diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 95114f2..8cf24d6 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -170,6 +170,13 @@ class ChatHistory: """Gets a copy of an entire thread""" return self._internal.get(thread, {}).copy() + def find_thread(self, thread_id: str): + """Attempts to find a thread.""" + if c := self.get_thread(thread_id): + return c + if d := self.load_thread(thread_id): + return d + SERVER_KEYS = list(CONFIG["ollama"].keys()) @@ -467,10 +474,11 @@ class Ollama(commands.Cog): if context is None: context = self.history.create_thread(ctx.user) elif context is not None and self.history.get_thread(context) is None: - if not self.history.load_thread(context): + __thread = self.history.find_thread(context) + if not __thread: return await ctx.respond("Invalid thread ID.") else: - context = self.history.create_thread(ctx.user) + context = list(__thread.keys())[0] messages = self.history.get_history(context) user_message = { @@ -570,7 +578,7 @@ class Ollama(commands.Cog): async def ollama_history( self, ctx: discord.ApplicationContext, - thread: typing.Annotated[ + thread_id: typing.Annotated[ str, discord.Option( name="thread_id", @@ -584,7 +592,10 @@ class Ollama(commands.Cog): await ctx.defer(ephemeral=True) paginator = commands.Paginator("", "", 4000, "\n\n") - history = self.history.get_history(thread) + thread = self.history.load_thread(thread_id) + if not thread: + return await ctx.respond("No thread with that ID exists.") + history = self.history.get_history(thread_id) if not history: return await ctx.respond("No history or invalid context key.")