From 380d500e322528619395802c18383cc03eabb224 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 12 Jan 2024 15:39:39 +0000 Subject: [PATCH] Migrate to chat endpoint --- src/cogs/ollama.py | 95 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 19 deletions(-) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 8fec1bc..28bd596 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -65,6 +65,63 @@ class OllamaView(View): self.stop() +class ChatHistory: + def __init__(self): + self._internal = {} + + def create_thread(self, member: discord.Member) -> str: + """ + Creates a thread, returns its ID. + """ + key = os.urandom(3).hex() + self._internal[key] = { + "member": member, + "messages": [] + } + with open("./assets/ollama-prompt.txt") as file: + system_prompt = file.read() + self.add_message( + key, + "system", + system_prompt + ) + return key + + @staticmethod + def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]: + x = { + "role": role, + "content": content + } + if images: + x["images"] = images + return x + + def add_message( + self, + thread: str, + role: typing.Literal["user", "assistant", "system"], + content: str, + images: typing.Optional[list[str]] = None + ) -> None: + """ + Appends a message to the given thread. + + :param thread: The thread's ID. + :param role: The author of the message. + :param content: The message's actual content. + :param images: Any images that were attached to the message, in base64. + :return: None + """ + self._internal[thread]["messages"].append(self._construct_message(role, content, images)) + + def get_history(self, thread: str) -> list[dict[str, str]]: + """ + Gets the history of a thread. + """ + return self._internal[thread]["messages"].copy() # copy() makes it immutable. + + SERVER_KEYS = list(CONFIG["ollama"].keys()) @@ -74,6 +131,7 @@ class Ollama(commands.Cog): self.log = logging.getLogger("jimmy.cogs.ollama") self.last_server = 0 self.contexts = {} + self.history = ChatHistory() def next_server(self, increment: bool = True) -> str: """Returns the next server key.""" @@ -163,9 +221,6 @@ class Ollama(commands.Cog): if context not in self.contexts: await ctx.respond("Invalid context key.") return - return await ctx.respond("Context is currently disabled.", ephemeral=True) - with open("./assets/ollama-prompt.txt") as file: - system_prompt = file.read() await ctx.defer() model = model.casefold() @@ -366,19 +421,24 @@ class Ollama(commands.Cog): params["top_k"] = 500 params["top_p"] = 500 + if context is None or context in self.contexts: + context = self.history.create_thread(ctx.user) + messages = self.history.get_history(context) + user_message = { + "role": "user", + "content": query + } + if image_data: + user_message["images"] = [image_data] + messages.append(user_message) payload = { "model": model, - "prompt": query, - "system": system_prompt, "stream": True, "options": params, + "messages": messages } - if context is not None: - payload["context"] = self.contexts[context] - if image_data: - payload["images"] = [image_data] async with session.post( - "/api/generate", + "/api/chat", json=payload, ) as response: if response.status != 200: @@ -394,16 +454,13 @@ class Ollama(commands.Cog): last_update = time.time() buffer = io.StringIO() - context = [] if not view.cancel.is_set(): async for line in self.ollama_stream(response.content): - if "context" in line: - context = line["context"] - buffer.write(line["response"]) - embed.description += line["response"] + buffer.write(line["assistant"]) + embed.description += line["assistant"] embed.timestamp = discord.utils.utcnow() if len(embed.description) >= 4096: - embed.description = embed.description = "..." + line["response"] + embed.description = embed.description = "..." + line["assistant"] if view.cancel.is_set(): break @@ -413,9 +470,9 @@ class Ollama(commands.Cog): self.log.debug(f"Updating message ({last_update} -> {time.time()})") last_update = time.time() view.stop() - if context: - self.contexts[key] = context - embed.add_field(name="Context Key", value=key, inline=True) + self.history.add_message(context, "user", user_message["content"], user_message["images"]) + self.history.add_message(context, "assistant", buffer.getvalue()) + embed.add_field(name="Context Key", value=context, inline=True) self.log.debug("Ollama finished consuming.") embed.title = "Done!" embed.colour = discord.Color.green()