329 lines
14 KiB
Python
329 lines
14 KiB
Python
|
import asyncio
|
||
|
import io
|
||
|
import logging
|
||
|
import time
|
||
|
import typing
|
||
|
import contextlib
|
||
|
from fnmatch import fnmatch
|
||
|
|
||
|
import discord
|
||
|
from discord import Interaction
|
||
|
from ollama import AsyncClient, ResponseError, Options
|
||
|
from discord.ext import commands
|
||
|
from jimmy.utils import async_ratio, create_ollama_message
|
||
|
from jimmy.config import get_servers, ServerConfig, get_server
|
||
|
from jimmy.db import OllamaThread
|
||
|
from humanize import naturalsize
|
||
|
|
||
|
|
||
|
@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> AsyncClient:
    """Async context manager yielding an ``AsyncClient`` for *host*.

    Extra keyword arguments are forwarded to ``AsyncClient``.  The client's
    underlying HTTP transport is always closed on exit.
    """
    api = AsyncClient(str(host), **kwargs)
    try:
        yield api
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        await api._client.aclose()
|
||
|
|
||
|
|
||
|
class StopDownloadView(discord.ui.View):
    """A persistent view with a single "Cancel" button.

    Pressing the button sets ``self.event``; long-running loops (model pulls,
    chat streaming) poll that event to know when to abort.
    """

    def __init__(self, ctx: discord.ApplicationContext):
        # timeout=None makes the view persistent (the button never expires on its own).
        super().__init__(timeout=None)
        # The invoking context; used to restrict button presses to the invoker.
        self.ctx = ctx
        # Set once the user presses Cancel.
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the user who started the command may cancel it.
        return interaction.user == self.ctx.user

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        """Signal cancellation, stop listening for input, and remove the button."""
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)
|
||
|
|
||
|
|
||
|
async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    """Autocomplete the model tags installed on the selected (or default) server.

    Matching is a case-insensitive substring test of what the user has typed
    so far against each model name.  Returns an empty list when the server
    cannot be reached so the autocomplete UI degrades gracefully instead of
    raising inside the autocomplete callback.
    """
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    try:
        async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
            tags = (await client.list())["models"]
    except Exception:
        # An offline or slow server should not break autocomplete.
        return []
    query = ctx.value.casefold()
    return [tag["model"] for tag in tags if query in tag["model"].casefold()]
|
||
|
|
||
|
|
||
|
# Pre-built slash-command choices, one per configured server (computed once at import time).
_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
|
||
|
|
||
|
|
||
|
class Chat(commands.Cog):
    """Slash commands for checking on, and chatting with, the configured Ollama servers."""

    def __init__(self, bot):
        self.bot = bot
        # One lock per configured server.  Servers marked `throttle` serve a
        # single generation at a time; /status reports a locked server as in use.
        self.server_locks = {server.name: asyncio.Lock() for server in get_servers()}
        self.log = logging.getLogger(__name__)

    @commands.slash_command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status on all servers."""
        await ctx.defer()

        def decorate_name(_s: ServerConfig):
            # GPU-backed servers get a high-voltage (\u26A1) marker after their name.
            if _s.gpu:
                return f"{_s.name} (\u26A1)"
            return _s.name

        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple()
        )
        # server -> index of its field in the embed, so each placeholder can be
        # replaced in place once its probe finishes.
        fields = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                # Busy throttled servers are reported as in-use and never probed.
                embed.add_field(
                    name=decorate_name(server),
                    value=f"\N{closed lock with key} In use.",
                    inline=False
                )
            else:
                embed.add_field(
                    name=decorate_name(server),
                    value=f"\N{hourglass with flowing sand} Waiting...",
                    inline=False
                )
            fields[server] = len(embed.fields) - 1

        await ctx.respond(embed=embed)

        # Probe every non-busy server concurrently.
        tasks = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                continue
            tasks[server] = asyncio.create_task(server.is_online())

        await asyncio.gather(*tasks.values())
        for server, task in tasks.items():
            if task.result():
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value=f"\N{white heavy check mark} Online.",
                    inline=False
                )
            else:
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value=f"\N{cross mark} Offline.",
                    inline=False
                )
        await ctx.edit(embed=embed)

    @commands.slash_command(name="ollama")
    async def start_ollama_chat(
        self,
        ctx: discord.ApplicationContext,
        prompt: str,
        system_prompt: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The system prompt to use.",
                default=None
            )
        ],
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                choices=_ServerOptionChoices,
                default=get_servers()[0].name
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest"
            )
        ],
        image: typing.Annotated[
            discord.Attachment | None,
            discord.Option(
                discord.SlashCommandOptionType.attachment,
                description="The image to use for llava.",
                default=None
            )
        ],
        thread_id: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The thread ID to continue.",
                default=None
            )
        ]
    ):
        """Have a chat with ollama"""
        await ctx.defer()
        server = get_server(server)
        async with self.server_locks[server.name]:
            if not await server.is_online():
                await ctx.respond(
                    content=f"{server} is offline.",
                    delete_after=60
                )
                return
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                if model not in [x["model"] for x in tags]:
                    # The model is not installed on this server: stream the pull
                    # progress into an editable embed with a cancel button.
                    embed = discord.Embed(
                        title=f"Downloading {model} on {server}.",
                        description="Initiating download...",
                        color=discord.Color.blurple()
                    )
                    view = StopDownloadView(ctx)
                    await ctx.respond(
                        embed=embed,
                        view=view
                    )
                    # 0 makes the first progress line trigger an edit immediately.
                    last_edit = 0
                    async with ctx.typing():
                        try:
                            last_completed = 0
                            last_completed_ts = time.time()

                            async for line in await client.pull(model, stream=True):
                                if view.event.is_set():
                                    embed.add_field(name="Error!", value="Download cancelled.")
                                    embed.colour = discord.Colour.red()
                                    await ctx.edit(embed=embed)
                                    return
                                self.log.info("Response from %r: %r", server, line)
                                if line["status"] in {
                                    "pulling manifest",
                                    "verifying sha256 digest",
                                    "writing manifest",
                                    "removing any unused layers",
                                    "success"
                                }:
                                    # Bookkeeping phases carry no byte counts.
                                    embed.description = line["status"].capitalize()
                                else:
                                    total = line["total"]
                                    completed = line.get("completed", 0)
                                    percent = round(completed / total * 100, 1)
                                    pb_fill = "▰" * int(percent / 10)
                                    pb_empty = "▱" * (10 - int(percent / 10))
                                    # Guard against two progress lines arriving within
                                    # the same clock tick (would divide by zero).
                                    elapsed = time.time() - last_completed_ts
                                    if elapsed > 0:
                                        bytes_per_second = (completed - last_completed) / elapsed
                                    else:
                                        bytes_per_second = 0.0
                                    last_completed = completed
                                    last_completed_ts = time.time()
                                    mbps = round((bytes_per_second * 8) / 1024 / 1024)
                                    progress_bar = f"[{pb_fill}{pb_empty}]"
                                    ns_total = naturalsize(total, binary=True)
                                    ns_completed = naturalsize(completed, binary=True)
                                    embed.description = (
                                        f"{line['status'].capitalize()} {percent}% {progress_bar} "
                                        f"({ns_completed}/{ns_total} @ {mbps} Mb/s)"
                                    )

                                # Throttle message edits to stay clear of rate limits.
                                if time.time() - last_edit >= 2.5:
                                    await ctx.edit(embed=embed)
                                    last_edit = time.time()
                        except ResponseError as err:
                            if err.error.endswith("file does not exist"):
                                await ctx.edit(
                                    embed=None,
                                    content="The model %r does not exist." % model,
                                    delete_after=60,
                                    view=None
                                )
                            else:
                                embed.add_field(
                                    name="Error!",
                                    value=err.error
                                )
                                embed.colour = discord.Colour.red()
                                await ctx.edit(embed=embed, view=None)
                            return
                        else:
                            embed.colour = discord.Colour.green()
                            embed.description = f"Downloaded {model} on {server}."
                            await ctx.edit(embed=embed, delete_after=30, view=None)

                # Assemble the history: prior thread (if any), optional system
                # prompt, then the user's prompt (with attached image, if any).
                messages = []
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    else:
                        # Best-effort: warn, but still run the chat without history.
                        await ctx.respond(content="No thread with that ID exists.", delete_after=30)
                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
                embed = discord.Embed(title=f"{model}:", description="")
                view = StopDownloadView(ctx)
                msg = await ctx.respond(
                    embed=embed,
                    view=view
                )
                last_edit = time.time()
                buffer = io.StringIO()
                async for response in await client.chat(
                    model,
                    messages,
                    stream=True,
                    options=Options(
                        num_ctx=4096,
                        low_vram=server.vram_gb < 8,
                        temperature=1.5
                    )
                ):
                    self.log.info("Response from %r: %r", server, response)
                    buffer.write(response["message"]["content"])

                    text = buffer.getvalue()
                    if len(text) > 4096:
                        # Embed descriptions cap out at 4096 characters: show the
                        # most recent output behind an ellipsis marker.  (The old
                        # code dropped only the first 4 characters, which still
                        # exceeded the limit and made the edit fail.)
                        value = "... " + text[-4092:]
                    else:
                        value = text
                    embed.description = value
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Chat cancelled.")
                        embed.colour = discord.Colour.red()
                        await msg.edit(embed=embed, view=None)
                        return
                    if time.time() - last_edit >= 2.5:
                        await msg.edit(embed=embed, view=view)
                        last_edit = time.time()
                embed.colour = discord.Colour.green()
                if len(buffer.getvalue()) > 4096:
                    # Attach the untruncated transcript when it could not fit.
                    file = discord.File(
                        io.BytesIO(buffer.getvalue().encode()),
                        filename="full-chat.txt"
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                              "You can download the `full-chat.txt` file to see the full message."
                    )
                else:
                    file = discord.utils.MISSING

                # Include the assistant's reply before persisting, so a thread
                # continued via thread_id keeps the model's answer in context.
                messages.append(await create_ollama_message(buffer.getvalue(), role="assistant"))
                thread = OllamaThread(
                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                )
                await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)
|
||
|
|
||
|
|
||
|
def setup(bot):
    """Extension entry point: register the Chat cog with the bot."""
    bot.add_cog(Chat(bot))
|