From c785c22800f13ab800ce91faa3fc2d09f7fd89f2 Mon Sep 17 00:00:00 2001 From: nex Date: Wed, 22 Nov 2023 16:01:11 +0000 Subject: [PATCH] Fix ollama server selector [pt 3] --- cogs/other.py | 44 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/cogs/other.py b/cogs/other.py index 503559b..b3edcb4 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -34,6 +34,7 @@ import httpx import psutil import pytesseract import pyttsx3 +from discord import Interaction from discord.ext import commands from dns import asyncresolver from PIL import Image @@ -1934,15 +1935,22 @@ class OtherCog(commands.Cog): "owner": 421698654189912064 }, } + H_DEFAULT = { + "name": "Other", + "allow": ["*"], + "owner": 1019217990111199243 + } - def model_is_allowed(model_name: str, srv: dict[str, str | list[str] | int]) -> bool: - for pat in srv.get("allow", ['*']): + def model_is_allowed(model_name: str, _srv: dict[str, str | list[str] | int]) -> bool: + if _srv["owner"] == ctx.user.id: + return True + for pat in _srv.get("allow", ['*']): if not fnmatch.fnmatch(model_name.casefold(), pat.casefold()): print( "Server %r does not support %r (only %r.)" % ( - srv['name'], + _srv['name'], model_name, - ', '.join(srv['allow']) + ', '.join(_srv['allow']) ) ) return False @@ -1955,6 +1963,9 @@ class OtherCog(commands.Cog): ) self.chosen_server = None + async def interaction_check(self, interaction: Interaction) -> bool: + return interaction.user == ctx.user + @discord.ui.select( placeholder="Choose a server.", custom_id="select", @@ -2010,17 +2021,33 @@ class OtherCog(commands.Cog): self.chosen_server = f"{_modal.hostname}:{_modal.port}" else: self.chosen_server = item.values[0] + await interaction.response.defer(ephemeral=True) + await interaction.followup.send( + f"\N{white heavy check mark} Selected server {self.chosen_server}/", + ephemeral=True + ) self.stop() if server == "auto": selector = ServerSelector() - await ctx.send("Select a server:", view=selector) + selector_message = await ctx.respond("Select a server:", view=selector) await selector.wait() if not selector.chosen_server: return host = selector.chosen_server + await selector_message.delete(delay=1) else: host = server + srv = servers.get(host, H_DEFAULT) + if not model_is_allowed(model, srv): + return await ctx.respond( + ":x: <@{!s}> does not allow you to run that model on the server {!r}. You can, however, use" + " any of the following: {}".format( + srv["owner"], + srv["name"], + ", ".join(srv.get("allowed", ["*"])) + ) + ) content = None embed = discord.Embed( @@ -2031,7 +2058,12 @@ class OtherCog(commands.Cog): url=f"http://{host}", icon_url="https://cdn.discordapp.com/emojis/1101463077586735174.gif" ) - embed.set_footer(text="Using server {} ({})".format(host, servers.get(host, "Other"))) + embed.set_footer( + text="Using server {} ({})".format( + host, + servers.get(host, H_DEFAULT)['name'] + ) + ) msg = await ctx.respond(embed=embed, ephemeral=False) async with httpx.AsyncClient(follow_redirects=True) as client: