diff --git a/.gitignore b/.gitignore
index 6df56ff..138c68e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -281,3 +281,5 @@ pyrightconfig.json
 .ionide
 
 # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all
+.venv/
+default.db
diff --git a/config.toml b/config.toml
new file mode 100644
index 0000000..8897883
--- /dev/null
+++ b/config.toml
@@ -0,0 +1,31 @@
+[bot]
+token = "MTI0OTUwNjU1ODIwODExNDgxMA.G_JIT7.R2R_1-2IHhdzEf6mHUgIa82oyRuwonBRrkd_Pc"
+debug_guilds = [1106243455816052847]
+
+[servers]
+order = ["SpeedySHRoNK", "IvyPC", "nextop-ts", "optiplex", "shronk"]
+
+[servers.SpeedySHRoNK]
+base_url = "http://ollama.shronk.net:11434"
+gpu = true
+vram_gb = 10
+
+[servers.IvyPC]
+base_url = "http://192.168.0.26:11435"
+gpu = true
+vram_gb = 8
+
+[servers.nextop-ts]
+base_url = "http://laptop-linux.fluffy-gentoo.ts.net:11434"
+gpu = true
+vram_gb = 4
+
+[servers.optiplex]
+base_url = "http://192.168.0.254:11434"
+gpu = false
+vram_gb = 16
+
+[servers.shronk]
+base_url = "http://ollama.shronk.net:11434"
+gpu = false
+vram_gb = 16
diff --git a/jimmy/cogs/chat.py b/jimmy/cogs/chat.py
new file mode 100644
index 0000000..fd250b8
--- /dev/null
+++ b/jimmy/cogs/chat.py
@@ -0,0 +1,328 @@
+import asyncio
+import contextlib
+import io
+import logging
+import time
+import typing
+
+import discord
+from discord import Interaction
+from discord.ext import commands
+from humanize import naturalsize
+from ollama import AsyncClient, Options, ResponseError
+
+from jimmy.config import ServerConfig, get_server, get_servers
+from jimmy.db import OllamaThread
+from jimmy.utils import create_ollama_message
+
+
+@contextlib.asynccontextmanager
+async def ollama_client(host: str, **kwargs) -> AsyncClient:
+    host = str(host)
+    client = AsyncClient(host, **kwargs)
+    try:
+        yield client
+    finally:
+        # Ollama doesn't auto-close the client, so we have to do it ourselves.
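+        # (AsyncClient wraps an httpx.AsyncClient, kept on the private
+        # ``_client`` attribute; closing it releases the connection pool.)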
+        await client._client.aclose()
+
+
+class StopDownloadView(discord.ui.View):
+    def __init__(self, ctx: discord.ApplicationContext):
+        super().__init__(timeout=None)
+        self.ctx = ctx
+        self.event = asyncio.Event()
+
+    async def interaction_check(self, interaction: Interaction) -> bool:
+        return interaction.user == self.ctx.user
+
+    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
+    async def cancel_download(self, _, interaction: discord.Interaction):
+        self.event.set()
+        self.stop()
+        await interaction.response.edit_message(view=None)
+
+
+async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
+    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
+    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
+        tags = (await client.list())["models"]
+        return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
+
+
+_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
+
+
+class Chat(commands.Cog):
+    def __init__(self, bot):
+        self.bot = bot
+        self.server_locks = {}
+        for server in get_servers():
+            self.server_locks[server.name] = asyncio.Lock()
+        self.log = logging.getLogger(__name__)
+
+    @commands.slash_command()
+    async def status(self, ctx: discord.ApplicationContext):
+        """Checks the status of all servers."""
+        await ctx.defer()
+
+        def decorate_name(_s: ServerConfig):
+            if _s.gpu:
+                return f"{_s.name} (\u26A1)"
+            return _s.name
+
+        embed = discord.Embed(
+            title="Ollama Statuses:",
+            color=discord.Color.blurple()
+        )
+        fields = {}
+        for server in get_servers():
+            if server.throttle and self.server_locks[server.name].locked():
+                embed.add_field(
+                    name=decorate_name(server),
+                    value="\N{closed lock with key} In use.",
+                    inline=False
+                )
+            else:
+                embed.add_field(
+                    name=decorate_name(server),
+                    value="\N{hourglass with flowing sand} Waiting...",
+                    inline=False
+                )
+            fields[server] = len(embed.fields) - 1
+
+        await ctx.respond(embed=embed)
+        tasks = {}
+        for server in get_servers():
+            if server.throttle and self.server_locks[server.name].locked():
+                continue
+            tasks[server] = asyncio.create_task(server.is_online())
+
+        await asyncio.gather(*tasks.values())
+        for server, task in tasks.items():
+            if task.result():
+                embed.set_field_at(
+                    fields[server],
+                    name=decorate_name(server),
+                    value="\N{white heavy check mark} Online.",
+                    inline=False
+                )
+            else:
+                embed.set_field_at(
+                    fields[server],
+                    name=decorate_name(server),
+                    value="\N{cross mark} Offline.",
+                    inline=False
+                )
+        await ctx.edit(embed=embed)
+
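+    # /ollama pulls the requested model if the server lacks it, streams the
+    # chat completion into an embed, and stores the exchange as an OllamaThread
+    # so it can be resumed later via thread_id.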
+    @commands.slash_command(name="ollama")
+    async def start_ollama_chat(
+        self,
+        ctx: discord.ApplicationContext,
+        prompt: str,
+        system_prompt: typing.Annotated[
+            str | None,
+            discord.Option(
+                discord.SlashCommandOptionType.string,
+                description="The system prompt to use.",
+                default=None
+            )
+        ],
+        server: typing.Annotated[
+            str,
+            discord.Option(
+                discord.SlashCommandOptionType.string,
+                description="The server to use.",
+                choices=_ServerOptionChoices,
+                default=get_servers()[0].name
+            )
+        ],
+        model: typing.Annotated[
+            str,
+            discord.Option(
+                discord.SlashCommandOptionType.string,
+                description="The model to use.",
+                autocomplete=get_available_tags_autocomplete,
+                default="llama3:latest"
+            )
+        ],
+        image: typing.Annotated[
+            discord.Attachment | None,
+            discord.Option(
+                discord.SlashCommandOptionType.attachment,
+                description="The image to use for llava.",
+                default=None
+            )
+        ],
+        thread_id: typing.Annotated[
+            str | None,
+            discord.Option(
+                discord.SlashCommandOptionType.string,
+                description="The thread ID to continue.",
+                default=None
+            )
+        ]
+    ):
+        """Have a chat with ollama."""
+        await ctx.defer()
+        server = get_server(server)
+        async with self.server_locks[server.name]:
+            if not await server.is_online():
+                await ctx.respond(
+                    content=f"{server} is offline.",
+                    delete_after=60
+                )
+                return
+            async with ollama_client(str(server.base_url)) as client:
+                client: AsyncClient
+                self.log.info("Checking if %r has the model %r", server, model)
+                tags = (await client.list())["models"]
+                if model not in [x["model"] for x in tags]:
+                    embed = discord.Embed(
+                        title=f"Downloading {model} on {server}.",
+                        description="Initiating download...",
+                        color=discord.Color.blurple()
+                    )
+                    view = StopDownloadView(ctx)
+                    await ctx.respond(
+                        embed=embed,
+                        view=view
+                    )
+                    last_edit = 0
+                    async with ctx.typing():
+                        try:
+                            last_completed = 0
+                            last_completed_ts = time.time()
+
+                            async for line in await client.pull(model, stream=True):
+                                if view.event.is_set():
+                                    embed.add_field(name="Error!", value="Download cancelled.")
+                                    embed.colour = discord.Colour.red()
+                                    await ctx.edit(embed=embed)
+                                    return
+                                self.log.info("Response from %r: %r", server, line)
+                                if line["status"] in {
+                                    "pulling manifest",
+                                    "verifying sha256 digest",
+                                    "writing manifest",
+                                    "removing any unused layers",
+                                    "success"
+                                }:
+                                    embed.description = line["status"].capitalize()
+                                else:
+                                    total = line["total"]
+                                    completed = line.get("completed", 0)
+                                    percent = round(completed / total * 100, 1)
+                                    pb_fill = "▰" * int(percent / 10)
+                                    pb_empty = "▱" * (10 - int(percent / 10))
+                                    bytes_per_second = completed - last_completed
+                                    bytes_per_second /= (time.time() - last_completed_ts)
+                                    last_completed = completed
+                                    last_completed_ts = time.time()
+                                    mbps = round((bytes_per_second * 8) / 1024 / 1024)
+                                    progress_bar = f"[{pb_fill}{pb_empty}]"
+                                    ns_total = naturalsize(total, binary=True)
+                                    ns_completed = naturalsize(completed, binary=True)
+                                    embed.description = (
+                                        f"{line['status'].capitalize()} {percent}% {progress_bar} "
+                                        f"({ns_completed}/{ns_total} @ {mbps} Mb/s)"
+                                    )
+
+                                if time.time() - last_edit >= 2.5:
+                                    await ctx.edit(embed=embed)
+                                    last_edit = time.time()
+                        except ResponseError as err:
+                            if err.error.endswith("file does not exist"):
+                                await ctx.edit(
+                                    embed=None,
+                                    content="The model %r does not exist." % model,
+                                    delete_after=60,
+                                    view=None
+                                )
+                            else:
+                                embed.add_field(
+                                    name="Error!",
+                                    value=err.error
+                                )
+                                embed.colour = discord.Colour.red()
+                                await ctx.edit(embed=embed, view=None)
+                            return
+                        else:
+                            embed.colour = discord.Colour.green()
+                            embed.description = f"Downloaded {model} on {server}."
+                            await ctx.edit(embed=embed, delete_after=30, view=None)
+
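+                # Rebuild the conversation: prior messages from the stored
+                # thread (when a thread_id is given), then the optional system
+                # prompt, then the new user prompt with any attached image bytes.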
+                messages = []
+                if thread_id:
+                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
+                    if thread:
+                        for msg in thread.messages:
+                            messages.append(
+                                await create_ollama_message(msg["content"], role=msg["role"])
+                            )
+                    else:
+                        await ctx.respond(content="No thread with that ID exists.", delete_after=30)
+                if system_prompt:
+                    messages.append(await create_ollama_message(system_prompt, role="system"))
+                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
+                embed = discord.Embed(title=f"{model}:", description="")
+                view = StopDownloadView(ctx)
+                msg = await ctx.respond(
+                    embed=embed,
+                    view=view
+                )
+                last_edit = time.time()
+                buffer = io.StringIO()
+                async for response in await client.chat(
+                    model,
+                    messages,
+                    stream=True,
+                    options=Options(
+                        num_ctx=4096,
+                        low_vram=server.vram_gb < 8,
+                        temperature=1.5
+                    )
+                ):
+                    self.log.info("Response from %r: %r", server, response)
+                    buffer.write(response["message"]["content"])
+
+                    if len(buffer.getvalue()) > 4096:
+                        # Embed descriptions cap at 4096 characters, so keep the tail.
+                        value = "... " + buffer.getvalue()[-4092:]
+                    else:
+                        value = buffer.getvalue()
+                    embed.description = value
+                    if view.event.is_set():
+                        embed.add_field(name="Error!", value="Chat cancelled.")
+                        embed.colour = discord.Colour.red()
+                        await msg.edit(embed=embed, view=None)
+                        return
+                    if time.time() - last_edit >= 2.5:
+                        await msg.edit(embed=embed, view=view)
+                        last_edit = time.time()
+                embed.colour = discord.Colour.green()
+                if len(buffer.getvalue()) > 4096:
+                    file = discord.File(
+                        io.BytesIO(buffer.getvalue().encode()),
+                        filename="full-chat.txt"
+                    )
+                    embed.add_field(
+                        name="Full chat",
+                        value="The chat was too long to fit in this message. "
+                              "You can download the `full-chat.txt` file to see the full message."
+                    )
+                else:
+                    file = discord.utils.MISSING
+
+                # Store the assistant's reply too, so resumed threads carry the
+                # full exchange.
+                messages.append(await create_ollama_message(buffer.getvalue(), role="assistant"))
+                thread = OllamaThread(
+                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
+                )
+                await thread.save()
+                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
+                await msg.edit(embed=embed, view=None, file=file)
+
+
+def setup(bot):
+    bot.add_cog(Chat(bot))
diff --git a/jimmy/config.py b/jimmy/config.py
new file mode 100644
index 0000000..10b144f
--- /dev/null
+++ b/jimmy/config.py
@@ -0,0 +1,69 @@
+import logging
+import tomllib
+from typing import Callable
+
+import httpx
+from pydantic import AnyHttpUrl, BaseModel, Field
+
+log = logging.getLogger(__name__)
+
+
+class ServerConfig(BaseModel):
+    name: str = Field(min_length=1, max_length=32)
+    base_url: AnyHttpUrl
+    gpu: bool = False
+    vram_gb: int = 4
+    throttle: bool = False
+
+    def __repr__(self):
+        return "<ServerConfig name={0.name!r} base_url={0.base_url!r}>".format(self)
+
+    def __str__(self):
+        return self.name
+
+    async def is_online(self) -> bool:
+        """
+        Checks that the current server is online and responding to requests.
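+        A 200 response from ``GET /api/tags`` counts as online; any transport
+        error counts as offline.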
+        """
+        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
+            try:
+                response = await client.get("/api/tags")
+                return response.status_code == 200
+            except httpx.RequestError:
+                return False
+
+    def __hash__(self):
+        return hash(self.base_url)
+
+
+def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
+    config = get_config()
+    # "order" sits alongside the server tables, so pop it out before treating
+    # the remaining keys as server names; fall back to file order when it is
+    # missing or empty.
+    order = config["servers"].pop("order", [])
+    keys = order or list(config["servers"].keys())
+    log.info("Servers: %r", keys)
+    servers = [ServerConfig(name=key, **config["servers"][key]) for key in keys]
+    if filter_func:
+        servers = list(filter(filter_func, servers))
+    return servers
+
+
+def get_server(name_or_base_url: str) -> ServerConfig | None:
+    for server in get_servers():
+        if server.name == name_or_base_url or str(server.base_url) == name_or_base_url:
+            return server
+    return None
+
+
+def get_config():
+    with open("config.toml", "rb") as file:
+        _loaded = tomllib.load(file)
+
+    _loaded.setdefault("servers", {})
+    _loaded["servers"].setdefault("order", [])
+    _loaded.setdefault("bot", {})
+    return _loaded
diff --git a/jimmy/db.py b/jimmy/db.py
new file mode 100644
index 0000000..1d769e7
--- /dev/null
+++ b/jimmy/db.py
@@ -0,0 +1,13 @@
+import os
+
+from tortoise import fields
+from tortoise.models import Model
+
+
+class OllamaThread(Model):
+    thread_id = fields.CharField(max_length=255, unique=True, default=lambda: os.urandom(4).hex())
+    # A callable default avoids sharing one list instance across rows.
+    messages = fields.JSONField(default=list)
+    created_at = fields.DatetimeField(auto_now_add=True)
+
+    class Meta:
+        table = "ollama_threads"
diff --git a/jimmy/main.py b/jimmy/main.py
new file mode 100644
index 0000000..f26f1ca
--- /dev/null
+++ b/jimmy/main.py
@@ -0,0 +1,53 @@
+import logging
+import os
+import sys
+
+import discord
+from discord.ext import commands
+from tortoise import Tortoise
+
+sys.path.append("..")
+from .config import get_config  # noqa: E402
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class SentientJimmy(commands.Bot):
+    def __init__(self):
+        intents = discord.Intents.default()
+        # noinspection PyUnresolvedReferences
+        intents.message_content = True
+        super().__init__(
+            commands.when_mentioned_or("."),
+            intents=intents,
+            case_insensitive=True,
+            strip_after_prefix=True,
+            debug_guilds=get_config()["bot"].get("debug_guilds"),
+        )
+        self.load_extension("jimmy.cogs.chat")
+        self.load_extension("jishaku")
+
+    async def start(self, token: str, *, reconnect: bool = True) -> None:
+        is_docker = os.path.exists("/.dockerenv")
+        default_db = "sqlite://:memory:" if is_docker else "sqlite://default.db"
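+        # /.dockerenv only exists inside a Docker container, where the
+        # filesystem is ephemeral; default to an in-memory database there
+        # rather than a file that disappears with the container.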
+        await Tortoise.init(
+            db_url=get_config()["bot"].get("db_url", default_db),
+            modules={"models": ["jimmy.db"]}
+        )
+        await Tortoise.generate_schemas()
+        await super().start(token, reconnect=reconnect)
+
+    async def close(self) -> None:
+        await Tortoise.close_connections()
+        await super().close()
+
+    def run(self) -> None:
+        token = get_config()["bot"]["token"]
+        super().run(token)
+
+
+bot = SentientJimmy()
+
+
+if __name__ == "__main__":
+    bot.run()
diff --git a/jimmy/utils.py b/jimmy/utils.py
new file mode 100644
index 0000000..782fe35
--- /dev/null
+++ b/jimmy/utils.py
@@ -0,0 +1,47 @@
+import asyncio
+import typing
+from functools import partial
+
+from fuzzywuzzy.fuzz import ratio
+from ollama import Message
+
+__all__ = (
+    'async_ratio',
+    'create_ollama_message',
+)
+
+
+async def async_ratio(a: str, b: str) -> int:
+    """
+    Wraps fuzzywuzzy's ratio in an async function.
+
+    :param a: str - first string
+    :param b: str - second string
+    :return: int - ratio of similarity
+    """
+    return await asyncio.to_thread(partial(ratio, a, b))
+
+
+async def create_ollama_message(
+    content: str,
+    role: typing.Literal["system", "assistant", "user"] = "user",
+    images: list[str | bytes] | None = None
+) -> Message:
+    """
+    Create a message for ollama.
+
+    :param content: str - the content of the message
+    :param role: str - the role of the message
+    :param images: list - the images to attach to the message
+    :return: dict - the message
+    """
+    def factory(**kwargs):
+        return Message(**kwargs)
+
+    return await asyncio.to_thread(
+        partial(
+            factory,
+            role=role,
+            content=content,
+            images=images
+        )
+    )
diff --git a/requirements.txt b/requirements.txt
index fec8d5d..2302399 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,8 @@
 py-cord~=2.5
-ollama-python~=0.1
+ollama~=0.2
 tortoise-orm[asyncpg]~=0.21
 uvicorn[standard]~=0.30
 fastapi~=0.111
+jishaku~=2.5
+fuzzywuzzy~=0.18
+humanize~=4.9
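+# Optional: python-Levenshtein speeds up fuzzywuzzy and silences its
+# slow-pure-python warning.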