Fix automatic saving
All checks were successful
Build and Publish college-bot-v2 / build_and_publish (push) Successful in 14s

This commit is contained in:
Nexus 2024-06-04 02:17:02 +01:00
parent 4a42918dee
commit 0ce03d2667
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import base64 import base64
import functools
import io import io
import json import json
import logging import logging
@ -269,6 +270,8 @@ class OllamaView(View):
class ChatHistory: class ChatHistory:
# TO.DO: Make this async. currently, a remote redis will actually block the bot. this isn't as noticeable in compose
def __init__(self): def __init__(self):
self._internal = {} self._internal = {}
self.log = logging.getLogger("jimmy.cogs.ollama.history") self.log = logging.getLogger("jimmy.cogs.ollama.history")
@ -338,6 +341,8 @@ class ChatHistory:
role: typing.Literal["user", "assistant", "system"], role: typing.Literal["user", "assistant", "system"],
content: str, content: str,
images: typing.Optional[list[str]] = None, images: typing.Optional[list[str]] = None,
*,
save: bool = True
) -> None: ) -> None:
""" """
Appends a message to the given thread. Appends a message to the given thread.
@ -346,11 +351,13 @@ class ChatHistory:
:param role: The author of the message. :param role: The author of the message.
:param content: The message's actual content. :param content: The message's actual content.
:param images: Any images that were attached to the message, in base64. :param images: Any images that were attached to the message, in base64.
:param save: Saves the thread after adding the message. Set to False when bulk adding.
:return: None :return: None
""" """
new = self._construct_message(role, content, images) new = self._construct_message(role, content, images)
self.log.debug("Adding message to thread %r: %r", thread, new) self.log.debug("Adding message to thread %r: %r", thread, new)
self._internal[thread]["messages"].append(new) self._internal[thread]["messages"].append(new)
if save:
self.save_thread(thread) self.save_thread(thread)
def get_history(self, thread: str) -> list[dict[str, str]]: def get_history(self, thread: str) -> list[dict[str, str]]:
@ -817,7 +824,6 @@ class Ollama(commands.Cog):
view.stop() view.stop()
self.history.add_message(context, "user", user_message["content"], user_message.get("images")) self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
self.history.add_message(context, "assistant", buffer.getvalue()) self.history.add_message(context, "assistant", buffer.getvalue())
self.history.save_thread(context)
embed.add_field(name="Context Key", value=context, inline=True) embed.add_field(name="Context Key", value=context, inline=True)
self.log.debug("Ollama finished consuming.") self.log.debug("Ollama finished consuming.")
@ -1029,9 +1035,16 @@ class Ollama(commands.Cog):
if not embed.description or embed.description.strip() == "NEW TRUTH": if not embed.description or embed.description.strip() == "NEW TRUTH":
continue continue
if embed.type == "rich" and embed.colour and embed.colour.value == 0x5448EE: if embed.type == "rich" and embed.colour and embed.colour.value == 0x5448EE:
self.history.add_message(thread_id, "assistant", embed.description) await asyncio.to_thread(
functools.partial(
self.history.add_message,
thread_id,
"assistant",
embed.description,
save=False
)
)
self.history.add_message(thread_id, "user", "Generate a new truth post.") self.history.add_message(thread_id, "user", "Generate a new truth post.")
self.history.save_thread(thread_id)
tried = set() tried = set()
for _ in range(10): for _ in range(10):