Add a stop button to ollama

This commit is contained in:
Nexus 2024-01-10 10:13:37 +00:00
parent d2e334bb98
commit e8f7f447bd

View file

@ -1,11 +1,12 @@
import asyncio
import collections import collections
import json import json
import logging import logging
from pydoc import describe
import textwrap import textwrap
import time import time
import typing import typing
import io import io
from discord.ui import View, button
from fnmatch import fnmatch from fnmatch import fnmatch
import aiohttp import aiohttp
@ -14,6 +15,20 @@ from discord.ext import commands
from conf import CONFIG from conf import CONFIG
class OllamaView(View):
def __init__(self, ctx: discord.ApplicationContext):
super().__init__(timeout=3600, disable_on_timeout=True)
self.ctx = ctx
self.cancel = asyncio.Event()
@button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f")
async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
self.cancel.set()
btn.disabled = True
await interaction.response.edit_message(view=self)
self.stop()
SERVER_KEYS = list(CONFIG["ollama"].keys()) SERVER_KEYS = list(CONFIG["ollama"].keys())
class Ollama(commands.Cog): class Ollama(commands.Cog):
@ -227,6 +242,7 @@ class Ollama(commands.Cog):
await ctx.edit(embed=embed) await ctx.edit(embed=embed)
except discord.NotFound: except discord.NotFound:
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
view = OllamaView(ctx)
self.log.debug("Beginning to generate response.") self.log.debug("Beginning to generate response.")
async with session.post( async with session.post(
"/api/generate", "/api/generate",
@ -254,6 +270,7 @@ class Ollama(commands.Cog):
value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."),
inline=False inline=False
) )
await ctx.edit(view=view, embed=embed)
buffer = io.StringIO() buffer = io.StringIO()
async for line in self.ollama_stream(response.content): async for line in self.ollama_stream(response.content):
buffer.write(line["response"]) buffer.write(line["response"])
@ -261,10 +278,15 @@ class Ollama(commands.Cog):
embed.timestamp = discord.utils.utcnow() embed.timestamp = discord.utils.utcnow()
if len(embed.description) >= 4096: if len(embed.description) >= 4096:
embed.description = embed.description = "..." + line["response"] embed.description = embed.description = "..." + line["response"]
if view.cancel.is_set():
break
if time.time() >= (last_update + 5.1): if time.time() >= (last_update + 5.1):
await ctx.edit(embed=embed) await ctx.edit(embed=embed)
self.log.debug(f"Updating message ({last_update} -> {time.time()})") self.log.debug(f"Updating message ({last_update} -> {time.time()})")
last_update = time.time() last_update = time.time()
view.stop()
self.log.debug("Ollama finished consuming.") self.log.debug("Ollama finished consuming.")
embed.title = "Done!" embed.title = "Done!"
embed.color = discord.Color.green() embed.color = discord.Color.green()