Add killswitch

This commit is contained in:
Nexus 2023-11-11 18:42:33 +00:00
parent b616554fa0
commit 474fe98b38
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@@ -1831,7 +1831,9 @@ class OtherCog(commands.Cog):
await interaction.edit_original_response(view=self) await interaction.edit_original_response(view=self)
self.stop() self.stop()
@commands.command(hidden=True) @commands.command(
usage="[model:<name:tag>] [server:<ip[:port]>] <query>"
)
@commands.is_owner() @commands.is_owner()
@commands.max_concurrency(1, wait=True) @commands.max_concurrency(1, wait=True)
async def ollama(self, ctx: commands.Context, *, query: str): async def ollama(self, ctx: commands.Context, *, query: str):
@@ -1846,8 +1848,19 @@ class OtherCog(commands.Cog):
else: else:
model = "orca-mini" model = "orca-mini"
msg = await ctx.reply(f"Preparing {model!r} <a:loading:1101463077586735174>") if query.startswith("server:"):
async with httpx.AsyncClient(base_url="http://192.168.0.90:11434/api") as client: host, query = query.split(" ", 1)
host = host[7:]
try:
host, port = host.split(":", 1)
int(port)
except ValueError:
host += ":11434"
else:
host = "192.168.0.90:11434"
msg = await ctx.reply(f"Preparing [{model!r}](http://{host}) <a:loading:1101463077586735174>")
async with httpx.AsyncClient(base_url=f"http://{host}/api") as client:
# get models # get models
try: try:
response = await client.post("/show", json={"name": model}) response = await client.post("/show", json={"name": model})
@@ -1910,8 +1923,10 @@ class OtherCog(commands.Cog):
if response.status_code != 200: if response.status_code != 200:
error = await response.aread() error = await response.aread()
return await msg.edit(content="Failed to generate text: `%s`" % error.decode()) return await msg.edit(content="Failed to generate text: `%s`" % error.decode())
self.ollama_locks[msg] = asyncio.Event()
view = self.OllamaKillSwitchView(ctx, msg)
await msg.edit(view=view)
async for chunk in ollama_stream_reader(response): async for chunk in ollama_stream_reader(response):
print(chunk)
if "done" not in chunk.keys() or "response" not in chunk.keys(): if "done" not in chunk.keys() or "response" not in chunk.keys():
continue continue
else: else:
@@ -1921,7 +1936,9 @@ class OtherCog(commands.Cog):
output.description += chunk["response"] output.description += chunk["response"]
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp() last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
if (time() - last_edit) >= 5 or chunk["done"] is True: if (time() - last_edit) >= 5 or chunk["done"] is True:
await msg.edit(content=content, embed=output) await msg.edit(content=content, embed=output, view=view)
if self.ollama_locks[msg].is_set():
return await msg.edit(content="Aborted.", embed=output, view=None)
def get_time_spent(nanoseconds: int) -> str: def get_time_spent(nanoseconds: int) -> str:
hours, minutes, seconds = 0, 0, 0 hours, minutes, seconds = 0, 0, 0
@@ -1963,7 +1980,8 @@ class OtherCog(commands.Cog):
tokens_per_second tokens_per_second
), ),
) )
await msg.edit(content=None, embed=output) await msg.edit(content=None, embed=output, view=None)
self.ollama_locks.pop(msg, None)
def setup(bot): def setup(bot):