mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Move to interactions
This commit is contained in:
parent
f3e1986313
commit
11f2455025
1 changed files with 103 additions and 50 deletions
131
cogs/other.py
131
cogs/other.py
|
@ -1,9 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import functools
|
import functools
|
||||||
import glob
|
import glob
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import typing
|
import typing
|
||||||
|
import zlib
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
@ -1810,7 +1812,7 @@ class OtherCog(commands.Cog):
|
||||||
)
|
)
|
||||||
|
|
||||||
class OllamaKillSwitchView(discord.ui.View):
|
class OllamaKillSwitchView(discord.ui.View):
|
||||||
def __init__(self, ctx: commands.Context, msg: discord.Message):
|
def __init__(self, ctx: discord.ApplicationContext, msg: discord.Message):
|
||||||
super().__init__(timeout=None)
|
super().__init__(timeout=None)
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
@ -1831,12 +1833,63 @@ class OtherCog(commands.Cog):
|
||||||
await interaction.edit_original_response(view=self)
|
await interaction.edit_original_response(view=self)
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
@commands.command(
|
@commands.slash_command()
|
||||||
usage="[model:<name:tag>] [server:<ip[:port]>] <query>"
|
@commands.max_concurrency(1, commands.BucketType.user, wait=False)
|
||||||
)
|
async def ollama(
|
||||||
@commands.max_concurrency(1, commands.BucketType.user, wait=True)
|
self,
|
||||||
async def ollama(self, ctx: commands.Context, *, query: str):
|
ctx: discord.ApplicationContext,
|
||||||
|
model: str = "orca-mini",
|
||||||
|
query: str = None,
|
||||||
|
context: str = None
|
||||||
|
):
|
||||||
""":3"""
|
""":3"""
|
||||||
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
|
system_prompt = file.read().replace("\n", " ").strip()
|
||||||
|
if query is None:
|
||||||
|
class InputPrompt(discord.ui.Modal):
|
||||||
|
def __init__(self, is_owner: bool):
|
||||||
|
super().__init__(
|
||||||
|
discord.ui.InputText(
|
||||||
|
label="User Prompt",
|
||||||
|
placeholder="Enter prompt",
|
||||||
|
min_length=1,
|
||||||
|
max_length=4000,
|
||||||
|
style=discord.InputTextStyle.long,
|
||||||
|
),
|
||||||
|
title="Enter prompt",
|
||||||
|
timeout=120
|
||||||
|
)
|
||||||
|
if is_owner:
|
||||||
|
self.add_item(
|
||||||
|
discord.ui.InputText(
|
||||||
|
label="System Prompt",
|
||||||
|
placeholder="Enter prompt",
|
||||||
|
min_length=1,
|
||||||
|
max_length=4000,
|
||||||
|
style=discord.InputTextStyle.long,
|
||||||
|
value=system_prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.user_prompt = None
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
async def callback(self, interaction: discord.Interaction):
|
||||||
|
self.user_prompt = self.children[0].value
|
||||||
|
if len(self.children) > 1:
|
||||||
|
self.system_prompt = self.children[1].value
|
||||||
|
await interaction.response.defer()
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
modal = InputPrompt(await self.bot.is_owner(ctx.author))
|
||||||
|
await ctx.send_modal(modal)
|
||||||
|
await modal.wait()
|
||||||
|
query = modal.user_prompt
|
||||||
|
if not modal.user_prompt:
|
||||||
|
return
|
||||||
|
system_prompt = modal.system_prompt or system_prompt
|
||||||
|
else:
|
||||||
|
await ctx.defer()
|
||||||
content = None
|
content = None
|
||||||
try_hosts = {
|
try_hosts = {
|
||||||
"127.0.0.1:11434": "localhost",
|
"127.0.0.1:11434": "localhost",
|
||||||
|
@ -1844,38 +1897,16 @@ class OtherCog(commands.Cog):
|
||||||
"100.66.187.46:11434": "Nexbox",
|
"100.66.187.46:11434": "Nexbox",
|
||||||
"100.116.242.161:11434": "PortaPi"
|
"100.116.242.161:11434": "PortaPi"
|
||||||
}
|
}
|
||||||
if query.startswith("model:"):
|
|
||||||
model, query = query.split(" ", 1)
|
|
||||||
model = model[6:].casefold()
|
|
||||||
try:
|
|
||||||
_name, _tag = model.split(":", 1)
|
|
||||||
except ValueError:
|
|
||||||
model += ":latest"
|
|
||||||
else:
|
|
||||||
model = "orca-mini"
|
|
||||||
|
|
||||||
model = model.casefold()
|
model = model.casefold()
|
||||||
|
|
||||||
if not await self.bot.is_owner(ctx.author):
|
if not await self.bot.is_owner(ctx.author):
|
||||||
if not model.startswith("orca-mini"):
|
if not model.startswith("orca-mini"):
|
||||||
await ctx.reply(":warning: You can only use `orca-mini` models.", delete_after=30)
|
await ctx.respond(
|
||||||
|
":warning: You can only use `orca-mini` models.",
|
||||||
|
delete_after=30,
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
model = "orca-mini"
|
model = "orca-mini"
|
||||||
|
|
||||||
if query.startswith("server:"):
|
|
||||||
host, query = query.split(" ", 1)
|
|
||||||
host = host[7:]
|
|
||||||
try:
|
|
||||||
host, port = host.split(":", 1)
|
|
||||||
int(port)
|
|
||||||
except ValueError:
|
|
||||||
host += ":11434"
|
|
||||||
else:
|
|
||||||
# try_hosts = [
|
|
||||||
# "127.0.0.1:11434", # Localhost
|
|
||||||
# "100.106.34.86:11434", # Laptop
|
|
||||||
# "100.66.187.46:11434", # optiplex
|
|
||||||
# "100.116.242.161:11434" # Raspberry Pi
|
|
||||||
# ]
|
|
||||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
for host in try_hosts.keys():
|
for host in try_hosts.keys():
|
||||||
try:
|
try:
|
||||||
|
@ -1888,7 +1919,7 @@ class OtherCog(commands.Cog):
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return await ctx.reply(":x: No servers available.")
|
return await ctx.respond(":x: No servers available.")
|
||||||
|
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
colour=discord.Colour.greyple()
|
colour=discord.Colour.greyple()
|
||||||
|
@ -1900,7 +1931,7 @@ class OtherCog(commands.Cog):
|
||||||
)
|
)
|
||||||
embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
|
embed.set_footer(text="Using server {} ({})".format(host, try_hosts.get(host, "Other")))
|
||||||
|
|
||||||
msg = await ctx.reply(embed=embed)
|
msg = await ctx.respond(embed=embed, ephemeral=False)
|
||||||
async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
|
async with httpx.AsyncClient(base_url=f"http://{host}/api", follow_redirects=True) as client:
|
||||||
# get models
|
# get models
|
||||||
try:
|
try:
|
||||||
|
@ -1975,8 +2006,6 @@ class OtherCog(commands.Cog):
|
||||||
embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
|
embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
|
||||||
await msg.edit(embed=embed)
|
await msg.edit(embed=embed)
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
|
||||||
system_prompt = file.read().replace("\n", " ").strip()
|
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
"/generate",
|
"/generate",
|
||||||
|
@ -2069,21 +2098,45 @@ class OtherCog(commands.Cog):
|
||||||
load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
|
load_time_spent = get_time_spent(chunk.get("load_duration", 999999999.0))
|
||||||
sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
|
sample_time_sent = get_time_spent(chunk.get("sample_duration", 999999999.0))
|
||||||
prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
|
prompt_eval_time_spent = get_time_spent(chunk.get("prompt_eval_duration", 999999999.0))
|
||||||
|
context: Optional[list[int]] = chunk.get("context")
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
if context:
|
||||||
|
context_json = json.dumps(context)
|
||||||
|
start = time()
|
||||||
|
context_json_compressed = await asyncio.to_thread(
|
||||||
|
zlib.compress,
|
||||||
|
context_json.encode(),
|
||||||
|
9
|
||||||
|
)
|
||||||
|
end = time()
|
||||||
|
compress_time_spent = format(end * 1000 - start * 1000, ",")
|
||||||
|
context: str = base64.b64encode(context_json_compressed).decode()
|
||||||
|
else:
|
||||||
|
compress_time_spent = "N/A"
|
||||||
|
context = None
|
||||||
value = ("* Total: {}\n"
|
value = ("* Total: {}\n"
|
||||||
"* Model load: {}\n"
|
"* Model load: {}\n"
|
||||||
"* Sample generation: {}\n"
|
"* Sample generation: {}\n"
|
||||||
"* Prompt eval: {}\n"
|
"* Prompt eval: {}\n"
|
||||||
"* Response generation: {}").format(
|
"* Response generation: {}\n"
|
||||||
|
"* Context compression: {} milliseconds").format(
|
||||||
total_time_spent,
|
total_time_spent,
|
||||||
load_time_spent,
|
load_time_spent,
|
||||||
sample_time_sent,
|
sample_time_sent,
|
||||||
prompt_eval_time_spent,
|
prompt_eval_time_spent,
|
||||||
eval_time_spent
|
eval_time_spent,
|
||||||
|
compress_time_spent
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Timings",
|
name="Timings",
|
||||||
value=value
|
value=value
|
||||||
)
|
)
|
||||||
|
if context:
|
||||||
|
embed.add_field(
|
||||||
|
name="Context",
|
||||||
|
value=f"```\n{context}\n```"[:1024],
|
||||||
|
inline=True
|
||||||
|
)
|
||||||
await msg.edit(content=None, embed=embed, view=None)
|
await msg.edit(content=None, embed=embed, view=None)
|
||||||
self.ollama_locks.pop(msg, None)
|
self.ollama_locks.pop(msg, None)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue