diff --git a/cogs/other.py b/cogs/other.py
index 3be259a..6761d04 100644
--- a/cogs/other.py
+++ b/cogs/other.py
@@ -1,9 +1,11 @@
 import asyncio
+import base64
 import functools
 import glob
 import io
 import json
 import typing
+import zlib
 import math
 import os
 
@@ -1810,7 +1812,7 @@ class OtherCog(commands.Cog):
         )
 
     class OllamaKillSwitchView(discord.ui.View):
-        def __init__(self, ctx: commands.Context, msg: discord.Message):
+        def __init__(self, ctx: discord.ApplicationContext, msg: discord.Message):
             super().__init__(timeout=None)
             self.ctx = ctx
             self.msg = msg
@@ -1831,12 +1833,63 @@ class OtherCog(commands.Cog):
             await interaction.edit_original_response(view=self)
             self.stop()
 
-    @commands.command(
-        usage="[model:] [server:] <query>"
-    )
-    @commands.max_concurrency(1, commands.BucketType.user, wait=True)
-    async def ollama(self, ctx: commands.Context, *, query: str):
+    @commands.slash_command()
+    @commands.max_concurrency(1, commands.BucketType.user, wait=False)
+    async def ollama(
+        self,
+        ctx: discord.ApplicationContext,
+        model: str = "orca-mini",
+        query: typing.Optional[str] = None,
+        context: typing.Optional[str] = None
+    ):
         """:3"""
+        with open("./assets/ollama-prompt.txt") as file:
+            system_prompt = file.read().replace("\n", " ").strip()
+        if query is None:
+            class InputPrompt(discord.ui.Modal):
+                def __init__(self, is_owner: bool):
+                    super().__init__(
+                        discord.ui.InputText(
+                            label="User Prompt",
+                            placeholder="Enter prompt",
+                            min_length=1,
+                            max_length=4000,
+                            style=discord.InputTextStyle.long,
+                        ),
+                        title="Enter prompt",
+                        timeout=120
+                    )
+                    if is_owner:
+                        self.add_item(
+                            discord.ui.InputText(
+                                label="System Prompt",
+                                placeholder="Enter prompt",
+                                min_length=1,
+                                max_length=4000,
+                                style=discord.InputTextStyle.long,
+                                value=system_prompt,
+                            )
+                        )
+
+                    self.user_prompt = None
+                    self.system_prompt = system_prompt
+
+                async def callback(self, interaction: discord.Interaction):
+                    self.user_prompt = self.children[0].value
+                    if len(self.children) > 1:
+                        self.system_prompt = self.children[1].value
+                    await interaction.response.defer()
+                    self.stop()
+
+            modal = InputPrompt(await self.bot.is_owner(ctx.author))
+            await ctx.send_modal(modal)
+            await modal.wait()
+            query = modal.user_prompt
+            if not modal.user_prompt:
+                return
+            system_prompt = modal.system_prompt or system_prompt
+        else:
+            await ctx.defer()
         content = None
         try_hosts = {
             "127.0.0.1:11434": "localhost",
@@ -1844,51 +1897,29 @@ class OtherCog(commands.Cog):
             "100.66.187.46:11434": "Nexbox",
             "100.116.242.161:11434": "PortaPi"
         }
-        if query.startswith("model:"):
-            model, query = query.split(" ", 1)
-            model = model[6:].casefold()
-            try:
-                _name, _tag = model.split(":", 1)
-            except ValueError:
-                model += ":latest"
-        else:
-            model = "orca-mini"
-
         model = model.casefold()
         if not await self.bot.is_owner(ctx.author):
            if not model.startswith("orca-mini"):
-                await ctx.reply(":warning: You can only use `orca-mini` models.", delete_after=30)
+                await ctx.respond(
+                    ":warning: You can only use `orca-mini` models.",
+                    delete_after=30,
+                    ephemeral=True
+                )
                 model = "orca-mini"
-
-        if query.startswith("server:"):
-            host, query = query.split(" ", 1)
-            host = host[7:]
-            try:
-                host, port = host.split(":", 1)
-                int(port)
-            except ValueError:
-                host += ":11434"
-        else:
-            # try_hosts = [
-            #     "127.0.0.1:11434",  # Localhost
-            #     "100.106.34.86:11434",  # Laptop
-            #     "100.66.187.46:11434",  # optiplex
-            #     "100.116.242.161:11434"  # Raspberry Pi
-            # ]
-            async with httpx.AsyncClient(follow_redirects=True) as client:
-                for host in try_hosts.keys():
-                    try:
-                        response = await client.get(
-                            f"http://{host}/api/tags",
-                        )
-                        response.raise_for_status()
-                    except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
-                        continue
-                    else:
-                        break
+        async with httpx.AsyncClient(follow_redirects=True) as client:
+            for host in try_hosts.keys():
+                try:
+                    response = await client.get(
+                        f"http://{host}/api/tags",
+                    )
+                    response.raise_for_status()
+                except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
+                    continue
                 else:
-                    return await ctx.reply(":x: No servers available.")
+                    break
+            else:
+                return await ctx.respond(":x: No servers available.")
 
         embed = discord.Embed(
             colour=discord.Colour.greyple()
         )
@@ -1900,7 +1931,7 @@ class OtherCog(commands.Cog):
         )
         embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
 
-        msg = await ctx.reply(embed=embed)
+        msg = await ctx.respond(embed=embed, ephemeral=False)
         async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
             # get models
             try:
@@ -1975,8 +2006,6 @@ class OtherCog(commands.Cog):
         embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
         await msg.edit(embed=embed)
         async with ctx.channel.typing():
-            with open("./assets/ollama-prompt.txt") as file:
-                system_prompt = file.read().replace("\n", " ").strip()
             async with client.stream(
                 "POST",
                 "/generate",
@@ -2069,21 +2098,45 @@ class OtherCog(commands.Cog):
                     load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
                     sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
                     prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
+                    context: typing.Optional[list[int]] = chunk.get("context")
+                    # noinspection PyTypeChecker
+                    if context:
+                        context_json = json.dumps(context)
+                        start = time()
+                        context_json_compressed = await asyncio.to_thread(
+                            zlib.compress,
+                            context_json.encode(),
+                            9
+                        )
+                        end = time()
+                        compress_time_spent = format(end * 1000 - start * 1000, ",")
+                        context: str = base64.b64encode(context_json_compressed).decode()
+                    else:
+                        compress_time_spent = "N/A"
+                        context = None
                     value = ("* Total: {}\n"
                              "* Model load: {}\n"
                              "* Sample generation: {}\n"
                              "* Prompt eval: {}\n"
-                             "* Response generation: {}").format(
+                             "* Response generation: {}\n"
+                             "* Context compression: {} milliseconds").format(
                         total_time_spent,
                         load_time_spent,
                         sample_time_sent,
                         prompt_eval_time_spent,
-                        eval_time_spent
+                        eval_time_spent,
+                        compress_time_spent
                    )
                     embed.add_field(
                         name="Timings",
                         value=value
                     )
+                    if context:
+                        embed.add_field(
+                            name="Context",
+                            value=f"```\n{context[:1014]}\n```",
+                            inline=True
+                        )
                     await msg.edit(content=None, embed=embed, view=None)
 
         self.ollama_locks.pop(msg, None)