diff --git a/cogs/other.py b/cogs/other.py index 5bffc9b..2cdbf9c 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -4,7 +4,6 @@ import os import random import re import textwrap -import traceback import dns.resolver import aiofiles @@ -37,6 +36,24 @@ class OtherCog(commands.Cog): self.bot = bot self.lock = asyncio.Lock() + class AbortScreenshotTask(discord.ui.View): + def __init__(self, task: asyncio.Task): + super().__init__() + self.task = task + + @discord.ui.button(label="Abort", style=discord.ButtonStyle.red) + async def abort(self, button: discord.ui.Button, interaction: discord.Interaction): + new: discord.Interaction = await interaction.response.send_message("Aborting...", ephemeral=True) + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + self.disable_all_items() + button.label = "[ aborted ]" + await new.edit_original_response(content="Aborted screenshot task.", view=self) + self.stop() + async def screenshot_website( self, ctx: discord.ApplicationContext, @@ -126,19 +143,19 @@ class OtherCog(commands.Cog): end_init = time() console.log("Driver '{}' initialised in {} seconds.".format(driver_name, round(end_init - start_init, 2))) - async def _edit(content: str): + def _edit(content: str): self.bot.loop.create_task(ctx.interaction.edit_original_response(content=content)) - await _edit(content=f"Screenshotting <{friendly_url}>... (49%, loading webpage)") + _edit(content=f"Screenshotting <{friendly_url}>... (49%, loading webpage)") await _blocking(driver.set_page_load_timeout, render_time) start = time() await _blocking(driver.get, website) end = time() get_time = round((end - start) * 1000) render_time_expires = round(time() + render_time) - await _edit(content=f"Screenshotting <{friendly_url}>... (66%, stopping render )") + _edit(content=f"Screenshotting <{friendly_url}>... (66%, stopping render )") await asyncio.sleep(render_time) - await _edit(content=f"Screenshotting <{friendly_url}>... (83%, saving screenshot)") + _edit(content=f"Screenshotting <{friendly_url}>... (83%, saving screenshot)") domain = re.sub(r"https?://", "", website) screenshot_method = driver.get_screenshot_as_png @@ -417,6 +434,8 @@ class OtherCog(commands.Cog): await ctx.respond(page, ephemeral=secure) @commands.slash_command() + @commands.max_concurrency(1, commands.BucketType.user) + @commands.cooldown(1, 30, commands.BucketType.user) async def screenshot( self, ctx: discord.ApplicationContext, diff --git a/main.py b/main.py index 0b8765f..af6e69e 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ import discord from discord.ext import commands from asyncio import Lock import config -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from utils import registry, console, get_or_none, JimmyBans @@ -49,6 +49,19 @@ async def on_connect(): @bot.listen("on_application_command_error") async def on_application_command_error(ctx: discord.ApplicationContext, error: Exception): + if isinstance(error, commands.CommandOnCooldown): + now = discord.utils.utcnow() + now += timedelta(seconds=error.retry_after) + return await ctx.respond( + f"\N{stopwatch} This command is on cooldown. You can use this command again " + f"{discord.utils.format_dt(now, 'R')}.", + delete_after=error.retry_after, + ) + elif isinstance(error, commands.MaxConcurrencyReached): + return await ctx.respond( + f"\N{warning sign} This command is already running. Please wait for it to finish.", + ephemeral=True, + ) await ctx.respond("Application Command Error: `%r`" % error) raise error