Add history command
This commit is contained in:
parent
583ab92b54
commit
c47cac4118
1 changed files with 64 additions and 1 deletions
|
@ -18,6 +18,9 @@ from discord.ext import commands
|
|||
from conf import CONFIG
|
||||
|
||||
|
||||
_M = typing.TypeVar("_M", discord.Member, discord.User)
|
||||
|
||||
|
||||
def get_time_spent(nanoseconds: int) -> str:
|
||||
hours, minutes, seconds = 0, 0, 0
|
||||
seconds = nanoseconds / 1e9
|
||||
|
@ -98,6 +101,25 @@ class ChatHistory:
|
|||
x["images"] = images
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def autocomplete(ctx: discord.AutocompleteContext):
|
||||
# noinspection PyTypeChecker
|
||||
cog: Ollama = ctx.bot.get_cog("Ollama")
|
||||
instance = cog.history
|
||||
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
|
||||
|
||||
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
|
||||
"""Returns all saved threads."""
|
||||
return self._internal.copy()
|
||||
|
||||
def threads_for(self, user: _M) -> dict[str, list[dict[str, str]] | _M | int]:
|
||||
"""Returns all saved threads for a specific user"""
|
||||
t = self.all_threads()
|
||||
for k, v in t.copy().items():
|
||||
if v["member"] != user:
|
||||
t.pop(k)
|
||||
return t
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
thread: str,
|
||||
|
@ -120,6 +142,8 @@ class ChatHistory:
|
|||
"""
|
||||
Gets the history of a thread.
|
||||
"""
|
||||
if self._internal.get(thread) is None:
|
||||
return []
|
||||
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
|
||||
|
||||
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
|
||||
|
@ -420,7 +444,6 @@ class Ollama(commands.Cog):
|
|||
await ctx.respond(embed=embed, view=view)
|
||||
self.log.debug("Beginning to generate response with key %r.", key)
|
||||
|
||||
|
||||
if context is None or context in self.contexts:
|
||||
context = self.history.create_thread(ctx.user)
|
||||
messages = self.history.get_history(context)
|
||||
|
@ -516,6 +539,46 @@ class Ollama(commands.Cog):
|
|||
)
|
||||
return await ctx.respond(embed=embed, ephemeral=True)
|
||||
|
||||
@commands.slash_command(name="ollama-history")
|
||||
async def ollama_history(
|
||||
self,
|
||||
ctx: discord.ApplicationContext,
|
||||
thread: typing.Annotated[
|
||||
str,
|
||||
discord.Option(
|
||||
name="thread_id",
|
||||
description="Thread/Context ID",
|
||||
type=str,
|
||||
autocomplete=ChatHistory.autocomplete,
|
||||
)
|
||||
]
|
||||
):
|
||||
"""Shows the history for a thread."""
|
||||
await ctx.defer(ephemeral=True)
|
||||
paginator = commands.Paginator("", "", 4000, "\n\n")
|
||||
|
||||
history = self.history.get_history(thread)
|
||||
if not history:
|
||||
return await ctx.respond("No history or invalid context key.")
|
||||
|
||||
for message in history:
|
||||
if message["role"] == "system":
|
||||
continue
|
||||
max_length = 4000 - len("> **%s**: " % message["role"])
|
||||
paginator.add_line(
|
||||
"> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length))
|
||||
)
|
||||
|
||||
embeds = []
|
||||
for page in paginator.pages:
|
||||
embeds.append(
|
||||
discord.Embed(
|
||||
description=page
|
||||
)
|
||||
)
|
||||
for chunk in discord.utils.as_chunks(iter(embeds), 10):
|
||||
await ctx.respond(chunk, ephemeral=True)
|
||||
|
||||
|
||||
def setup(bot):
|
||||
bot.add_cog(Ollama(bot))
|
||||
|
|
Loading…
Reference in a new issue