Allow saving and loading threads (with REDIS)
This commit is contained in:
parent
c47cac4118
commit
9db6b99f90
4 changed files with 38 additions and 5 deletions
|
@ -11,6 +11,8 @@ services:
|
||||||
- jimmy-data:/app/data
|
- jimmy-data:/app/data
|
||||||
ports:
|
ports:
|
||||||
- 11444:8080
|
- 11444:8080
|
||||||
|
extra_hosts:
|
||||||
|
- host.docker.internal:host-gateway
|
||||||
ollama:
|
ollama:
|
||||||
image: ollama/ollama:latest
|
image: ollama/ollama:latest
|
||||||
container_name: ollama
|
container_name: ollama
|
||||||
|
@ -19,6 +21,8 @@ services:
|
||||||
- 11434:11434
|
- 11434:11434
|
||||||
volumes:
|
volumes:
|
||||||
- ollama-data:/root/.ollama
|
- ollama-data:/root/.ollama
|
||||||
|
redis:
|
||||||
|
image: redis
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
ollama-data:
|
ollama-data:
|
||||||
|
|
|
@ -16,3 +16,4 @@ uvicorn==0.25.0
|
||||||
psutil==5.9.7
|
psutil==5.9.7
|
||||||
pydantic==2.5.3
|
pydantic==2.5.3
|
||||||
humanize==4.9.0
|
humanize==4.9.0
|
||||||
|
redis==5.0.1
|
||||||
|
|
|
@ -7,7 +7,7 @@ import time
|
||||||
import typing
|
import typing
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import humanize
|
import redis
|
||||||
|
|
||||||
from discord.ui import View, button
|
from discord.ui import View, button
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
|
@ -71,6 +71,19 @@ class OllamaView(View):
|
||||||
class ChatHistory:
|
class ChatHistory:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._internal = {}
|
self._internal = {}
|
||||||
|
self.redis = redis.Redis(**CONFIG["redis"])
|
||||||
|
|
||||||
|
def load_thread(self, thread_id: str):
|
||||||
|
value: str = self.redis.get("threads:" + thread_id)
|
||||||
|
if value:
|
||||||
|
loaded = json.loads(value)
|
||||||
|
self._internal.update(loaded)
|
||||||
|
return self.get_thread(thread_id)
|
||||||
|
|
||||||
|
def save_thread(self, thread_id: str):
|
||||||
|
self.redis.set(
|
||||||
|
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||||
|
)
|
||||||
|
|
||||||
def create_thread(self, member: discord.Member) -> str:
|
def create_thread(self, member: discord.Member) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -106,7 +119,7 @@ class ChatHistory:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
cog: Ollama = ctx.bot.get_cog("Ollama")
|
cog: Ollama = ctx.bot.get_cog("Ollama")
|
||||||
instance = cog.history
|
instance = cog.history
|
||||||
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))
|
return discord.utils.basic_autocomplete(list(instance.threads_for(ctx.interaction.user).keys()))(ctx)
|
||||||
|
|
||||||
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
|
def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | discord.Member | int]]:
|
||||||
"""Returns all saved threads."""
|
"""Returns all saved threads."""
|
||||||
|
@ -444,8 +457,14 @@ class Ollama(commands.Cog):
|
||||||
await ctx.respond(embed=embed, view=view)
|
await ctx.respond(embed=embed, view=view)
|
||||||
self.log.debug("Beginning to generate response with key %r.", key)
|
self.log.debug("Beginning to generate response with key %r.", key)
|
||||||
|
|
||||||
if context is None or context in self.contexts:
|
if context is None:
|
||||||
context = self.history.create_thread(ctx.user)
|
context = self.history.create_thread(ctx.user)
|
||||||
|
elif context is not None and self.history.get_thread(context) is None:
|
||||||
|
if not self.history.load_thread(context):
|
||||||
|
return await ctx.respond("Invalid thread ID.")
|
||||||
|
else:
|
||||||
|
context = self.history.create_thread(ctx.user)
|
||||||
|
|
||||||
messages = self.history.get_history(context)
|
messages = self.history.get_history(context)
|
||||||
user_message = {
|
user_message = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -525,6 +544,7 @@ class Ollama(commands.Cog):
|
||||||
await ctx.edit(embed=embed, view=None)
|
await ctx.edit(embed=embed, view=None)
|
||||||
|
|
||||||
if line.get("done"):
|
if line.get("done"):
|
||||||
|
self.history.save_thread(context)
|
||||||
total_duration = get_time_spent(line["total_duration"])
|
total_duration = get_time_spent(line["total_duration"])
|
||||||
load_duration = get_time_spent(line["load_duration"])
|
load_duration = get_time_spent(line["load_duration"])
|
||||||
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
||||||
|
@ -576,8 +596,8 @@ class Ollama(commands.Cog):
|
||||||
description=page
|
description=page
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for chunk in discord.utils.as_chunks(iter(embeds), 10):
|
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
|
||||||
await ctx.respond(chunk, ephemeral=True)
|
await ctx.respond(embeds=chunk, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
|
|
@ -15,6 +15,14 @@ try:
|
||||||
"channel": 1032974266527907901
|
"channel": 1032974266527907901
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
CONFIG.setdefault(
|
||||||
|
"redis",
|
||||||
|
{
|
||||||
|
"host": "redis",
|
||||||
|
"port": 6379,
|
||||||
|
"decode_responses": True
|
||||||
|
}
|
||||||
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
cwd = Path.cwd()
|
cwd = Path.cwd()
|
||||||
logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)
|
logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)
|
||||||
|
|
Loading…
Reference in a new issue