sentient-jimmy/jimmy/cogs/chat.py

583 lines
24 KiB
Python
Raw Permalink Normal View History

import asyncio
2024-06-11 01:37:20 +01:00
import datetime
import io
import logging
import time
import typing
import contextlib
import discord
2024-06-16 15:53:43 +01:00
import httpx
from discord import Interaction
from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands
2024-06-10 17:03:58 +01:00
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
2024-06-16 15:53:43 +01:00
from jimmy.config import get_servers, get_server, get_config
2024-06-10 17:03:58 +01:00
from jimmy.db import OllamaThread
from humanize import naturalsize, naturaldelta
@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> typing.AsyncIterator[AsyncClient]:
    """Yield an Ollama ``AsyncClient`` for ``host`` and close it when the block exits.

    :param host: Base URL of the Ollama server; it is passed through ``str()`` first.
    :param kwargs: Forwarded unchanged to ``AsyncClient`` (for example ``timeout``).
    :yields: The connected ``AsyncClient``.
    """
    client = AsyncClient(str(host), **kwargs)
    try:
        yield client
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        # ``_client`` is a private attribute (presumably the underlying httpx client).
        await client._client.aclose()
class StopDownloadView(discord.ui.View):
    """A single "Cancel" button that only the invoking user may press.

    Pressing it sets :attr:`event` (which the streaming loops in this file poll),
    stops the view and removes the button from the message.
    """

    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=None)
        # The command context; its user is the only one allowed to cancel.
        self.ctx = ctx
        # Becomes set once "Cancel" has been pressed.
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Presses from anybody other than the command's author are rejected.
        is_owner = interaction.user == self.ctx.user
        return is_owner

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        # Signal the running loop first, then tear the view down.
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)
async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    """Suggest model tags installed on the server picked in the ``server`` option.

    Falls back to the first configured server when no server has been chosen.
    The user's current (non-empty) input is offered first so a tag that is not
    installed yet can still be submitted.

    :returns: At most 25 suggestion strings (Discord's autocomplete choice limit).
    """
    typed = ctx.value or ""
    # Discord rejects choices whose name is empty, so only echo non-empty input.
    suggestions = [typed] if typed else []

    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    if not chosen_server:
        # The "server" option holds an unknown name; there is nothing to list.
        return suggestions

    try:
        async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
            tags = (await client.list())["models"]
    except (httpx.HTTPError, ResponseError) as err:
        # Offline/unreachable server: degrade to echoing the input instead of erroring.
        logging.getLogger(__name__).debug(
            "Tag autocomplete failed for %r: %r", chosen_server, err
        )
        return suggestions

    needle = typed.casefold()
    for tag in tags:
        name = tag["model"]
        if needle in name.casefold() and name not in suggestions:
            suggestions.append(name)
    return suggestions[:25]
# Autocomplete source for every "server" option: the names of all configured
# servers, collected once when this module is imported.
_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
    [configured.name for configured in get_servers()]
)
class Chat(commands.Cog):
    # Cog providing the /ollama slash-command group: server status and info,
    # chatting with a model, pulling models and listing loaded models.
    # NOTE: command docstrings are kept verbatim because py-cord uses them as the
    # slash-command descriptions; explanations live in comments instead.

    # Pull statuses that carry no byte counts; they are displayed verbatim.
    _PULL_PHASES = frozenset(
        {
            "pulling manifest",
            "verifying sha256 digest",
            "writing manifest",
            "removing any unused layers",
            "success",
        }
    )
    # Discord rejects embed descriptions longer than this many characters.
    _EMBED_DESCRIPTION_LIMIT = 4096
    # Minimum number of seconds between live edits of a progress/chat message.
    _EDIT_INTERVAL = 2.5

    def __init__(self, bot):
        self.bot = bot
        # One lock per configured server: /ollama chat holds it for the whole
        # exchange, and /ollama status reports locked servers as "In use".
        self.server_locks = {}
        for server in get_servers():
            self.server_locks[server.name] = asyncio.Lock()
        self.log = logging.getLogger(__name__)

    ollama_group = discord.SlashCommandGroup(
        name="ollama",
        description="Commands related to ollama.",
        guild_only=True
    )

    # ------------------------------------------------------------------ helpers

    def _clip_description(self, text: str) -> str:
        """Fit ``text`` into an embed description, keeping the most recent tail."""
        limit = self._EMBED_DESCRIPTION_LIMIT
        if len(text) <= limit:
            return text
        # "... " is 4 characters, so keep the last ``limit - 4`` characters.
        return "... " + text[-(limit - 4):]

    def _describe_progress(
        self, status: str, completed: int, total: int, bytes_per_second: float
    ) -> str:
        """Render one pull-progress line: percentage, bar, sizes, speed and ETA."""
        fraction = min(max(completed / total, 0.0), 1.0)
        percent = round(fraction * 100, 1)
        filled = int(percent / 10)
        progress_bar = "[" + "\N{FULL BLOCK}" * filled + "\N{LIGHT SHADE}" * (10 - filled) + "]"
        mbps = round((bytes_per_second * 8) / 1024 / 1024)
        if bytes_per_second > 0:
            eta = naturaldelta((total - completed) / bytes_per_second)
        else:
            # No rate measured yet; any estimate would be meaningless.
            eta = "unknown"
        ns_total = naturalsize(total, binary=True)
        ns_completed = naturalsize(completed, binary=True)
        return (
            f"{status.capitalize()} {percent}% {progress_bar} "
            f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
            f"[ETA: {eta}]"
        )

    async def _pull_with_progress(
        self,
        ctx: discord.ApplicationContext,
        client: AsyncClient,
        server,
        model: str,
    ) -> bool:
        """Pull ``model`` with ``client`` while streaming progress into a new message.

        Sends a new response carrying a "Cancel" button and edits that same
        message (at most every ``_EDIT_INTERVAL`` seconds) as the pull advances.

        :returns: ``True`` if the pull completed; ``False`` if it was cancelled or
            Ollama reported an error (the user has already been informed).
        """
        embed = discord.Embed(
            title=f"Downloading {model} on {server}.",
            description="Initiating download...",
            color=discord.Color.blurple()
        )
        view = StopDownloadView(ctx)
        # Edit the message we send here rather than the interaction's original
        # response: in /ollama chat the original response may be the "server is
        # offline" notice, which has already been scheduled for deletion.
        message = await ctx.respond(embed=embed, view=view)
        last_edit = 0.0
        # Transfer-rate baseline; restarted whenever a new layer (digest) begins,
        # because "completed" counts from zero again for each layer.
        rate_digest = None
        rate_completed = 0
        rate_ts = time.monotonic()
        bytes_per_second = 0.0
        try:
            async with ctx.typing():
                stream = await client.pull(model, stream=True)
                # aclosing() ends the HTTP stream promptly if we stop early.
                async with contextlib.aclosing(stream):
                    async for line in stream:
                        if view.event.is_set():
                            embed.add_field(name="Error!", value="Download cancelled.")
                            embed.colour = discord.Colour.red()
                            await message.edit(embed=embed, view=None)
                            return False
                        self.log.debug("Response from %r: %r", server, line)
                        status = line.get("status", "")
                        total = line.get("total") or 0
                        if status in self._PULL_PHASES or total <= 0:
                            embed.description = status.capitalize()
                        else:
                            completed = line.get("completed", 0)
                            digest = line.get("digest")
                            now = time.monotonic()
                            if digest != rate_digest or completed < rate_completed:
                                rate_digest, rate_completed, rate_ts = digest, completed, now
                            elif now - rate_ts >= 1:
                                # Measuring over >= 1 s avoids division by (near) zero.
                                bytes_per_second = (completed - rate_completed) / (now - rate_ts)
                                rate_completed, rate_ts = completed, now
                            embed.description = self._describe_progress(
                                status, completed, total, bytes_per_second
                            )
                        if time.monotonic() - last_edit >= self._EDIT_INTERVAL:
                            await message.edit(embed=embed)
                            last_edit = time.monotonic()
        except ResponseError as err:
            if err.error.endswith("file does not exist"):
                await message.edit(
                    embed=None,
                    content="The model %r does not exist." % model,
                    view=None
                )
                await message.delete(delay=60)
            else:
                embed.add_field(
                    name="Error!",
                    value=err.error
                )
                embed.colour = discord.Colour.red()
                await message.edit(embed=embed, view=None)
            return False
        finally:
            # timeout=None views are never expired automatically; release this one.
            view.stop()
        embed.colour = discord.Colour.green()
        embed.description = f"Downloaded {model} on {server}."
        await message.edit(embed=embed, view=None)
        await message.delete(delay=30)
        return True

    # ----------------------------------------------------------------- commands

    @ollama_group.command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status on all servers."""
        await ctx.defer()
        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple()
        )
        # (server, field index) for every server shown as "Waiting..."; keyed by
        # name so server objects need not be hashable.
        pending = {}
        for server in get_servers():
            if self.server_locks[server.name].locked():
                # A chat is running there; report it busy instead of probing.
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{closed lock with key} In use.",
                    inline=False
                )
            else:
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{hourglass with flowing sand} Waiting...",
                    inline=False
                )
                pending[server.name] = (server, len(embed.fields) - 1)
        await ctx.respond(embed=embed)

        # Probe exactly the servers shown as "Waiting..." so each such field is
        # resolved; one failing probe must not abort the others.
        probes = list(pending.values())
        results = await asyncio.gather(
            *(server.is_online() for server, _ in probes),
            return_exceptions=True
        )
        for (server, index), online in zip(probes, results):
            if isinstance(online, BaseException):
                self.log.warning("Status check for %r failed: %r", server, online)
                online = False
            embed.set_field_at(
                index,
                name=decorate_name(server),
                value="\N{white heavy check mark} Online." if online else "\N{cross mark} Offline.",
                inline=False
            )
        await ctx.edit(embed=embed)

    @ollama_group.command(name="server-info")
    async def get_server_info(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ]
    ):
        """Gets information on a given server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            # Without this guard the attribute accesses below would raise.
            return await ctx.respond("\N{cross mark} Unknown server.")
        is_online = await server.is_online()
        yes = "\N{white heavy check mark}"
        no = "\N{cross mark}"
        memory_kind = "VRAM" if server.gpu else "RAM"
        lines = [
            f"Name: {server.name!r}",
            f"Base URL: {server.base_url!r}",
            f"GPU Enabled: {yes if server.gpu else no}",
            f"{memory_kind}: {server.vram_gb:,} GB",
            f"Default Model: {server.default_model!r}",
            f"Is Online: {yes if is_online else no}"
        ]
        return await ctx.respond("```md\n" + "\n".join(lines) + "```")

    @ollama_group.command(name="chat")
    async def start_ollama_chat(
        self,
        ctx: discord.ApplicationContext,
        prompt: str,
        system_prompt: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The system prompt to use.",
                default=None
            )
        ],
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="default"
            )
        ],
        image: typing.Annotated[
            discord.Attachment | None,
            discord.Option(
                discord.SlashCommandOptionType.attachment,
                description="The image to use for llava.",
                default=None
            )
        ],
        thread_id: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The thread ID to continue.",
                default=None
            )
        ],
        temperature: typing.Annotated[
            float,
            discord.Option(
                discord.SlashCommandOptionType.number,
                description="The temperature to use.",
                default=1.5,
                min_value=0.0,
                max_value=2.0
            )
        ]
    ):
        """Have a chat with ollama"""
        # Flow: resolve (or substitute) the server, pull the model if missing,
        # load any earlier thread, stream the reply, then persist the exchange.
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown Server.")
        elif not await server.is_online():
            await ctx.respond(
                content=f"{server} is offline. Finding a suitable server...",
            )
            try:
                server = await find_suitable_server()
            except ValueError as err:
                return await ctx.edit(content=str(err), delete_after=30)
            await ctx.delete(delay=5)

        async with self.server_locks[server.name]:
            if model == "default":
                model = server.default_model
            async with ollama_client(str(server.base_url)) as client:
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                if model not in [x["model"] for x in tags]:
                    if not await self._pull_with_progress(ctx, client, server, model):
                        return

                messages = []
                thread = None
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    elif len(thread_id) == 6:
                        # Is a legacy thread
                        _cfg = get_config()["truth_api"]
                        try:
                            async with httpx.AsyncClient(
                                base_url=_cfg["url"],
                                auth=(_cfg["username"], _cfg["password"])
                            ) as http_client:
                                response = await http_client.get(f"/ollama/thread/threads:{thread_id}")
                        except httpx.HTTPError as err:
                            return await ctx.respond(
                                content=f"Failed to fetch legacy ollama thread from jimmy v2: {err!r}",
                            )
                        if response.status_code == 200:
                            messages = response.json()["messages"]
                            thread = OllamaThread(
                                messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                            )
                            await thread.save()
                        else:
                            return await ctx.respond(
                                content="Failed to fetch legacy ollama thread from jimmy v2: HTTP %d (`%r`)" % (
                                    response.status_code, response.text
                                ),
                            )
                    else:
                        return await ctx.respond(content="No thread with that ID exists.", delete_after=30)

                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))

                embed = discord.Embed(description="")
                embed.set_author(
                    name=f"{model} @ {decorate_name(server)!r}" if server.gpu else model,
                    icon_url="https://ollama.com/public/icon-64x64.png"
                )
                view = StopDownloadView(ctx)
                msg = await ctx.respond(
                    embed=embed,
                    view=view
                )
                last_edit = time.monotonic()
                buffer = io.StringIO()
                try:
                    stream = await client.chat(
                        model,
                        messages,
                        stream=True,
                        options=Options(
                            num_ctx=4096,
                            low_vram=server.vram_gb < 8,
                            temperature=temperature
                        )
                    )
                    # aclosing() ends the HTTP stream promptly if the chat is cancelled.
                    async with contextlib.aclosing(stream):
                        async for response in stream:
                            self.log.debug("Response from %r: %r", server, response)
                            buffer.write(response["message"]["content"])
                            embed.description = self._clip_description(buffer.getvalue())
                            if view.event.is_set():
                                embed.add_field(name="Error!", value="Chat cancelled.")
                                embed.colour = discord.Colour.red()
                                await msg.edit(embed=embed, view=None)
                                return
                            if time.monotonic() - last_edit >= self._EDIT_INTERVAL:
                                await msg.edit(embed=embed, view=view)
                                last_edit = time.monotonic()
                except ResponseError as err:
                    embed.add_field(name="Error!", value=err.error)
                    embed.colour = discord.Colour.red()
                    await msg.edit(embed=embed, view=None)
                    return
                finally:
                    # timeout=None views are never expired automatically; release this one.
                    view.stop()

                embed.colour = discord.Colour.green()
                full_text = buffer.getvalue()
                if len(full_text) > self._EMBED_DESCRIPTION_LIMIT:
                    file = discord.File(
                        io.BytesIO(full_text.encode()),
                        filename="full-chat.txt"
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                              "You can download the `full-chat.txt` file to see the full message."
                    )
                else:
                    file = discord.utils.MISSING

                # Persist the whole exchange, including the model's reply, so that
                # continuing with this Chat ID resumes the full conversation.
                messages.append({"role": "assistant", "content": full_text})
                stored = [{"role": m["role"], "content": m["content"]} for m in messages]
                if thread is None:
                    thread = OllamaThread(messages=stored)
                else:
                    thread.messages = stored
                await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)

    @ollama_group.command(name="pull")
    async def pull_ollama_model(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest"
            )
        ],
    ):
        """Downloads a tag on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        async with ollama_client(str(server.base_url)) as client:
            await self._pull_with_progress(ctx, client, server, model)

    @ollama_group.command(name="ps")
    async def ollama_proc_list(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ]
    ):
        """Checks the loaded models on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        async with ollama_client(str(server.base_url)) as client:
            response = (await client.ps())["models"]
        if not response:
            embed = discord.Embed(
                title=f"No models loaded on {server}.",
                color=discord.Color.blurple()
            )
            return await ctx.respond(embed=embed)
        embed = discord.Embed(
            title=f"Models loaded on {server}",
            color=discord.Color.blurple()
        )
        # Embeds hold at most 25 fields, one per loaded model.
        for model in response[:25]:
            size = naturalsize(model["size"], binary=True)
            size_vram = naturalsize(model["size_vram"], binary=True)
            size_ram = naturalsize(model["size"] - model["size_vram"], binary=True)
            percent_in_vram = round(model["size_vram"] / model["size"] * 100)
            percent_in_ram = 100 - percent_in_vram
            expires = datetime.datetime.fromisoformat(model["expires_at"])
            lines = [
                f"* Size: {size}",
                f"* Unloaded: {discord.utils.format_dt(expires, style='R')}",
            ]
            if percent_in_ram > 0:
                lines.extend(
                    [
                        f"* VRAM/RAM: {percent_in_vram}%/{percent_in_ram}%",
                        f"* VRAM Size: {size_vram}",
                        f"* RAM Size: {size_ram}"
                    ]
                )
            else:
                lines.append(f"* VRAM Size: {size_vram} (100%)")
            embed.add_field(
                name=model["model"],
                value="\n".join(lines),
                inline=False
            )
        await ctx.respond(embed=embed)
def setup(bot):
    """Extension entry point: register the :class:`Chat` cog on ``bot``."""
    cog = Chat(bot)
    bot.add_cog(cog)