Allow saving and loading threads (with REDIS)

This commit is contained in:
Nexus 2024-01-12 16:47:45 +00:00
parent c47cac4118
commit 9db6b99f90
4 changed files with 38 additions and 5 deletions

View file

@ -11,6 +11,8 @@ services:
- jimmy-data:/app/data
ports:
- 11444:8080
extra_hosts:
- host.docker.internal:host-gateway
ollama:
image: ollama/ollama:latest
container_name: ollama
@ -19,6 +21,8 @@ services:
- 11434:11434
volumes:
- ollama-data:/root/.ollama
redis:
image: redis
volumes:
ollama-data:

View file

@ -16,3 +16,4 @@ uvicorn==0.25.0
psutil==5.9.7
pydantic==2.5.3
humanize==4.9.0
redis==5.0.1

View file

@ -7,7 +7,7 @@ import time
import typing
import base64
import io
import humanize
import redis
from discord.ui import View, button
from fnmatch import fnmatch
@ -71,6 +71,19 @@ class OllamaView(View):
class ChatHistory:
def __init__(self):
self._internal = {}
self.redis = redis.Redis(**CONFIG["redis"])
def load_thread(self, thread_id: str):
value: str = self.redis.get("threads:" + thread_id)
if value:
loaded = json.loads(value)
self._internal.update(loaded)
return self.get_thread(thread_id)
def save_thread(self, thread_id: str):
self.redis.set(
"threads:" + thread_id, json.dumps(self._internal[thread_id])
)
def create_thread(self, member: discord.Member) -> str:
"""
@ -106,7 +119,7 @@ class ChatHistory:
# noinspection PyTypeChecker
cog: Ollama = ctx.bot.get_cog("Ollama")
instance = cog.history
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))(ctx)
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
"""Returns all saved threads."""
@ -444,8 +457,14 @@ class Ollama(commands.Cog):
await ctx.respond(embed=embed, view=view)
self.log.debug("Beginning to generate response with key %r.", key)
if context is None or context in self.contexts:
if context is None:
context = self.history.create_thread(ctx.user)
elif context is not None and self.history.get_thread(context) is None:
if not self.history.load_thread(context):
return await ctx.respond("Invalid thread ID.")
else:
context = self.history.create_thread(ctx.user)
messages = self.history.get_history(context)
user_message = {
"role": "user",
@ -525,6 +544,7 @@ class Ollama(commands.Cog):
await ctx.edit(embed=embed, view=None)
if line.get("done"):
self.history.save_thread(context)
total_duration = get_time_spent(line["total_duration"])
load_duration = get_time_spent(line["load_duration"])
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
@ -576,8 +596,8 @@ class Ollama(commands.Cog):
description=page
)
)
for chunk in discord.utils.as_chunks(iter(embeds), 10):
await ctx.respond(chunk, ephemeral=True)
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
await ctx.respond(embeds=chunk, ephemeral=True)
def setup(bot):

View file

@ -15,6 +15,14 @@ try:
"channel": 1032974266527907901
}
)
CONFIG.setdefault(
"redis",
{
"host": "redis",
"port": 6379,
"decode_responses": True
}
)
except FileNotFoundError:
cwd = Path.cwd()
logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)