Move to interactions

Nexus 2023-11-13 19:29:44 +00:00
parent f3e1986313
commit 11f2455025
Signed by: nex
GPG key ID: 0FA334385D0B689F


@@ -1,9 +1,11 @@
 import asyncio
+import base64
 import functools
 import glob
 import io
 import json
 import typing
+import zlib
 import math
 import os
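Note: the two added imports, base64 and zlib, back the context compression introduced in the last hunk of this diff; the surrounding imports are unchanged context.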
@@ -1810,7 +1812,7 @@ class OtherCog(commands.Cog):
         )

    class OllamaKillSwitchView(discord.ui.View):
-        def __init__(self, ctx: commands.Context, msg: discord.Message):
+        def __init__(self, ctx: discord.ApplicationContext, msg: discord.Message):
             super().__init__(timeout=None)
             self.ctx = ctx
             self.msg = msg
@@ -1831,12 +1833,63 @@ class OtherCog(commands.Cog):
             await interaction.edit_original_response(view=self)
             self.stop()

-    @commands.command(
-        usage="[model:<name:tag>] [server:<ip[:port]>] <query>"
-    )
-    @commands.max_concurrency(1, commands.BucketType.user, wait=True)
-    async def ollama(self, ctx: commands.Context, *, query: str):
+    @commands.slash_command()
+    @commands.max_concurrency(1, commands.BucketType.user, wait=False)
+    async def ollama(
+            self,
+            ctx: discord.ApplicationContext,
+            model: str = "orca-mini",
+            query: str = None,
+            context: str = None
+    ):
         """:3"""
+        with open("./assets/ollama-prompt.txt") as file:
+            system_prompt = file.read().replace("\n", " ").strip()
+        if query is None:
+            class InputPrompt(discord.ui.Modal):
+                def __init__(self, is_owner: bool):
+                    super().__init__(
+                        discord.ui.InputText(
+                            label="User Prompt",
+                            placeholder="Enter prompt",
+                            min_length=1,
+                            max_length=4000,
+                            style=discord.InputTextStyle.long,
+                        ),
+                        title="Enter prompt",
+                        timeout=120
+                    )
+                    if is_owner:
+                        self.add_item(
+                            discord.ui.InputText(
+                                label="System Prompt",
+                                placeholder="Enter prompt",
+                                min_length=1,
+                                max_length=4000,
+                                style=discord.InputTextStyle.long,
+                                value=system_prompt,
+                            )
+                        )
+                    self.user_prompt = None
+                    self.system_prompt = system_prompt
+
+                async def callback(self, interaction: discord.Interaction):
+                    self.user_prompt = self.children[0].value
+                    if len(self.children) > 1:
+                        self.system_prompt = self.children[1].value
+                    await interaction.response.defer()
+                    self.stop()
+
+            modal = InputPrompt(await self.bot.is_owner(ctx.author))
+            await ctx.send_modal(modal)
+            await modal.wait()
+            query = modal.user_prompt
+            if not modal.user_prompt:
+                return
+            system_prompt = modal.system_prompt or system_prompt
+        else:
+            await ctx.defer()
         content = None
         try_hosts = {
             "127.0.0.1:11434": "localhost",
@@ -1844,51 +1897,29 @@ class OtherCog(commands.Cog):
             "100.66.187.46:11434": "Nexbox",
             "100.116.242.161:11434": "PortaPi"
         }
-        if query.startswith("model:"):
-            model, query = query.split(" ", 1)
-            model = model[6:].casefold()
-            try:
-                _name, _tag = model.split(":", 1)
-            except ValueError:
-                model += ":latest"
-        else:
-            model = "orca-mini"
+        model = model.casefold()
         if not await self.bot.is_owner(ctx.author):
             if not model.startswith("orca-mini"):
-                await ctx.reply(":warning: You can only use `orca-mini` models.", delete_after=30)
+                await ctx.respond(
+                    ":warning: You can only use `orca-mini` models.",
+                    delete_after=30,
+                    ephemeral=True
+                )
                 model = "orca-mini"
-        if query.startswith("server:"):
-            host, query = query.split(" ", 1)
-            host = host[7:]
-            try:
-                host, port = host.split(":", 1)
-                int(port)
-            except ValueError:
-                host += ":11434"
-        else:
-            # try_hosts = [
-            #     "127.0.0.1:11434",  # Localhost
-            #     "100.106.34.86:11434",  # Laptop
-            #     "100.66.187.46:11434",  # optiplex
-            #     "100.116.242.161:11434"  # Raspberry Pi
-            # ]
-            async with httpx.AsyncClient(follow_redirects=True) as client:
-                for host in try_hosts.keys():
-                    try:
-                        response = await client.get(
-                            f"http://{host}/api/tags",
-                        )
-                        response.raise_for_status()
-                    except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
-                        continue
-                    else:
-                        break
-                else:
-                    return await ctx.reply(":x: No servers available.")
+        async with httpx.AsyncClient(follow_redirects=True) as client:
+            for host in try_hosts.keys():
+                try:
+                    response = await client.get(
+                        f"http://{host}/api/tags",
+                    )
+                    response.raise_for_status()
+                except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
+                    continue
+                else:
+                    break
+            else:
+                return await ctx.respond(":x: No servers available.")
         embed = discord.Embed(
             colour=discord.Colour.greyple()
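One detail worth noting in the replacement block above: the ":x: No servers available." respond sits in a for/else, so it fires only when every host in try_hosts fails the probe and the loop finishes without a break. In miniature:

    # Python's for/else: the else branch runs only if no break occurred.
    for host in ("a", "b", "c"):
        if host == "b":
            break            # a host answered; the else is skipped
    else:
        print("no servers")  # reached only when the loop exhausts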
@@ -1900,7 +1931,7 @@ class OtherCog(commands.Cog):
         )
         embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
-        msg = await ctx.reply(embed=embed)
+        msg = await ctx.respond(embed=embed, ephemeral=False)

         async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
             # get models
             try:
@@ -1975,8 +2006,6 @@ class OtherCog(commands.Cog):
             embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
             await msg.edit(embed=embed)
             async with ctx.channel.typing():
-                with open("./assets/ollama-prompt.txt") as file:
-                    system_prompt = file.read().replace("\n", " ").strip()
                 async with client.stream(
                     "POST",
                     "/generate",
@@ -2069,21 +2098,45 @@ class OtherCog(commands.Cog):
                         load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
                         sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
                         prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
+                        context: Optional[list[int]] = chunk.get("context")
+                        # noinspection PyTypeChecker
+                        if context:
+                            context_json = json.dumps(context)
+                            start = time()
+                            context_json_compressed = await asyncio.to_thread(
+                                zlib.compress,
+                                context_json.encode(),
+                                9
+                            )
+                            end = time()
+                            compress_time_spent = format(end * 1000 - start * 1000, ",")
+                            context: str = base64.b64encode(context_json_compressed).decode()
+                        else:
+                            compress_time_spent = "N/A"
+                            context = None
                         value = ("* Total: {}\n"
                                  "* Model load: {}\n"
                                  "* Sample generation: {}\n"
                                  "* Prompt eval: {}\n"
-                                 "* Response generation: {}").format(
+                                 "* Response generation: {}\n"
+                                 "* Context compression: {} milliseconds").format(
                             total_time_spent,
                             load_time_spent,
                             sample_time_sent,
                             prompt_eval_time_spent,
-                            eval_time_spent
+                            eval_time_spent,
+                            compress_time_spent
                         )
                         embed.add_field(
                             name="Timings",
                             value=value
                         )
+                        if context:
+                            embed.add_field(
+                                name="Context",
+                                value=f"```\n{context}\n```"[:1024],
+                                inline=True
+                            )
                         await msg.edit(content=None, embed=embed, view=None)
                         self.ollama_locks.pop(msg, None)
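The Context field added above is zlib-compressed JSON, base64-encoded, presumably so a user can paste it back into the command's new context parameter. A hypothetical decoder for that encoding, not part of the commit itself; note the embed value is truncated to 1024 characters, so long contexts will not round-trip intact:

    # base64 -> zlib -> JSON -> list[int], the inverse of the encoding above.
    import base64
    import json
    import zlib

    def decode_context(blob: str) -> list[int]:
        return json.loads(zlib.decompress(base64.b64decode(blob)))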