Working shit

* /status works
* starting up works
* config works
* /ollama works
* autocomplete works
* downloading works
* threads work
* images work
Nexus 2024-06-10 03:14:52 +01:00
parent e45e87167f
commit 2b4c324ba6
Signed by: nex (GPG key ID: 0FA334385D0B689F)
8 changed files with 547 additions and 1 deletion

.gitignore (vendored, 2 changes)

@@ -281,3 +281,5 @@ pyrightconfig.json
 .ionide
 # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all
+.venv/
+default.db

config.toml (new file, 31 lines)

@@ -0,0 +1,31 @@
[bot]
token = "MTI0OTUwNjU1ODIwODExNDgxMA.G_JIT7.R2R_1-2IHhdzEf6mHUgIa82oyRuwonBRrkd_Pc"
debug_guilds = [1106243455816052847]

[servers]
order = ["SpeedySHRoNK", "IvyPC", "nextop-ts", "optiplex", "shronk"]

[servers.SpeedySHRoNK]
base_url = "http://ollama.shronk.net:11434"
gpu = true
vram_gb = 10

[servers.IvyPC]
base_url = "http://192.168.0.26:11435"
gpu = true
vram_gb = 8

[servers.nextop-ts]
base_url = "http://laptop-linux.fluffy-gentoo.ts.net:11434"
gpu = true
vram_gb = 4

[servers.optiplex]
base_url = "http://192.168.0.254:11434"
gpu = false
vram_gb = 16

[servers.shronk]
base_url = "http://ollama.shronk.net:11434"
gpu = false
vram_gb = 16
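
For reference, a minimal sketch of how a file like this can be loaded and walked with the standard-library tomllib; the real loader the commit adds lives in jimmy/config.py below, and the key layout here matches the file above:

import tomllib

# Sketch only: read config.toml and list servers in their declared order,
# falling back to table order when no explicit "order" list is present.
with open("config.toml", "rb") as fh:
    config = tomllib.load(fh)

order = config["servers"].pop("order", []) or list(config["servers"].keys())
for name in order:
    server = config["servers"][name]
    tag = "GPU" if server.get("gpu") else "CPU"
    print(f"{name}: {server['base_url']} ({tag}, {server.get('vram_gb', 4)} GB VRAM)")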

jimmy/cogs/chat.py (new file, 328 lines)

@@ -0,0 +1,328 @@
import asyncio
import contextlib
import io
import logging
import time
import typing

import discord
from discord import Interaction
from discord.ext import commands
from humanize import naturalsize
from ollama import AsyncClient, Options, ResponseError

from jimmy.config import ServerConfig, get_server, get_servers
from jimmy.db import OllamaThread
from jimmy.utils import create_ollama_message


@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> AsyncClient:
    host = str(host)
    client = AsyncClient(host, **kwargs)
    try:
        yield client
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        await client._client.aclose()


class StopDownloadView(discord.ui.View):
    """A view with a single Cancel button, reused for downloads and chats."""

    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=None)
        self.ctx = ctx
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the invoking user may cancel.
        return interaction.user == self.ctx.user

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)
async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    # Default to the first configured server if none has been picked yet.
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
        tags = (await client.list())["models"]
    return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]


_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
class Chat(commands.Cog):
    def __init__(self, bot):
        self.bot = bot
        self.server_locks = {}
        for server in get_servers():
            self.server_locks[server.name] = asyncio.Lock()
        self.log = logging.getLogger(__name__)

    @commands.slash_command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status of all servers."""
        await ctx.defer()

        def decorate_name(_s: ServerConfig):
            if _s.gpu:
                return f"{_s.name} (\u26A1)"
            return _s.name

        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple()
        )
        fields = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{CLOSED LOCK WITH KEY} In use.",
                    inline=False
                )
            else:
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{HOURGLASS WITH FLOWING SAND} Waiting...",
                    inline=False
                )
            fields[server] = len(embed.fields) - 1
        await ctx.respond(embed=embed)

        # Probe every idle server concurrently, then update its field in place.
        tasks = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                continue
            tasks[server] = asyncio.create_task(server.is_online())
        await asyncio.gather(*tasks.values())
        for server, task in tasks.items():
            if task.result():
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{WHITE HEAVY CHECK MARK} Online.",
                    inline=False
                )
            else:
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{CROSS MARK} Offline.",
                    inline=False
                )
        await ctx.edit(embed=embed)
    @commands.slash_command(name="ollama")
    async def start_ollama_chat(
        self,
        ctx: discord.ApplicationContext,
        prompt: str,
        system_prompt: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The system prompt to use.",
                default=None
            )
        ],
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                choices=_ServerOptionChoices,
                default=get_servers()[0].name
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest"
            )
        ],
        image: typing.Annotated[
            discord.Attachment | None,
            discord.Option(
                discord.SlashCommandOptionType.attachment,
                description="The image to use for llava.",
                default=None
            )
        ],
        thread_id: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The thread ID to continue.",
                default=None
            )
        ]
    ):
        """Have a chat with ollama."""
        await ctx.defer()
        server = get_server(server)
        async with self.server_locks[server.name]:
            if not await server.is_online():
                await ctx.respond(
                    content=f"{server} is offline.",
                    delete_after=60
                )
                return
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                if model not in [x["model"] for x in tags]:
                    embed = discord.Embed(
                        title=f"Downloading {model} on {server}.",
                        description="Initiating download...",
                        color=discord.Color.blurple()
                    )
                    view = StopDownloadView(ctx)
                    await ctx.respond(
                        embed=embed,
                        view=view
                    )
                    last_edit = 0
                    async with ctx.typing():
                        try:
                            last_completed = 0
                            last_completed_ts = time.time()
                            async for line in await client.pull(model, stream=True):
                                if view.event.is_set():
                                    embed.add_field(name="Error!", value="Download cancelled.")
                                    embed.colour = discord.Colour.red()
                                    await ctx.edit(embed=embed)
                                    return
                                self.log.info("Response from %r: %r", server, line)
                                if line["status"] in {
                                    "pulling manifest",
                                    "verifying sha256 digest",
                                    "writing manifest",
                                    "removing any unused layers",
                                    "success"
                                }:
                                    embed.description = line["status"].capitalize()
                                else:
                                    total = line["total"]
                                    completed = line.get("completed", 0)
                                    percent = round(completed / total * 100, 1)
                                    # Ten-segment progress bar built from filled/empty block glyphs.
                                    pb_fill = "\N{FULL BLOCK}" * int(percent / 10)
                                    pb_empty = "\N{LIGHT SHADE}" * (10 - int(percent / 10))
                                    # Instantaneous rate since the previous status line.
                                    bytes_per_second = completed - last_completed
                                    bytes_per_second /= (time.time() - last_completed_ts)
                                    last_completed = completed
                                    last_completed_ts = time.time()
                                    mbps = round((bytes_per_second * 8) / 1024 / 1024)
                                    progress_bar = f"[{pb_fill}{pb_empty}]"
                                    ns_total = naturalsize(total, binary=True)
                                    ns_completed = naturalsize(completed, binary=True)
                                    embed.description = (
                                        f"{line['status'].capitalize()} {percent}% {progress_bar} "
                                        f"({ns_completed}/{ns_total} @ {mbps} Mb/s)"
                                    )
                                # Rate-limit Discord edits to one every 2.5 seconds.
                                if time.time() - last_edit >= 2.5:
                                    await ctx.edit(embed=embed)
                                    last_edit = time.time()
                        except ResponseError as err:
                            if err.error.endswith("file does not exist"):
                                await ctx.edit(
                                    embed=None,
                                    content="The model %r does not exist." % model,
                                    delete_after=60,
                                    view=None
                                )
                            else:
                                embed.add_field(
                                    name="Error!",
                                    value=err.error
                                )
                                embed.colour = discord.Colour.red()
                                await ctx.edit(embed=embed, view=None)
                            return
                        else:
                            embed.colour = discord.Colour.green()
                            embed.description = f"Downloaded {model} on {server}."
                            await ctx.edit(embed=embed, delete_after=30, view=None)
                messages = []
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        # Replay the stored history so the model keeps its context.
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    else:
                        await ctx.respond(content="No thread with that ID exists.", delete_after=30)
                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))

                embed = discord.Embed(title=f"{model}:", description="")
                view = StopDownloadView(ctx)
                msg = await ctx.respond(
                    embed=embed,
                    view=view
                )
                last_edit = time.time()
                buffer = io.StringIO()
                async for response in await client.chat(
                    model,
                    messages,
                    stream=True,
                    options=Options(
                        num_ctx=4096,
                        low_vram=server.vram_gb < 8,
                        temperature=1.5
                    )
                ):
                    self.log.info("Response from %r: %r", server, response)
                    buffer.write(response["message"]["content"])
                    if len(buffer.getvalue()) > 4096:
                        # Embed descriptions cap out at 4096 characters; show the tail.
                        value = "... " + buffer.getvalue()[-4092:]
                    else:
                        value = buffer.getvalue()
                    embed.description = value
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Chat cancelled.")
                        embed.colour = discord.Colour.red()
                        await msg.edit(embed=embed, view=None)
                        return
                    if time.time() - last_edit >= 2.5:
                        await msg.edit(embed=embed, view=view)
                        last_edit = time.time()

                embed.colour = discord.Colour.green()
                if len(buffer.getvalue()) > 4096:
                    file = discord.File(
                        io.BytesIO(buffer.getvalue().encode()),
                        filename="full-chat.txt"
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                              "You can download the `full-chat.txt` file to see the full message."
                    )
                else:
                    file = discord.utils.MISSING
                thread = OllamaThread(
                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                )
                await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)


def setup(bot):
    bot.add_cog(Chat(bot))
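
The rate arithmetic in the pull loop is easy to sanity-check in isolation. A standalone sketch with made-up numbers (512 MiB of 1 GiB fetched, 128 MiB since the last status line, 2.5 s apart); note the loop reports mebibits per second under the "Mb/s" label:

# Sketch of the progress arithmetic from the download loop above, with
# fabricated example values rather than a live Ollama status stream.
total = 1024 ** 3                 # 1 GiB to download
completed = 512 * 1024 ** 2       # 512 MiB done so far
last_completed = 384 * 1024 ** 2  # 384 MiB done at the previous status line
elapsed = 2.5                     # seconds since that line

percent = round(completed / total * 100, 1)         # 50.0
pb_fill = "\N{FULL BLOCK}" * int(percent / 10)      # 5 filled segments
pb_empty = "\N{LIGHT SHADE}" * (10 - int(percent / 10))
bytes_per_second = (completed - last_completed) / elapsed
mbps = round((bytes_per_second * 8) / 1024 / 1024)  # 410 (mebibits/s)

print(f"[{pb_fill}{pb_empty}] {percent}% @ {mbps} Mb/s")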

jimmy/config.py (new file, 69 lines)

@@ -0,0 +1,69 @@
import logging
import tomllib
from typing import Callable

import httpx
from pydantic import AnyHttpUrl, BaseModel, Field

log = logging.getLogger(__name__)


class ServerConfig(BaseModel):
    name: str = Field(min_length=1, max_length=32)
    base_url: AnyHttpUrl
    gpu: bool = False
    vram_gb: int = 4
    throttle: bool = False

    def __repr__(self):
        return "<ServerConfig name={0.name} base_url={0.base_url} gpu={0.gpu!s} vram_gb={0.vram_gb}>".format(self)

    def __str__(self):
        return self.name

    def __hash__(self):
        # Servers are considered identical if they share a base URL.
        return hash(self.base_url)

    async def is_online(self) -> bool:
        """Checks that the current server is online and responding to requests."""
        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
            try:
                response = await client.get("/api/tags")
                return response.status_code == 200
            except httpx.RequestError:
                return False
def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
    config = get_config()
    # Pop "order" first so it is not treated as a server entry; fall back to
    # declaration order when no explicit order is configured.
    order = config["servers"].pop("order", [])
    keys = order or list(config["servers"].keys())
    log.info("Servers: %r", keys)
    servers = [ServerConfig(name=key, **config["servers"][key]) for key in keys]
    if filter_func:
        servers = list(filter(filter_func, servers))
    return servers


def get_server(name_or_base_url: str) -> ServerConfig | None:
    for server in get_servers():
        if server.name == name_or_base_url or str(server.base_url) == name_or_base_url:
            return server
    return None


def get_config():
    with open("config.toml", "rb") as file:
        _loaded = tomllib.load(file)
    _loaded.setdefault("servers", {})
    _loaded["servers"].setdefault("order", [])
    _loaded.setdefault("bot", {})
    return _loaded
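
A usage sketch for these helpers, assuming the config.toml shown earlier is in the working directory; the GPU filter and the probed server name are just examples:

import asyncio

from jimmy.config import get_server, get_servers


async def main():
    # List only GPU-backed servers, in configured order.
    gpu_servers = get_servers(filter_func=lambda s: s.gpu)
    print([s.name for s in gpu_servers])

    # Look a server up by name and probe its /api/tags endpoint.
    server = get_server("SpeedySHRoNK")
    if server:
        print(server, "online:", await server.is_online())


asyncio.run(main())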

jimmy/db.py (new file, 13 lines)

@@ -0,0 +1,13 @@
import os

from tortoise import fields
from tortoise.models import Model


class OllamaThread(Model):
    # Short random hex ID so users can resume a conversation later.
    thread_id = fields.CharField(max_length=255, unique=True, default=lambda: os.urandom(4).hex())
    messages = fields.JSONField(default=list)
    created_at = fields.DatetimeField(auto_now_add=True)

    class Meta:
        table = "ollama_threads"
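
For illustration, a minimal sketch of creating and resuming a thread with this model, run from the repository root so the jimmy package imports; the in-memory SQLite URL is only for the example (the bot wires up its real database in jimmy/main.py below):

import asyncio

from tortoise import Tortoise

from jimmy.db import OllamaThread


async def main():
    # Throwaway in-memory database, purely for this sketch.
    await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["jimmy.db"]})
    await Tortoise.generate_schemas()

    thread = OllamaThread(messages=[{"role": "user", "content": "hello"}])
    await thread.save()  # thread_id is auto-generated (8 hex characters)

    # This is how /ollama resumes a conversation from its Chat ID footer.
    same = await OllamaThread.get_or_none(thread_id=thread.thread_id)
    print(same.thread_id, same.messages)

    await Tortoise.close_connections()


asyncio.run(main())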

jimmy/main.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import logging
import os
import sys

import discord
from discord.ext import commands
from tortoise import Tortoise

sys.path.append("..")  # noqa: E402

from .config import get_config

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class SentientJimmy(commands.Bot):
    def __init__(self):
        intents = discord.Intents.default()
        # noinspection PyUnresolvedReferences
        intents.message_content = True
        super().__init__(
            commands.when_mentioned_or("."),
            intents=intents,
            case_insensitive=True,
            strip_after_prefix=True,
            debug_guilds=get_config()["bot"].get("debug_guilds"),
        )
        self.load_extension("jimmy.cogs.chat")
        self.load_extension("jishaku")

    async def start(self, token: str, *, reconnect: bool = True) -> None:
        # Use a throwaway in-memory database inside Docker, a local file otherwise.
        is_docker = os.path.exists("/.dockerenv")
        default_db = "sqlite://:memory:" if is_docker else "sqlite://default.db"
        await Tortoise.init(
            db_url=get_config()["bot"].get("db_url", default_db),
            modules={"models": ["jimmy.db"]}
        )
        await Tortoise.generate_schemas()
        await super().start(token, reconnect=reconnect)

    async def close(self) -> None:
        await Tortoise.close_connections()
        await super().close()

    def run(self) -> None:
        token = get_config()["bot"]["token"]
        super().run(token)


bot = SentientJimmy()

if __name__ == "__main__":
    bot.run()
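
Because this entry point uses a relative import (`from .config import get_config`), it presumably has to be launched as a module from the repository root, e.g. `python -m jimmy.main`, rather than as a bare script.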

jimmy/utils.py (new file, 47 lines)

@@ -0,0 +1,47 @@
import asyncio
import typing
from functools import partial

from fuzzywuzzy.fuzz import ratio
from ollama import Message

__all__ = (
    "async_ratio",
    "create_ollama_message",
)


async def async_ratio(a: str, b: str) -> int:
    """
    Wraps the fuzzywuzzy ratio in an async function.

    :param a: str - the first string
    :param b: str - the second string
    :return: int - the similarity ratio
    """
    return await asyncio.to_thread(partial(ratio, a, b))


async def create_ollama_message(
    content: str,
    role: typing.Literal["system", "assistant", "user"] = "user",
    images: typing.List[str | bytes] | None = None
) -> Message:
    """
    Creates a message for ollama.

    :param content: str - the content of the message
    :param role: str - the role of the message
    :param images: list - the images to attach to the message
    :return: Message - the constructed message
    """
    def factory(**kwargs):
        return Message(**kwargs)

    return await asyncio.to_thread(
        partial(
            factory,
            role=role,
            content=content,
            images=images
        )
    )
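
A usage sketch for both helpers, mirroring how the chat command builds the message list it feeds to client.chat; the image bytes are a placeholder:

import asyncio

from jimmy.utils import async_ratio, create_ollama_message


async def main():
    messages = [
        await create_ollama_message("You are a helpful assistant.", role="system"),
        await create_ollama_message("What is in this picture?", images=[b"<raw image bytes>"]),
    ]
    print([m["role"] for m in messages])

    # Fuzzy similarity score, computed off the event loop.
    print(await async_ratio("llama3:latest", "llama3"))


asyncio.run(main())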

requirements.txt (5 changes)

@@ -1,5 +1,8 @@
 py-cord~=2.5
-ollama-python~=0.1
+ollama~=0.2
 tortoise-orm[asyncpg]~=0.21
 uvicorn[standard]~=0.30
 fastapi~=0.111
+jishaku~=2.5
+fuzzywuzzy~=0.18
+humanize~=4.9