diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index c9669a8..bf59ec4 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import os import textwrap import time import typing @@ -35,6 +36,7 @@ class Ollama(commands.Cog): self.bot = bot self.log = logging.getLogger("jimmy.cogs.ollama") self.last_server = 0 + self.contexts = {} def next_server(self, increment: bool = True) -> str: """Returns the next server key.""" @@ -83,7 +85,27 @@ class Ollama(commands.Cog): choices=SERVER_KEYS ) ], + context: typing.Annotated[ + str, + discord.Option( + str, + "The context key of a previous ollama response to use as context.", + default=None + ) + ], + give_acid: typing.Annotated[ + bool, + discord.Option( + bool, + "Whether to give the AI acid, LSD, and other hallucinogens before responding.", + default=False + ) + ] ): + if context is not None: + if context not in self.contexts: + await ctx.respond("Invalid context key.") + return with open("./assets/ollama-prompt.txt") as file: system_prompt = file.read() await ctx.defer() @@ -150,10 +172,10 @@ class Ollama(commands.Cog): if resp.status == 404: self.log.debug("Beginning download of %r", model) - def progress_bar(value: float, action: str = None): - bar = "\N{large green square}" * round(value / 10) + def progress_bar(_v: float, action: str = None): + bar = "\N{large green square}" * round(_v / 10) bar += "\N{white large square}" * (10 - len(bar)) - bar += f" {value:.2f}%" + bar += f" {_v:.2f}%" if action: return f"{action} {bar}" return bar @@ -194,6 +216,8 @@ class Ollama(commands.Cog): else: self.log.debug("Model %r already exists on server.", model) + key = os.urandom(6).hex() + embed = discord.Embed( title="Generating response...", description=">>> ", @@ -216,15 +240,26 @@ class Ollama(commands.Cog): await ctx.edit(embed=embed, view=view) except discord.NotFound: await ctx.respond(embed=embed, view=view) - self.log.debug("Beginning to generate response.") + self.log.debug("Beginning to generate response with key %r.", key) + + params = {} + if give_acid is True: + params["temperature"] = 5 + params["top_k"] = 500 + params["top_p"] = 5 + + payload = { + "model": model, + "prompt": query, + "system": system_prompt, + "stream": True, + "options": params, + } + if context is not None: + payload["context"] = self.contexts[context] async with session.post( "/api/generate", - json={ - "model": model, - "prompt": query, - "system": system_prompt, - "stream": True - }, + json=payload, ) as response: if response.status != 200: embed = discord.Embed( @@ -239,8 +274,11 @@ class Ollama(commands.Cog): last_update = time.time() buffer = io.StringIO() + context = [] if not view.cancel.is_set(): async for line in self.ollama_stream(response.content): + if "context" in line: + context = line["context"] buffer.write(line["response"]) embed.description += line["response"] embed.timestamp = discord.utils.utcnow() @@ -255,6 +293,9 @@ class Ollama(commands.Cog): self.log.debug(f"Updating message ({last_update} -> {time.time()})") last_update = time.time() view.stop() + if context: + self.contexts[key] = context + embed.add_field(name="Context Key", value=key, inline=True) self.log.debug("Ollama finished consuming.") embed.title = "Done!" embed.color = discord.Color.green()