Use a view instead
This commit is contained in:
parent
a8a67d5ce6
commit
3e0152ee24
1 changed files with 41 additions and 17 deletions
|
@ -217,6 +217,41 @@ class OllamaGetPrompt(discord.ui.Modal):
|
|||
self.stop()
|
||||
|
||||
|
||||
class PromptSelector(discord.ui.View):
|
||||
def __init__(self, ctx: discord.ApplicationContext):
|
||||
super().__init__(timeout=600, disable_on_timeout=True)
|
||||
self.ctx = ctx
|
||||
self.system_prompt = None
|
||||
self.user_prompt = None
|
||||
|
||||
async def interaction_check(self, interaction: Interaction) -> bool:
|
||||
return interaction.user == self.ctx.user
|
||||
|
||||
def update_ui(self):
|
||||
if self.system_prompt is not None:
|
||||
self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore
|
||||
if self.user_prompt is not None:
|
||||
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
|
||||
|
||||
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
|
||||
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||
modal = OllamaGetPrompt(self.ctx, "System")
|
||||
await interaction.response.send_modal(modal)
|
||||
await modal.wait()
|
||||
self.system_prompt = modal.value
|
||||
|
||||
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
|
||||
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||
modal = OllamaGetPrompt(self.ctx)
|
||||
await interaction.response.send_modal(modal)
|
||||
await modal.wait()
|
||||
self.user_prompt = modal.value
|
||||
|
||||
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
|
||||
async def done(self, btn: discord.ui.Button, interaction: Interaction):
|
||||
self.stop()
|
||||
|
||||
|
||||
class Ollama(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
|
@ -314,28 +349,17 @@ class Ollama(commands.Cog):
|
|||
await ctx.respond("Invalid context key.")
|
||||
return
|
||||
|
||||
if query.startswith("$$"):
|
||||
prompt = OllamaGetPrompt(ctx, "System")
|
||||
await ctx.send_modal(prompt)
|
||||
await prompt.wait()
|
||||
system_query = prompt.value
|
||||
if not system_query:
|
||||
return await ctx.respond("No prompt provided. Aborting.")
|
||||
else:
|
||||
system_query = None
|
||||
if query == "$":
|
||||
prompt = OllamaGetPrompt(ctx)
|
||||
await ctx.send_modal(prompt)
|
||||
await prompt.wait()
|
||||
query = prompt.value
|
||||
if not query:
|
||||
return await ctx.respond("No prompt provided. Aborting.")
|
||||
|
||||
try:
|
||||
await ctx.defer()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
if query == "$":
|
||||
v = PromptSelector(ctx)
|
||||
await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v)
|
||||
await v.wait()
|
||||
query = v.user_prompt or query
|
||||
|
||||
model = model.casefold()
|
||||
try:
|
||||
model, tag = model.split(":", 1)
|
||||
|
|
Loading…
Reference in a new issue