import asyncio
import datetime
import io
import logging
import time
import typing
import contextlib

import discord
import httpx
from discord import Interaction
from ollama import AsyncClient, ResponseError, Options
from discord.ext import commands

from jimmy.utils import create_ollama_message, find_suitable_server, decorate_server_name as decorate_name
from jimmy.config import get_servers, get_server, get_config
from jimmy.db import OllamaThread
from humanize import naturalsize, naturaldelta


@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> typing.AsyncIterator[AsyncClient]:
    """Yield an ollama ``AsyncClient`` for *host*, closing it on exit.

    The ollama library does not close its underlying httpx client
    automatically, so this context manager does it in ``finally``.
    """
    host = str(host)
    client = AsyncClient(host, **kwargs)
    try:
        yield client
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        await client._client.aclose()


class StopDownloadView(discord.ui.View):
    """A persistent view with a single Cancel button.

    Sets ``self.event`` when the invoking user presses Cancel; long-running
    loops poll the event to abort downloads or chat generation.
    """

    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=None)
        self.ctx = ctx
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the user who invoked the command may cancel.
        return interaction.user == self.ctx.user

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)


async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    """Autocomplete model tags from the selected (or first) server.

    Always offers the user's raw input first so unknown tags can still be
    submitted (e.g. to trigger a download); capped at Discord's 25-choice
    limit.
    """
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
        tags = (await client.list())["models"]
    v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
    return [ctx.value, *v][:25]


_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
    [x.name for x in get_servers()]
)


class Chat(commands.Cog):
    """Slash commands for chatting with, and managing, ollama servers."""

    def __init__(self, bot):
        self.bot = bot
        # One lock per server so only a single chat/download runs on each
        # server at a time.
        self.server_locks = {}
        for server in get_servers():
            self.server_locks[server.name] = asyncio.Lock()
        self.log = logging.getLogger(__name__)

    ollama_group = discord.SlashCommandGroup(
        name="ollama",
        description="Commands related to ollama.",
        guild_only=True
    )

    @ollama_group.command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status on all servers."""
        await ctx.defer()
        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple()
        )
        # Maps each server to its field index so it can be updated in place.
        fields = {}
        for server in get_servers():
            if self.server_locks[server.name].locked():
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{closed lock with key} In use.",
                    inline=False
                )
            else:
                embed.add_field(
                    name=decorate_name(server),
                    value="\N{hourglass with flowing sand} Waiting...",
                    inline=False
                )
            fields[server] = len(embed.fields) - 1
        await ctx.respond(embed=embed)

        # Probe all idle servers concurrently, then update their fields.
        tasks = {}
        for server in get_servers():
            if self.server_locks[server.name].locked():
                continue
            tasks[server] = asyncio.create_task(server.is_online())
        await asyncio.gather(*tasks.values())
        for server, task in tasks.items():
            if task.result():
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{white heavy check mark} Online.",
                    inline=False
                )
            else:
                embed.set_field_at(
                    fields[server],
                    name=decorate_name(server),
                    value="\N{cross mark} Offline.",
                    inline=False
                )
        await ctx.edit(embed=embed)

    @ollama_group.command(name="server-info")
    async def get_server_info(
            self,
            ctx: discord.ApplicationContext,
            server: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The server to use.",
                    autocomplete=_ServerOptionAutocomplete,
                    default=get_servers()[0].name
                )
            ]
    ):
        """Gets information on a given server"""
        await ctx.defer()
        server = get_server(server)
        is_online = await server.is_online()
        y = "\N{white heavy check mark}"
        x = "\N{cross mark}"
        t = {True: y, False: x}
        rt = "VRAM" if server.gpu else "RAM"
        lines = [
            f"Name: {server.name!r}",
            f"Base URL: {server.base_url!r}",
            f"GPU Enabled: {t[server.gpu]}",
            f"{rt}: {server.vram_gb:,} GB",
            f"Default Model: {server.default_model!r}",
            f"Is Online: {t[is_online]}"
        ]
        p = "```md\n" + "\n".join(lines) + "```"
        return await ctx.respond(p)

    @ollama_group.command(name="chat")
    async def start_ollama_chat(
            self,
            ctx: discord.ApplicationContext,
            prompt: str,
            system_prompt: typing.Annotated[
                str | None,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The system prompt to use.",
                    default=None
                )
            ],
            server: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The server to use.",
                    autocomplete=_ServerOptionAutocomplete,
                    default=get_servers()[0].name
                )
            ],
            model: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The model to use.",
                    autocomplete=get_available_tags_autocomplete,
                    default="default"
                )
            ],
            image: typing.Annotated[
                discord.Attachment | None,
                discord.Option(
                    discord.SlashCommandOptionType.attachment,
                    description="The image to use for llava.",
                    default=None
                )
            ],
            thread_id: typing.Annotated[
                str | None,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The thread ID to continue.",
                    default=None
                )
            ],
            temperature: typing.Annotated[
                float,
                discord.Option(
                    discord.SlashCommandOptionType.number,
                    description="The temperature to use.",
                    default=1.5,
                    min_value=0.0,
                    max_value=2.0
                )
            ]
    ):
        """Have a chat with ollama"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown Server.")
        elif not await server.is_online():
            # Fall back to any online server before giving up.
            await ctx.respond(
                content=f"{server} is offline. Finding a suitable server...",
            )
            try:
                server = await find_suitable_server()
            except ValueError as err:
                return await ctx.edit(content=str(err), delete_after=30)
            await ctx.delete(delay=5)

        async with self.server_locks[server.name]:
            if model == "default":
                model = server.default_model
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                # Download code. It's recommended to collapse this in the editor.
                if model not in [x["model"] for x in tags]:
                    embed = discord.Embed(
                        title=f"Downloading {model} on {server}.",
                        description="Initiating download...",
                        color=discord.Color.blurple()
                    )
                    view = StopDownloadView(ctx)
                    await ctx.respond(
                        embed=embed,
                        view=view
                    )
                    last_edit = 0
                    async with ctx.typing():
                        try:
                            last_completed = 0
                            last_completed_ts = time.time()
                            async for line in await client.pull(model, stream=True):
                                if view.event.is_set():
                                    embed.add_field(name="Error!", value="Download cancelled.")
                                    embed.colour = discord.Colour.red()
                                    await ctx.edit(embed=embed)
                                    return
                                self.log.debug("Response from %r: %r", server, line)
                                if line["status"] in {
                                    "pulling manifest",
                                    "verifying sha256 digest",
                                    "writing manifest",
                                    "removing any unused layers",
                                    "success"
                                }:
                                    embed.description = line["status"].capitalize()
                                else:
                                    total = line["total"]
                                    completed = line.get("completed", 0)
                                    percent = round(completed / total * 100, 1)
                                    pb_fill = "▰" * int(percent / 10)
                                    pb_empty = "▱" * (10 - int(percent / 10))
                                    # Floor the elapsed time - two progress lines can
                                    # arrive within the clock resolution, which would
                                    # otherwise raise ZeroDivisionError.
                                    elapsed = max(time.time() - last_completed_ts, 1e-6)
                                    bytes_per_second = (completed - last_completed) / elapsed
                                    last_completed = completed
                                    last_completed_ts = time.time()
                                    mbps = round((bytes_per_second * 8) / 1024 / 1024)
                                    eta = (total - completed) / max(1, bytes_per_second)
                                    progress_bar = f"[{pb_fill}{pb_empty}]"
                                    ns_total = naturalsize(total, binary=True)
                                    ns_completed = naturalsize(completed, binary=True)
                                    embed.description = (
                                        f"{line['status'].capitalize()} {percent}% {progress_bar} "
                                        f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
                                        f"[ETA: {naturaldelta(eta)}]"
                                    )
                                # Rate-limit embed edits to avoid Discord API limits.
                                if time.time() - last_edit >= 2.5:
                                    await ctx.edit(embed=embed)
                                    last_edit = time.time()
                        except ResponseError as err:
                            if err.error.endswith("file does not exist"):
                                await ctx.edit(
                                    embed=None,
                                    content="The model %r does not exist." % model,
                                    delete_after=60,
                                    view=None
                                )
                            else:
                                embed.add_field(
                                    name="Error!",
                                    value=err.error
                                )
                                embed.colour = discord.Colour.red()
                                await ctx.edit(embed=embed, view=None)
                            return
                        else:
                            embed.colour = discord.Colour.green()
                            embed.description = f"Downloaded {model} on {server}."
                            await ctx.edit(embed=embed, delete_after=30, view=None)

                # Reconstruct the message history for a continued thread,
                # falling back to the jimmy v2 API for 6-char legacy IDs.
                messages = []
                thread = None
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    elif len(thread_id) == 6:
                        # Is a legacy thread
                        _cfg = get_config()["truth_api"]
                        async with httpx.AsyncClient(
                            base_url=_cfg["url"],
                            auth=(_cfg["username"], _cfg["password"])
                        ) as http_client:
                            response = await http_client.get(f"/ollama/thread/threads:{thread_id}")
                            if response.status_code == 200:
                                thread = response.json()
                                messages = thread["messages"]
                                # Migrate the legacy thread into the local DB.
                                thread = OllamaThread(
                                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                                )
                                await thread.save()
                            else:
                                return await ctx.respond(
                                    content="Failed to fetch legacy ollama thread from jimmy v2: HTTP %d (`%r`)" % (
                                        response.status_code, response.text
                                    ),
                                )
                    else:
                        return await ctx.respond(content="No thread with that ID exists.", delete_after=30)

                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))

                embed = discord.Embed(description="")
                embed.set_author(
                    name=f"{model} @ {decorate_name(server)!r}" if server.gpu else model,
                    icon_url="https://ollama.com/public/icon-64x64.png"
                )
                view = StopDownloadView(ctx)
                msg = await ctx.respond(
                    embed=embed,
                    view=view
                )
                last_edit = time.time()
                buffer = io.StringIO()
                async for response in await client.chat(
                    model,
                    messages,
                    stream=True,
                    options=Options(
                        num_ctx=4096,
                        low_vram=server.vram_gb < 8,
                        temperature=temperature
                    )
                ):
                    response: dict
                    self.log.debug("Response from %r: %r", server, response)
                    buffer.write(response["message"]["content"])
                    if len(buffer.getvalue()) > 4096:
                        # Keep only the tail: embed descriptions are capped at
                        # 4096 characters ("... " + 4092 == exactly 4096).
                        value = "... " + buffer.getvalue()[-4092:]
                    else:
                        value = buffer.getvalue()
                    embed.description = value
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Chat cancelled.")
                        embed.colour = discord.Colour.red()
                        await msg.edit(embed=embed, view=None)
                        return
                    # Rate-limit embed edits to avoid Discord API limits.
                    if time.time() - last_edit >= 2.5:
                        await msg.edit(embed=embed, view=view)
                        last_edit = time.time()

                embed.colour = discord.Colour.green()
                if len(buffer.getvalue()) > 4096:
                    # The full reply does not fit in the embed; attach it.
                    file = discord.File(
                        io.BytesIO(buffer.getvalue().encode()),
                        filename="full-chat.txt"
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                              "You can download the `full-chat.txt` file to see the full message."
                    )
                else:
                    file = discord.utils.MISSING

                # NOTE(review): only brand-new threads are persisted here; a
                # continued thread is not updated with the new messages -
                # confirm whether that is intentional.
                if not thread:
                    thread = OllamaThread(
                        messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                    )
                    await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)

    @ollama_group.command(name="pull")
    async def pull_ollama_model(
            self,
            ctx: discord.ApplicationContext,
            server: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The server to use.",
                    autocomplete=_ServerOptionAutocomplete,
                    default=get_servers()[0].name
                )
            ],
            model: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The model to use.",
                    autocomplete=get_available_tags_autocomplete,
                    default="llama3:latest"
                )
            ],
    ):
        """Downloads a tag on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        embed = discord.Embed(
            title=f"Downloading {model} on {server}.",
            description="Initiating download...",
            color=discord.Color.blurple()
        )
        view = StopDownloadView(ctx)
        await ctx.respond(
            embed=embed,
            view=view
        )
        last_edit = 0
        async with ctx.typing():
            try:
                last_completed = 0
                last_completed_ts = time.time()
                async with ollama_client(str(server.base_url)) as client:
                    async for line in await client.pull(model, stream=True):
                        if view.event.is_set():
                            embed.add_field(name="Error!", value="Download cancelled.")
                            embed.colour = discord.Colour.red()
                            await ctx.edit(embed=embed)
                            return
                        self.log.debug("Response from %r: %r", server, line)
                        if line["status"] in {
                            "pulling manifest",
                            "verifying sha256 digest",
                            "writing manifest",
                            "removing any unused layers",
                            "success"
                        }:
                            embed.description = line["status"].capitalize()
                        else:
                            total = line["total"]
                            completed = line.get("completed", 0)
                            percent = round(completed / total * 100, 1)
                            pb_fill = "▰" * int(percent / 10)
                            pb_empty = "▱" * (10 - int(percent / 10))
                            # Floor the elapsed time - two progress lines can
                            # arrive within the clock resolution, which would
                            # otherwise raise ZeroDivisionError.
                            elapsed = max(time.time() - last_completed_ts, 1e-6)
                            bytes_per_second = (completed - last_completed) / elapsed
                            last_completed = completed
                            last_completed_ts = time.time()
                            mbps = round((bytes_per_second * 8) / 1024 / 1024)
                            eta = (total - completed) / max(1, bytes_per_second)
                            progress_bar = f"[{pb_fill}{pb_empty}]"
                            ns_total = naturalsize(total, binary=True)
                            ns_completed = naturalsize(completed, binary=True)
                            embed.description = (
                                f"{line['status'].capitalize()} {percent}% {progress_bar} "
                                f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
                                f"[ETA: {naturaldelta(eta)}]"
                            )
                        # Rate-limit embed edits to avoid Discord API limits.
                        if time.time() - last_edit >= 2.5:
                            await ctx.edit(embed=embed)
                            last_edit = time.time()
            except ResponseError as err:
                if err.error.endswith("file does not exist"):
                    await ctx.edit(
                        embed=None,
                        content="The model %r does not exist." % model,
                        delete_after=60,
                        view=None
                    )
                else:
                    embed.add_field(
                        name="Error!",
                        value=err.error
                    )
                    embed.colour = discord.Colour.red()
                    await ctx.edit(embed=embed, view=None)
                return
            else:
                embed.colour = discord.Colour.green()
                embed.description = f"Downloaded {model} on {server}."
                await ctx.edit(embed=embed, delete_after=30, view=None)

    @ollama_group.command(name="ps")
    async def ollama_proc_list(
            self,
            ctx: discord.ApplicationContext,
            server: typing.Annotated[
                str,
                discord.Option(
                    discord.SlashCommandOptionType.string,
                    description="The server to use.",
                    autocomplete=_ServerOptionAutocomplete,
                    default=get_servers()[0].name
                )
            ]
    ):
        """Checks the loaded models on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        async with ollama_client(str(server.base_url)) as client:
            response = (await client.ps())["models"]
        if not response:
            embed = discord.Embed(
                title=f"No models loaded on {server}.",
                color=discord.Color.blurple()
            )
            return await ctx.respond(embed=embed)
        embed = discord.Embed(
            title=f"Models loaded on {server}",
            color=discord.Color.blurple()
        )
        # Discord embeds hold at most 25 fields.
        for model in response[:25]:
            size = naturalsize(model["size"], binary=True)
            size_vram = naturalsize(model["size_vram"], binary=True)
            size_ram = naturalsize(model["size"] - model["size_vram"], binary=True)
            percent_in_vram = round(model["size_vram"] / model["size"] * 100)
            percent_in_ram = 100 - percent_in_vram
            expires = datetime.datetime.fromisoformat(model["expires_at"])
            lines = [
                f"* Size: {size}",
                f"* Unloaded: {discord.utils.format_dt(expires, style='R')}",
            ]
            if percent_in_ram > 0:
                lines.extend(
                    [
                        f"* VRAM/RAM: {percent_in_vram}%/{percent_in_ram}%",
                        f"* VRAM Size: {size_vram}",
                        f"* RAM Size: {size_ram}"
                    ]
                )
            else:
                lines.append(f"* VRAM Size: {size_vram} (100%)")
            embed.add_field(
                name=model["model"],
                value="\n".join(lines),
                inline=False
            )
        await ctx.respond(embed=embed)


def setup(bot):
    bot.add_cog(Chat(bot))