Allow using system prompt and large user prompt

Nexus 2024-03-22 09:08:03 +00:00
parent 7cd2032de9
commit 50c648a618
Signed by: nex (GPG key ID: 0FA334385D0B689F)


@@ -8,6 +8,7 @@ import typing
 import base64
 import io
 import redis
+from discord import Interaction
 from discord.ui import View, button
 from fnmatch import fnmatch
@@ -89,7 +90,7 @@ class ChatHistory:
             "threads:" + thread_id, json.dumps(self._internal[thread_id])
         )
 
-    def create_thread(self, member: discord.Member) -> str:
+    def create_thread(self, member: discord.Member, default: str | None = None) -> str:
        """
        Creates a thread, returns its ID.
        """
@@ -100,7 +101,7 @@ class ChatHistory:
            "messages": []
        }
        with open("./assets/ollama-prompt.txt") as file:
-            system_prompt = file.read()
+            system_prompt = default or file.read()
        self.add_message(
            key,
            "system",
@@ -190,6 +191,32 @@ class ChatHistory:
 SERVER_KEYS = list(CONFIG["ollama"].keys())
 
 
+class OllamaGetPrompt(discord.ui.Modal):
+    def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
+        super().__init__(
+            discord.ui.InputText(
+                style=discord.InputTextStyle.long,
+                label="%s prompt" % prompt_type,
+                placeholder="Enter your prompt here.",
+            ),
+            timeout=300,
+            title="Ollama %s prompt" % prompt_type,
+        )
+        self.ctx = ctx
+        self.prompt_type = prompt_type
+        self.value = None
+
+    async def interaction_check(self, interaction: discord.Interaction) -> bool:
+        return interaction.user == self.ctx.user
+
+    async def callback(self, interaction: Interaction):
+        await interaction.response.defer()
+        self.ctx.interaction = interaction
+        self.value = self.children[0].value
+        self.stop()
+
+
 class Ollama(commands.Cog):
     def __init__(self, bot: commands.Bot):
         self.bot = bot
@@ -286,7 +313,28 @@ class Ollama(commands.Cog):
        if not self.history.get_thread(context):
            await ctx.respond("Invalid context key.")
            return
 
-        await ctx.defer()
+        if query.startswith("$$"):
+            prompt = OllamaGetPrompt(ctx, "System")
+            await ctx.send_modal(prompt)
+            await prompt.wait()
+            system_query = prompt.value
+            if not system_query:
+                return await ctx.respond("No prompt provided. Aborting.")
+        else:
+            system_query = None
+
+        if query == "$":
+            prompt = OllamaGetPrompt(ctx)
+            await ctx.send_modal(prompt)
+            await prompt.wait()
+            query = prompt.value
+            if not query:
+                return await ctx.respond("No prompt provided. Aborting.")
+
+        try:
+            await ctx.defer()
+        except discord.HTTPException:
+            pass
 
        model = model.casefold()
        try:
@@ -294,7 +342,7 @@ class Ollama(commands.Cog):
            model = model + ":" + tag
            self.log.debug("Model %r already has a tag", model)
        except ValueError:
-            model = model + ":latest"
+            model += ":latest"
        self.log.debug("Resolved model to %r" % model)
 
        if image:
@@ -315,7 +363,7 @@ class Ollama(commands.Cog):
            data = io.BytesIO()
            await image.save(data)
            data.seek(0)
-            image_data = base64.b64encode(data.read()).decode("utf-8")
+            image_data = base64.b64encode(data.read()).decode()
        else:
            image_data = None
@@ -336,7 +384,12 @@ class Ollama(commands.Cog):
        async with aiohttp.ClientSession(
            base_url=server_config["base_url"],
-            timeout=aiohttp.ClientTimeout(0)
+            timeout=aiohttp.ClientTimeout(
+                connect=30,
+                sock_read=10800,
+                sock_connect=30,
+                total=10830
+            )
        ) as session:
            embed = discord.Embed(
                title="Checking server...",
@@ -482,7 +535,7 @@ class Ollama(commands.Cog):
        self.log.debug("Beginning to generate response with key %r.", key)
        if context is None:
-            context = self.history.create_thread(ctx.user)
+            context = self.history.create_thread(ctx.user, system_query)
        elif context is not None and self.history.get_thread(context) is None:
            __thread = self.history.find_thread(context)
            if not __thread:
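
For reference, a minimal sketch of how the pieces above fit together end to end. This is not part of the commit; it only restates the flow from the hunks, using the names introduced there (OllamaGetPrompt, create_thread) and assuming py-cord's Modal / send_modal / wait() API as used in the diff.

    # Sketch only: how the "$" / "$$" prefixes collect prompts via a modal.
    async def collect_prompts(ctx, query: str) -> tuple[str | None, str | None]:
        system_query = None
        if query.startswith("$$"):              # "$$" -> ask for a custom system prompt
            modal = OllamaGetPrompt(ctx, "System")
            await ctx.send_modal(modal)
            await modal.wait()                  # resumes once callback() calls self.stop()
            system_query = modal.value
        if query == "$":                        # "$" -> ask for a long user prompt
            modal = OllamaGetPrompt(ctx)
            await ctx.send_modal(modal)
            await modal.wait()
            query = modal.value
        return system_query, query

    # system_query is later passed to ChatHistory.create_thread(ctx.user, system_query),
    # where it takes precedence over the default ./assets/ollama-prompt.txt system prompt.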