Allow importing legacy threads
All checks were successful
Build and Publish / build_and_publish (push) Successful in 58s
All checks were successful
Build and Publish / build_and_publish (push) Successful in 58s
This commit is contained in:
parent
d203376850
commit
448a23affa
3 changed files with 37 additions and 9 deletions
|
@ -7,11 +7,12 @@ import typing
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import httpx
|
||||||
from discord import Interaction
|
from discord import Interaction
|
||||||
from ollama import AsyncClient, ResponseError, Options
|
from ollama import AsyncClient, ResponseError, Options
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
|
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
|
||||||
from jimmy.config import get_servers, get_server
|
from jimmy.config import get_servers, get_server, get_config
|
||||||
from jimmy.db import OllamaThread
|
from jimmy.db import OllamaThread
|
||||||
from humanize import naturalsize, naturaldelta
|
from humanize import naturalsize, naturaldelta
|
||||||
|
|
||||||
|
@ -84,7 +85,7 @@ class Chat(commands.Cog):
|
||||||
if self.server_locks[server.name].locked():
|
if self.server_locks[server.name].locked():
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=decorate_name(server),
|
name=decorate_name(server),
|
||||||
value=f"\N{closed lock with key} In use.",
|
value="\N{closed lock with key} In use.",
|
||||||
inline=False
|
inline=False
|
||||||
)
|
)
|
||||||
fields[server] = len(embed.fields) - 1
|
fields[server] = len(embed.fields) - 1
|
||||||
|
@ -92,7 +93,7 @@ class Chat(commands.Cog):
|
||||||
else:
|
else:
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=decorate_name(server),
|
name=decorate_name(server),
|
||||||
value=f"\N{hourglass with flowing sand} Waiting...",
|
value="\N{hourglass with flowing sand} Waiting...",
|
||||||
inline=False
|
inline=False
|
||||||
)
|
)
|
||||||
fields[server] = len(embed.fields) - 1
|
fields[server] = len(embed.fields) - 1
|
||||||
|
@ -110,14 +111,14 @@ class Chat(commands.Cog):
|
||||||
embed.set_field_at(
|
embed.set_field_at(
|
||||||
fields[server],
|
fields[server],
|
||||||
name=decorate_name(server),
|
name=decorate_name(server),
|
||||||
value=f"\N{white heavy check mark} Online.",
|
value="\N{white heavy check mark} Online.",
|
||||||
inline=False
|
inline=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embed.set_field_at(
|
embed.set_field_at(
|
||||||
fields[server],
|
fields[server],
|
||||||
name=decorate_name(server),
|
name=decorate_name(server),
|
||||||
value=f"\N{cross mark} Offline.",
|
value="\N{cross mark} Offline.",
|
||||||
inline=False
|
inline=False
|
||||||
)
|
)
|
||||||
await ctx.edit(embed=embed)
|
await ctx.edit(embed=embed)
|
||||||
|
@ -238,7 +239,7 @@ class Chat(commands.Cog):
|
||||||
if model not in [x["model"] for x in tags]:
|
if model not in [x["model"] for x in tags]:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=f"Downloading {model} on {server}.",
|
title=f"Downloading {model} on {server}.",
|
||||||
description=f"Initiating download...",
|
description="Initiating download...",
|
||||||
color=discord.Color.blurple()
|
color=discord.Color.blurple()
|
||||||
)
|
)
|
||||||
view = StopDownloadView(ctx)
|
view = StopDownloadView(ctx)
|
||||||
|
@ -321,8 +322,29 @@ class Chat(commands.Cog):
|
||||||
messages.append(
|
messages.append(
|
||||||
await create_ollama_message(msg["content"], role=msg["role"])
|
await create_ollama_message(msg["content"], role=msg["role"])
|
||||||
)
|
)
|
||||||
|
elif len(thread_id) == 6:
|
||||||
|
# Is a legacy thread
|
||||||
|
_cfg = get_config()["truth_api"]
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
base_url=_cfg["url"],
|
||||||
|
auth=(_cfg["username"], _cfg["password"])
|
||||||
|
) as http_client:
|
||||||
|
response = await http_client.get(f"/threads/threads:{thread_id}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
thread = await response.json()
|
||||||
|
messages = thread["messages"]
|
||||||
|
thread = OllamaThread(
|
||||||
|
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
|
||||||
|
)
|
||||||
|
await thread.save()
|
||||||
else:
|
else:
|
||||||
await ctx.respond(content="No thread with that ID exists.", delete_after=30)
|
return await ctx.respond(
|
||||||
|
content="Failed to fetch legacy ollama thread from jimmy v2: HTTP %d (`%r`)" % (
|
||||||
|
response.status_code, response.text
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await ctx.respond(content="No thread with that ID exists.", delete_after=30)
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
messages.append(await create_ollama_message(system_prompt, role="system"))
|
messages.append(await create_ollama_message(system_prompt, role="system"))
|
||||||
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
|
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
|
||||||
|
@ -374,7 +396,7 @@ class Chat(commands.Cog):
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Full chat",
|
name="Full chat",
|
||||||
value="The chat was too long to fit in this message. "
|
value="The chat was too long to fit in this message. "
|
||||||
f"You can download the `full-chat.txt` file to see the full message."
|
"You can download the `full-chat.txt` file to see the full message."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
file = discord.utils.MISSING
|
file = discord.utils.MISSING
|
||||||
|
@ -419,7 +441,7 @@ class Chat(commands.Cog):
|
||||||
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
|
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=f"Downloading {model} on {server}.",
|
title=f"Downloading {model} on {server}.",
|
||||||
description=f"Initiating download...",
|
description="Initiating download...",
|
||||||
color=discord.Color.blurple()
|
color=discord.Color.blurple()
|
||||||
)
|
)
|
||||||
view = StopDownloadView(ctx)
|
view = StopDownloadView(ctx)
|
||||||
|
|
|
@ -102,6 +102,10 @@ def get_config():
|
||||||
_loaded.setdefault("servers", {})
|
_loaded.setdefault("servers", {})
|
||||||
_loaded["servers"].setdefault("order", [])
|
_loaded["servers"].setdefault("order", [])
|
||||||
_loaded.setdefault("bot", {})
|
_loaded.setdefault("bot", {})
|
||||||
|
_loaded.setdefault("truth_api", {})
|
||||||
|
_loaded["truth_api"].setdefault("url", "https://bots.nexy7574.co.uk/jimmy/v2/api")
|
||||||
|
_loaded["truth_api"].setdefault("username", "invalid")
|
||||||
|
_loaded["truth_api"].setdefault("password", "invalid")
|
||||||
if database_url := os.getenv("DATABASE_URL"):
|
if database_url := os.getenv("DATABASE_URL"):
|
||||||
_loaded["bot"]["db_url"] = database_url
|
_loaded["bot"]["db_url"] = database_url
|
||||||
return _loaded
|
return _loaded
|
||||||
|
|
2
tox.ini
Normal file
2
tox.ini
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
[flake8]
|
||||||
|
max-line-length = 120
|
Loading…
Reference in a new issue