Add debugging

This commit is contained in:
Nexus 2024-01-12 17:18:23 +00:00
parent 17ae2db7b7
commit e92713fe25

View file

@ -68,6 +68,7 @@ class OllamaView(View):
class ChatHistory: class ChatHistory:
def __init__(self): def __init__(self):
self._internal = {} self._internal = {}
self.log = logging.getLogger("jimmy.cogs.ollama.history")
no_ping = CONFIG["redis"].pop("no_ping", False) no_ping = CONFIG["redis"].pop("no_ping", False)
self.redis = redis.Redis(**CONFIG["redis"]) self.redis = redis.Redis(**CONFIG["redis"])
if no_ping is False: if no_ping is False:
@ -76,11 +77,13 @@ class ChatHistory:
def load_thread(self, thread_id: str): def load_thread(self, thread_id: str):
value: str = self.redis.get("threads:" + thread_id) value: str = self.redis.get("threads:" + thread_id)
if value: if value:
self.log.debug("Loaded thread %r: %r", thread_id, value)
loaded = json.loads(value) loaded = json.loads(value)
self._internal.update(loaded) self._internal.update(loaded)
return self.get_thread(thread_id) return self.get_thread(thread_id)
def save_thread(self, thread_id: str): def save_thread(self, thread_id: str):
self.log.info("Saving thread:%s - %r", thread_id, self._internal[thread_id])
self.redis.set( self.redis.set(
"threads:" + thread_id, json.dumps(self._internal[thread_id]) "threads:" + thread_id, json.dumps(self._internal[thread_id])
) )
@ -156,7 +159,9 @@ class ChatHistory:
:param images: Any images that were attached to the message, in base64. :param images: Any images that were attached to the message, in base64.
:return: None :return: None
""" """
self._internal[thread]["messages"].append(self._construct_message(role, content, images)) new = self._construct_message(role, content, images)
self.log.debug("Adding message to thread %r: %r", thread, new)
self._internal[thread]["messages"].append(new)
def get_history(self, thread: str) -> list[dict[str, str]]: def get_history(self, thread: str) -> list[dict[str, str]]:
""" """
@ -172,10 +177,13 @@ class ChatHistory:
def find_thread(self, thread_id: str): def find_thread(self, thread_id: str):
"""Attempts to find a thread.""" """Attempts to find a thread."""
self.log.debug("Checking cache for %r...", thread_id)
if c := self.get_thread(thread_id): if c := self.get_thread(thread_id):
return c return c
self.log.debug("Checking db for %r...", thread_id)
if d := self.load_thread(thread_id): if d := self.load_thread(thread_id):
return d return d
self.log.warning("No thread with ID %r found.", thread_id)
SERVER_KEYS = list(CONFIG["ollama"].keys()) SERVER_KEYS = list(CONFIG["ollama"].keys())
@ -536,6 +544,8 @@ class Ollama(commands.Cog):
view.stop() view.stop()
self.history.add_message(context, "user", user_message["content"], user_message.get("images")) self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
self.history.add_message(context, "assistant", buffer.getvalue()) self.history.add_message(context, "assistant", buffer.getvalue())
self.history.save_thread(context)
embed.add_field(name="Context Key", value=context, inline=True) embed.add_field(name="Context Key", value=context, inline=True)
self.log.debug("Ollama finished consuming.") self.log.debug("Ollama finished consuming.")
embed.title = "Done!" embed.title = "Done!"
@ -559,7 +569,6 @@ class Ollama(commands.Cog):
await ctx.edit(embed=embed, view=None) await ctx.edit(embed=embed, view=None)
if line.get("done"): if line.get("done"):
self.history.save_thread(context)
total_duration = get_time_spent(line["total_duration"]) total_duration = get_time_spent(line["total_duration"])
load_duration = get_time_spent(line["load_duration"]) load_duration = get_time_spent(line["load_duration"])
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"]) prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])