Allow importing legacy threads
All checks were successful
Build and Publish / build_and_publish (push) Successful in 58s

This commit is contained in:
Nexus 2024-06-16 15:53:43 +01:00
parent d203376850
commit 448a23affa
3 changed files with 37 additions and 9 deletions

View file

@ -7,11 +7,12 @@ import typing
import contextlib import contextlib
import discord import discord
import httpx
from discord import Interaction from discord import Interaction
from ollama import AsyncClient, ResponseError, Options from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands from discord.ext import commands
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
from jimmy.config import get_servers, get_server from jimmy.config import get_servers, get_server, get_config
from jimmy.db import OllamaThread from jimmy.db import OllamaThread
from humanize import naturalsize, naturaldelta from humanize import naturalsize, naturaldelta
@ -84,7 +85,7 @@ class Chat(commands.Cog):
if self.server_locks[server.name].locked(): if self.server_locks[server.name].locked():
embed.add_field( embed.add_field(
name=decorate_name(server), name=decorate_name(server),
value=f"\N{closed lock with key} In use.", value="\N{closed lock with key} In use.",
inline=False inline=False
) )
fields[server] = len(embed.fields) - 1 fields[server] = len(embed.fields) - 1
@ -92,7 +93,7 @@ class Chat(commands.Cog):
else: else:
embed.add_field( embed.add_field(
name=decorate_name(server), name=decorate_name(server),
value=f"\N{hourglass with flowing sand} Waiting...", value="\N{hourglass with flowing sand} Waiting...",
inline=False inline=False
) )
fields[server] = len(embed.fields) - 1 fields[server] = len(embed.fields) - 1
@ -110,14 +111,14 @@ class Chat(commands.Cog):
embed.set_field_at( embed.set_field_at(
fields[server], fields[server],
name=decorate_name(server), name=decorate_name(server),
value=f"\N{white heavy check mark} Online.", value="\N{white heavy check mark} Online.",
inline=False inline=False
) )
else: else:
embed.set_field_at( embed.set_field_at(
fields[server], fields[server],
name=decorate_name(server), name=decorate_name(server),
value=f"\N{cross mark} Offline.", value="\N{cross mark} Offline.",
inline=False inline=False
) )
await ctx.edit(embed=embed) await ctx.edit(embed=embed)
@ -238,7 +239,7 @@ class Chat(commands.Cog):
if model not in [x["model"] for x in tags]: if model not in [x["model"] for x in tags]:
embed = discord.Embed( embed = discord.Embed(
title=f"Downloading {model} on {server}.", title=f"Downloading {model} on {server}.",
description=f"Initiating download...", description="Initiating download...",
color=discord.Color.blurple() color=discord.Color.blurple()
) )
view = StopDownloadView(ctx) view = StopDownloadView(ctx)
@ -321,8 +322,29 @@ class Chat(commands.Cog):
messages.append( messages.append(
await create_ollama_message(msg["content"], role=msg["role"]) await create_ollama_message(msg["content"], role=msg["role"])
) )
elif len(thread_id) == 6:
# Is a legacy thread
_cfg = get_config()["truth_api"]
async with httpx.AsyncClient(
base_url=_cfg["url"],
auth=(_cfg["username"], _cfg["password"])
) as http_client:
response = await http_client.get(f"/threads/threads:{thread_id}")
if response.status_code == 200:
thread = await response.json()
messages = thread["messages"]
thread = OllamaThread(
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
)
await thread.save()
else: else:
await ctx.respond(content="No thread with that ID exists.", delete_after=30) return await ctx.respond(
content="Failed to fetch legacy ollama thread from jimmy v2: HTTP %d (`%r`)" % (
response.status_code, response.text
),
)
else:
return await ctx.respond(content="No thread with that ID exists.", delete_after=30)
if system_prompt: if system_prompt:
messages.append(await create_ollama_message(system_prompt, role="system")) messages.append(await create_ollama_message(system_prompt, role="system"))
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None)) messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
@ -374,7 +396,7 @@ class Chat(commands.Cog):
embed.add_field( embed.add_field(
name="Full chat", name="Full chat",
value="The chat was too long to fit in this message. " value="The chat was too long to fit in this message. "
f"You can download the `full-chat.txt` file to see the full message." "You can download the `full-chat.txt` file to see the full message."
) )
else: else:
file = discord.utils.MISSING file = discord.utils.MISSING
@ -419,7 +441,7 @@ class Chat(commands.Cog):
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding") return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
embed = discord.Embed( embed = discord.Embed(
title=f"Downloading {model} on {server}.", title=f"Downloading {model} on {server}.",
description=f"Initiating download...", description="Initiating download...",
color=discord.Color.blurple() color=discord.Color.blurple()
) )
view = StopDownloadView(ctx) view = StopDownloadView(ctx)

View file

@ -102,6 +102,10 @@ def get_config():
_loaded.setdefault("servers", {}) _loaded.setdefault("servers", {})
_loaded["servers"].setdefault("order", []) _loaded["servers"].setdefault("order", [])
_loaded.setdefault("bot", {}) _loaded.setdefault("bot", {})
_loaded.setdefault("truth_api", {})
_loaded["truth_api"].setdefault("url", "https://bots.nexy7574.co.uk/jimmy/v2/api")
_loaded["truth_api"].setdefault("username", "invalid")
_loaded["truth_api"].setdefault("password", "invalid")
if database_url := os.getenv("DATABASE_URL"): if database_url := os.getenv("DATABASE_URL"):
_loaded["bot"]["db_url"] = database_url _loaded["bot"]["db_url"] = database_url
return _loaded return _loaded

2
tox.ini Normal file
View file

@ -0,0 +1,2 @@
[flake8]
max-line-length = 120