Allow saving and loading threads (with REDIS)

This commit is contained in:
Nexus 2024-01-12 16:47:45 +00:00
parent c47cac4118
commit 9db6b99f90
4 changed files with 38 additions and 5 deletions

View file

@ -11,6 +11,8 @@ services:
- jimmy-data:/app/data - jimmy-data:/app/data
ports: ports:
- 11444:8080 - 11444:8080
extra_hosts:
- host.docker.internal:host-gateway
ollama: ollama:
image: ollama/ollama:latest image: ollama/ollama:latest
container_name: ollama container_name: ollama
@ -19,6 +21,8 @@ services:
- 11434:11434 - 11434:11434
volumes: volumes:
- ollama-data:/root/.ollama - ollama-data:/root/.ollama
redis:
image: redis
volumes: volumes:
ollama-data: ollama-data:

View file

@ -16,3 +16,4 @@ uvicorn==0.25.0
psutil==5.9.7 psutil==5.9.7
pydantic==2.5.3 pydantic==2.5.3
humanize==4.9.0 humanize==4.9.0
redis==5.0.1

View file

@ -7,7 +7,7 @@ import time
import typing import typing
import base64 import base64
import io import io
import humanize import redis
from discord.ui import View, button from discord.ui import View, button
from fnmatch import fnmatch from fnmatch import fnmatch
@ -71,6 +71,19 @@ class OllamaView(View):
class ChatHistory: class ChatHistory:
def __init__(self): def __init__(self):
self._internal = {} self._internal = {}
self.redis = redis.Redis(**CONFIG["redis"])
def load_thread(self, thread_id: str):
    """
    Load a previously saved thread from Redis into the in-memory history.

    :param thread_id: ID of the thread; it is looked up under the ``threads:<thread_id>`` key.
    :return: The result of ``get_thread(thread_id)`` once the thread has been restored,
        or ``None`` when Redis holds nothing for this ID.
    """
    value: str | None = self.redis.get("threads:" + thread_id)
    if not value:
        # Nothing stored for this ID; callers treat a falsy result as "unknown thread".
        return None
    # save_thread() persists only this thread's own data, so it must be put back under
    # its ID. Merging it into the top-level mapping (dict.update) would scatter the
    # thread's inner keys across the history and never register the thread itself.
    self._internal[thread_id] = json.loads(value)
    return self.get_thread(thread_id)
def save_thread(self, thread_id: str):
    """
    Write one in-memory thread to Redis.

    The thread's data is JSON-encoded and stored under the ``threads:<thread_id>`` key.

    :param thread_id: ID of the thread to persist; it must already exist in the in-memory history.
    """
    redis_key = "threads:" + thread_id
    payload = json.dumps(self._internal[thread_id])
    self.redis.set(redis_key, payload)
def create_thread(self, member: discord.Member) -> str: def create_thread(self, member: discord.Member) -> str:
""" """
@ -106,7 +119,7 @@ class ChatHistory:
# noinspection PyTypeChecker # noinspection PyTypeChecker
cog: Ollama = ctx.bot.get_cog("Ollama") cog: Ollama = ctx.bot.get_cog("Ollama")
instance = cog.history instance = cog.history
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys())) return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))(ctx)
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]: def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
"""Returns all saved threads.""" """Returns all saved threads."""
@ -444,8 +457,14 @@ class Ollama(commands.Cog):
await ctx.respond(embed=embed, view=view) await ctx.respond(embed=embed, view=view)
self.log.debug("Beginning to generate response with key %r.", key) self.log.debug("Beginning to generate response with key %r.", key)
if context is None or context in self.contexts: if context is None:
context = self.history.create_thread(ctx.user) context = self.history.create_thread(ctx.user)
elif context is not None and self.history.get_thread(context) is None:
if not self.history.load_thread(context):
return await ctx.respond("Invalid thread ID.")
else:
context = self.history.create_thread(ctx.user)
messages = self.history.get_history(context) messages = self.history.get_history(context)
user_message = { user_message = {
"role": "user", "role": "user",
@ -525,6 +544,7 @@ class Ollama(commands.Cog):
await ctx.edit(embed=embed, view=None) await ctx.edit(embed=embed, view=None)
if line.get("done"): if line.get("done"):
self.history.save_thread(context)
total_duration = get_time_spent(line["total_duration"]) total_duration = get_time_spent(line["total_duration"])
load_duration = get_time_spent(line["load_duration"]) load_duration = get_time_spent(line["load_duration"])
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"]) prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
@ -576,8 +596,8 @@ class Ollama(commands.Cog):
description=page description=page
) )
) )
for chunk in discord.utils.as_chunks(iter(embeds), 10): for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
await ctx.respond(chunk, ephemeral=True) await ctx.respond(embeds=chunk, ephemeral=True)
def setup(bot): def setup(bot):

View file

@ -15,6 +15,14 @@ try:
"channel": 1032974266527907901 "channel": 1032974266527907901
} }
) )
CONFIG.setdefault(
"redis",
{
"host": "redis",
"port": 6379,
"decode_responses": True
}
)
except FileNotFoundError: except FileNotFoundError:
cwd = Path.cwd() cwd = Path.cwd()
logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True) logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)