diff --git a/cogs/other.py b/cogs/other.py
index 245f95c..e93e060 100644
--- a/cogs/other.py
+++ b/cogs/other.py
@@ -134,6 +134,8 @@ class OtherCog(commands.Cog):
         self._fmt_queue = asyncio.Queue()
         self._worker_task = self.bot.loop.create_task(self.cache_population_job())
 
+        self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
+
     def cog_unload(self):
         self._worker_task.cancel()
 
@@ -1807,6 +1809,29 @@ class OtherCog(commands.Cog):
                 % (content, output_location.name)
             )
 
+
+    class OllamaKillSwitchView(discord.ui.View):
+        def __init__(self, ctx: commands.Context, msg: discord.Message):
+            super().__init__(timeout=None)
+            self.ctx = ctx
+            self.msg = msg
+
+        async def interaction_check(self, interaction: discord.Interaction) -> bool:
+            return interaction.user == self.ctx.author and interaction.channel == self.ctx.channel
+
+        @discord.ui.button(
+            label="Abort",
+            style=discord.ButtonStyle.red,
+            emoji="\N{waste basket}",
+        )
+        async def abort_button(self, _, interaction: discord.Interaction):
+            await interaction.response.defer()
+            if self.msg in self.ctx.command.cog.ollama_locks:
+                self.ctx.command.cog.ollama_locks[self.msg].set()
+            self.disable_all_items()
+            await interaction.edit_original_response(view=self)
+            self.stop()
+
     @commands.command(hidden=True)
     @commands.is_owner()
     @commands.max_concurrency(1, wait=True)
@@ -1839,7 +1864,8 @@ class OtherCog(commands.Cog):
                     timeout=None
                 ) as response:
                     if response.status_code != 200:
-                        return await msg.edit(content="Failed to download model: `%s`" % response.text)
+                        error = await response.aread()
+                        return await msg.edit(content="Failed to download model: `%s`" % error.decode())
                     async for chunk in ollama_stream_reader(response):
                         print(chunk)
                         if "total" in chunk and "completed" in chunk:
@@ -1852,18 +1878,21 @@ class OtherCog(commands.Cog):
                             await msg.edit(content=f"`{chunk['status']}` - {percent}%")
                         else:
                             await msg.edit(content=f"`{chunk['status']}`")
-                await msg.edit(content=f"Downloaded model {model}. Re-run please.")
-                return
+                await msg.edit(content=f"Downloaded model {model}.")
+                while (await client.post("/show", json={"name": model})).status_code != 200:
+                    await asyncio.sleep(5)
             elif response.status_code != 200:
-                return await msg.edit(content="Failed to get model: `%s`" % response.text)
+                error = await response.aread()
+                return await msg.edit(content="Failed to get model: `%s`" % error.decode())
 
             output = discord.Embed(
                 title=f"{model} says:",
                 description="",
                 colour=discord.Colour.blurple(),
+                timestamp=discord.utils.utcnow()
             )
             output.set_footer(text="Powered by Ollama")
-            await msg.edit(content="Starting generation. Please wait.")
+            await msg.edit(embed=output)
             async with ctx.channel.typing():
                 async with client.stream(
                     "POST",
@@ -1880,7 +1909,8 @@ class OtherCog(commands.Cog):
                     timeout=None
                 ) as response:
                     if response.status_code != 200:
-                        return await msg.edit(content="Failed to generate text: `%s`" % response.text)
+                        error = await response.aread()
+                        return await msg.edit(content="Failed to generate text: `%s`" % error.decode())
                     async for chunk in ollama_stream_reader(response):
                         print(chunk)
                         if "done" not in chunk.keys() or "response" not in chunk.keys():
@@ -1893,6 +1923,47 @@ class OtherCog(commands.Cog):
                         last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
                         if (time() - last_edit) >= 5 or chunk["done"] is True:
                             await msg.edit(content=content, embed=output)
+
+                    def get_time_spent(nanoseconds: int) -> str:
+                        hours, minutes, seconds = 0, 0, 0
+                        seconds = nanoseconds / 1e9
+                        if seconds >= 60:
+                            minutes, seconds = divmod(seconds, 60)
+                            if minutes >= 60:
+                                hours, minutes = divmod(minutes, 60)
+
+                        result = []
+                        if seconds:
+                            if seconds != 1:
+                                label = "seconds"
+                            else:
+                                label = "second"
+                            result.append(f"{round(seconds)} {label}")
+                        if minutes:
+                            if minutes != 1:
+                                label = "minutes"
+                            else:
+                                label = "minute"
+                            result.append(f"{round(minutes)} {label}")
+                        if hours:
+                            if hours != 1:
+                                label = "hours"
+                            else:
+                                label = "hour"
+                            result.append(f"{round(hours)} {label}")
+                        return ", ".join(reversed(result))
+
+                    total_time_spent = get_time_spent(chunk["total_duration"])
+                    eval_time_spent = get_time_spent(chunk["eval_duration"])
+                    tokens_per_second = chunk["eval_count"] / (chunk["eval_duration"] / 1e9)
+                    output.add_field(
+                        name="Timings",
+                        value="Total: {}\nEval: {} ({:,.2f}/s)".format(
+                            total_time_spent,
+                            eval_time_spent,
+                            tokens_per_second
+                        ),
+                    )
             await msg.edit(content=None, embed=output)
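
The hunk that actually arms the kill switch inside the `ollama` command is not shown above, so the following is a minimal sketch (an assumption, not part of the patch) of how `ollama_locks` and `OllamaKillSwitchView` presumably fit around the streaming loop; `msg`, `output`, `response`, and `ollama_stream_reader` are names taken from the surrounding context.

# Assumed wiring, not present in the diff above: register an Event for this
# message, attach the abort view, and bail out of the stream once it is set.
killswitch = asyncio.Event()
self.ollama_locks[msg] = killswitch
await msg.edit(embed=output, view=self.OllamaKillSwitchView(ctx, msg))
try:
    async for chunk in ollama_stream_reader(response):
        if killswitch.is_set():  # set by the "Abort" button
            break
        ...
finally:
    # Drop the Event so ollama_locks does not keep finished messages alive.
    self.ollama_locks.pop(msg, None)

Keying the dict on the `discord.Message` object works because the button callback reaches the cog through `self.ctx.command.cog`, but popping the entry afterwards matters, since otherwise an entry would accumulate for every generation.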