diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index eca5070..c85dae3 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -68,6 +68,7 @@ class OllamaView(View): class ChatHistory: def __init__(self): self._internal = {} + self.log = logging.getLogger("jimmy.cogs.ollama.history") no_ping = CONFIG["redis"].pop("no_ping", False) self.redis = redis.Redis(**CONFIG["redis"]) if no_ping is False: @@ -76,11 +77,13 @@ class ChatHistory: def load_thread(self, thread_id: str): value: str = self.redis.get("threads:" + thread_id) if value: + self.log.debug("Loaded thread %r: %r", thread_id, value) loaded = json.loads(value) self._internal.update(loaded) return self.get_thread(thread_id) def save_thread(self, thread_id: str): + self.log.info("Saving thread:%s - %r", thread_id, self._internal[thread_id]) self.redis.set( "threads:" + thread_id, json.dumps(self._internal[thread_id]) ) @@ -156,7 +159,9 @@ class ChatHistory: :param images: Any images that were attached to the message, in base64. :return: None """ - self._internal[thread]["messages"].append(self._construct_message(role, content, images)) + new = self._construct_message(role, content, images) + self.log.debug("Adding message to thread %r: %r", thread, new) + self._internal[thread]["messages"].append(new) def get_history(self, thread: str) -> list[dict[str, str]]: """ @@ -172,10 +177,13 @@ class ChatHistory: def find_thread(self, thread_id: str): """Attempts to find a thread.""" + self.log.debug("Checking cache for %r...", thread_id) if c := self.get_thread(thread_id): return c + self.log.debug("Checking db for %r...", thread_id) if d := self.load_thread(thread_id): return d + self.log.warning("No thread with ID %r found.", thread_id) SERVER_KEYS = list(CONFIG["ollama"].keys()) @@ -536,6 +544,8 @@ class Ollama(commands.Cog): view.stop() self.history.add_message(context, "user", user_message["content"], user_message.get("images")) self.history.add_message(context, "assistant", buffer.getvalue()) + self.history.save_thread(context) + embed.add_field(name="Context Key", value=context, inline=True) self.log.debug("Ollama finished consuming.") embed.title = "Done!" @@ -559,7 +569,6 @@ class Ollama(commands.Cog): await ctx.edit(embed=embed, view=None) if line.get("done"): - self.history.save_thread(context) total_duration = get_time_spent(line["total_duration"]) load_duration = get_time_spent(line["load_duration"]) prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])