sentient-jimmy/jimmy/cogs/chat.py
import asyncio
import datetime
import io
import logging
import time
import typing
import contextlib
import discord
from discord import Interaction
from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands
from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
from jimmy.config import get_servers, get_server
from jimmy.db import OllamaThread
from humanize import naturalsize, naturaldelta
@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> AsyncClient:
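"""Yield an AsyncClient for `host`, closing the underlying httpx client on exit."""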
host = str(host)
client = AsyncClient(host, **kwargs)
try:
yield client
finally:
# Ollama doesn't auto-close the client, so we have to do it ourselves.
await client._client.aclose()
class StopDownloadView(discord.ui.View):
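"""View with a single Cancel button; pressing it sets an event that the streaming loops poll to abort early."""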
def __init__(self, ctx: discord.ApplicationContext):
super().__init__(timeout=None)
self.ctx = ctx
self.event = asyncio.Event()
async def interaction_check(self, interaction: Interaction) -> bool:
return interaction.user == self.ctx.user
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
async def cancel_download(self, _, interaction: discord.Interaction):
self.event.set()
self.stop()
await interaction.response.edit_message(view=None)
async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
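"""Suggest model tags available on the chosen server (or the first configured one)."""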
chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
tags = (await client.list())["models"]
v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
return [ctx.value, *v][:25]
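# The set of configured servers is fixed for the lifetime of the process,
# so a pre-computed autocomplete list is sufficient here.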
_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
[x.name for x in get_servers()]
)
class Chat(commands.Cog):
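"""Slash commands for chatting with ollama and managing ollama servers."""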
def __init__(self, bot):
self.bot = bot
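# One lock per server; servers marked as throttled only serve one request at a time.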
self.server_locks = {}
for server in get_servers():
self.server_locks[server.name] = asyncio.Lock()
self.log = logging.getLogger(__name__)
ollama_group = discord.SlashCommandGroup(
name="ollama",
description="Commands related to ollama.",
guild_only=True
)
@ollama_group.command()
async def status(self, ctx: discord.ApplicationContext):
"""Checks the status on all servers."""
await ctx.defer()
embed = discord.Embed(
title="Ollama Statuses:",
color=discord.Color.blurple()
)
fields = {}
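# Map each server to its embed field index so the placeholder status can be
# replaced in-place once the online check completes.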
for server in get_servers():
if server.throttle and self.server_locks[server.name].locked():
embed.add_field(
name=decorate_name(server),
value="\N{closed lock with key} In use.",
inline=False
)
else:
embed.add_field(
name=decorate_name(server),
value="\N{hourglass with flowing sand} Waiting...",
inline=False
)
fields[server] = len(embed.fields) - 1
await ctx.respond(embed=embed)
tasks = {}
for server in get_servers():
if server.throttle and self.server_locks[server.name].locked():
continue
tasks[server] = asyncio.create_task(server.is_online())
await asyncio.gather(*tasks.values())
for server, task in tasks.items():
if task.result():
embed.set_field_at(
fields[server],
name=decorate_name(server),
value=f"\N{white heavy check mark} Online.",
inline=False
)
else:
embed.set_field_at(
fields[server],
name=decorate_name(server),
value=f"\N{cross mark} Offline.",
inline=False
)
await ctx.edit(embed=embed)
@ollama_group.command(name="server-info")
async def get_server_info(
self,
ctx: discord.ApplicationContext,
server: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
]
):
"""Gets information on a given server"""
await ctx.defer()
server = get_server(server)
if not server:
return await ctx.respond("\N{cross mark} Unknown server.")
is_online = await server.is_online()
y = "\N{white heavy check mark}"
x = "\N{cross mark}"
t = {True: y, False: x}
rt = "VRAM" if server.gpu else "RAM"
lines = [
f"Name: {server.name!r}",
f"Base URL: {server.base_url!r}",
f"GPU Enabled: {t[server.gpu]}",
f"{rt}: {server.vram_gb:,} GB",
f"Default Model: {server.default_model!r}",
f"Is Online: {t[is_online]}"
]
p = "```md\n" + "\n".join(lines) + "```"
return await ctx.respond(p)
@ollama_group.command(name="chat")
async def start_ollama_chat(
self,
ctx: discord.ApplicationContext,
prompt: str,
system_prompt: typing.Annotated[
str | None,
discord.Option(
discord.SlashCommandOptionType.string,
description="The system prompt to use.",
default=None
)
],
server: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
],
model: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The model to use.",
autocomplete=get_available_tags_autocomplete,
default="default"
)
],
image: typing.Annotated[
discord.Attachment | None,
discord.Option(
discord.SlashCommandOptionType.attachment,
description="The image to use for llava.",
default=None
)
],
thread_id: typing.Annotated[
str | None,
discord.Option(
discord.SlashCommandOptionType.string,
description="The thread ID to continue.",
default=None
)
],
temperature: typing.Annotated[
float,
discord.Option(
discord.SlashCommandOptionType.number,
description="The temperature to use.",
default=1.5,
min_value=0.0,
max_value=2.0
)
]
):
"""Have a chat with ollama"""
await ctx.defer()
server = get_server(server)
if not server:
return await ctx.respond("\N{cross mark} Unknown Server.")
elif not await server.is_online():
await ctx.respond(
content=f"{server} is offline. Finding a suitable server...",
)
try:
server = await find_suitable_server()
except ValueError as err:
return await ctx.edit(content=str(err), delete_after=30)
await ctx.delete(delay=5)
async with self.server_locks[server.name]:
if model == "default":
model = server.default_model
async with ollama_client(str(server.base_url)) as client:
client: AsyncClient
self.log.info("Checking if %r has the model %r", server, model)
tags = (await client.list())["models"]
# Download code. It's recommended to collapse this in the editor.
if model not in [x["model"] for x in tags]:
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
description=f"Initiating download...",
color=discord.Color.blurple()
)
view = StopDownloadView(ctx)
await ctx.respond(
embed=embed,
view=view
)
last_edit = 0
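# Rate-limit embed edits to roughly one every 2.5 seconds to stay well within Discord's limits.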
async with ctx.typing():
try:
last_completed = 0
last_completed_ts = time.time()
async for line in await client.pull(model, stream=True):
if view.event.is_set():
embed.add_field(name="Error!", value="Download cancelled.")
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed)
return
self.log.debug("Response from %r: %r", server, line)
if line["status"] in {
"pulling manifest",
"verifying sha256 digest",
"writing manifest",
"removing any unused layers",
"success"
}:
embed.description = line["status"].capitalize()
else:
total = line["total"]
completed = line.get("completed", 0)
percent = round(completed / total * 100, 1)
pb_fill = "" * int(percent / 10)
pb_empty = "" * (10 - int(percent / 10))
bytes_per_second = completed - last_completed
bytes_per_second /= (time.time() - last_completed_ts)
last_completed = completed
last_completed_ts = time.time()
mbps = round((bytes_per_second * 8) / 1024 / 1024)
eta = (total - completed) / max(1, bytes_per_second)
progress_bar = f"[{pb_fill}{pb_empty}]"
ns_total = naturalsize(total, binary=True)
ns_completed = naturalsize(completed, binary=True)
embed.description = (
f"{line['status'].capitalize()} {percent}% {progress_bar} "
f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
f"[ETA: {naturaldelta(eta)}]"
)
if time.time() - last_edit >= 2.5:
await ctx.edit(embed=embed)
last_edit = time.time()
except ResponseError as err:
if err.error.endswith("file does not exist"):
await ctx.edit(
embed=None,
content="The model %r does not exist." % model,
delete_after=60,
view=None
)
else:
embed.add_field(
name="Error!",
value=err.error
)
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed, view=None)
return
else:
embed.colour = discord.Colour.green()
embed.description = f"Downloaded {model} on {server}."
await ctx.edit(embed=embed, delete_after=30, view=None)
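# Build the conversation: stored thread history first, then the optional
# system prompt, then the new user prompt (with any attached image).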
messages = []
thread = None
if thread_id:
thread = await OllamaThread.get_or_none(thread_id=thread_id)
if thread:
for msg in thread.messages:
messages.append(
await create_ollama_message(msg["content"], role=msg["role"])
)
else:
await ctx.respond(content="No thread with that ID exists.", delete_after=30)
if system_prompt:
messages.append(await create_ollama_message(system_prompt, role="system"))
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
embed = discord.Embed(description="")
embed.set_author(
name=f"{model} @ {decorate_name(server)!r}" if server.gpu else model,
icon_url="https://ollama.com/public/icon-64x64.png"
)
view = StopDownloadView(ctx)
msg = await ctx.respond(
embed=embed,
view=view
)
last_edit = time.time()
buffer = io.StringIO()
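# Accumulate streamed tokens here; the embed is only refreshed on the 2.5-second cadence.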
async for response in await client.chat(
model,
messages,
stream=True,
options=Options(
num_ctx=4096,
low_vram=server.vram_gb < 8,
temperature=temperature
)
):
response: dict
self.log.debug("Response from %r: %r", server, response)
buffer.write(response["message"]["content"])
if len(buffer.getvalue()) > 4096:
value = "... " + buffer.getvalue()[4:]
else:
value = buffer.getvalue()
embed.description = value
if view.event.is_set():
embed.add_field(name="Error!", value="Chat cancelled.")
embed.colour = discord.Colour.red()
await msg.edit(embed=embed, view=None)
return
if time.time() - last_edit >= 2.5:
await msg.edit(embed=embed, view=view)
last_edit = time.time()
embed.colour = discord.Colour.green()
if len(buffer.getvalue()) > 4096:
file = discord.File(
io.BytesIO(buffer.getvalue().encode()),
filename="full-chat.txt"
)
embed.add_field(
name="Full chat",
value="The chat was too long to fit in this message. "
f"You can download the `full-chat.txt` file to see the full message."
)
else:
file = discord.utils.MISSING
if not thread:
# Store the assistant's reply alongside the prompts, so resuming the
# thread via thread_id replays the full exchange.
messages.append({"role": "assistant", "content": buffer.getvalue()})
thread = OllamaThread(
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
)
await thread.save()
embed.set_footer(text=f"Chat ID: {thread.thread_id}")
await msg.edit(embed=embed, view=None, file=file)
@ollama_group.command(name="pull")
async def pull_ollama_model(
self,
ctx: discord.ApplicationContext,
server: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
],
model: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The model to use.",
autocomplete=get_available_tags_autocomplete,
default="llama3:latest"
)
],
):
"""Downloads a tag on the target server"""
await ctx.defer()
server = get_server(server)
if not server:
return await ctx.respond("\N{cross mark} Unknown server.")
elif not await server.is_online():
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
description=f"Initiating download...",
color=discord.Color.blurple()
)
view = StopDownloadView(ctx)
await ctx.respond(
embed=embed,
view=view
)
last_edit = 0
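# Same progress rendering and 2.5-second edit throttle as the download path in /ollama chat.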
async with ctx.typing():
try:
last_completed = 0
last_completed_ts = time.time()
async with ollama_client(str(server.base_url)) as client:
async for line in await client.pull(model, stream=True):
if view.event.is_set():
embed.add_field(name="Error!", value="Download cancelled.")
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed)
return
self.log.debug("Response from %r: %r", server, line)
if line["status"] in {
"pulling manifest",
"verifying sha256 digest",
"writing manifest",
"removing any unused layers",
"success"
}:
embed.description = line["status"].capitalize()
else:
total = line["total"]
completed = line.get("completed", 0)
percent = round(completed / total * 100, 1)
pb_fill = "" * int(percent / 10)
pb_empty = "" * (10 - int(percent / 10))
bytes_per_second = completed - last_completed
bytes_per_second /= (time.time() - last_completed_ts)
last_completed = completed
last_completed_ts = time.time()
mbps = round((bytes_per_second * 8) / 1024 / 1024)
eta = (total - completed) / max(1, bytes_per_second)
progress_bar = f"[{pb_fill}{pb_empty}]"
ns_total = naturalsize(total, binary=True)
ns_completed = naturalsize(completed, binary=True)
embed.description = (
f"{line['status'].capitalize()} {percent}% {progress_bar} "
f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
f"[ETA: {naturaldelta(eta)}]"
)
if time.time() - last_edit >= 2.5:
await ctx.edit(embed=embed)
last_edit = time.time()
except ResponseError as err:
if err.error.endswith("file does not exist"):
await ctx.edit(
embed=None,
content="The model %r does not exist." % model,
delete_after=60,
view=None
)
else:
embed.add_field(
name="Error!",
value=err.error
)
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed, view=None)
return
else:
embed.colour = discord.Colour.green()
embed.description = f"Downloaded {model} on {server}."
await ctx.edit(embed=embed, delete_after=30, view=None)
@ollama_group.command(name="ps")
async def ollama_proc_list(
self,
ctx: discord.ApplicationContext,
server: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
]
):
"""Checks the loaded models on the target server"""
await ctx.defer()
server = get_server(server)
if not server:
return await ctx.respond("\N{cross mark} Unknown server.")
elif not await server.is_online():
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
async with ollama_client(str(server.base_url)) as client:
response = (await client.ps())["models"]
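# `ps` reports the models currently loaded into memory, including how much of each sits in VRAM.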
if not response:
embed = discord.Embed(
title=f"No models loaded on {server}.",
color=discord.Color.blurple()
)
return await ctx.respond(embed=embed)
embed = discord.Embed(
title=f"Models loaded on {server}",
color=discord.Color.blurple()
)
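# Discord embeds are limited to 25 fields, hence the slice below.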
for model in response[:25]:
size = naturalsize(model["size"], binary=True)
size_vram = naturalsize(model["size_vram"], binary=True)
size_ram = naturalsize(model["size"] - model["size_vram"], binary=True)
percent_in_vram = round(model["size_vram"] / model["size"] * 100)
percent_in_ram = 100 - percent_in_vram
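# `expires_at` is when ollama will unload the model after idling.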
expires = datetime.datetime.fromisoformat(model["expires_at"])
lines = [
f"* Size: {size}",
f"* Unloaded: {discord.utils.format_dt(expires, style='R')}",
]
if percent_in_ram > 0:
lines.extend(
[
f"* VRAM/RAM: {percent_in_vram}%/{percent_in_ram}%",
f"* VRAM Size: {size_vram}",
f"* RAM Size: {size_ram}"
]
)
else:
lines.append(f"* VRAM Size: {size_vram} (100%)")
embed.add_field(
name=model["model"],
value="\n".join(lines),
inline=False
)
await ctx.respond(embed=embed)
def setup(bot):
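"""Standard pycord extension entry point."""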
bot.add_cog(Chat(bot))