From 3c61504cb3f04527b46fa20ab569ec657cb489cd Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Tue, 11 Jun 2024 01:21:34 +0100 Subject: [PATCH] Fix /ollama pull --- jimmy/cogs/chat.py | 77 +++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/jimmy/cogs/chat.py b/jimmy/cogs/chat.py index 0a93258..c345dca 100644 --- a/jimmy/cogs/chat.py +++ b/jimmy/cogs/chat.py @@ -434,45 +434,46 @@ class Chat(commands.Cog): last_completed = 0 last_completed_ts = time.time() - async for line in await client.pull(model, stream=True): - if view.event.is_set(): - embed.add_field(name="Error!", value="Download cancelled.") - embed.colour = discord.Colour.red() - await ctx.edit(embed=embed) - return - self.log.debug("Response from %r: %r", server, line) - if line["status"] in { - "pulling manifest", - "verifying sha256 digest", - "writing manifest", - "removing any unused layers", - "success" - }: - embed.description = line["status"].capitalize() - else: - total = line["total"] - completed = line.get("completed", 0) - percent = round(completed / total * 100, 1) - pb_fill = "▰" * int(percent / 10) - pb_empty = "▱" * (10 - int(percent / 10)) - bytes_per_second = completed - last_completed - bytes_per_second /= (time.time() - last_completed_ts) - last_completed = completed - last_completed_ts = time.time() - mbps = round((bytes_per_second * 8) / 1024 / 1024) - eta = (total - completed) / max(1, bytes_per_second) - progress_bar = f"[{pb_fill}{pb_empty}]" - ns_total = naturalsize(total, binary=True) - ns_completed = naturalsize(completed, binary=True) - embed.description = ( - f"{line['status'].capitalize()} {percent}% {progress_bar} " - f"({ns_completed}/{ns_total} @ {mbps} Mb/s) " - f"[ETA: {naturaldelta(eta)}]" - ) + async with ollama_client(str(server.base_url)) as client: + async for line in await client.pull(model, stream=True): + if view.event.is_set(): + embed.add_field(name="Error!", value="Download cancelled.") + embed.colour = discord.Colour.red() + await ctx.edit(embed=embed) + return + self.log.debug("Response from %r: %r", server, line) + if line["status"] in { + "pulling manifest", + "verifying sha256 digest", + "writing manifest", + "removing any unused layers", + "success" + }: + embed.description = line["status"].capitalize() + else: + total = line["total"] + completed = line.get("completed", 0) + percent = round(completed / total * 100, 1) + pb_fill = "▰" * int(percent / 10) + pb_empty = "▱" * (10 - int(percent / 10)) + bytes_per_second = completed - last_completed + bytes_per_second /= (time.time() - last_completed_ts) + last_completed = completed + last_completed_ts = time.time() + mbps = round((bytes_per_second * 8) / 1024 / 1024) + eta = (total - completed) / max(1, bytes_per_second) + progress_bar = f"[{pb_fill}{pb_empty}]" + ns_total = naturalsize(total, binary=True) + ns_completed = naturalsize(completed, binary=True) + embed.description = ( + f"{line['status'].capitalize()} {percent}% {progress_bar} " + f"({ns_completed}/{ns_total} @ {mbps} Mb/s) " + f"[ETA: {naturaldelta(eta)}]" + ) - if time.time() - last_edit >= 2.5: - await ctx.edit(embed=embed) - last_edit = time.time() + if time.time() - last_edit >= 2.5: + await ctx.edit(embed=embed) + last_edit = time.time() except ResponseError as err: if err.error.endswith("file does not exist"): await ctx.edit(