mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 10:03:40 +01:00
spice
This commit is contained in:
parent
451e80640e
commit
395a9a9581
1 changed file with 77 additions and 6 deletions
|
@ -134,6 +134,8 @@ class OtherCog(commands.Cog):
|
|||
self._fmt_queue = asyncio.Queue()
|
||||
self._worker_task = self.bot.loop.create_task(self.cache_population_job())
|
||||
|
||||
self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
|
||||
|
||||
def cog_unload(self):
|
||||
self._worker_task.cancel()
|
||||
|
||||
|
@ -1807,6 +1809,29 @@ class OtherCog(commands.Cog):
|
|||
% (content, output_location.name)
|
||||
)
|
||||
|
||||
|
||||
class OllamaKillSwitchView(discord.ui.View):
    """Single-button view that lets the invoking user abort an Ollama run.

    Pressing "Abort" sets the ``asyncio.Event`` stored under ``msg`` in the
    owning cog's ``ollama_locks`` mapping.
    NOTE(review): presumably the streaming loop watches that event and stops
    when it is set -- confirm against the command that registers the lock.
    """

    def __init__(self, ctx: commands.Context, msg: discord.Message):
        """
        :param ctx: Invocation context; its author and channel decide who may
            use the button, and its command's cog holds ``ollama_locks``.
        :param msg: Message used as the key into ``ollama_locks``.
        """
        # timeout=None is passed so the view is not given an expiry.
        super().__init__(timeout=None)
        self.ctx = ctx
        self.msg = msg

    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        """Allow only the command author, and only in the invoking channel."""
        return interaction.user == self.ctx.author and interaction.channel == self.ctx.channel

    @discord.ui.button(
        label="Abort",
        style=discord.ButtonStyle.red,
        # U+1F5D1 is named WASTEBASKET (one word). "\N{waste basket}" is not a
        # valid character name and is rejected when the module is compiled.
        emoji="\N{wastebasket}",
    )
    async def abort_button(self, _, interaction: discord.Interaction):
        """Signal the abort event for ``self.msg`` and disable the button."""
        await interaction.response.defer()
        locks = self.ctx.command.cog.ollama_locks
        if self.msg in locks:
            locks[self.msg].set()
        # Disable the button whether or not a lock was registered, so it
        # cannot be pressed a second time.
        self.disable_all_items()
        await interaction.edit_original_response(view=self)
        self.stop()
|
||||
|
||||
@commands.command(hidden=True)
|
||||
@commands.is_owner()
|
||||
@commands.max_concurrency(1, wait=True)
|
||||
|
@ -1839,7 +1864,8 @@ class OtherCog(commands.Cog):
|
|||
timeout=None
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
return await msg.edit(content="Failed to download model: `%s`" % response.text)
|
||||
error = await response.aread()
|
||||
return await msg.edit(content="Failed to download model: `%s`" % error.decode())
|
||||
async for chunk in ollama_stream_reader(response):
|
||||
print(chunk)
|
||||
if "total" in chunk and "completed" in chunk:
|
||||
|
@ -1852,18 +1878,21 @@ class OtherCog(commands.Cog):
|
|||
await msg.edit(content=f"`{chunk['status']}` - {percent}%")
|
||||
else:
|
||||
await msg.edit(content=f"`{chunk['status']}`")
|
||||
await msg.edit(content=f"Downloaded model {model}. Re-run please.")
|
||||
return
|
||||
await msg.edit(content=f"Downloaded model {model}.")
|
||||
while (await client.post("/show", json={"name": model})).status_code != 200:
|
||||
await asyncio.sleep(5)
|
||||
elif response.status_code != 200:
|
||||
return await msg.edit(content="Failed to get model: `%s`" % response.text)
|
||||
error = await response.aread()
|
||||
return await msg.edit(content="Failed to get model: `%s`" % error.decode())
|
||||
|
||||
output = discord.Embed(
|
||||
title=f"{model} says:",
|
||||
description="",
|
||||
colour=discord.Colour.blurple(),
|
||||
timestamp=discord.utils.utcnow()
|
||||
)
|
||||
output.set_footer(text="Powered by Ollama")
|
||||
await msg.edit(content="Starting generation. Please wait.")
|
||||
await msg.edit(embed=output)
|
||||
async with ctx.channel.typing():
|
||||
async with client.stream(
|
||||
"POST",
|
||||
|
@ -1880,7 +1909,8 @@ class OtherCog(commands.Cog):
|
|||
timeout=None
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
return await msg.edit(content="Failed to generate text: `%s`" % response.text)
|
||||
error = await response.aread()
|
||||
return await msg.edit(content="Failed to generate text: `%s`" % error.decode())
|
||||
async for chunk in ollama_stream_reader(response):
|
||||
print(chunk)
|
||||
if "done" not in chunk.keys() or "response" not in chunk.keys():
|
||||
|
@ -1893,6 +1923,47 @@ class OtherCog(commands.Cog):
|
|||
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
|
||||
if (time() - last_edit) >= 5 or chunk["done"] is True:
|
||||
await msg.edit(content=content, embed=output)
|
||||
|
||||
def get_time_spent(nanoseconds: int) -> str:
    """Format a duration in nanoseconds as e.g. ``"1 hour, 2 minutes, 5 seconds"``.

    The duration is rounded to whole seconds *before* it is split into
    hours/minutes/seconds. Rounding each component separately could yield
    impossible output such as "60 seconds", and choosing the plural label from
    the unrounded value produced "1 seconds".

    :param nanoseconds: Duration in nanoseconds. Negative values are clamped
        to zero.
    :return: Comma-separated, largest-unit-first description. Zero-valued
        units are omitted; a duration that rounds to zero gives "0 seconds".
    """
    total_seconds = max(round(nanoseconds / 1e9), 0)
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)

    parts = []
    for amount, unit in ((hours, "hour"), (minutes, "minute"), (seconds, "second")):
        if amount:
            label = unit if amount == 1 else unit + "s"
            parts.append(f"{amount} {label}")
    # Never return an empty string: sub-second durations read as "0 seconds".
    return ", ".join(parts) or "0 seconds"
|
||||
|
||||
total_time_spent = get_time_spent(chunk["total_duration"])
|
||||
eval_time_spent = get_time_spent(chunk["eval_duration"])
|
||||
tokens_per_second = chunk["eval_count"] / chunk["eval_duration"]
|
||||
output.add_field(
|
||||
name="Timings",
|
||||
value="Total: {}\nEval: {} ({:,.2f}/s)".format(
|
||||
total_time_spent,
|
||||
eval_time_spent,
|
||||
tokens_per_second
|
||||
),
|
||||
)
|
||||
await msg.edit(content=None, embed=output)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue