This commit is contained in:
Nexus 2023-11-11 18:31:48 +00:00
parent 451e80640e
commit 395a9a9581
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -134,6 +134,8 @@ class OtherCog(commands.Cog):
self._fmt_queue = asyncio.Queue()
self._worker_task = self.bot.loop.create_task(self.cache_population_job())
self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
def cog_unload(self):
self._worker_task.cancel()
@ -1807,6 +1809,29 @@ class OtherCog(commands.Cog):
% (content, output_location.name)
)
class OllamaKillSwitchView(discord.ui.View):
    """A one-button view that lets the command author abort a running Ollama job.

    Pressing "Abort" sets the ``asyncio.Event`` stored under ``msg`` in the
    owning cog's ``ollama_locks`` mapping; whatever is watching that event is
    expected to stop its work.

    :param ctx: The command context; its author and channel gate who may press the button.
    :param msg: The message used as the key into the cog's ``ollama_locks``.
    """

    def __init__(self, ctx: commands.Context, msg: discord.Message):
        # timeout=None keeps the button usable for as long as the view is attached.
        super().__init__(timeout=None)
        self.ctx = ctx
        self.msg = msg

    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        """Accept the interaction only from the invoking user in the invoking channel."""
        return interaction.user == self.ctx.author and interaction.channel == self.ctx.channel

    @discord.ui.button(
        label="Abort",
        style=discord.ButtonStyle.red,
        # U+1F5D1 is named WASTEBASKET (one word); "\N{waste basket}" is an
        # unknown character name and makes the whole module fail to compile.
        emoji="\N{wastebasket}",
    )
    async def abort_button(self, _, interaction: discord.Interaction):
        """Signal the abort event for ``self.msg`` (if any), then disable and stop the view."""
        await interaction.response.defer()
        abort_event = self.ctx.command.cog.ollama_locks.get(self.msg)
        if abort_event is not None:
            abort_event.set()
        # NOTE(review): assumes only the .set() call is guarded by the lookup and
        # the view is always disabled afterwards -- confirm against the original indentation.
        self.disable_all_items()
        await interaction.edit_original_response(view=self)
        self.stop()
@commands.command(hidden=True)
@commands.is_owner()
@commands.max_concurrency(1, wait=True)
@ -1839,7 +1864,8 @@ class OtherCog(commands.Cog):
timeout=None
) as response:
if response.status_code != 200:
return await msg.edit(content="Failed to download model: `%s`" % response.text)
error = await response.aread()
return await msg.edit(content="Failed to download model: `%s`" % error.decode())
async for chunk in ollama_stream_reader(response):
print(chunk)
if "total" in chunk and "completed" in chunk:
@ -1852,18 +1878,21 @@ class OtherCog(commands.Cog):
await msg.edit(content=f"`{chunk['status']}` - {percent}%")
else:
await msg.edit(content=f"`{chunk['status']}`")
await msg.edit(content=f"Downloaded model {model}. Re-run please.")
return
await msg.edit(content=f"Downloaded model {model}.")
while (await client.post("/show", json={"name": model})).status_code != 200:
await asyncio.sleep(5)
elif response.status_code != 200:
return await msg.edit(content="Failed to get model: `%s`" % response.text)
error = await response.aread()
return await msg.edit(content="Failed to get model: `%s`" % error.decode())
output = discord.Embed(
title=f"{model} says:",
description="",
colour=discord.Colour.blurple(),
timestamp=discord.utils.utcnow()
)
output.set_footer(text="Powered by Ollama")
await msg.edit(content="Starting generation. Please wait.")
await msg.edit(embed=output)
async with ctx.channel.typing():
async with client.stream(
"POST",
@ -1880,7 +1909,8 @@ class OtherCog(commands.Cog):
timeout=None
) as response:
if response.status_code != 200:
return await msg.edit(content="Failed to generate text: `%s`" % response.text)
error = await response.aread()
return await msg.edit(content="Failed to generate text: `%s`" % error.decode())
async for chunk in ollama_stream_reader(response):
print(chunk)
if "done" not in chunk.keys() or "response" not in chunk.keys():
@ -1893,6 +1923,47 @@ class OtherCog(commands.Cog):
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
if (time() - last_edit) >= 5 or chunk["done"] is True:
await msg.edit(content=content, embed=output)
def get_time_spent(nanoseconds: int) -> str:
    """Format a duration given in nanoseconds as human-readable text.

    The duration is rounded to the nearest whole second *before* it is split
    into hours, minutes and seconds, so a value can never render as
    "60 seconds", and the singular/plural label always agrees with the number
    that is actually printed.

    :param nanoseconds: Duration in nanoseconds. Negative values are treated as zero.
    :returns: Text such as ``"1 hour, 2 minutes, 5 seconds"``. Units whose value
        is zero are omitted; a duration that rounds to zero yields ``"0 seconds"``.
    """
    total_seconds = max(0, round(nanoseconds / 1e9))
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)

    parts = [
        f"{value} {unit}" if value == 1 else f"{value} {unit}s"
        for value, unit in ((hours, "hour"), (minutes, "minute"), (seconds, "second"))
        if value
    ]
    # Never return an empty string: callers interpolate this into an embed field.
    return ", ".join(parts) or "0 seconds"
total_time_spent = get_time_spent(chunk["total_duration"])
eval_time_spent = get_time_spent(chunk["eval_duration"])
tokens_per_second = chunk["eval_count"] / chunk["eval_duration"]
output.add_field(
name="Timings",
value="Total: {}\nEval: {} ({:,.2f}/s)".format(
total_time_spent,
eval_time_spent,
tokens_per_second
),
)
await msg.edit(content=None, embed=output)