import asyncio
import contextlib
import io
import logging
import time
import typing

import discord
from discord import Interaction
from discord.ext import commands
from humanize import naturaldelta, naturalsize
from ollama import AsyncClient, Options, ResponseError

from jimmy.config import get_server, get_servers
from jimmy.db import OllamaThread
from jimmy.utils import (
    create_ollama_message,
    decorate_server_name as decorate_name,
    find_suitable_server,
)

# Discord caps embed descriptions at 4096 characters.
MAX_EMBED_DESCRIPTION = 4096


@contextlib.asynccontextmanager
async def ollama_client(host: str, **kwargs) -> typing.AsyncIterator[AsyncClient]:
    """Yield an ollama ``AsyncClient`` for *host*, closing its HTTP session on exit.

    :param host: base URL of the ollama server (coerced to ``str``).
    :param kwargs: forwarded to :class:`ollama.AsyncClient` (e.g. ``timeout``).
    """
    client = AsyncClient(str(host), **kwargs)
    try:
        yield client
    finally:
        # Ollama doesn't auto-close the client, so we have to do it ourselves.
        await client._client.aclose()


class StopDownloadView(discord.ui.View):
    """A persistent view with a single 'Cancel' button.

    Pressing the button sets :attr:`event`, which long-running loops
    (model downloads, streamed chats) poll to abort early.
    """

    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=None)
        self.ctx = ctx
        self.event = asyncio.Event()

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the user who invoked the command may cancel.
        return interaction.user == self.ctx.user

    @discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
    async def cancel_download(self, _, interaction: discord.Interaction):
        self.event.set()
        self.stop()
        await interaction.response.edit_message(view=None)


async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    """Autocomplete model tags from the selected (or first configured) server."""
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
        tags = (await client.list())["models"]
    matches = [
        tag["model"] for tag in tags
        if ctx.value.casefold() in tag["model"].casefold()
    ]
    # Always offer the literal typed value first; Discord caps results at 25.
    return [ctx.value, *matches][:25]


_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
    [x.name for x in get_servers()]
)


class Chat(commands.Cog):
    """Slash commands for chatting with, and managing, ollama servers."""

    def __init__(self, bot):
        self.bot = bot
        # One lock per server so throttled servers run one job at a time.
        self.server_locks = {server.name: asyncio.Lock() for server in get_servers()}
        self.log = logging.getLogger(__name__)

    ollama_group = discord.SlashCommandGroup(
        name="ollama",
        description="Commands related to ollama.",
        guild_only=True,
        max_concurrency=commands.MaxConcurrency(1, per=commands.BucketType.user, wait=False),
        cooldown=commands.CooldownMapping(commands.Cooldown(1, 10), commands.BucketType.user),
    )

    async def _download_model(
        self,
        ctx: discord.ApplicationContext,
        client: AsyncClient,
        server,
        model: str,
    ) -> bool:
        """Pull *model* onto *server*, showing live progress in an embed.

        Shared by ``/ollama chat`` and ``/ollama pull`` (previously duplicated).

        :returns: True on success; False if the download failed or was
            cancelled (the user has already been notified in that case).
        """
        embed = discord.Embed(
            title=f"Downloading {model} on {server}.",
            description="Initiating download...",
            color=discord.Color.blurple(),
        )
        view = StopDownloadView(ctx)
        await ctx.respond(embed=embed, view=view)
        last_edit = 0
        async with ctx.typing():
            try:
                last_completed = 0
                last_completed_ts = time.time()
                async for line in await client.pull(model, stream=True):
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Download cancelled.")
                        embed.colour = discord.Colour.red()
                        await ctx.edit(embed=embed)
                        return False
                    self.log.debug("Response from %r: %r", server, line)
                    if line["status"] in {
                        "pulling manifest",
                        "verifying sha256 digest",
                        "writing manifest",
                        "removing any unused layers",
                        "success",
                    }:
                        # Phase lines carry no byte counts - show them verbatim.
                        embed.description = line["status"].capitalize()
                    else:
                        total = line["total"]
                        completed = line.get("completed", 0)
                        percent = round(completed / total * 100, 1)
                        pb_fill = "▰" * int(percent / 10)
                        pb_empty = "▱" * (10 - int(percent / 10))
                        # Throughput estimated from progress since the last line.
                        bytes_per_second = completed - last_completed
                        bytes_per_second /= (time.time() - last_completed_ts)
                        last_completed = completed
                        last_completed_ts = time.time()
                        mbps = round((bytes_per_second * 8) / 1024 / 1024)
                        eta = (total - completed) / max(1, bytes_per_second)
                        progress_bar = f"[{pb_fill}{pb_empty}]"
                        ns_total = naturalsize(total, binary=True)
                        ns_completed = naturalsize(completed, binary=True)
                        embed.description = (
                            f"{line['status'].capitalize()} {percent}% {progress_bar} "
                            f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
                            f"[ETA: {naturaldelta(eta)}]"
                        )
                    # Throttle message edits to stay under Discord rate limits.
                    if time.time() - last_edit >= 2.5:
                        await ctx.edit(embed=embed)
                        last_edit = time.time()
            except ResponseError as err:
                if err.error.endswith("file does not exist"):
                    await ctx.edit(
                        embed=None,
                        content="The model %r does not exist." % model,
                        delete_after=60,
                        view=None,
                    )
                else:
                    embed.add_field(name="Error!", value=err.error)
                    embed.colour = discord.Colour.red()
                    await ctx.edit(embed=embed, view=None)
                return False
            else:
                embed.colour = discord.Colour.green()
                embed.description = f"Downloaded {model} on {server}."
                await ctx.edit(embed=embed, delete_after=30, view=None)
                return True

    @ollama_group.command()
    async def status(self, ctx: discord.ApplicationContext):
        """Checks the status on all servers."""
        await ctx.defer()
        embed = discord.Embed(
            title="Ollama Statuses:",
            color=discord.Color.blurple(),
        )
        # Map each server to its embed field index for the later update pass.
        fields = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                value = "\N{closed lock with key} In use."
            else:
                value = "\N{hourglass with flowing sand} Waiting..."
            embed.add_field(name=decorate_name(server), value=value, inline=False)
            fields[server] = len(embed.fields) - 1
        await ctx.respond(embed=embed)

        # Probe all non-busy servers concurrently.
        tasks = {}
        for server in get_servers():
            if server.throttle and self.server_locks[server.name].locked():
                continue
            tasks[server] = asyncio.create_task(server.is_online())
        await asyncio.gather(*tasks.values())
        for server, task in tasks.items():
            if task.result():
                value = "\N{white heavy check mark} Online."
            else:
                value = "\N{cross mark} Offline."
            embed.set_field_at(
                fields[server],
                name=decorate_name(server),
                value=value,
                inline=False,
            )
        await ctx.edit(embed=embed)

    @ollama_group.command(name="server-info")
    async def get_server_info(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name,
            ),
        ],
    ):
        """Gets information on a given server"""
        await ctx.defer()
        server = get_server(server)
        # FIX: guard against an unknown server name, consistent with the
        # other commands (previously this raised AttributeError on None).
        if not server:
            return await ctx.respond("\N{cross mark} Unknown Server.")
        is_online = await server.is_online()
        tick = {True: "\N{white heavy check mark}", False: "\N{cross mark}"}
        ram_type = "VRAM" if server.gpu else "RAM"
        lines = [
            f"Name: {server.name!r}",
            f"Base URL: {server.base_url!r}",
            f"GPU Enabled: {tick[server.gpu]}",
            f"{ram_type}: {server.vram_gb:,} GB",
            f"Default Model: {server.default_model!r}",
            f"Is Online: {tick[is_online]}",
        ]
        return await ctx.respond("```md\n" + "\n".join(lines) + "```")

    @ollama_group.command(name="chat")
    async def start_ollama_chat(
        self,
        ctx: discord.ApplicationContext,
        prompt: str,
        system_prompt: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The system prompt to use.",
                default=None,
            ),
        ],
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name,
            ),
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="default",
            ),
        ],
        image: typing.Annotated[
            discord.Attachment | None,
            discord.Option(
                discord.SlashCommandOptionType.attachment,
                description="The image to use for llava.",
                default=None,
            ),
        ],
        thread_id: typing.Annotated[
            str | None,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The thread ID to continue.",
                default=None,
            ),
        ],
        temperature: typing.Annotated[
            float,
            discord.Option(
                discord.SlashCommandOptionType.number,
                description="The temperature to use.",
                default=1.5,
                min_value=0.0,
                max_value=2.0,
            ),
        ],
    ):
        """Have a chat with ollama"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown Server.")
        elif not await server.is_online():
            await ctx.respond(
                content=f"{server} is offline. Finding a suitable server...",
            )
            try:
                server = await find_suitable_server()
            except ValueError as err:
                return await ctx.edit(content=str(err), delete_after=30)
            await ctx.delete(delay=5)

        async with self.server_locks[server.name]:
            if model == "default":
                model = server.default_model
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                if model not in [x["model"] for x in tags]:
                    # Model is missing locally - pull it first.
                    if not await self._download_model(ctx, client, server, model):
                        return

                messages = []
                thread = None
                if thread_id:
                    thread = await OllamaThread.get_or_none(thread_id=thread_id)
                    if thread:
                        # Replay the stored conversation for context.
                        for msg in thread.messages:
                            messages.append(
                                await create_ollama_message(msg["content"], role=msg["role"])
                            )
                    else:
                        await ctx.respond(content="No thread with that ID exists.", delete_after=30)
                if system_prompt:
                    messages.append(await create_ollama_message(system_prompt, role="system"))
                messages.append(
                    await create_ollama_message(
                        prompt,
                        images=[await image.read()] if image else None,
                    )
                )

                embed = discord.Embed(description="")
                embed.set_author(
                    name=f"{model} @ {decorate_name(server)!r}" if server.gpu else model,
                    icon_url="https://ollama.com/public/icon-64x64.png",
                )
                view = StopDownloadView(ctx)
                msg = await ctx.respond(embed=embed, view=view)
                last_edit = time.time()
                buffer = io.StringIO()
                async for response in await client.chat(
                    model,
                    messages,
                    stream=True,
                    options=Options(
                        num_ctx=4096,
                        low_vram=server.vram_gb < 8,
                        temperature=temperature,
                    ),
                ):
                    response: dict
                    self.log.debug("Response from %r: %r", server, response)
                    buffer.write(response["message"]["content"])
                    content = buffer.getvalue()
                    if len(content) > MAX_EMBED_DESCRIPTION:
                        # FIX: keep only the tail so the description fits Discord's
                        # 4096-char limit. The old slice ("... " + content[4:])
                        # left the string over-length and the edit would fail.
                        embed.description = "... " + content[-(MAX_EMBED_DESCRIPTION - 4):]
                    else:
                        embed.description = content
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Chat cancelled.")
                        embed.colour = discord.Colour.red()
                        await msg.edit(embed=embed, view=None)
                        return
                    # Throttle edits to stay under Discord rate limits.
                    if time.time() - last_edit >= 2.5:
                        await msg.edit(embed=embed, view=view)
                        last_edit = time.time()

                embed.colour = discord.Colour.green()
                if len(buffer.getvalue()) > MAX_EMBED_DESCRIPTION:
                    # Attach the full transcript when it overflowed the embed.
                    file = discord.File(
                        io.BytesIO(buffer.getvalue().encode()),
                        filename="full-chat.txt",
                    )
                    embed.add_field(
                        name="Full chat",
                        value="The chat was too long to fit in this message. "
                        f"You can download the `full-chat.txt` file to see the full message.",
                    )
                else:
                    file = discord.utils.MISSING
                if not thread:
                    # NOTE(review): only the prompt-side messages are persisted;
                    # the assistant's reply is not stored - confirm intended.
                    thread = OllamaThread(
                        messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                    )
                    await thread.save()
                embed.set_footer(text=f"Chat ID: {thread.thread_id}")
                await msg.edit(embed=embed, view=None, file=file)

    @ollama_group.command(name="pull")
    async def pull_ollama_model(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name,
            ),
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest",
            ),
        ],
    ):
        """Downloads a tag on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        # FIX: the original body referenced `client` without ever creating one
        # (a guaranteed NameError at runtime) - open a client for the chosen
        # server and reuse the shared download helper.
        async with ollama_client(str(server.base_url)) as client:
            await self._download_model(ctx, client, server, model)


def setup(bot):
    bot.add_cog(Chat(bot))