Mirror of https://github.com/nexy7574/LCC-bot.git (synced 2024-09-19 18:16:34 +01:00)
Move to interactions

Commit: 11f2455025
Parent: f3e1986313
1 changed file with 103 additions and 50 deletions

cogs/other.py
@@ -1,9 +1,11 @@
 import asyncio
+import base64
 import functools
 import glob
 import io
 import json
 import typing
+import zlib
 
 import math
 import os
@@ -1810,7 +1812,7 @@ class OtherCog(commands.Cog):
     )
 
     class OllamaKillSwitchView(discord.ui.View):
-        def __init__(self, ctx: commands.Context, msg: discord.Message):
+        def __init__(self, ctx: discord.ApplicationContext, msg: discord.Message):
             super().__init__(timeout=None)
             self.ctx = ctx
             self.msg = msg
@@ -1831,12 +1833,63 @@ class OtherCog(commands.Cog):
             await interaction.edit_original_response(view=self)
             self.stop()
 
-    @commands.command(
-        usage="[model:<name:tag>] [server:<ip[:port]>] <query>"
-    )
-    @commands.max_concurrency(1, commands.BucketType.user, wait=True)
-    async def ollama(self, ctx: commands.Context, *, query: str):
+    @commands.slash_command()
+    @commands.max_concurrency(1, commands.BucketType.user, wait=False)
+    async def ollama(
+        self,
+        ctx: discord.ApplicationContext,
+        model: str = "orca-mini",
+        query: str = None,
+        context: str = None
+    ):
         """:3"""
+        with open("./assets/ollama-prompt.txt") as file:
+            system_prompt = file.read().replace("\n", " ").strip()
+        if query is None:
+            class InputPrompt(discord.ui.Modal):
+                def __init__(self, is_owner: bool):
+                    super().__init__(
+                        discord.ui.InputText(
+                            label="User Prompt",
+                            placeholder="Enter prompt",
+                            min_length=1,
+                            max_length=4000,
+                            style=discord.InputTextStyle.long,
+                        ),
+                        title="Enter prompt",
+                        timeout=120
+                    )
+                    if is_owner:
+                        self.add_item(
+                            discord.ui.InputText(
+                                label="System Prompt",
+                                placeholder="Enter prompt",
+                                min_length=1,
+                                max_length=4000,
+                                style=discord.InputTextStyle.long,
+                                value=system_prompt,
+                            )
+                        )
+
+                    self.user_prompt = None
+                    self.system_prompt = system_prompt
+
+                async def callback(self, interaction: discord.Interaction):
+                    self.user_prompt = self.children[0].value
+                    if len(self.children) > 1:
+                        self.system_prompt = self.children[1].value
+                    await interaction.response.defer()
+                    self.stop()
+
+            modal = InputPrompt(await self.bot.is_owner(ctx.author))
+            await ctx.send_modal(modal)
+            await modal.wait()
+            query = modal.user_prompt
+            if not modal.user_prompt:
+                return
+            system_prompt = modal.system_prompt or system_prompt
+        else:
+            await ctx.defer()
         content = None
         try_hosts = {
             "127.0.0.1:11434": "localhost",
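Reviewer note: the modal flow added above follows Pycord's send-and-wait pattern, where the modal must be the first response to the interaction and modal.wait() resumes once callback() calls stop(). A minimal, self-contained sketch of that pattern, with hypothetical names (bot, PromptModal, ask), not the bot's actual code:

    from typing import Optional

    import discord

    bot = discord.Bot()

    class PromptModal(discord.ui.Modal):
        """Collects one long-text value and hands it back to the awaiting command."""

        def __init__(self):
            super().__init__(
                discord.ui.InputText(
                    label="Prompt",
                    min_length=1,
                    max_length=4000,
                    style=discord.InputTextStyle.long,
                ),
                title="Enter prompt",
                timeout=120,
            )
            self.value: Optional[str] = None

        async def callback(self, interaction: discord.Interaction):
            self.value = self.children[0].value
            await interaction.response.defer()  # acknowledge without a visible reply
            self.stop()                         # unblocks modal.wait()

    @bot.slash_command()
    async def ask(ctx: discord.ApplicationContext):
        modal = PromptModal()
        await ctx.send_modal(modal)  # must be the first interaction response
        await modal.wait()           # returns after callback(), or on timeout
        if modal.value:
            await ctx.respond(f"You said: {modal.value}")

    # bot.run("TOKEN")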
@@ -1844,38 +1897,16 @@ class OtherCog(commands.Cog):
             "100.66.187.46:11434": "Nexbox",
             "100.116.242.161:11434": "PortaPi"
         }
-        if query.startswith("model:"):
-            model, query = query.split(" ", 1)
-            model = model[6:].casefold()
-            try:
-                _name, _tag = model.split(":", 1)
-            except ValueError:
-                model += ":latest"
-        else:
-            model = "orca-mini"
 
         model = model.casefold()
 
         if not await self.bot.is_owner(ctx.author):
             if not model.startswith("orca-mini"):
-                await ctx.reply(":warning: You can only use `orca-mini` models.", delete_after=30)
+                await ctx.respond(
+                    ":warning: You can only use `orca-mini` models.",
+                    delete_after=30,
+                    ephemeral=True
+                )
                 model = "orca-mini"
 
-        if query.startswith("server:"):
-            host, query = query.split(" ", 1)
-            host = host[7:]
-            try:
-                host, port = host.split(":", 1)
-                int(port)
-            except ValueError:
-                host += ":11434"
-        else:
-            # try_hosts = [
-            #     "127.0.0.1:11434", # Localhost
-            #     "100.106.34.86:11434", # Laptop
-            #     "100.66.187.46:11434", # optiplex
-            #     "100.116.242.161:11434" # Raspberry Pi
-            # ]
-            async with httpx.AsyncClient(follow_redirects=True) as client:
-                for host in try_hosts.keys():
-                    try:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
+            for host in try_hosts.keys():
+                try:
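The deleted model:/server: string parsing is superseded by slash-command options: in Pycord, keyword parameters with defaults on a slash command surface as optional options in the Discord client. A simplified, hypothetical standalone example of that mapping (not the cog's code):

    import discord

    bot = discord.Bot()

    @bot.slash_command()
    async def ollama(
        ctx: discord.ApplicationContext,
        model: str = "orca-mini",  # optional option, pre-filled with a default
        query: str = None,         # optional; the real command opens a modal when omitted
    ):
        await ctx.respond(f"model={model!r}, query={query!r}")

    # bot.run("TOKEN")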
@@ -1888,7 +1919,7 @@ class OtherCog(commands.Cog):
                 else:
                     break
             else:
-                return await ctx.reply(":x: No servers available.")
+                return await ctx.respond(":x: No servers available.")
 
         embed = discord.Embed(
             colour=discord.Colour.greyple()
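The loop completed in this hunk relies on Python's for/else: the else arm runs only when no break occurred, i.e. when no host answered. A minimal sketch of the same fallback idea, assuming Ollama's GET /api/tags as a cheap liveness probe; the helper name first_reachable and the timeout value are illustrative, not from the bot:

    from typing import Dict, Optional

    import httpx

    async def first_reachable(hosts: Dict[str, str]) -> Optional[str]:
        """Return the first host that responds, or None if every probe fails."""
        async with httpx.AsyncClient(follow_redirects=True) as client:
            for host in hosts:
                try:
                    await client.get(f"http://{host}/api/tags", timeout=2)
                except httpx.HTTPError:
                    continue  # unreachable, try the next host
                else:
                    return host  # the break-equivalent: a live server was found
        return None

    # import asyncio
    # asyncio.run(first_reachable({"127.0.0.1:11434": "localhost"}))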
@@ -1900,7 +1931,7 @@ class OtherCog(commands.Cog):
         )
         embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
 
-        msg = await ctx.reply(embed=embed)
+        msg = await ctx.respond(embed=embed, ephemeral=False)
         async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
             # get models
             try:
@@ -1975,8 +2006,6 @@ class OtherCog(commands.Cog):
                 embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
                 await msg.edit(embed=embed)
                 async with ctx.channel.typing():
-                    with open("./assets/ollama-prompt.txt") as file:
-                        system_prompt = file.read().replace("\n", " ").strip()
                     async with client.stream(
                         "POST",
                         "/generate",
@@ -2069,21 +2098,45 @@ class OtherCog(commands.Cog):
                         load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
                         sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
                         prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
+                        context: Optional[list[int]] = chunk.get("context")
+                        # noinspection PyTypeChecker
+                        if context:
+                            context_json = json.dumps(context)
+                            start = time()
+                            context_json_compressed = await asyncio.to_thread(
+                                zlib.compress,
+                                context_json.encode(),
+                                9
+                            )
+                            end = time()
+                            compress_time_spent = format(end * 1000 - start * 1000, ",")
+                            context: str = base64.b64encode(context_json_compressed).decode()
+                        else:
+                            compress_time_spent = "N/A"
+                            context = None
                         value = ("* Total: {}\n"
                                  "* Model load: {}\n"
                                  "* Sample generation: {}\n"
                                  "* Prompt eval: {}\n"
-                                 "* Response generation: {}").format(
+                                 "* Response generation: {}\n"
+                                 "* Context compression: {} milliseconds").format(
                             total_time_spent,
                             load_time_spent,
                             sample_time_sent,
                             prompt_eval_time_spent,
-                            eval_time_spent
+                            eval_time_spent,
+                            compress_time_spent
                         )
                         embed.add_field(
                             name="Timings",
                             value=value
                         )
+                        if context:
+                            embed.add_field(
+                                name="Context",
+                                value=f"```\n{context}\n```"[:1024],
+                                inline=True
+                            )
                         await msg.edit(content=None, embed=embed, view=None)
         self.ollama_locks.pop(msg, None)
 
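The context handling added above is a JSON-then-zlib-then-base64 pipeline, so the token array fits in a Discord embed field. The decode direction is not shown in this diff; presumably the command's context option is unpacked the same way elsewhere. A round-trip sketch with hypothetical helper names (pack_context/unpack_context), stdlib only:

    import base64
    import json
    import zlib

    def pack_context(context: list) -> str:
        """JSON-encode, zlib-compress (level 9), then base64 the token context."""
        return base64.b64encode(zlib.compress(json.dumps(context).encode(), 9)).decode()

    def unpack_context(packed: str) -> list:
        """Invert pack_context: base64-decode, zlib-decompress, JSON-parse."""
        return json.loads(zlib.decompress(base64.b64decode(packed)))

    assert unpack_context(pack_context([1, 2, 3])) == [1, 2, 3]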