Move to interactions

Nexus 2023-11-13 19:29:44 +00:00
parent f3e1986313
commit 11f2455025
Signed by: nex
GPG key ID: 0FA334385D0B689F


@@ -1,9 +1,11 @@
 import asyncio
+import base64
 import functools
 import glob
 import io
 import json
 import typing
+import zlib
 import math
 import os
@@ -1810,7 +1812,7 @@ class OtherCog(commands.Cog):
         )

     class OllamaKillSwitchView(discord.ui.View):
-        def __init__(self, ctx: commands.Context, msg: discord.Message):
+        def __init__(self, ctx: discord.ApplicationContext, msg: discord.Message):
             super().__init__(timeout=None)
             self.ctx = ctx
             self.msg = msg
@@ -1831,12 +1833,63 @@ class OtherCog(commands.Cog):
             await interaction.edit_original_response(view=self)
             self.stop()

-    @commands.command(
-        usage="[model:<name:tag>] [server:<ip[:port]>] <query>"
-    )
-    @commands.max_concurrency(1, commands.BucketType.user, wait=True)
-    async def ollama(self, ctx: commands.Context, *, query: str):
+    @commands.slash_command()
+    @commands.max_concurrency(1, commands.BucketType.user, wait=False)
+    async def ollama(
+            self,
+            ctx: discord.ApplicationContext,
+            model: str = "orca-mini",
+            query: str = None,
+            context: str = None
+    ):
         """:3"""
+        with open("./assets/ollama-prompt.txt") as file:
+            system_prompt = file.read().replace("\n", " ").strip()
+        if query is None:
+            class InputPrompt(discord.ui.Modal):
+                def __init__(self, is_owner: bool):
+                    super().__init__(
+                        discord.ui.InputText(
+                            label="User Prompt",
+                            placeholder="Enter prompt",
+                            min_length=1,
+                            max_length=4000,
+                            style=discord.InputTextStyle.long,
+                        ),
+                        title="Enter prompt",
+                        timeout=120
+                    )
+                    if is_owner:
+                        self.add_item(
+                            discord.ui.InputText(
+                                label="System Prompt",
+                                placeholder="Enter prompt",
+                                min_length=1,
+                                max_length=4000,
+                                style=discord.InputTextStyle.long,
+                                value=system_prompt,
+                            )
+                        )
+                    self.user_prompt = None
+                    self.system_prompt = system_prompt
+
+                async def callback(self, interaction: discord.Interaction):
+                    self.user_prompt = self.children[0].value
+                    if len(self.children) > 1:
+                        self.system_prompt = self.children[1].value
+                    await interaction.response.defer()
+                    self.stop()
+
+            modal = InputPrompt(await self.bot.is_owner(ctx.author))
+            await ctx.send_modal(modal)
+            await modal.wait()
+            query = modal.user_prompt
+            if not modal.user_prompt:
+                return
+            system_prompt = modal.system_prompt or system_prompt
+        else:
+            await ctx.defer()
         content = None
         try_hosts = {
             "127.0.0.1:11434": "localhost",
@@ -1844,51 +1897,29 @@ class OtherCog(commands.Cog):
             "100.66.187.46:11434": "Nexbox",
             "100.116.242.161:11434": "PortaPi"
         }
-        if query.startswith("model:"):
-            model, query = query.split(" ", 1)
-            model = model[6:].casefold()
-            try:
-                _name, _tag = model.split(":", 1)
-            except ValueError:
-                model += ":latest"
-        else:
-            model = "orca-mini"
         model = model.casefold()
         if not await self.bot.is_owner(ctx.author):
             if not model.startswith("orca-mini"):
-                await ctx.reply(":warning: You can only use `orca-mini` models.", delete_after=30)
+                await ctx.respond(
+                    ":warning: You can only use `orca-mini` models.",
+                    delete_after=30,
+                    ephemeral=True
+                )
                 model = "orca-mini"
-        if query.startswith("server:"):
-            host, query = query.split(" ", 1)
-            host = host[7:]
-            try:
-                host, port = host.split(":", 1)
-                int(port)
-            except ValueError:
-                host += ":11434"
-        else:
-            # try_hosts = [
-            #     "127.0.0.1:11434",  # Localhost
-            #     "100.106.34.86:11434",  # Laptop
-            #     "100.66.187.46:11434",  # optiplex
-            #     "100.116.242.161:11434"  # Raspberry Pi
-            # ]
-            async with httpx.AsyncClient(follow_redirects=True) as client:
-                for host in try_hosts.keys():
-                    try:
-                        response = await client.get(
-                            f"http://{host}/api/tags",
-                        )
-                        response.raise_for_status()
-                    except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
-                        continue
-                    else:
-                        break
-                else:
-                    return await ctx.reply(":x: No servers available.")
+        async with httpx.AsyncClient(follow_redirects=True) as client:
+            for host in try_hosts.keys():
+                try:
+                    response = await client.get(
+                        f"http://{host}/api/tags",
+                    )
+                    response.raise_for_status()
+                except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
+                    continue
+                else:
+                    break
+            else:
+                return await ctx.respond(":x: No servers available.")

         embed = discord.Embed(
             colour=discord.Colour.greyple()
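Note: httpx.NetworkError is already a subclass of httpx.TransportError, so catching TransportError and HTTPStatusError alone would cover everything the except clause above handles. The for/else probe could also be factored into a helper along these lines (a sketch under the same httpx assumptions; first_reachable is an illustrative name):

    import typing
    import httpx

    async def first_reachable(hosts: typing.Iterable[str]) -> typing.Optional[str]:
        """Probe each Ollama host's /api/tags and return the first that answers 2xx."""
        async with httpx.AsyncClient(follow_redirects=True) as client:
            for host in hosts:
                try:
                    response = await client.get(f"http://{host}/api/tags")
                    response.raise_for_status()
                except (httpx.TransportError, httpx.HTTPStatusError):
                    continue  # unreachable or unhealthy; try the next host
                return host
        return None  # mirrors the ":x: No servers available." branch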
@@ -1900,7 +1931,7 @@ class OtherCog(commands.Cog):
         )
         embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
-        msg = await ctx.reply(embed=embed)
+        msg = await ctx.respond(embed=embed, ephemeral=False)
         async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
             # get models
             try:
@@ -1975,8 +2006,6 @@ class OtherCog(commands.Cog):
             embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
             await msg.edit(embed=embed)
             async with ctx.channel.typing():
-                with open("./assets/ollama-prompt.txt") as file:
-                    system_prompt = file.read().replace("\n", " ").strip()
                 async with client.stream(
                     "POST",
                     "/generate",
@@ -2069,21 +2098,45 @@ class OtherCog(commands.Cog):
                     load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
                     sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
                     prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
+                    context: Optional[list[int]] = chunk.get("context")
+                    # noinspection PyTypeChecker
+                    if context:
+                        context_json = json.dumps(context)
+                        start = time()
+                        context_json_compressed = await asyncio.to_thread(
+                            zlib.compress,
+                            context_json.encode(),
+                            9
+                        )
+                        end = time()
+                        compress_time_spent = format(end * 1000 - start * 1000, ",")
+                        context: str = base64.b64encode(context_json_compressed).decode()
+                    else:
+                        compress_time_spent = "N/A"
+                        context = None
                     value = ("* Total: {}\n"
                              "* Model load: {}\n"
                              "* Sample generation: {}\n"
                              "* Prompt eval: {}\n"
-                             "* Response generation: {}").format(
+                             "* Response generation: {}\n"
+                             "* Context compression: {} milliseconds").format(
                         total_time_spent,
                         load_time_spent,
                         sample_time_sent,
                         prompt_eval_time_spent,
-                        eval_time_spent
+                        eval_time_spent,
+                        compress_time_spent
                     )
                     embed.add_field(
                         name="Timings",
                         value=value
                     )
+                    if context:
+                        embed.add_field(
+                            name="Context",
+                            value=f"```\n{context}\n```"[:1024],
+                            inline=True
+                        )
                     await msg.edit(content=None, embed=embed, view=None)
                     self.ollama_locks.pop(msg, None)
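Note: the new Context embed field and the context slash-command option round-trip Ollama's token-id list as base64(zlib(json(list[int]))). Both directions of that encoding, using only the stdlib modules this commit imports, look like the sketch below; encode_context/decode_context are illustrative names, and the decoding half is the inverse the command would need to feed a pasted context back into /api/generate.

    import base64
    import json
    import zlib

    def encode_context(token_ids: list[int]) -> str:
        # json -> zlib (level 9, as in the diff) -> base64 text
        raw = json.dumps(token_ids).encode()
        return base64.b64encode(zlib.compress(raw, 9)).decode()

    def decode_context(blob: str) -> list[int]:
        # base64 text -> zlib -> json
        raw = zlib.decompress(base64.b64decode(blob))
        return json.loads(raw)

    assert decode_context(encode_context([1, 2, 3])) == [1, 2, 3]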