From 8a49acaea3908651acc133fc01e640871efffa10 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Tue, 5 Dec 2023 21:48:46 +0000 Subject: [PATCH] Add a transcribe lock --- cogs/other.py | 90 ++++++++++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/cogs/other.py b/cogs/other.py index 07fca77..27c9616 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -112,6 +112,7 @@ class OtherCog(commands.Cog): def __init__(self, bot): self.bot = bot self.lock = asyncio.Lock() + self.transcribe_lock = asyncio.Lock() self.http = httpx.AsyncClient() self._fmt_cache = {} self._fmt_queue = asyncio.Queue() @@ -2283,54 +2284,55 @@ class OtherCog(commands.Cog): @commands.message_command(name="Transcribe") async def transcribe_message(self, ctx: discord.ApplicationContext, message: discord.Message): await ctx.defer() - if not message.attachments: - return await ctx.respond("No attachments found.") + async with self.transcribe_lock: + if not message.attachments: + return await ctx.respond("No attachments found.") - _ft = "wav" - for attachment in message.attachments: - if attachment.content_type.startswith("audio/"): - _ft = attachment.filename.split(".")[-1] - break - else: - return await ctx.respond("No voice messages.") - if getattr(config, "OPENAI_KEY", None) is None: - return await ctx.respond("Service unavailable.") - file_hash = hashlib.sha1(usedforsecurity=False) - file_hash.update(await attachment.read()) - file_hash = file_hash.hexdigest() + _ft = "wav" + for attachment in message.attachments: + if attachment.content_type.startswith("audio/"): + _ft = attachment.filename.split(".")[-1] + break + else: + return await ctx.respond("No voice messages.") + if getattr(config, "OPENAI_KEY", None) is None: + return await ctx.respond("Service unavailable.") + file_hash = hashlib.sha1(usedforsecurity=False) + file_hash.update(await attachment.read()) + file_hash = file_hash.hexdigest() - cache = Path.home() / ".cache" / "lcc-bot" / ("%s-transcript.txt" % file_hash) - cached = False - if not cache.exists(): - client = openai.OpenAI(api_key=config.OPENAI_KEY) - with tempfile.NamedTemporaryFile("wb+", suffix=".mp4") as f: - with tempfile.NamedTemporaryFile("wb+", suffix="-" + attachment.filename) as f2: - await attachment.save(f2.name) - f2.seek(0) - seg: pydub.AudioSegment = await asyncio.to_thread(pydub.AudioSegment.from_file, file=f2, format=_ft) - seg = seg.set_channels(1) - await asyncio.to_thread(seg.export, f.name, format="mp4") - f.seek(0) + cache = Path.home() / ".cache" / "lcc-bot" / ("%s-transcript.txt" % file_hash) + cached = False + if not cache.exists(): + client = openai.OpenAI(api_key=config.OPENAI_KEY) + with tempfile.NamedTemporaryFile("wb+", suffix=".mp4") as f: + with tempfile.NamedTemporaryFile("wb+", suffix="-" + attachment.filename) as f2: + await attachment.save(f2.name) + f2.seek(0) + seg: pydub.AudioSegment = await asyncio.to_thread(pydub.AudioSegment.from_file, file=f2, format=_ft) + seg = seg.set_channels(1) + await asyncio.to_thread(seg.export, f.name, format="mp4") + f.seek(0) - transcript = await asyncio.to_thread( - client.audio.transcriptions.create, file=pathlib.Path(f.name), model="whisper-1" + transcript = await asyncio.to_thread( + client.audio.transcriptions.create, file=pathlib.Path(f.name), model="whisper-1" + ) + text = transcript.text + cache.write_text(text) + else: + text = cache.read_text() + cached = True + + paginator = commands.Paginator("", "", 4096) + for line in text.splitlines(): + paginator.add_line(textwrap.shorten(line, 4096)) + embeds = list(map(lambda p: discord.Embed(description=p), paginator.pages)) + await ctx.respond(embeds=embeds or [discord.Embed(description="No text found.")]) + + if await self.bot.is_owner(ctx.user): + await ctx.respond( + ("Cached response ({})" if cached else "Uncached response ({})").format(file_hash), ephemeral=True ) - text = transcript.text - cache.write_text(text) - else: - text = cache.read_text() - cached = True - - paginator = commands.Paginator("", "", 4096) - for line in text.splitlines(): - paginator.add_line(textwrap.shorten(line, 4096)) - embeds = list(map(lambda p: discord.Embed(description=p), paginator.pages)) - await ctx.respond(embeds=embeds or [discord.Embed(description="No text found.")]) - - if await self.bot.is_owner(ctx.user): - await ctx.respond( - ("Cached response ({})" if cached else "Uncached response ({})").format(file_hash), ephemeral=True - ) @commands.slash_command() async def whois(self, ctx: discord.ApplicationContext, domain: str):