From e8f7f447bd37afb5639ab71884c6612aceb49c0b Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Wed, 10 Jan 2024 10:13:37 +0000 Subject: [PATCH] Add a stop button to ollama --- src/cogs/ollama.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 7dcc9d5..7470f0f 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -1,11 +1,12 @@ +import asyncio import collections import json import logging -from pydoc import describe import textwrap import time import typing import io +from discord.ui import View, button from fnmatch import fnmatch import aiohttp @@ -14,6 +15,20 @@ from discord.ext import commands from conf import CONFIG +class OllamaView(View): + def __init__(self, ctx: discord.ApplicationContext): + super().__init__(timeout=3600, disable_on_timeout=True) + self.ctx = ctx + self.cancel = asyncio.Event() + + @button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f") + async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction): + self.cancel.set() + btn.disabled = True + await interaction.response.edit_message(view=self) + self.stop() + + SERVER_KEYS = list(CONFIG["ollama"].keys()) class Ollama(commands.Cog): @@ -227,6 +242,7 @@ class Ollama(commands.Cog): await ctx.edit(embed=embed) except discord.NotFound: await ctx.respond(embed=embed) + view = OllamaView(ctx) self.log.debug("Beginning to generate response.") async with session.post( "/api/generate", @@ -254,6 +270,7 @@ class Ollama(commands.Cog): value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False ) + await ctx.edit(view=view, embed=embed) buffer = io.StringIO() async for line in self.ollama_stream(response.content): buffer.write(line["response"]) @@ -261,10 +278,15 @@ class Ollama(commands.Cog): embed.timestamp = discord.utils.utcnow() if len(embed.description) >= 4096: embed.description = embed.description = "..." + line["response"] + + if view.cancel.is_set(): + break + if time.time() >= (last_update + 5.1): await ctx.edit(embed=embed) self.log.debug(f"Updating message ({last_update} -> {time.time()})") last_update = time.time() + view.stop() self.log.debug("Ollama finished consuming.") embed.title = "Done!" embed.color = discord.Color.green()