diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 81d7108..51aabef 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -1,5 +1,6 @@ import asyncio import base64 +import functools import io import json import logging @@ -269,6 +270,8 @@ class OllamaView(View): class ChatHistory: + # TO.DO: Make this async. currently, a remote redis will actually block the bot. this isn't as noticeable in compose + def __init__(self): self._internal = {} self.log = logging.getLogger("jimmy.cogs.ollama.history") @@ -338,6 +341,8 @@ class ChatHistory: role: typing.Literal["user", "assistant", "system"], content: str, images: typing.Optional[list[str]] = None, + *, + save: bool = True ) -> None: """ Appends a message to the given thread. @@ -346,12 +351,14 @@ class ChatHistory: :param role: The author of the message. :param content: The message's actual content. :param images: Any images that were attached to the message, in base64. + :param save: Saves the thread after adding the message. Set to False when bulk adding. :return: None """ new = self._construct_message(role, content, images) self.log.debug("Adding message to thread %r: %r", thread, new) self._internal[thread]["messages"].append(new) - self.save_thread(thread) + if save: + self.save_thread(thread) def get_history(self, thread: str) -> list[dict[str, str]]: """ @@ -817,7 +824,6 @@ class Ollama(commands.Cog): view.stop() self.history.add_message(context, "user", user_message["content"], user_message.get("images")) self.history.add_message(context, "assistant", buffer.getvalue()) - self.history.save_thread(context) embed.add_field(name="Context Key", value=context, inline=True) self.log.debug("Ollama finished consuming.") @@ -1029,9 +1035,16 @@ class Ollama(commands.Cog): if not embed.description or embed.description.strip() == "NEW TRUTH": continue if embed.type == "rich" and embed.colour and embed.colour.value == 0x5448EE: - self.history.add_message(thread_id, "assistant", embed.description) + await asyncio.to_thread( + functools.partial( + self.history.add_message, + thread_id, + "assistant", + embed.description, + save=False + ) + ) self.history.add_message(thread_id, "user", "Generate a new truth post.") - self.history.save_thread(thread_id) tried = set() for _ in range(10):