mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Fix ollama server selector [pt 3]
This commit is contained in:
parent
bc0123be24
commit
c785c22800
1 changed files with 38 additions and 6 deletions
|
@ -34,6 +34,7 @@ import httpx
|
||||||
import psutil
|
import psutil
|
||||||
import pytesseract
|
import pytesseract
|
||||||
import pyttsx3
|
import pyttsx3
|
||||||
|
from discord import Interaction
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from dns import asyncresolver
|
from dns import asyncresolver
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -1934,15 +1935,22 @@ class OtherCog(commands.Cog):
|
||||||
"owner": 421698654189912064
|
"owner": 421698654189912064
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
H_DEFAULT = {
|
||||||
|
"name": "Other",
|
||||||
|
"allow": ["*"],
|
||||||
|
"owner": 1019217990111199243
|
||||||
|
}
|
||||||
|
|
||||||
def model_is_allowed(model_name: str, srv: dict[str, str | list[str] | int]) -> bool:
|
def model_is_allowed(model_name: str, _srv: dict[str, str | list[str] | int]) -> bool:
|
||||||
for pat in srv.get("allow", ['*']):
|
if _srv["owner"] == ctx.user.id:
|
||||||
|
return True
|
||||||
|
for pat in _srv.get("allow", ['*']):
|
||||||
if not fnmatch.fnmatch(model_name.casefold(), pat.casefold()):
|
if not fnmatch.fnmatch(model_name.casefold(), pat.casefold()):
|
||||||
print(
|
print(
|
||||||
"Server %r does not support %r (only %r.)" % (
|
"Server %r does not support %r (only %r.)" % (
|
||||||
srv['name'],
|
_srv['name'],
|
||||||
model_name,
|
model_name,
|
||||||
', '.join(srv['allow'])
|
', '.join(_srv['allow'])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
@ -1955,6 +1963,9 @@ class OtherCog(commands.Cog):
|
||||||
)
|
)
|
||||||
self.chosen_server = None
|
self.chosen_server = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: Interaction) -> bool:
|
||||||
|
return interaction.user == ctx.user
|
||||||
|
|
||||||
@discord.ui.select(
|
@discord.ui.select(
|
||||||
placeholder="Choose a server.",
|
placeholder="Choose a server.",
|
||||||
custom_id="select",
|
custom_id="select",
|
||||||
|
@ -2010,17 +2021,33 @@ class OtherCog(commands.Cog):
|
||||||
self.chosen_server = f"{_modal.hostname}:{_modal.port}"
|
self.chosen_server = f"{_modal.hostname}:{_modal.port}"
|
||||||
else:
|
else:
|
||||||
self.chosen_server = item.values[0]
|
self.chosen_server = item.values[0]
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
await interaction.followup.send(
|
||||||
|
f"\N{white heavy check mark} Selected server {self.chosen_server}/",
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
if server == "auto":
|
if server == "auto":
|
||||||
selector = ServerSelector()
|
selector = ServerSelector()
|
||||||
await ctx.send("Select a server:", view=selector)
|
selector_message = await ctx.respond("Select a server:", view=selector)
|
||||||
await selector.wait()
|
await selector.wait()
|
||||||
if not selector.chosen_server:
|
if not selector.chosen_server:
|
||||||
return
|
return
|
||||||
host = selector.chosen_server
|
host = selector.chosen_server
|
||||||
|
await selector_message.delete(delay=1)
|
||||||
else:
|
else:
|
||||||
host = server
|
host = server
|
||||||
|
srv = servers.get(host, H_DEFAULT)
|
||||||
|
if not model_is_allowed(model, srv):
|
||||||
|
return await ctx.respond(
|
||||||
|
":x: <@{!s}> does not allow you to run that model on the server {!r}. You can, however, use"
|
||||||
|
" any of the following: {}".format(
|
||||||
|
srv["owner"],
|
||||||
|
srv["name"],
|
||||||
|
", ".join(srv.get("allowed", ["*"]))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
content = None
|
content = None
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
@ -2031,7 +2058,12 @@ class OtherCog(commands.Cog):
|
||||||
url=f"http://{host}",
|
url=f"http://{host}",
|
||||||
icon_url="https://cdn.discordapp.com/emojis/1101463077586735174.gif"
|
icon_url="https://cdn.discordapp.com/emojis/1101463077586735174.gif"
|
||||||
)
|
)
|
||||||
embed.set_footer(text="Using server {} ({})".format(host, servers.get(host, "Other")))
|
embed.set_footer(
|
||||||
|
text="Using server {} ({})".format(
|
||||||
|
host,
|
||||||
|
servers.get(host, H_DEFAULT)['name']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
msg = await ctx.respond(embed=embed, ephemeral=False)
|
msg = await ctx.respond(embed=embed, ephemeral=False)
|
||||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
|
Loading…
Reference in a new issue