Allow using system prompt and large user prompt
This commit is contained in:
parent
7cd2032de9
commit
50c648a618
1 changed files with 60 additions and 7 deletions
|
@ -8,6 +8,7 @@ import typing
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import redis
|
import redis
|
||||||
|
from discord import Interaction
|
||||||
|
|
||||||
from discord.ui import View, button
|
from discord.ui import View, button
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
|
@ -89,7 +90,7 @@ class ChatHistory:
|
||||||
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_thread(self, member: discord.Member) -> str:
|
def create_thread(self, member: discord.Member, default: str | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a thread, returns its ID.
|
Creates a thread, returns its ID.
|
||||||
"""
|
"""
|
||||||
|
@ -100,7 +101,7 @@ class ChatHistory:
|
||||||
"messages": []
|
"messages": []
|
||||||
}
|
}
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
system_prompt = file.read()
|
system_prompt = default or file.read()
|
||||||
self.add_message(
|
self.add_message(
|
||||||
key,
|
key,
|
||||||
"system",
|
"system",
|
||||||
|
@ -190,6 +191,32 @@ class ChatHistory:
|
||||||
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaGetPrompt(discord.ui.Modal):
|
||||||
|
|
||||||
|
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
|
||||||
|
super().__init__(
|
||||||
|
discord.ui.InputText(
|
||||||
|
style=discord.InputTextStyle.long,
|
||||||
|
label="%s prompt" % prompt_type,
|
||||||
|
placeholder="Enter your prompt here.",
|
||||||
|
),
|
||||||
|
timeout=300,
|
||||||
|
title="Ollama %s prompt" % prompt_type,
|
||||||
|
)
|
||||||
|
self.ctx = ctx
|
||||||
|
self.prompt_type = prompt_type
|
||||||
|
self.value = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||||
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
|
async def callback(self, interaction: Interaction):
|
||||||
|
await interaction.response.defer()
|
||||||
|
self.ctx.interaction = interaction
|
||||||
|
self.value = self.children[0].value
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
class Ollama(commands.Cog):
|
class Ollama(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
@ -286,7 +313,28 @@ class Ollama(commands.Cog):
|
||||||
if not self.history.get_thread(context):
|
if not self.history.get_thread(context):
|
||||||
await ctx.respond("Invalid context key.")
|
await ctx.respond("Invalid context key.")
|
||||||
return
|
return
|
||||||
await ctx.defer()
|
|
||||||
|
if query.startswith("$$"):
|
||||||
|
prompt = OllamaGetPrompt(ctx, "System")
|
||||||
|
await ctx.send_modal(prompt)
|
||||||
|
await prompt.wait()
|
||||||
|
system_query = prompt.value
|
||||||
|
if not system_query:
|
||||||
|
return await ctx.respond("No prompt provided. Aborting.")
|
||||||
|
else:
|
||||||
|
system_query = None
|
||||||
|
if query == "$":
|
||||||
|
prompt = OllamaGetPrompt(ctx)
|
||||||
|
await ctx.send_modal(prompt)
|
||||||
|
await prompt.wait()
|
||||||
|
query = prompt.value
|
||||||
|
if not query:
|
||||||
|
return await ctx.respond("No prompt provided. Aborting.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ctx.defer()
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
model = model.casefold()
|
model = model.casefold()
|
||||||
try:
|
try:
|
||||||
|
@ -294,7 +342,7 @@ class Ollama(commands.Cog):
|
||||||
model = model + ":" + tag
|
model = model + ":" + tag
|
||||||
self.log.debug("Model %r already has a tag", model)
|
self.log.debug("Model %r already has a tag", model)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
model = model + ":latest"
|
model += ":latest"
|
||||||
self.log.debug("Resolved model to %r" % model)
|
self.log.debug("Resolved model to %r" % model)
|
||||||
|
|
||||||
if image:
|
if image:
|
||||||
|
@ -315,7 +363,7 @@ class Ollama(commands.Cog):
|
||||||
data = io.BytesIO()
|
data = io.BytesIO()
|
||||||
await image.save(data)
|
await image.save(data)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
image_data = base64.b64encode(data.read()).decode("utf-8")
|
image_data = base64.b64encode(data.read()).decode()
|
||||||
else:
|
else:
|
||||||
image_data = None
|
image_data = None
|
||||||
|
|
||||||
|
@ -336,7 +384,12 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
base_url=server_config["base_url"],
|
base_url=server_config["base_url"],
|
||||||
timeout=aiohttp.ClientTimeout(0)
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
connect=30,
|
||||||
|
sock_read=10800,
|
||||||
|
sock_connect=30,
|
||||||
|
total=10830
|
||||||
|
)
|
||||||
) as session:
|
) as session:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Checking server...",
|
title="Checking server...",
|
||||||
|
@ -482,7 +535,7 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug("Beginning to generate response with key %r.", key)
|
self.log.debug("Beginning to generate response with key %r.", key)
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = self.history.create_thread(ctx.user)
|
context = self.history.create_thread(ctx.user, system_query)
|
||||||
elif context is not None and self.history.get_thread(context) is None:
|
elif context is not None and self.history.get_thread(context) is None:
|
||||||
__thread = self.history.find_thread(context)
|
__thread = self.history.find_thread(context)
|
||||||
if not __thread:
|
if not __thread:
|
||||||
|
|
Loading…
Reference in a new issue