diff --git a/cogs/other.py b/cogs/other.py index 6b6836a..3899a22 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -64,6 +64,26 @@ except Exception as _pyttsx3_err: VOICES = [] +class OllamaStreamReader: + def __init__(self, response: httpx.Response): + self.response = response + self.stream = response.aiter_bytes(1) + self._buffer = b"" + + async def __aiter__(self): + return self + + async def __anext__(self) -> dict[str, str | int | bool]: + if self.response.is_stream_consumed: + raise StopAsyncIteration + self._buffer = b"" + while not self._buffer.endswith(b"}\n"): + async for char in self.stream: + self._buffer += char + + return json.loads(self._buffer.decode("utf-8", "replace")) + + def format_autocomplete(ctx: discord.AutocompleteContext): url = ctx.options.get("url", os.urandom(6).hex()) self: "OtherCog" = ctx.bot.cogs["OtherCog"] # type: ignore @@ -1021,9 +1041,7 @@ class OtherCog(commands.Cog): try: extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False) except yt_dlp.utils.DownloadError: - title = chosen_format = chosen_format_id = final_extension = format_note = "error" - resolution = vcodec = acodec = "error" - fps = 0 + title = "error" thumbnail_url = webpage_url = discord.Embed.Empty else: title = extracted_info.get("title", url) @@ -1771,6 +1789,83 @@ class OtherCog(commands.Cog): % (content, output_location.name) ) + @commands.command(hidden=True) + @commands.is_owner() + @commands.max_concurrency(1, wait=True) + async def ollama(self, ctx: commands.Context, *, query: str): + """:3""" + if query.startswith("model:"): + model, query = query.split(" ", 1) + model = model[6:].casefold() + try: + _name, _tag = model.split(":", 1) + except ValueError: + model += ":latest" + else: + model = "orca-mini" + + msg = await ctx.reply(f"Preparing {model!r} ") + async with httpx.AsyncClient(base_url="http://localhost:11434/api") as client: + # get models + try: + response = await client.post("/show", json={"name": model}) + except httpx.TransportError as e: + return await msg.edit(content="Failed to connect to Ollama: `%s`" % e) + if response.status_code == 404: + await msg.edit(content="Downloading model %r, please wait.") + async with ctx.channel.typing(): + async with client.stream( + "POST", + "/pull", + json={"name": model}, + timeout=None + ) as response: + if response.status_code != 200: + return await msg.edit(content="Failed to download model: `%s`" % response.text) + async for chunk in OllamaStreamReader(response): + if "total" in chunk and "completed" in chunk: + completed = chunk["completed"] or 1 # avoid division by zero + total = chunk["total"] or 1 + percent = completed / total * 100 + if not percent % 10: + await msg.edit(content=f"`{chunk['status']}` - {percent:.0f}%") + else: + await msg.edit(content=f"`{chunk['status']}`") + elif response.status_code != 200: + return await msg.edit(content="Failed to get model: `%s`" % response.text) + + output = discord.Embed( + title=f"{model} says:", + description="", + colour=discord.Colour.blurple(), + ) + output.set_footer(text="Powered by Ollama") + + async with ctx.channel.typing(): + async with client.stream( + "POST", + "/generate", + json={ + "model": model, + "prompt": query, + "format": "json", + "system": "You are a discord bot called Jimmy Saville. " + "Be helpful and make sure your response is safe for work, " + "and is less than 3500 characters" + } + ) as response: + if response.status_code != 200: + return await msg.edit(content="Failed to generate text: `%s`" % response.text) + last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp() + async for chunk in OllamaStreamReader(response): + if "done" not in chunk or "response" not in chunk: + continue + else: + output.description = chunk["response"] + if (time() - last_edit) >= 5 or chunk["done"] is True: + await msg.edit(content=None, embed=output) + break + def setup(bot): bot.add_cog(OtherCog(bot))