Add history command

This commit is contained in:
Nexus 2024-01-12 16:26:32 +00:00
parent 583ab92b54
commit c47cac4118

View file

@ -18,6 +18,9 @@ from discord.ext import commands
from conf import CONFIG from conf import CONFIG
_M = typing.TypeVar("_M", discord.Member, discord.User)
def get_time_spent(nanoseconds: int) -> str: def get_time_spent(nanoseconds: int) -> str:
hours, minutes, seconds = 0, 0, 0 hours, minutes, seconds = 0, 0, 0
seconds = nanoseconds / 1e9 seconds = nanoseconds / 1e9
@ -98,6 +101,25 @@ class ChatHistory:
x["images"] = images x["images"] = images
return x return x
@staticmethod
def autocomplete(ctx: discord.AutocompleteContext):
# noinspection PyTypeChecker
cog: Ollama = ctx.bot.get_cog("Ollama")
instance = cog.history
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
"""Returns all saved threads."""
return self._internal.copy()
def threads_for(self, user: _M) -> dict[str, list[dict[str, str]] | _M | int]:
"""Returns all saved threads for a specific user"""
t = self.all_threads()
for k, v in t.copy().items():
if v["member"] != user:
t.pop(k)
return t
def add_message( def add_message(
self, self,
thread: str, thread: str,
@ -120,6 +142,8 @@ class ChatHistory:
""" """
Gets the history of a thread. Gets the history of a thread.
""" """
if self._internal.get(thread) is None:
return []
return self._internal[thread]["messages"].copy() # copy() makes it immutable. return self._internal[thread]["messages"].copy() # copy() makes it immutable.
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]: def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
@ -420,7 +444,6 @@ class Ollama(commands.Cog):
await ctx.respond(embed=embed, view=view) await ctx.respond(embed=embed, view=view)
self.log.debug("Beginning to generate response with key %r.", key) self.log.debug("Beginning to generate response with key %r.", key)
if context is None or context in self.contexts: if context is None or context in self.contexts:
context = self.history.create_thread(ctx.user) context = self.history.create_thread(ctx.user)
messages = self.history.get_history(context) messages = self.history.get_history(context)
@ -516,6 +539,46 @@ class Ollama(commands.Cog):
) )
return await ctx.respond(embed=embed, ephemeral=True) return await ctx.respond(embed=embed, ephemeral=True)
@commands.slash_command(name="ollama-history")
async def ollama_history(
self,
ctx: discord.ApplicationContext,
thread: typing.Annotated[
str,
discord.Option(
name="thread_id",
description="Thread/Context ID",
type=str,
autocomplete=ChatHistory.autocomplete,
)
]
):
"""Shows the history for a thread."""
await ctx.defer(ephemeral=True)
paginator = commands.Paginator("", "", 4000, "\n\n")
history = self.history.get_history(thread)
if not history:
return await ctx.respond("No history or invalid context key.")
for message in history:
if message["role"] == "system":
continue
max_length = 4000 - len("> **%s**: " % message["role"])
paginator.add_line(
"> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length))
)
embeds = []
for page in paginator.pages:
embeds.append(
discord.Embed(
description=page
)
)
for chunk in discord.utils.as_chunks(iter(embeds), 10):
await ctx.respond(chunk, ephemeral=True)
def setup(bot): def setup(bot):
bot.add_cog(Ollama(bot)) bot.add_cog(Ollama(bot))