Add a transcribe lock

This commit is contained in:
nexy7574 2023-12-05 21:48:46 +00:00
parent f003e5bd4c
commit 8a49acaea3

View file

@@ -112,6 +112,7 @@ class OtherCog(commands.Cog):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.transcribe_lock = asyncio.Lock()
self.http = httpx.AsyncClient() self.http = httpx.AsyncClient()
self._fmt_cache = {} self._fmt_cache = {}
self._fmt_queue = asyncio.Queue() self._fmt_queue = asyncio.Queue()
@@ -2283,54 +2284,55 @@ class OtherCog(commands.Cog):
@commands.message_command(name="Transcribe")
async def transcribe_message(self, ctx: discord.ApplicationContext, message: discord.Message):
    """Transcribe the first audio attachment of ``message`` and reply with the text.

    The attachment is hashed (SHA-1 of its bytes) and the transcript is cached
    on disk at ``~/.cache/lcc-bot/<hash>-transcript.txt``. On a cache miss the
    audio is converted to mono mp4 with pydub and sent to OpenAI's ``whisper-1``
    model. ``self.transcribe_lock`` serialises the download/convert/transcribe
    step so only one transcription runs at a time; validation and the Discord
    replies happen outside the lock so they never queue behind a running job.

    :param ctx: The application context of the message command invocation.
    :param message: The message whose attachments are inspected.
    :returns: The result of ``ctx.respond`` on an early exit, otherwise ``None``.
    """
    await ctx.defer()

    # Cheap validation first: none of this needs the transcription lock.
    if not message.attachments:
        return await ctx.respond("No attachments found.")

    # ``content_type`` may be None for some attachments, so fall back to ""
    # instead of calling ``.startswith`` on None.
    attachment = next(
        (a for a in message.attachments if (a.content_type or "").startswith("audio/")),
        None,
    )
    if attachment is None:
        return await ctx.respond("No voice messages.")
    if getattr(config, "OPENAI_KEY", None) is None:
        return await ctx.respond("Service unavailable.")

    # The file extension tells pydub which decoder to use.
    _ft = attachment.filename.split(".")[-1]
    cached = False

    async with self.transcribe_lock:
        # Download once; the same bytes feed both the hash and the temp file.
        data = await attachment.read()
        file_hash = hashlib.sha1(data, usedforsecurity=False).hexdigest()
        cache = Path.home() / ".cache" / "lcc-bot" / ("%s-transcript.txt" % file_hash)

        # The cache lookup stays inside the lock so that a second request for
        # the same file waits for the first and then hits the cache.
        if cache.exists():
            text = cache.read_text()
            cached = True
        else:
            client = openai.OpenAI(api_key=config.OPENAI_KEY)
            with tempfile.NamedTemporaryFile("wb+", suffix=".mp4") as f, tempfile.NamedTemporaryFile(
                "wb+", suffix="-" + attachment.filename
            ) as f2:
                f2.write(data)
                f2.flush()
                f2.seek(0)
                seg: pydub.AudioSegment = await asyncio.to_thread(
                    pydub.AudioSegment.from_file, file=f2, format=_ft
                )
                seg = seg.set_channels(1)
                await asyncio.to_thread(seg.export, f.name, format="mp4")
                f.seek(0)
                transcript = await asyncio.to_thread(
                    client.audio.transcriptions.create, file=pathlib.Path(f.name), model="whisper-1"
                )
            text = transcript.text
            # The cache directory may not exist on a fresh machine.
            cache.parent.mkdir(parents=True, exist_ok=True)
            cache.write_text(text)

    paginator = commands.Paginator("", "", 4096)
    for line in text.splitlines():
        # NOTE(review): a line shortened to (nearly) 4096 characters may still
        # exceed the paginator's per-line limit - confirm against
        # Paginator.add_line before relying on this for very long lines.
        paginator.add_line(textwrap.shorten(line, 4096))
    embeds = [discord.Embed(description=page) for page in paginator.pages]
    await ctx.respond(embeds=embeds or [discord.Embed(description="No text found.")])

    # Owners additionally get a private note saying whether the cache was used.
    if await self.bot.is_owner(ctx.user):
        await ctx.respond(
            ("Cached response ({})" if cached else "Uncached response ({})").format(file_hash), ephemeral=True
        )
@commands.slash_command() @commands.slash_command()
async def whois(self, ctx: discord.ApplicationContext, domain: str): async def whois(self, ctx: discord.ApplicationContext, domain: str):