Working shit

* /status works
* starting up works
* Config works
* /ollama works
* autocomplete works
* downloading works
* threads work
* images work
This commit is contained in:
parent e45e87167f
commit 2b4c324ba6

8 changed files with 547 additions and 1 deletion
.gitignore (vendored, 2 additions)

@@ -281,3 +281,5 @@ pyrightconfig.json
 .ionide

 # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all
+.venv/
+default.db
config.toml (new file, 31 additions)

@@ -0,0 +1,31 @@
[bot]
token = "MTI0OTUwNjU1ODIwODExNDgxMA.G_JIT7.R2R_1-2IHhdzEf6mHUgIa82oyRuwonBRrkd_Pc"
debug_guilds = [1106243455816052847]

[servers]
order = ["SpeedySHRoNK", "IvyPC", "nextop-ts", "optiplex", "shronk"]

[servers.SpeedySHRoNK]
base_url = "http://ollama.shronk.net:11434"
gpu = true
vram_gb = 10

[servers.IvyPC]
base_url = "http://192.168.0.26:11435"
gpu = true
vram_gb = 8

[servers.nextop-ts]
base_url = "http://laptop-linux.fluffy-gentoo.ts.net:11434"
gpu = true
vram_gb = 4

[servers.optiplex]
base_url = "http://192.168.0.254:11434"
gpu = false
vram_gb = 16

[servers.shronk]
base_url = "http://ollama.shronk.net:11434"
gpu = false
vram_gb = 16
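
The `order` array doubles as a priority list: `get_servers()` in jimmy/config.py yields servers in this order, so the first entry is both the default target for `/ollama` and the fallback server for model autocomplete. A minimal sketch of resolving the default from this file (Python 3.11+ for `tomllib`; illustrative, not part of the commit):

    import tomllib

    with open("config.toml", "rb") as f:
        cfg = tomllib.load(f)

    order = cfg["servers"]["order"]
    default = order[0]  # "SpeedySHRoNK": /ollama defaults to the first configured server
    print(default, cfg["servers"][default]["base_url"])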
jimmy/cogs/chat.py (new file, 328 additions)

@@ -0,0 +1,328 @@
import asyncio
import io
import logging
import time
import typing
import contextlib

import discord
from discord import Interaction
from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands
from jimmy.utils import create_ollama_message
from jimmy.config import get_servers, ServerConfig, get_server
from jimmy.db import OllamaThread
from humanize import naturalsize


@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> typing.AsyncIterator[AsyncClient]:
    host = str(host)
    client = AsyncClient(host, **kwargs)
    try:
        yield client
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        await client._client.aclose()

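# Illustrative usage (not part of this commit): every call site below follows
# this pattern, so the wrapped httpx client is closed even if a request fails:
#
#     async with ollama_client("http://localhost:11434", timeout=2) as client:
#         models = (await client.list())["models"]
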
class StopDownloadView(discord.ui.View):
    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=None)
        self.ctx = ctx
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the user who invoked the command may cancel it.
        return interaction.user == self.ctx.user

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)


async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
        tags = (await client.list())["models"]
        return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]


_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]


class Chat(commands.Cog):
    def __init__(self, bot):
        self.bot = bot
        self.server_locks = {}
        for server in get_servers():
            self.server_locks[server.name] = asyncio.Lock()
        self.log = logging.getLogger(__name__)

    @commands.slash_command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status on all servers."""
        await ctx.defer()

        def decorate_name(_s: ServerConfig):
            if _s.gpu:
                return f"{_s.name} (\u26A1)"
            return _s.name

        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple()
        )
        fields = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{CLOSED LOCK WITH KEY} In use.",
                    inline=False
                )
            else:
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{HOURGLASS WITH FLOWING SAND} Waiting...",
                    inline=False
                )
            fields[server] = len(embed.fields) - 1

        await ctx.respond(embed=embed)
        tasks = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                continue
            tasks[server] = asyncio.create_task(server.is_online())

        await asyncio.gather(*tasks.values())
        for server, task in tasks.items():
            if task.result():
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{WHITE HEAVY CHECK MARK} Online.",
                    inline=False
                )
            else:
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{CROSS MARK} Offline.",
                    inline=False
                )
        await ctx.edit(embed=embed)

    @commands.slash_command(name="ollama")
    async def start_ollama_chat(
            self,
            ctx: discord.ApplicationContext,
            prompt: str,
            system_prompt: typing.Annotated[
                str | None,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The system prompt to use.",
                    default=None
                )
            ],
            server: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The server to use.",
                    choices=_ServerOptionChoices,
                    default=get_servers()[0].name
                )
            ],
            model: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The model to use.",
                    autocomplete=get_available_tags_autocomplete,
                    default="llama3:latest"
                )
            ],
            image: typing.Annotated[
                discord.Attachment | None,
                discord.Option(
                    discord.SlashCommandOptionType.attachment,
                    description="The image to use for llava.",
                    default=None
                )
            ],
            thread_id: typing.Annotated[
                str | None,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The thread ID to continue.",
                    default=None
                )
            ]
    ):
        """Have a chat with ollama"""
        await ctx.defer()
        server = get_server(server)
        async with self.server_locks[server.name]:
            if not await server.is_online():
                await ctx.respond(
                    content=f"{server} is offline.",
                    delete_after=60
                )
                return
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                if model not in [x["model"] for x in tags]:
                    embed = discord.Embed(
                        title=f"Downloading {model} on {server}.",
                        description="Initiating download...",
                        color=discord.Color.blurple()
                    )
                    view = StopDownloadView(ctx)
                    await ctx.respond(
                        embed=embed,
                        view=view
                    )
                    last_edit = 0
                    async with ctx.typing():
                        try:
                            last_completed = 0
                            last_completed_ts = time.time()

                            async for line in await client.pull(model, stream=True):
                                if view.event.is_set():
                                    embed.add_field(name="Error!", value="Download cancelled.")
                                    embed.colour = discord.Colour.red()
                                    await ctx.edit(embed=embed)
                                    return
                                self.log.info("Response from %r: %r", server, line)
                                if line["status"] in {
                                    "pulling manifest",
                                    "verifying sha256 digest",
                                    "writing manifest",
                                    "removing any unused layers",
                                    "success"
                                }:
                                    embed.description = line["status"].capitalize()
                                else:
                                    total = line["total"]
                                    completed = line.get("completed", 0)
                                    percent = round(completed / total * 100, 1)
                                    pb_fill = "▰" * int(percent / 10)
                                    pb_empty = "▱" * (10 - int(percent / 10))
                                    # Instantaneous rate since the last progress line.
                                    elapsed = max(time.time() - last_completed_ts, 1e-6)
                                    bytes_per_second = (completed - last_completed) / elapsed
                                    last_completed = completed
                                    last_completed_ts = time.time()
                                    mbps = round((bytes_per_second * 8) / 1024 / 1024)
                                    progress_bar = f"[{pb_fill}{pb_empty}]"
                                    ns_total = naturalsize(total, binary=True)
                                    ns_completed = naturalsize(completed, binary=True)
                                    embed.description = (
                                        f"{line['status'].capitalize()} {percent}% {progress_bar} "
                                        f"({ns_completed}/{ns_total} @ {mbps} Mb/s)"
                                    )

                                # Throttle Discord edits to one every 2.5 seconds.
                                if time.time() - last_edit >= 2.5:
                                    await ctx.edit(embed=embed)
                                    last_edit = time.time()
                        except ResponseError as err:
                            if err.error.endswith("file does not exist"):
                                await ctx.edit(
                                    embed=None,
                                    content="The model %r does not exist." % model,
                                    delete_after=60,
                                    view=None
                                )
                            else:
                                embed.add_field(
                                    name="Error!",
                                    value=err.error
                                )
                                embed.colour = discord.Colour.red()
                                await ctx.edit(embed=embed, view=None)
                            return
                        else:
                            embed.colour = discord.Colour.green()
                            embed.description = f"Downloaded {model} on {server}."
                            await ctx.edit(embed=embed, delete_after=30, view=None)

                messages = []
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    else:
                        await ctx.respond(content="No thread with that ID exists.", delete_after=30)
                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
                embed = discord.Embed(title=f"{model}:", description="")
                view = StopDownloadView(ctx)
                msg = await ctx.respond(
                    embed=embed,
                    view=view
                )
                last_edit = time.time()
                buffer = io.StringIO()
                async for response in await client.chat(
                        model,
                        messages,
                        stream=True,
                        options=Options(
                            num_ctx=4096,
                            low_vram=server.vram_gb < 8,  # conserve memory on small-VRAM hosts
                            temperature=1.5
                        )
                ):
                    self.log.info("Response from %r: %r", server, response)
                    buffer.write(response["message"]["content"])

                    # Embed descriptions cap at 4096 characters; keep the tail visible.
                    if len(buffer.getvalue()) > 4096:
                        value = "... " + buffer.getvalue()[-4092:]
                    else:
                        value = buffer.getvalue()
                    embed.description = value
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Chat cancelled.")
                        embed.colour = discord.Colour.red()
                        await msg.edit(embed=embed, view=None)
                        return
                    if time.time() - last_edit >= 2.5:
                        await msg.edit(embed=embed, view=view)
                        last_edit = time.time()
                embed.colour = discord.Colour.green()
                if len(buffer.getvalue()) > 4096:
                    file = discord.File(
                        io.BytesIO(buffer.getvalue().encode()),
                        filename="full-chat.txt"
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                              "You can download the `full-chat.txt` file to see the full message."
                    )
                else:
                    file = discord.utils.MISSING

                # Persist the assistant's reply too, so a continued thread keeps it.
                messages.append(await create_ollama_message(buffer.getvalue(), role="assistant"))
                thread = OllamaThread(
                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                )
                await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)


def setup(bot):
    bot.add_cog(Chat(bot))
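
The download loop above folds its progress maths into the embed edit: completion is mapped onto a ten-segment bar, and an instantaneous megabit-per-second rate is derived from the delta since the previous progress line. The same arithmetic as a standalone sketch (the helper name is hypothetical, not in the commit):

    import time

    def render_progress(completed: int, total: int,
                        last_completed: int, last_ts: float) -> tuple[str, int]:
        # Ten-segment bar, as in the /ollama download embed.
        percent = round(completed / total * 100, 1)
        bar = "▰" * int(percent / 10) + "▱" * (10 - int(percent / 10))
        # Bytes since the last progress line, over the elapsed interval, in Mb/s.
        elapsed = max(time.time() - last_ts, 1e-6)  # guard against a zero interval
        mbps = round((completed - last_completed) / elapsed * 8 / 1024 / 1024)
        return f"[{bar}] {percent}%", mbps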
jimmy/config.py (new file, 69 additions)

@@ -0,0 +1,69 @@
import tomllib
import logging
from typing import Callable

import httpx
from pydantic import BaseModel, Field, AnyHttpUrl

log = logging.getLogger(__name__)


class ServerConfig(BaseModel):
    name: str = Field(min_length=1, max_length=32)
    base_url: AnyHttpUrl
    gpu: bool = False
    vram_gb: int = 4
    throttle: bool = False

    def __repr__(self):
        return "<ServerConfig name={0.name} base_url={0.base_url} gpu={0.gpu!s} vram_gb={0.vram_gb}>".format(self)

    def __str__(self):
        return self.name

    async def is_online(self) -> bool:
        """
        Checks that the current server is online and responding to requests.
        """
        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
            try:
                response = await client.get("/api/tags")
                return response.status_code == 200
            except httpx.RequestError:
                return False

    def __hash__(self):
        return hash(self.base_url)


def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
    config = get_config()
    servers_config = config["servers"]
    # "order" lives alongside the server tables, so pop it out before treating
    # the remaining keys as server names. (dict.pop raises KeyError, not
    # ValueError; get_config() defaults "order" to an empty list.)
    order = servers_config.pop("order", [])
    keys = order or list(servers_config.keys())
    log.info("Servers: %r", keys)
    servers = [ServerConfig(name=key, **servers_config[key]) for key in keys]
    if filter_func:
        servers = list(filter(filter_func, servers))
    return servers


def get_server(name_or_base_url: str) -> ServerConfig | None:
    servers = get_servers()
    for server in servers:
        if server.name == name_or_base_url or str(server.base_url) == name_or_base_url:
            return server
    return None


def get_config():
    with open("config.toml", "rb") as file:
        _loaded = tomllib.load(file)

    _loaded.setdefault("servers", {})
    _loaded["servers"].setdefault("order", [])
    _loaded.setdefault("bot", {})
    return _loaded
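
Since `get_servers()` accepts an optional predicate, capability-based selection stays a one-liner; a few illustrative calls (the variable names are hypothetical):

    gpu_servers = get_servers(lambda s: s.gpu)        # GPU hosts, in configured order
    roomy = get_servers(lambda s: s.vram_gb >= 8)     # hosts with at least 8 GB of VRAM
    speedy = get_server("SpeedySHRoNK")               # lookup by name or base URL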
jimmy/db.py (new file, 13 additions)

@@ -0,0 +1,13 @@
import os

from tortoise.models import Model
from tortoise import fields


class OllamaThread(Model):
    thread_id = fields.CharField(max_length=255, unique=True, default=lambda: os.urandom(4).hex())
    # Use the list constructor, not a shared [] instance, as the default.
    messages = fields.JSONField(default=list)
    created_at = fields.DatetimeField(auto_now_add=True)

    class Meta:
        table = "ollama_threads"
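
A minimal round-trip sketch (run inside an async context, with Tortoise initialised the way main.py's start() does it):

    thread = OllamaThread(messages=[{"role": "user", "content": "hello"}])
    await thread.save()  # thread_id defaults to os.urandom(4).hex()
    fetched = await OllamaThread.get_or_none(thread_id=thread.thread_id)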
jimmy/main.py (new file, 53 additions)

@@ -0,0 +1,53 @@
import os
import sys
import logging
import discord
from discord.ext import commands
from tortoise import Tortoise

sys.path.append("..")  # noqa: E402

from .config import get_config


log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class SentientJimmy(commands.Bot):
    def __init__(self):
        intents = discord.Intents.default()
        # noinspection PyUnresolvedReferences
        intents.message_content = True
        super().__init__(
            commands.when_mentioned_or("."),
            intents=intents,
            case_insensitive=True,
            strip_after_prefix=True,
            debug_guilds=get_config()["bot"].get("debug_guilds"),
        )
        self.load_extension("jimmy.cogs.chat")
        self.load_extension("jishaku")

    async def start(self, token: str, *, reconnect: bool = True) -> None:
        # Initialise the database before connecting to the gateway; inside
        # Docker, fall back to an in-memory database.
        is_docker = os.path.exists("/.dockerenv")
        default_db = "sqlite://:memory:" if is_docker else "sqlite://default.db"
        await Tortoise.init(
            db_url=get_config()["bot"].get("db_url", default_db),
            modules={"models": ["jimmy.db"]}
        )
        await Tortoise.generate_schemas()
        await super().start(token, reconnect=reconnect)

    async def close(self) -> None:
        await Tortoise.close_connections()
        await super().close()

    def run(self) -> None:
        token = get_config()["bot"]["token"]
        super().run(token)


bot = SentientJimmy()


if __name__ == "__main__":
    bot.run()
jimmy/utils.py (new file, 47 additions)

@@ -0,0 +1,47 @@
import asyncio
import typing
from functools import partial
from fuzzywuzzy.fuzz import ratio
from ollama import Message


__all__ = (
    'async_ratio',
    'create_ollama_message',
)


async def async_ratio(a: str, b: str) -> int:
    """
    Wraps fuzzywuzzy ratio in an async function.

    :param a: str - first string
    :param b: str - second string
    :return: int - ratio of similarity
    """
    return await asyncio.to_thread(partial(ratio, a, b))


async def create_ollama_message(
        content: str,
        role: typing.Literal["system", "assistant", "user"] = "user",
        images: typing.Optional[typing.List[str | bytes]] = None
) -> Message:
    """
    Create a message for ollama.

    :param content: str - the content of the message
    :param role: str - the role of the message
    :param images: list - the images to attach to the message
    :return: Message - the message
    """
    def factory(**kwargs):
        return Message(**kwargs)
    return await asyncio.to_thread(
        partial(
            factory,
            role=role,
            content=content,
            images=images
        )
    )
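
A quick usage sketch for `create_ollama_message` (illustrative; in ollama 0.2 a `Message` reads like a dict, which is how chat.py consumes it):

    import asyncio

    async def demo():
        msg = await create_ollama_message("describe this image", images=[b"<raw bytes>"])
        print(msg["role"], msg["content"])

    asyncio.run(demo())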
requirements.txt (4 additions, 1 deletion)

@@ -1,5 +1,8 @@
 py-cord~=2.5
-ollama-python~=0.1
+ollama~=0.2
 tortoise-orm[asyncpg]~=0.21
 uvicorn[standard]~=0.30
 fastapi~=0.111
+jishaku~=2.5
+fuzzywuzzy~=0.18
+humanize~=4.9