diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index e5a40b3..3d0b767 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -217,6 +217,41 @@ class OllamaGetPrompt(discord.ui.Modal): self.stop() +class PromptSelector(discord.ui.View): + def __init__(self, ctx: discord.ApplicationContext): + super().__init__(timeout=600, disable_on_timeout=True) + self.ctx = ctx + self.system_prompt = None + self.user_prompt = None + + async def interaction_check(self, interaction: Interaction) -> bool: + return interaction.user == self.ctx.user + + def update_ui(self): + if self.system_prompt is not None: + self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore + if self.user_prompt is not None: + self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore + + @discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys") + async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction): + modal = OllamaGetPrompt(self.ctx, "System") + await interaction.response.send_modal(modal) + await modal.wait() + self.system_prompt = modal.value + + @discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="usr") + async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction): + modal = OllamaGetPrompt(self.ctx) + await interaction.response.send_modal(modal) + await modal.wait() + self.user_prompt = modal.value + + @discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done") + async def done(self, btn: discord.ui.Button, interaction: Interaction): + self.stop() + + class Ollama(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot @@ -314,28 +349,17 @@ class Ollama(commands.Cog): await ctx.respond("Invalid context key.") return - if query.startswith("$$"): - prompt = OllamaGetPrompt(ctx, "System") - await ctx.send_modal(prompt) - await prompt.wait() - system_query = prompt.value - if not system_query: - return await ctx.respond("No prompt provided. Aborting.") - else: - system_query = None - if query == "$": - prompt = OllamaGetPrompt(ctx) - await ctx.send_modal(prompt) - await prompt.wait() - query = prompt.value - if not query: - return await ctx.respond("No prompt provided. Aborting.") - try: await ctx.defer() except discord.HTTPException: pass + if query == "$": + v = PromptSelector(ctx) + await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v) + await v.wait() + query = v.user_prompt or query + model = model.casefold() try: model, tag = model.split(":", 1)