diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 7828bf9..e5a40b3 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -8,6 +8,7 @@ import typing import base64 import io import redis +from discord import Interaction from discord.ui import View, button from fnmatch import fnmatch @@ -89,7 +90,7 @@ class ChatHistory: "threads:" + thread_id, json.dumps(self._internal[thread_id]) ) - def create_thread(self, member: discord.Member) -> str: + def create_thread(self, member: discord.Member, default: str | None = None) -> str: """ Creates a thread, returns its ID. """ @@ -100,7 +101,7 @@ class ChatHistory: "messages": [] } with open("./assets/ollama-prompt.txt") as file: - system_prompt = file.read() + system_prompt = default or file.read() self.add_message( key, "system", @@ -190,6 +191,32 @@ class ChatHistory: SERVER_KEYS = list(CONFIG["ollama"].keys()) +class OllamaGetPrompt(discord.ui.Modal): + + def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"): + super().__init__( + discord.ui.InputText( + style=discord.InputTextStyle.long, + label="%s prompt" % prompt_type, + placeholder="Enter your prompt here.", + ), + timeout=300, + title="Ollama %s prompt" % prompt_type, + ) + self.ctx = ctx + self.prompt_type = prompt_type + self.value = None + + async def interaction_check(self, interaction: discord.Interaction) -> bool: + return interaction.user == self.ctx.user + + async def callback(self, interaction: Interaction): + await interaction.response.defer() + self.ctx.interaction = interaction + self.value = self.children[0].value + self.stop() + + class Ollama(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot @@ -286,7 +313,28 @@ class Ollama(commands.Cog): if not self.history.get_thread(context): await ctx.respond("Invalid context key.") return - await ctx.defer() + + if query.startswith("$$"): + prompt = OllamaGetPrompt(ctx, "System") + await ctx.send_modal(prompt) + await prompt.wait() + system_query = prompt.value + if not system_query: + return await ctx.respond("No prompt provided. Aborting.") + else: + system_query = None + if query == "$": + prompt = OllamaGetPrompt(ctx) + await ctx.send_modal(prompt) + await prompt.wait() + query = prompt.value + if not query: + return await ctx.respond("No prompt provided. Aborting.") + + try: + await ctx.defer() + except discord.HTTPException: + pass model = model.casefold() try: @@ -294,7 +342,7 @@ class Ollama(commands.Cog): model = model + ":" + tag self.log.debug("Model %r already has a tag", model) except ValueError: - model = model + ":latest" + model += ":latest" self.log.debug("Resolved model to %r" % model) if image: @@ -315,7 +363,7 @@ class Ollama(commands.Cog): data = io.BytesIO() await image.save(data) data.seek(0) - image_data = base64.b64encode(data.read()).decode("utf-8") + image_data = base64.b64encode(data.read()).decode() else: image_data = None @@ -336,7 +384,12 @@ class Ollama(commands.Cog): async with aiohttp.ClientSession( base_url=server_config["base_url"], - timeout=aiohttp.ClientTimeout(0) + timeout=aiohttp.ClientTimeout( + connect=30, + sock_read=10800, + sock_connect=30, + total=10830 + ) ) as session: embed = discord.Embed( title="Checking server...", @@ -482,7 +535,7 @@ class Ollama(commands.Cog): self.log.debug("Beginning to generate response with key %r.", key) if context is None: - context = self.history.create_thread(ctx.user) + context = self.history.create_thread(ctx.user, system_query) elif context is not None and self.history.get_thread(context) is None: __thread = self.history.find_thread(context) if not __thread: