From 28908f217c8d86fd3297ced9104b0a9609a873a9 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Tue, 11 Jun 2024 00:58:17 +0100 Subject: [PATCH] Enable ollama pull --- jimmy/cogs/chat.py | 144 +++++++++++++++++++++++++++++++++++++++++---- jimmy/config.py | 2 +- 2 files changed, 135 insertions(+), 11 deletions(-) diff --git a/jimmy/cogs/chat.py b/jimmy/cogs/chat.py index 2e05e6d..2f75b0f 100644 --- a/jimmy/cogs/chat.py +++ b/jimmy/cogs/chat.py @@ -46,10 +46,13 @@ async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext): chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name) async with ollama_client(str(chosen_server.base_url), timeout=2) as client: tags = (await client.list())["models"] - return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()] + v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()] + return [ctx.value, *v][:25] -_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()] +_ServerOptionAutocomplete = discord.utils.basic_autocomplete( + [x.name for x in get_servers()] +) class Chat(commands.Cog): @@ -112,7 +115,15 @@ class Chat(commands.Cog): ) await ctx.edit(embed=embed) - @commands.slash_command(name="ollama") + ollama_group = discord.SlashCommandGroup( + name="ollama", + description="Commands related to ollama.", + guild_only=True, + max_concurrency=commands.MaxConcurrency(1, per=commands.BucketType.user, wait=False), + cooldown=commands.CooldownMapping(commands.Cooldown(1, 10), commands.BucketType.user) + ) + + @ollama_group.command(name="chat") async def start_ollama_chat( self, ctx: discord.ApplicationContext, @@ -130,7 +141,7 @@ class Chat(commands.Cog): discord.Option( discord.SlashCommandOptionType.string, description="The server to use.", - choices=_ServerOptionChoices, + autocomplete=_ServerOptionAutocomplete, default=get_servers()[0].name ) ], @@ -140,7 +151,7 @@ class Chat(commands.Cog): discord.SlashCommandOptionType.string, description="The model to use.", autocomplete=get_available_tags_autocomplete, - default="llama3:latest" + default="default" ) ], image: typing.Annotated[ @@ -173,7 +184,9 @@ class Chat(commands.Cog): """Have a chat with ollama""" await ctx.defer() server = get_server(server) - if not await server.is_online(): + if not server: + return await ctx.respond("\N{cross mark} Unknown Server.") + elif not await server.is_online(): await ctx.respond( content=f"{server} is offline. Finding a suitable server...", ) @@ -183,10 +196,13 @@ class Chat(commands.Cog): return await ctx.edit(content=str(err), delete_after=30) await ctx.delete(delay=5) async with self.server_locks[server.name]: + if model == "default": + model = server.default_model async with ollama_client(str(server.base_url)) as client: client: AsyncClient self.log.info("Checking if %r has the model %r", server, model) tags = (await client.list())["models"] + # Download code. It's recommended to collapse this in the editor. if model not in [x["model"] for x in tags]: embed = discord.Embed( title=f"Downloading {model} on {server}.", @@ -265,6 +281,7 @@ class Chat(commands.Cog): await ctx.edit(embed=embed, delete_after=30, view=None) messages = [] + thread = None if thread_id: thread = await OllamaThread.get_or_none(thread_id=thread_id) if thread: @@ -330,13 +347,120 @@ class Chat(commands.Cog): else: file = discord.utils.MISSING - thread = OllamaThread( - messages=[{"role": m["role"], "content": m["content"]} for m in messages], - ) - await thread.save() + if not thread: + thread = OllamaThread( + messages=[{"role": m["role"], "content": m["content"]} for m in messages], + ) + await thread.save() embed.set_footer(text=f"Chat ID: {thread.thread_id}") await msg.edit(embed=embed, view=None, file=file) + @ollama_group.command(name="pull") + async def pull_ollama_model( + self, + ctx: discord.ApplicationContext, + server: typing.Annotated[ + str, + discord.Option( + discord.SlashCommandOptionType.string, + description="The server to use.", + autocomplete=_ServerOptionAutocomplete, + default=get_servers()[0].name + ) + ], + model: typing.Annotated[ + str, + discord.Option( + discord.SlashCommandOptionType.string, + description="The model to use.", + autocomplete=get_available_tags_autocomplete, + default="llama3:latest" + ) + ], + ): + """Downloads a tag on the target server""" + await ctx.defer() + server = get_server(server) + if not server: + return await ctx.respond("\N{cross mark} Unknown server.") + elif not await server.is_online(): + return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding") + embed = discord.Embed( + title=f"Downloading {model} on {server}.", + description=f"Initiating download...", + color=discord.Color.blurple() + ) + view = StopDownloadView(ctx) + await ctx.respond( + embed=embed, + view=view + ) + last_edit = 0 + async with ctx.typing(): + try: + last_completed = 0 + last_completed_ts = time.time() + + async for line in await client.pull(model, stream=True): + if view.event.is_set(): + embed.add_field(name="Error!", value="Download cancelled.") + embed.colour = discord.Colour.red() + await ctx.edit(embed=embed) + return + self.log.debug("Response from %r: %r", server, line) + if line["status"] in { + "pulling manifest", + "verifying sha256 digest", + "writing manifest", + "removing any unused layers", + "success" + }: + embed.description = line["status"].capitalize() + else: + total = line["total"] + completed = line.get("completed", 0) + percent = round(completed / total * 100, 1) + pb_fill = "▰" * int(percent / 10) + pb_empty = "▱" * (10 - int(percent / 10)) + bytes_per_second = completed - last_completed + bytes_per_second /= (time.time() - last_completed_ts) + last_completed = completed + last_completed_ts = time.time() + mbps = round((bytes_per_second * 8) / 1024 / 1024) + eta = (total - completed) / max(1, bytes_per_second) + progress_bar = f"[{pb_fill}{pb_empty}]" + ns_total = naturalsize(total, binary=True) + ns_completed = naturalsize(completed, binary=True) + embed.description = ( + f"{line['status'].capitalize()} {percent}% {progress_bar} " + f"({ns_completed}/{ns_total} @ {mbps} Mb/s) " + f"[ETA: {naturaldelta(eta)}]" + ) + + if time.time() - last_edit >= 2.5: + await ctx.edit(embed=embed) + last_edit = time.time() + except ResponseError as err: + if err.error.endswith("file does not exist"): + await ctx.edit( + embed=None, + content="The model %r does not exist." % model, + delete_after=60, + view=None + ) + else: + embed.add_field( + name="Error!", + value=err.error + ) + embed.colour = discord.Colour.red() + await ctx.edit(embed=embed, view=None) + return + else: + embed.colour = discord.Colour.green() + embed.description = f"Downloaded {model} on {server}." + await ctx.edit(embed=embed, delete_after=30, view=None) + def setup(bot): bot.add_cog(Chat(bot)) diff --git a/jimmy/config.py b/jimmy/config.py index a2b7b14..12ca67e 100644 --- a/jimmy/config.py +++ b/jimmy/config.py @@ -66,7 +66,7 @@ def get_server(name_or_base_url: str) -> ServerConfig | None: else: if parsed.netloc and parsed.scheme in ["http", "https"]: defaults = { - "name": ":temporary:", + "name": parsed.hostname, "base_url": "{0.scheme}://{0.netloc}".format(parsed), "gpu": False, "vram_gb": 2,