NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy. The indentation of the context lines in the
"-247" hunk, the whitespace on the removed blank line in the "-255" hunk, and
the nesting of `size = ...`, `for file in files:` and `links = ...` are
assumptions - confirm against src/cogs/auto_responder.py before applying.

diff --git a/src/cogs/auto_responder.py b/src/cogs/auto_responder.py
index 5776ce6..7c64075 100644
--- a/src/cogs/auto_responder.py
+++ b/src/cogs/auto_responder.py
@@ -1,11 +1,12 @@
 import asyncio
+import io
+import typing
 from collections.abc import Iterable
 import logging
 import pathlib
-import re
 import subprocess
 import tempfile
-from urllib.parse import urlparse
+from urllib.parse import urlparse, ParseResult
 
 import aiohttp
 import discord
@@ -22,12 +23,21 @@ class AutoResponder(commands.Cog):
         self.transcode_lock = asyncio.Lock()
 
     @staticmethod
-    def extract_links(text: str, *domains: str) -> list[str]:
+    @typing.overload
+    def extract_links(text: str, *domains: str, raw: typing.Literal[True]) -> list[ParseResult]: ...
+
+    @staticmethod
+    @typing.overload
+    def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: ...
+
+    @staticmethod
+    def extract_links(text: str, *domains: str, raw: bool = False) -> list[str] | list[ParseResult]:
         """
         Extracts all links from a given text.
 
         :param text: The raw text to extract links from.
         :param domains: A list of domains to filter for.
+        :param raw: Whether to return the raw ParseResult objects instead of URL strings.
         :return: A list of found links
         """
         split = text.split()
@@ -38,7 +48,7 @@ class AutoResponder(commands.Cog):
                 continue
             if domains and url.netloc not in domains:
                 continue
-            links.append(url.geturl())
+            links.append(url if raw else url.geturl())
         return links
 
     async def _transcode_hevc_to_h264(
@@ -247,6 +257,44 @@ class AutoResponder(commands.Cog):
             except Exception as e:
                 self.log.error("Failed to transcode %r: %r", link, e)
 
+    async def copy_ncfe_docs(self, message: discord.Message, links: list[ParseResult]) -> None:
+        """
+        Downloads every linked PDF and replies to the message with a copy of each.
+
+        :param message: The message that contained the links; copies are sent as replies to it.
+        :param links: The parsed links to check for PDF documents.
+        :return: None
+        """
+        files = []
+        async with aiohttp.ClientSession() as session:
+            for link in links:
+                if link.path.endswith(".pdf"):
+                    buffer = io.BytesIO()
+                    self.log.info("Downloading %s", link)
+                    async with session.get(link.geturl()) as response:
+                        if response.status != 200:
+                            self.log.warning(
+                                "Failed to download %s: HTTP %d - %r",
+                                link,
+                                response.status,
+                                await response.text()
+                            )
+                            continue
+                        async for chunk in response.content.iter_any():
+                            # noinspection PyTypeChecker
+                            buffer.write(chunk)
+                    size = buffer.tell()
+                    self.log.info("Downloaded %d bytes", size)
+                    buffer.seek(0)
+                    if size > 25 * 1024 * 1024:
+                        self.log.warning("File was too large to upload. Skipping.")
+                        continue
+                    p = pathlib.Path(link.path).name
+                    file = discord.File(buffer, filename=p, description="Copy of %s" % link.geturl())
+                    files.append(file)
+        for file in files:
+            await message.reply(file=file)
+
     @commands.Cog.listener("on_message")
     async def auto_responder(self, message: discord.Message):
         if message.author == self.bot.user:
@@ -255,7 +303,12 @@ class AutoResponder(commands.Cog):
         # Check for HEVC truth social links and convert into h264
         if message.channel.name == "spam" and message.author != self.bot.user:
             await self.transcode_hevc_to_h264(message)
-    
+
+        links = self.extract_links(message.content, "www.ncfe.co.uk", "ncfe.co.uk", raw=True)
+        if links:
+            await self.copy_ncfe_docs(message, links)
+
+
     @commands.Cog.listener("on_reaction_add")
     async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
         if user == self.bot.user: