mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Add a transcribe lock
This commit is contained in:
parent
f003e5bd4c
commit
8a49acaea3
1 changed files with 46 additions and 44 deletions
|
@ -112,6 +112,7 @@ class OtherCog(commands.Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
self.transcribe_lock = asyncio.Lock()
|
||||||
self.http = httpx.AsyncClient()
|
self.http = httpx.AsyncClient()
|
||||||
self._fmt_cache = {}
|
self._fmt_cache = {}
|
||||||
self._fmt_queue = asyncio.Queue()
|
self._fmt_queue = asyncio.Queue()
|
||||||
|
@ -2283,54 +2284,55 @@ class OtherCog(commands.Cog):
|
||||||
@commands.message_command(name="Transcribe")
|
@commands.message_command(name="Transcribe")
|
||||||
async def transcribe_message(self, ctx: discord.ApplicationContext, message: discord.Message):
|
async def transcribe_message(self, ctx: discord.ApplicationContext, message: discord.Message):
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
if not message.attachments:
|
async with self.transcribe_lock:
|
||||||
return await ctx.respond("No attachments found.")
|
if not message.attachments:
|
||||||
|
return await ctx.respond("No attachments found.")
|
||||||
|
|
||||||
_ft = "wav"
|
_ft = "wav"
|
||||||
for attachment in message.attachments:
|
for attachment in message.attachments:
|
||||||
if attachment.content_type.startswith("audio/"):
|
if attachment.content_type.startswith("audio/"):
|
||||||
_ft = attachment.filename.split(".")[-1]
|
_ft = attachment.filename.split(".")[-1]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return await ctx.respond("No voice messages.")
|
return await ctx.respond("No voice messages.")
|
||||||
if getattr(config, "OPENAI_KEY", None) is None:
|
if getattr(config, "OPENAI_KEY", None) is None:
|
||||||
return await ctx.respond("Service unavailable.")
|
return await ctx.respond("Service unavailable.")
|
||||||
file_hash = hashlib.sha1(usedforsecurity=False)
|
file_hash = hashlib.sha1(usedforsecurity=False)
|
||||||
file_hash.update(await attachment.read())
|
file_hash.update(await attachment.read())
|
||||||
file_hash = file_hash.hexdigest()
|
file_hash = file_hash.hexdigest()
|
||||||
|
|
||||||
cache = Path.home() / ".cache" / "lcc-bot" / ("%s-transcript.txt" % file_hash)
|
cache = Path.home() / ".cache" / "lcc-bot" / ("%s-transcript.txt" % file_hash)
|
||||||
cached = False
|
cached = False
|
||||||
if not cache.exists():
|
if not cache.exists():
|
||||||
client = openai.OpenAI(api_key=config.OPENAI_KEY)
|
client = openai.OpenAI(api_key=config.OPENAI_KEY)
|
||||||
with tempfile.NamedTemporaryFile("wb+", suffix=".mp4") as f:
|
with tempfile.NamedTemporaryFile("wb+", suffix=".mp4") as f:
|
||||||
with tempfile.NamedTemporaryFile("wb+", suffix="-" + attachment.filename) as f2:
|
with tempfile.NamedTemporaryFile("wb+", suffix="-" + attachment.filename) as f2:
|
||||||
await attachment.save(f2.name)
|
await attachment.save(f2.name)
|
||||||
f2.seek(0)
|
f2.seek(0)
|
||||||
seg: pydub.AudioSegment = await asyncio.to_thread(pydub.AudioSegment.from_file, file=f2, format=_ft)
|
seg: pydub.AudioSegment = await asyncio.to_thread(pydub.AudioSegment.from_file, file=f2, format=_ft)
|
||||||
seg = seg.set_channels(1)
|
seg = seg.set_channels(1)
|
||||||
await asyncio.to_thread(seg.export, f.name, format="mp4")
|
await asyncio.to_thread(seg.export, f.name, format="mp4")
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
|
|
||||||
transcript = await asyncio.to_thread(
|
transcript = await asyncio.to_thread(
|
||||||
client.audio.transcriptions.create, file=pathlib.Path(f.name), model="whisper-1"
|
client.audio.transcriptions.create, file=pathlib.Path(f.name), model="whisper-1"
|
||||||
|
)
|
||||||
|
text = transcript.text
|
||||||
|
cache.write_text(text)
|
||||||
|
else:
|
||||||
|
text = cache.read_text()
|
||||||
|
cached = True
|
||||||
|
|
||||||
|
paginator = commands.Paginator("", "", 4096)
|
||||||
|
for line in text.splitlines():
|
||||||
|
paginator.add_line(textwrap.shorten(line, 4096))
|
||||||
|
embeds = list(map(lambda p: discord.Embed(description=p), paginator.pages))
|
||||||
|
await ctx.respond(embeds=embeds or [discord.Embed(description="No text found.")])
|
||||||
|
|
||||||
|
if await self.bot.is_owner(ctx.user):
|
||||||
|
await ctx.respond(
|
||||||
|
("Cached response ({})" if cached else "Uncached response ({})").format(file_hash), ephemeral=True
|
||||||
)
|
)
|
||||||
text = transcript.text
|
|
||||||
cache.write_text(text)
|
|
||||||
else:
|
|
||||||
text = cache.read_text()
|
|
||||||
cached = True
|
|
||||||
|
|
||||||
paginator = commands.Paginator("", "", 4096)
|
|
||||||
for line in text.splitlines():
|
|
||||||
paginator.add_line(textwrap.shorten(line, 4096))
|
|
||||||
embeds = list(map(lambda p: discord.Embed(description=p), paginator.pages))
|
|
||||||
await ctx.respond(embeds=embeds or [discord.Embed(description="No text found.")])
|
|
||||||
|
|
||||||
if await self.bot.is_owner(ctx.user):
|
|
||||||
await ctx.respond(
|
|
||||||
("Cached response ({})" if cached else "Uncached response ({})").format(file_hash), ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
@commands.slash_command()
|
@commands.slash_command()
|
||||||
async def whois(self, ctx: discord.ApplicationContext, domain: str):
|
async def whois(self, ctx: discord.ApplicationContext, domain: str):
|
||||||
|
|
Loading…
Reference in a new issue