From 09f8a59d6b0e98971ba660422b508c89a6841403 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 31 May 2024 18:24:15 +0100 Subject: [PATCH] Update ollama progress bar --- src/cogs/ollama.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 7e3e473..914fcf1 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -654,10 +654,12 @@ class Ollama(commands.Cog): if resp.status == 404: self.log.debug("Beginning download of %r", model) - def progress_bar(_v: float, action: str = None): + def progress_bar(_v: float, action: str = None, _mbps: float = None): bar = "\N{large green square}" * round(_v / 10) bar += "\N{white large square}" * (10 - len(bar)) bar += f" {_v:.2f}%" + if _mbps: + bar += f" ({_mbps:.2f} MB/s)" if action: return f"{action} {bar}" return bar @@ -685,6 +687,7 @@ class Ollama(commands.Cog): embed.set_footer(text="Unable to continue.") return await ctx.edit(embed=embed) view = OllamaView(ctx) + last_downloaded = 0 async for line in ollama_stream(response.content): if view.cancel.is_set(): embed = discord.Embed( @@ -695,11 +698,14 @@ class Ollama(commands.Cog): return await ctx.edit(embed=embed, view=None) if time.time() >= (last_update + 5.1): if line.get("total") is not None and line.get("completed") is not None: + new_bytes = line["completed"] - last_downloaded + mbps = new_bytes / 1024 / 1024 / 5 percent = (line["completed"] / line["total"]) * 100 else: percent = 50.0 + mbps = 0.0 - embed.fields[0].value = progress_bar(percent, line["status"]) + embed.fields[0].value = progress_bar(percent, line["status"], mbps) await ctx.edit(embed=embed, view=view) last_update = time.time() else: