Allow saving and loading threads (with REDIS)
This commit is contained in:
parent
c47cac4118
commit
9db6b99f90
4 changed files with 38 additions and 5 deletions
|
@ -11,6 +11,8 @@ services:
|
|||
- jimmy-data:/app/data
|
||||
ports:
|
||||
- 11444:8080
|
||||
extra_hosts:
|
||||
- host.docker.internal:host-gateway
|
||||
ollama:
|
||||
image: ollama/ollama:latest
|
||||
container_name: ollama
|
||||
|
@ -19,6 +21,8 @@ services:
|
|||
- 11434:11434
|
||||
volumes:
|
||||
- ollama-data:/root/.ollama
|
||||
redis:
|
||||
image: redis
|
||||
|
||||
volumes:
|
||||
ollama-data:
|
||||
|
|
|
@ -16,3 +16,4 @@ uvicorn==0.25.0
|
|||
psutil==5.9.7
|
||||
pydantic==2.5.3
|
||||
humanize==4.9.0
|
||||
redis==5.0.1
|
||||
|
|
|
@ -7,7 +7,7 @@ import time
|
|||
import typing
|
||||
import base64
|
||||
import io
|
||||
import humanize
|
||||
import redis
|
||||
|
||||
from discord.ui import View, button
|
||||
from fnmatch import fnmatch
|
||||
|
@ -71,6 +71,19 @@ class OllamaView(View):
|
|||
class ChatHistory:
|
||||
def __init__(self):
|
||||
self._internal = {}
|
||||
self.redis = redis.Redis(**CONFIG["redis"])
|
||||
|
||||
def load_thread(self, thread_id: str):
|
||||
value: str = self.redis.get("threads:" + thread_id)
|
||||
if value:
|
||||
loaded = json.loads(value)
|
||||
self._internal.update(loaded)
|
||||
return self.get_thread(thread_id)
|
||||
|
||||
def save_thread(self, thread_id: str):
|
||||
self.redis.set(
|
||||
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||
)
|
||||
|
||||
def create_thread(self, member: discord.Member) -> str:
|
||||
"""
|
||||
|
@ -106,7 +119,7 @@ class ChatHistory:
|
|||
# noinspection PyTypeChecker
|
||||
cog: Ollama = ctx.bot.get_cog("Ollama")
|
||||
instance = cog.history
|
||||
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
|
||||
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))(ctx)
|
||||
|
||||
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
|
||||
"""Returns all saved threads."""
|
||||
|
@ -444,8 +457,14 @@ class Ollama(commands.Cog):
|
|||
await ctx.respond(embed=embed, view=view)
|
||||
self.log.debug("Beginning to generate response with key %r.", key)
|
||||
|
||||
if context is None or context in self.contexts:
|
||||
if context is None:
|
||||
context = self.history.create_thread(ctx.user)
|
||||
elif context is not None and self.history.get_thread(context) is None:
|
||||
if not self.history.load_thread(context):
|
||||
return await ctx.respond("Invalid thread ID.")
|
||||
else:
|
||||
context = self.history.create_thread(ctx.user)
|
||||
|
||||
messages = self.history.get_history(context)
|
||||
user_message = {
|
||||
"role": "user",
|
||||
|
@ -525,6 +544,7 @@ class Ollama(commands.Cog):
|
|||
await ctx.edit(embed=embed, view=None)
|
||||
|
||||
if line.get("done"):
|
||||
self.history.save_thread(context)
|
||||
total_duration = get_time_spent(line["total_duration"])
|
||||
load_duration = get_time_spent(line["load_duration"])
|
||||
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
||||
|
@ -576,8 +596,8 @@ class Ollama(commands.Cog):
|
|||
description=page
|
||||
)
|
||||
)
|
||||
for chunk in discord.utils.as_chunks(iter(embeds), 10):
|
||||
await ctx.respond(chunk, ephemeral=True)
|
||||
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
|
||||
await ctx.respond(embeds=chunk, ephemeral=True)
|
||||
|
||||
|
||||
def setup(bot):
|
||||
|
|
|
@ -15,6 +15,14 @@ try:
|
|||
"channel": 1032974266527907901
|
||||
}
|
||||
)
|
||||
CONFIG.setdefault(
|
||||
"redis",
|
||||
{
|
||||
"host": "redis",
|
||||
"port": 6379,
|
||||
"decode_responses": True
|
||||
}
|
||||
)
|
||||
except FileNotFoundError:
|
||||
cwd = Path.cwd()
|
||||
logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)
|
||||
|
|
Loading…
Reference in a new issue