Allow importing legacy threads
All checks were successful
Build and Publish / build_and_publish (push) Successful in 58s

This commit is contained in:
Nexus 2024-06-16 15:53:43 +01:00
parent d203376850
commit 448a23affa
3 changed files with 37 additions and 9 deletions

View file

@ -7,11 +7,12 @@ import typing
import contextlib
import discord
import httpx
from discord import Interaction
from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
from jimmy.config import get_servers, get_server
from jimmy.config import get_servers, get_server, get_config
from jimmy.db import OllamaThread
from humanize import naturalsize, naturaldelta
@ -84,7 +85,7 @@ class Chat(commands.Cog):
if self.server_locks[server.name].locked():
embed.add_field(
name=decorate_name(server),
value=f"\N{closed lock with key} In use.",
value="\N{closed lock with key} In use.",
inline=False
)
fields[server] = len(embed.fields) - 1
@ -92,7 +93,7 @@ class Chat(commands.Cog):
else:
embed.add_field(
name=decorate_name(server),
value=f"\N{hourglass with flowing sand} Waiting...",
value="\N{hourglass with flowing sand} Waiting...",
inline=False
)
fields[server] = len(embed.fields) - 1
@ -110,14 +111,14 @@ class Chat(commands.Cog):
embed.set_field_at(
fields[server],
name=decorate_name(server),
value=f"\N{white heavy check mark} Online.",
value="\N{white heavy check mark} Online.",
inline=False
)
else:
embed.set_field_at(
fields[server],
name=decorate_name(server),
value=f"\N{cross mark} Offline.",
value="\N{cross mark} Offline.",
inline=False
)
await ctx.edit(embed=embed)
@ -238,7 +239,7 @@ class Chat(commands.Cog):
if model not in [x["model"] for x in tags]:
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
description=f"Initiating download...",
description="Initiating download...",
color=discord.Color.blurple()
)
view = StopDownloadView(ctx)
@ -321,8 +322,29 @@ class Chat(commands.Cog):
messages.append(
await create_ollama_message(msg["content"], role=msg["role"])
)
elif len(thread_id) == 6:
# Is a legacy thread
_cfg = get_config()["truth_api"]
async with httpx.AsyncClient(
base_url=_cfg["url"],
auth=(_cfg["username"], _cfg["password"])
) as http_client:
response = await http_client.get(f"/threads/threads:{thread_id}")
if response.status_code == 200:
thread = await response.json()
messages = thread["messages"]
thread = OllamaThread(
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
)
await thread.save()
else:
return await ctx.respond(
content="Failed to fetch legacy ollama thread from jimmy v2: HTTP %d (`%r`)" % (
response.status_code, response.text
),
)
else:
await ctx.respond(content="No thread with that ID exists.", delete_after=30)
return await ctx.respond(content="No thread with that ID exists.", delete_after=30)
if system_prompt:
messages.append(await create_ollama_message(system_prompt, role="system"))
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
@ -374,7 +396,7 @@ class Chat(commands.Cog):
embed.add_field(
name="Full chat",
value="The chat was too long to fit in this message. "
f"You can download the `full-chat.txt` file to see the full message."
"You can download the `full-chat.txt` file to see the full message."
)
else:
file = discord.utils.MISSING
@ -419,7 +441,7 @@ class Chat(commands.Cog):
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
description=f"Initiating download...",
description="Initiating download...",
color=discord.Color.blurple()
)
view = StopDownloadView(ctx)

View file

@ -102,6 +102,10 @@ def get_config():
_loaded.setdefault("servers", {})
_loaded["servers"].setdefault("order", [])
_loaded.setdefault("bot", {})
_loaded.setdefault("truth_api", {})
_loaded["truth_api"].setdefault("url", "https://bots.nexy7574.co.uk/jimmy/v2/api")
_loaded["truth_api"].setdefault("username", "invalid")
_loaded["truth_api"].setdefault("password", "invalid")
if database_url := os.getenv("DATABASE_URL"):
_loaded["bot"]["db_url"] = database_url
return _loaded

2
tox.ini Normal file
View file

@ -0,0 +1,2 @@
[flake8]
max-line-length = 120