Allow using system prompt and large user prompt
This commit is contained in:
parent
7cd2032de9
commit
50c648a618
1 changed files with 60 additions and 7 deletions
|
@ -8,6 +8,7 @@ import typing
|
|||
import base64
|
||||
import io
|
||||
import redis
|
||||
from discord import Interaction
|
||||
|
||||
from discord.ui import View, button
|
||||
from fnmatch import fnmatch
|
||||
|
@ -89,7 +90,7 @@ class ChatHistory:
|
|||
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||
)
|
||||
|
||||
def create_thread(self, member: discord.Member) -> str:
|
||||
def create_thread(self, member: discord.Member, default: str | None = None) -> str:
|
||||
"""
|
||||
Creates a thread, returns its ID.
|
||||
"""
|
||||
|
@ -100,7 +101,7 @@ class ChatHistory:
|
|||
"messages": []
|
||||
}
|
||||
with open("./assets/ollama-prompt.txt") as file:
|
||||
system_prompt = file.read()
|
||||
system_prompt = default or file.read()
|
||||
self.add_message(
|
||||
key,
|
||||
"system",
|
||||
|
@ -190,6 +191,32 @@ class ChatHistory:
|
|||
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
||||
|
||||
|
||||
class OllamaGetPrompt(discord.ui.Modal):
|
||||
|
||||
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
|
||||
super().__init__(
|
||||
discord.ui.InputText(
|
||||
style=discord.InputTextStyle.long,
|
||||
label="%s prompt" % prompt_type,
|
||||
placeholder="Enter your prompt here.",
|
||||
),
|
||||
timeout=300,
|
||||
title="Ollama %s prompt" % prompt_type,
|
||||
)
|
||||
self.ctx = ctx
|
||||
self.prompt_type = prompt_type
|
||||
self.value = None
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||
return interaction.user == self.ctx.user
|
||||
|
||||
async def callback(self, interaction: Interaction):
|
||||
await interaction.response.defer()
|
||||
self.ctx.interaction = interaction
|
||||
self.value = self.children[0].value
|
||||
self.stop()
|
||||
|
||||
|
||||
class Ollama(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
|
@ -286,7 +313,28 @@ class Ollama(commands.Cog):
|
|||
if not self.history.get_thread(context):
|
||||
await ctx.respond("Invalid context key.")
|
||||
return
|
||||
await ctx.defer()
|
||||
|
||||
if query.startswith("$$"):
|
||||
prompt = OllamaGetPrompt(ctx, "System")
|
||||
await ctx.send_modal(prompt)
|
||||
await prompt.wait()
|
||||
system_query = prompt.value
|
||||
if not system_query:
|
||||
return await ctx.respond("No prompt provided. Aborting.")
|
||||
else:
|
||||
system_query = None
|
||||
if query == "$":
|
||||
prompt = OllamaGetPrompt(ctx)
|
||||
await ctx.send_modal(prompt)
|
||||
await prompt.wait()
|
||||
query = prompt.value
|
||||
if not query:
|
||||
return await ctx.respond("No prompt provided. Aborting.")
|
||||
|
||||
try:
|
||||
await ctx.defer()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
model = model.casefold()
|
||||
try:
|
||||
|
@ -294,7 +342,7 @@ class Ollama(commands.Cog):
|
|||
model = model + ":" + tag
|
||||
self.log.debug("Model %r already has a tag", model)
|
||||
except ValueError:
|
||||
model = model + ":latest"
|
||||
model += ":latest"
|
||||
self.log.debug("Resolved model to %r" % model)
|
||||
|
||||
if image:
|
||||
|
@ -315,7 +363,7 @@ class Ollama(commands.Cog):
|
|||
data = io.BytesIO()
|
||||
await image.save(data)
|
||||
data.seek(0)
|
||||
image_data = base64.b64encode(data.read()).decode("utf-8")
|
||||
image_data = base64.b64encode(data.read()).decode()
|
||||
else:
|
||||
image_data = None
|
||||
|
||||
|
@ -336,7 +384,12 @@ class Ollama(commands.Cog):
|
|||
|
||||
async with aiohttp.ClientSession(
|
||||
base_url=server_config["base_url"],
|
||||
timeout=aiohttp.ClientTimeout(0)
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
connect=30,
|
||||
sock_read=10800,
|
||||
sock_connect=30,
|
||||
total=10830
|
||||
)
|
||||
) as session:
|
||||
embed = discord.Embed(
|
||||
title="Checking server...",
|
||||
|
@ -482,7 +535,7 @@ class Ollama(commands.Cog):
|
|||
self.log.debug("Beginning to generate response with key %r.", key)
|
||||
|
||||
if context is None:
|
||||
context = self.history.create_thread(ctx.user)
|
||||
context = self.history.create_thread(ctx.user, system_query)
|
||||
elif context is not None and self.history.get_thread(context) is None:
|
||||
__thread = self.history.find_thread(context)
|
||||
if not __thread:
|
||||
|
|
Loading…
Reference in a new issue