diff --git a/docker-compose.yml b/docker-compose.yml
index 4e322cb..2c2374f 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -11,6 +11,8 @@ services:
       - jimmy-data:/app/data
     ports:
       - 11444:8080
+    extra_hosts:
+      - host.docker.internal:host-gateway
   ollama:
     image: ollama/ollama:latest
     container_name: ollama
@@ -19,6 +21,8 @@ services:
       - 11434:11434
     volumes:
       - ollama-data:/root/.ollama
+  redis:
+    image: redis
 
 volumes:
   ollama-data:
diff --git a/requirements.txt b/requirements.txt
index e746281..0ac76cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ uvicorn==0.25.0
 psutil==5.9.7
 pydantic==2.5.3
 humanize==4.9.0
+redis==5.0.1
diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py
index bd9a3d4..dde7d82 100644
--- a/src/cogs/ollama.py
+++ b/src/cogs/ollama.py
@@ -7,7 +7,7 @@ import time
 import typing
 import base64
 import io
-import humanize
+import redis
 
 from discord.ui import View, button
 from fnmatch import fnmatch
@@ -71,6 +71,21 @@ class OllamaView(View):
 class ChatHistory:
     def __init__(self):
         self._internal = {}
+        self.redis = redis.Redis(**CONFIG["redis"])
+
+    def load_thread(self, thread_id: str):
+        """Restore ``thread_id`` from Redis if it was saved; return ``get_thread(thread_id)``."""
+        value: str | None = self.redis.get("threads:" + thread_id)
+        if value:
+            # save_thread() stores a single thread's data, so put it back under its own ID.
+            self._internal[thread_id] = json.loads(value)
+        return self.get_thread(thread_id)
+
+    def save_thread(self, thread_id: str):
+        """Write the in-memory data of ``thread_id`` to Redis as JSON under ``threads:<id>``."""
+        self.redis.set(
+            "threads:" + thread_id, json.dumps(self._internal[thread_id])
+        )
 
     def create_thread(self, member: discord.Member) -> str:
         """
@@ -106,7 +121,7 @@ class ChatHistory:
         # noinspection PyTypeChecker
         cog: Ollama = ctx.bot.get_cog("Ollama")
         instance = cog.history
-        return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
+        return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))(ctx)
 
     def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
         """Returns all saved threads."""
@@ -444,8 +459,13 @@ class Ollama(commands.Cog):
             await ctx.respond(embed=embed, view=view)
 
             self.log.debug("Beginning to generate response with key %r.", key)
-            if context is None or context in self.contexts:
+            if context is None:
                 context = self.history.create_thread(ctx.user)
+            elif self.history.get_thread(context) is None:
+                # Not cached in memory: fall back to Redis before rejecting the ID.
+                if not self.history.load_thread(context):
+                    return await ctx.respond("Invalid thread ID.")
+
             messages = self.history.get_history(context)
             user_message = {
                 "role": "user",
@@ -525,6 +545,7 @@ class Ollama(commands.Cog):
                             await ctx.edit(embed=embed, view=None)
 
                         if line.get("done"):
+                            self.history.save_thread(context)
                             total_duration = get_time_spent(line["total_duration"])
                             load_duration = get_time_spent(line["load_duration"])
                             prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
@@ -576,8 +597,8 @@ class Ollama(commands.Cog):
                     description=page
                 )
             )
-        for chunk in discord.utils.as_chunks(iter(embeds), 10):
-            await ctx.respond(chunk, ephemeral=True)
+        for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
+            await ctx.respond(embeds=chunk, ephemeral=True)
 
 
 def setup(bot):
diff --git a/src/conf.py b/src/conf.py
index aee0f08..2fd7a74 100644
--- a/src/conf.py
+++ b/src/conf.py
@@ -15,6 +15,14 @@ try:
             "channel": 1032974266527907901
         }
     )
+    CONFIG.setdefault(
+        "redis",
+        {
+            "host": "redis",
+            "port": 6379,
+            "decode_responses": True
+        }
+    )
 except FileNotFoundError:
     cwd = Path.cwd()
     logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)