diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 74b0493..bd9a3d4 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -18,6 +18,9 @@ from discord.ext import commands from conf import CONFIG +_M = typing.TypeVar("_M", discord.Member, discord.User) + + def get_time_spent(nanoseconds: int) -> str: hours, minutes, seconds = 0, 0, 0 seconds = nanoseconds / 1e9 @@ -98,6 +101,25 @@ class ChatHistory: x["images"] = images return x + @staticmethod + def autocomplete(ctx: discord.AutocompleteContext): + # noinspection PyTypeChecker + cog: Ollama = ctx.bot.get_cog("Ollama") + instance = cog.history + return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys())) + + def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]: + """Returns all saved threads.""" + return self._internal.copy() + + def threads_for(self, user: _M) -> dict[str, list[dict[str, str]] | _M | int]: + """Returns all saved threads for a specific user""" + t = self.all_threads() + for k, v in t.copy().items(): + if v["member"] != user: + t.pop(k) + return t + def add_message( self, thread: str, @@ -120,6 +142,8 @@ class ChatHistory: """ Gets the history of a thread. """ + if self._internal.get(thread) is None: + return [] return self._internal[thread]["messages"].copy() # copy() makes it immutable. def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]: @@ -420,7 +444,6 @@ class Ollama(commands.Cog): await ctx.respond(embed=embed, view=view) self.log.debug("Beginning to generate response with key %r.", key) - if context is None or context in self.contexts: context = self.history.create_thread(ctx.user) messages = self.history.get_history(context) @@ -516,6 +539,46 @@ class Ollama(commands.Cog): ) return await ctx.respond(embed=embed, ephemeral=True) + @commands.slash_command(name="ollama-history") + async def ollama_history( + self, + ctx: discord.ApplicationContext, + thread: typing.Annotated[ + str, + discord.Option( + name="thread_id", + description="Thread/Context ID", + type=str, + autocomplete=ChatHistory.autocomplete, + ) + ] + ): + """Shows the history for a thread.""" + await ctx.defer(ephemeral=True) + paginator = commands.Paginator("", "", 4000, "\n\n") + + history = self.history.get_history(thread) + if not history: + return await ctx.respond("No history or invalid context key.") + + for message in history: + if message["role"] == "system": + continue + max_length = 4000 - len("> **%s**: " % message["role"]) + paginator.add_line( + "> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length)) + ) + + embeds = [] + for page in paginator.pages: + embeds.append( + discord.Embed( + description=page + ) + ) + for chunk in discord.utils.as_chunks(iter(embeds), 10): + await ctx.respond(chunk, ephemeral=True) + def setup(bot): bot.add_cog(Ollama(bot))