Fix automatic saving
All checks were successful
Build and Publish college-bot-v2 / build_and_publish (push) Successful in 14s

This commit is contained in:
Nexus 2024-06-04 02:17:02 +01:00
parent 4a42918dee
commit 0ce03d2667
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -1,5 +1,6 @@
import asyncio
import base64
import functools
import io
import json
import logging
@ -269,6 +270,8 @@ class OllamaView(View):
class ChatHistory:
# TO.DO: Make this async. currently, a remote redis will actually block the bot. this isn't as noticeable in compose
def __init__(self):
self._internal = {}
self.log = logging.getLogger("jimmy.cogs.ollama.history")
@ -338,6 +341,8 @@ class ChatHistory:
role: typing.Literal["user", "assistant", "system"],
content: str,
images: typing.Optional[list[str]] = None,
*,
save: bool = True
) -> None:
"""
Appends a message to the given thread.
@ -346,12 +351,14 @@ class ChatHistory:
:param role: The author of the message.
:param content: The message's actual content.
:param images: Any images that were attached to the message, in base64.
:param save: Saves the thread after adding the message. Set to False when bulk adding.
:return: None
"""
new = self._construct_message(role, content, images)
self.log.debug("Adding message to thread %r: %r", thread, new)
self._internal[thread]["messages"].append(new)
self.save_thread(thread)
if save:
self.save_thread(thread)
def get_history(self, thread: str) -> list[dict[str, str]]:
"""
@ -817,7 +824,6 @@ class Ollama(commands.Cog):
view.stop()
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
self.history.add_message(context, "assistant", buffer.getvalue())
self.history.save_thread(context)
embed.add_field(name="Context Key", value=context, inline=True)
self.log.debug("Ollama finished consuming.")
@ -1029,9 +1035,16 @@ class Ollama(commands.Cog):
if not embed.description or embed.description.strip() == "NEW TRUTH":
continue
if embed.type == "rich" and embed.colour and embed.colour.value == 0x5448EE:
self.history.add_message(thread_id, "assistant", embed.description)
await asyncio.to_thread(
functools.partial(
self.history.add_message,
thread_id,
"assistant",
embed.description,
save=False
)
)
self.history.add_message(thread_id, "user", "Generate a new truth post.")
self.history.save_thread(thread_id)
tried = set()
for _ in range(10):