I hate the word "thread" now.

This commit is contained in:
Nexus 2024-01-12 17:11:16 +00:00
parent b0c25254ea
commit 14237b8f5f

View file

@ -170,6 +170,13 @@ class ChatHistory:
"""Gets a copy of an entire thread"""
return self._internal.get(thread, {}).copy()
def find_thread(self, thread_id: str):
"""Attempts to find a thread."""
if c := self.get_thread(thread_id):
return c
if d := self.load_thread(thread_id):
return d
SERVER_KEYS = list(CONFIG["ollama"].keys())
@ -467,10 +474,11 @@ class Ollama(commands.Cog):
if context is None:
context = self.history.create_thread(ctx.user)
elif context is not None and self.history.get_thread(context) is None:
if not self.history.load_thread(context):
__thread = self.history.find_thread(context)
if not __thread:
return await ctx.respond("Invalid thread ID.")
else:
context = self.history.create_thread(ctx.user)
context = list(__thread.keys())[0]
messages = self.history.get_history(context)
user_message = {
@ -570,7 +578,7 @@ class Ollama(commands.Cog):
async def ollama_history(
self,
ctx: discord.ApplicationContext,
thread: typing.Annotated[
thread_id: typing.Annotated[
str,
discord.Option(
name="thread_id",
@ -584,7 +592,10 @@ class Ollama(commands.Cog):
await ctx.defer(ephemeral=True)
paginator = commands.Paginator("", "", 4000, "\n\n")
history = self.history.get_history(thread)
thread = self.history.load_thread(thread_id)
if not thread:
return await ctx.respond("No thread with that ID exists.")
history = self.history.get_history(thread_id)
if not history:
return await ctx.respond("No history or invalid context key.")