Allow using system prompt and large user prompt

Nexus 2024-03-22 09:08:03 +00:00
parent 7cd2032de9
commit 50c648a618
Signed by: nex
GPG key ID: 0FA334385D0B689F


@@ -8,6 +8,7 @@ import typing
 import base64
 import io
 import redis
+from discord import Interaction
 from discord.ui import View, button
 from fnmatch import fnmatch
@@ -89,7 +90,7 @@ class ChatHistory:
             "threads:" + thread_id, json.dumps(self._internal[thread_id])
         )
 
-    def create_thread(self, member: discord.Member) -> str:
+    def create_thread(self, member: discord.Member, default: str | None = None) -> str:
        """
        Creates a thread, returns its ID.
        """
@@ -100,7 +101,7 @@
             "messages": []
         }
         with open("./assets/ollama-prompt.txt") as file:
-            system_prompt = file.read()
+            system_prompt = default or file.read()
         self.add_message(
             key,
             "system",
@@ -190,6 +191,32 @@ class ChatHistory:
 SERVER_KEYS = list(CONFIG["ollama"].keys())
 
 
+class OllamaGetPrompt(discord.ui.Modal):
+    def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
+        super().__init__(
+            discord.ui.InputText(
+                style=discord.InputTextStyle.long,
+                label="%s prompt" % prompt_type,
+                placeholder="Enter your prompt here.",
+            ),
+            timeout=300,
+            title="Ollama %s prompt" % prompt_type,
+        )
+        self.ctx = ctx
+        self.prompt_type = prompt_type
+        self.value = None
+
+    async def interaction_check(self, interaction: discord.Interaction) -> bool:
+        return interaction.user == self.ctx.user
+
+    async def callback(self, interaction: Interaction):
+        await interaction.response.defer()
+        self.ctx.interaction = interaction
+        self.value = self.children[0].value
+        self.stop()
+
+
 class Ollama(commands.Cog):
     def __init__(self, bot: commands.Bot):
         self.bot = bot
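
The new modal collects one long-form text field; interaction_check locks it to the invoking user, and callback() defers the modal's own interaction, swaps it into ctx (so later responses ride on it), records the submitted text, and stop()s so awaiting code resumes. A hedged usage sketch (py-cord; Discord requires a modal to be the interaction's first response, so it must be sent before any defer):

    prompt = OllamaGetPrompt(ctx, "System")
    await ctx.send_modal(prompt)
    await prompt.wait()            # resumes once callback() calls self.stop(), or on timeout
    if prompt.value is None:
        ...                        # dismissed or timed out without submitting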
@@ -286,7 +313,28 @@ class Ollama(commands.Cog):
         if not self.history.get_thread(context):
             await ctx.respond("Invalid context key.")
             return
-        await ctx.defer()
+        if query.startswith("$$"):
+            prompt = OllamaGetPrompt(ctx, "System")
+            await ctx.send_modal(prompt)
+            await prompt.wait()
+            system_query = prompt.value
+            if not system_query:
+                return await ctx.respond("No prompt provided. Aborting.")
+        else:
+            system_query = None
+
+        if query == "$":
+            prompt = OllamaGetPrompt(ctx)
+            await ctx.send_modal(prompt)
+            await prompt.wait()
+            query = prompt.value
+            if not query:
+                return await ctx.respond("No prompt provided. Aborting.")
+
+        try:
+            await ctx.defer()
+        except discord.HTTPException:
+            pass
 
         model = model.casefold()
         try:
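
The query argument is overloaded with sentinels: a "$$" prefix first collects a system prompt via the modal, and a bare "$" replaces the user query the same way, sidestepping how unwieldy long text is in a slash-command option. ctx.defer() then becomes best-effort: once a modal has been shown, the interaction swapped into ctx was already deferred inside callback(), so a second defer raises. A sketch of that guard as a helper (the name safe_defer is illustrative, not in the diff):

    async def safe_defer(ctx: discord.ApplicationContext) -> None:
        # Best-effort acknowledgement: the interaction may already be
        # deferred by the modal flow above.
        try:
            await ctx.defer()
        except discord.HTTPException:
            pass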
@@ -294,7 +342,7 @@
             model = model + ":" + tag
             self.log.debug("Model %r already has a tag", model)
         except ValueError:
-            model = model + ":latest"
+            model += ":latest"
         self.log.debug("Resolved model to %r" % model)
 
         if image:
@@ -315,7 +363,7 @@
             data = io.BytesIO()
             await image.save(data)
             data.seek(0)
-            image_data = base64.b64encode(data.read()).decode("utf-8")
+            image_data = base64.b64encode(data.read()).decode()
         else:
             image_data = None
 
@@ -336,7 +384,12 @@
         async with aiohttp.ClientSession(
             base_url=server_config["base_url"],
-            timeout=aiohttp.ClientTimeout(0)
+            timeout=aiohttp.ClientTimeout(
+                connect=30,
+                sock_read=10800,
+                sock_connect=30,
+                total=10830
+            )
         ) as session:
             embed = discord.Embed(
                 title="Checking server...",
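
The bare ClientTimeout(0) becomes a granular budget: 30 s to acquire and establish a connection, up to 10800 s (3 h) between chunks of a streamed generation, and a 10830 s hard ceiling overall. A self-contained sketch of the same construction (the /api/tags probe and list_models helper are illustrative):

    import aiohttp

    TIMEOUT = aiohttp.ClientTimeout(
        connect=30,        # acquiring a connection (pool wait + handshake)
        sock_connect=30,   # the TCP connect itself
        sock_read=10800,   # max gap between reads of a streamed reply
        total=10830,       # hard ceiling for the whole request
    )

    async def list_models(base_url: str) -> dict:
        async with aiohttp.ClientSession(base_url=base_url, timeout=TIMEOUT) as session:
            async with session.get("/api/tags") as resp:
                return await resp.json()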
@@ -482,7 +535,7 @@
         self.log.debug("Beginning to generate response with key %r.", key)
 
         if context is None:
-            context = self.history.create_thread(ctx.user)
+            context = self.history.create_thread(ctx.user, system_query)
         elif context is not None and self.history.get_thread(context) is None:
             __thread = self.history.find_thread(context)
             if not __thread: