Add history command
This commit is contained in:
parent
583ab92b54
commit
c47cac4118
1 changed files with 64 additions and 1 deletions
|
@ -18,6 +18,9 @@ from discord.ext import commands
|
||||||
from conf import CONFIG
|
from conf import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
_M = typing.TypeVar("_M", discord.Member, discord.User)
|
||||||
|
|
||||||
|
|
||||||
def get_time_spent(nanoseconds: int) -> str:
|
def get_time_spent(nanoseconds: int) -> str:
|
||||||
hours, minutes, seconds = 0, 0, 0
|
hours, minutes, seconds = 0, 0, 0
|
||||||
seconds = nanoseconds / 1e9
|
seconds = nanoseconds / 1e9
|
||||||
|
@ -98,6 +101,25 @@ class ChatHistory:
|
||||||
x["images"] = images
|
x["images"] = images
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def autocomplete(ctx: discord.AutocompleteContext):
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
cog: Ollama = ctx.bot.get_cog("Ollama")
|
||||||
|
instance = cog.history
|
||||||
|
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
|
||||||
|
|
||||||
|
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
|
||||||
|
"""Returns all saved threads."""
|
||||||
|
return self._internal.copy()
|
||||||
|
|
||||||
|
def threads_for(self, user: _M) -> dict[str, list[dict[str, str]] | _M | int]:
|
||||||
|
"""Returns all saved threads for a specific user"""
|
||||||
|
t = self.all_threads()
|
||||||
|
for k, v in t.copy().items():
|
||||||
|
if v["member"] != user:
|
||||||
|
t.pop(k)
|
||||||
|
return t
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self,
|
self,
|
||||||
thread: str,
|
thread: str,
|
||||||
|
@ -120,6 +142,8 @@ class ChatHistory:
|
||||||
"""
|
"""
|
||||||
Gets the history of a thread.
|
Gets the history of a thread.
|
||||||
"""
|
"""
|
||||||
|
if self._internal.get(thread) is None:
|
||||||
|
return []
|
||||||
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
|
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
|
||||||
|
|
||||||
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
|
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
|
||||||
|
@ -420,7 +444,6 @@ class Ollama(commands.Cog):
|
||||||
await ctx.respond(embed=embed, view=view)
|
await ctx.respond(embed=embed, view=view)
|
||||||
self.log.debug("Beginning to generate response with key %r.", key)
|
self.log.debug("Beginning to generate response with key %r.", key)
|
||||||
|
|
||||||
|
|
||||||
if context is None or context in self.contexts:
|
if context is None or context in self.contexts:
|
||||||
context = self.history.create_thread(ctx.user)
|
context = self.history.create_thread(ctx.user)
|
||||||
messages = self.history.get_history(context)
|
messages = self.history.get_history(context)
|
||||||
|
@ -516,6 +539,46 @@ class Ollama(commands.Cog):
|
||||||
)
|
)
|
||||||
return await ctx.respond(embed=embed, ephemeral=True)
|
return await ctx.respond(embed=embed, ephemeral=True)
|
||||||
|
|
||||||
|
@commands.slash_command(name="ollama-history")
|
||||||
|
async def ollama_history(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
thread: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(
|
||||||
|
name="thread_id",
|
||||||
|
description="Thread/Context ID",
|
||||||
|
type=str,
|
||||||
|
autocomplete=ChatHistory.autocomplete,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
):
|
||||||
|
"""Shows the history for a thread."""
|
||||||
|
await ctx.defer(ephemeral=True)
|
||||||
|
paginator = commands.Paginator("", "", 4000, "\n\n")
|
||||||
|
|
||||||
|
history = self.history.get_history(thread)
|
||||||
|
if not history:
|
||||||
|
return await ctx.respond("No history or invalid context key.")
|
||||||
|
|
||||||
|
for message in history:
|
||||||
|
if message["role"] == "system":
|
||||||
|
continue
|
||||||
|
max_length = 4000 - len("> **%s**: " % message["role"])
|
||||||
|
paginator.add_line(
|
||||||
|
"> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length))
|
||||||
|
)
|
||||||
|
|
||||||
|
embeds = []
|
||||||
|
for page in paginator.pages:
|
||||||
|
embeds.append(
|
||||||
|
discord.Embed(
|
||||||
|
description=page
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for chunk in discord.utils.as_chunks(iter(embeds), 10):
|
||||||
|
await ctx.respond(chunk, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(Ollama(bot))
|
bot.add_cog(Ollama(bot))
|
||||||
|
|
Loading…
Reference in a new issue