Add debugging
This commit is contained in:
parent
17ae2db7b7
commit
e92713fe25
1 changed files with 11 additions and 2 deletions
|
@ -68,6 +68,7 @@ class OllamaView(View):
|
||||||
class ChatHistory:
|
class ChatHistory:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._internal = {}
|
self._internal = {}
|
||||||
|
self.log = logging.getLogger("jimmy.cogs.ollama.history")
|
||||||
no_ping = CONFIG["redis"].pop("no_ping", False)
|
no_ping = CONFIG["redis"].pop("no_ping", False)
|
||||||
self.redis = redis.Redis(**CONFIG["redis"])
|
self.redis = redis.Redis(**CONFIG["redis"])
|
||||||
if no_ping is False:
|
if no_ping is False:
|
||||||
|
@ -76,11 +77,13 @@ class ChatHistory:
|
||||||
def load_thread(self, thread_id: str):
|
def load_thread(self, thread_id: str):
|
||||||
value: str = self.redis.get("threads:" + thread_id)
|
value: str = self.redis.get("threads:" + thread_id)
|
||||||
if value:
|
if value:
|
||||||
|
self.log.debug("Loaded thread %r: %r", thread_id, value)
|
||||||
loaded = json.loads(value)
|
loaded = json.loads(value)
|
||||||
self._internal.update(loaded)
|
self._internal.update(loaded)
|
||||||
return self.get_thread(thread_id)
|
return self.get_thread(thread_id)
|
||||||
|
|
||||||
def save_thread(self, thread_id: str):
|
def save_thread(self, thread_id: str):
|
||||||
|
self.log.info("Saving thread:%s - %r", thread_id, self._internal[thread_id])
|
||||||
self.redis.set(
|
self.redis.set(
|
||||||
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||||
)
|
)
|
||||||
|
@ -156,7 +159,9 @@ class ChatHistory:
|
||||||
:param images: Any images that were attached to the message, in base64.
|
:param images: Any images that were attached to the message, in base64.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self._internal[thread]["messages"].append(self._construct_message(role, content, images))
|
new = self._construct_message(role, content, images)
|
||||||
|
self.log.debug("Adding message to thread %r: %r", thread, new)
|
||||||
|
self._internal[thread]["messages"].append(new)
|
||||||
|
|
||||||
def get_history(self, thread: str) -> list[dict[str, str]]:
|
def get_history(self, thread: str) -> list[dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -172,10 +177,13 @@ class ChatHistory:
|
||||||
|
|
||||||
def find_thread(self, thread_id: str):
|
def find_thread(self, thread_id: str):
|
||||||
"""Attempts to find a thread."""
|
"""Attempts to find a thread."""
|
||||||
|
self.log.debug("Checking cache for %r...", thread_id)
|
||||||
if c := self.get_thread(thread_id):
|
if c := self.get_thread(thread_id):
|
||||||
return c
|
return c
|
||||||
|
self.log.debug("Checking db for %r...", thread_id)
|
||||||
if d := self.load_thread(thread_id):
|
if d := self.load_thread(thread_id):
|
||||||
return d
|
return d
|
||||||
|
self.log.warning("No thread with ID %r found.", thread_id)
|
||||||
|
|
||||||
|
|
||||||
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
||||||
|
@ -536,6 +544,8 @@ class Ollama(commands.Cog):
|
||||||
view.stop()
|
view.stop()
|
||||||
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
|
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
|
||||||
self.history.add_message(context, "assistant", buffer.getvalue())
|
self.history.add_message(context, "assistant", buffer.getvalue())
|
||||||
|
self.history.save_thread(context)
|
||||||
|
|
||||||
embed.add_field(name="Context Key", value=context, inline=True)
|
embed.add_field(name="Context Key", value=context, inline=True)
|
||||||
self.log.debug("Ollama finished consuming.")
|
self.log.debug("Ollama finished consuming.")
|
||||||
embed.title = "Done!"
|
embed.title = "Done!"
|
||||||
|
@ -559,7 +569,6 @@ class Ollama(commands.Cog):
|
||||||
await ctx.edit(embed=embed, view=None)
|
await ctx.edit(embed=embed, view=None)
|
||||||
|
|
||||||
if line.get("done"):
|
if line.get("done"):
|
||||||
self.history.save_thread(context)
|
|
||||||
total_duration = get_time_spent(line["total_duration"])
|
total_duration = get_time_spent(line["total_duration"])
|
||||||
load_duration = get_time_spent(line["load_duration"])
|
load_duration = get_time_spent(line["load_duration"])
|
||||||
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
||||||
|
|
Loading…
Reference in a new issue