From eaa195fb7b7be822f57ffaf86715fd0b0393d313 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Thu, 19 Sep 2024 17:28:25 +0100 Subject: [PATCH] Add MSC searching --- app/modules/memetera_counter.py | 4 +- app/modules/msc_getter.py | 95 +++++++++++++++++++++++++++++++-- 2 files changed, 93 insertions(+), 6 deletions(-) diff --git a/app/modules/memetera_counter.py b/app/modules/memetera_counter.py index 01fa243..b060b3d 100644 --- a/app/modules/memetera_counter.py +++ b/app/modules/memetera_counter.py @@ -26,7 +26,9 @@ class MemeteraCounter(niobot.Module): } WORDS = { "memetera": re.compile(r"^memetera$", re.IGNORECASE), - "comma": re.compile(r",") + "comma": re.compile(r","), + "nyo": re.compile(r"nyo", re.IGNORECASE), + "bite": re.compile(r"bite", re.IGNORECASE) } bot: "NonsenseBot" log = logging.getLogger(__name__) diff --git a/app/modules/msc_getter.py b/app/modules/msc_getter.py index 2b3a214..8e4e77e 100644 --- a/app/modules/msc_getter.py +++ b/app/modules/msc_getter.py @@ -21,7 +21,7 @@ class MSCGetter(niobot.Module): self.msc_cache = Path.cwd() / ".msc-cache" self.msc_cache.mkdir(parents=True, exist_ok=True) - async def get_msc_with_cache(self, number: int): + async def get_msc_with_cache(self, number: int) -> dict: if number not in range(1, 10_000): return {"error": "Invalid MSC ID"} @@ -43,11 +43,49 @@ class MSCGetter(niobot.Module): % number ) if response.status_code != 200: - return {"error": "HTTP %d" % response.status_code} + return { + "error": "Failed to fetch issue from GitHub (HTTP {!s})," + "and no cached version was available.".format(response.status_code) + } content = response.json() file.write_text(json.dumps(content)) return content + async def search_for_msc(self, query: str) -> list[dict]: + force_fetch = "+fetch" in query + query = query.replace("+fetch", "") + found = [] + for cached_msc in self.msc_cache.glob("*.json"): + data = json.loads(cached_msc.read_text()) + if query.casefold() in data["title"].casefold(): + found.append(data) + + if len(found) < 10 or force_fetch: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://api.github.com/search/issues", + params={ + "q": query, + "per_page": 10 + } + ) + if response.status_code == 200: + data = response.json() + for pr in data: + found.append(pr) + number = pr["number"] + file = self.msc_cache / ("%d.json" % number) + if not file.exists(): + file.write_text(json.dumps(pr)) + else: + return [ + { + "title": "Error querying GitHub (HTTP %s)." % str(response.status_code), + "html_url": "https://http.cat/" + str(response.status_code) + } + ] + return found + @niobot.event("message") async def on_message(self, room: niobot.MatrixRoom, message: niobot.RoomMessage): if self.bot.is_old(message): @@ -61,7 +99,7 @@ class MSCGetter(niobot.Module): if await self.bot.redis.get( self.bot.redis_key(room.room_id, "auto_msc.enabled") ): - matches = re.finditer("^[MmSsCc]\W?([0-9]{1,4})$", message.body) + matches = re.finditer(r"((msc)\W?)([0-9]{1,4})", message.body, re.IGNORECASE) lines = [] for m in matches: @@ -76,9 +114,55 @@ class MSCGetter(niobot.Module): room, "\n".join((f"* {ln}" for ln in lines)), reply_to=message ) + @staticmethod + def pr_to_display(data: dict) -> str: + return f"* [{data['title']}]({data['html_url']})" + @niobot.command() async def msc(self, ctx: niobot.Context, number: str): """Fetches the given MSC""" + if number.startswith("?"): # search + msg = await ctx.respond("Searching for relevant MSCs...") + results = await self.search_for_msc(number[1:]) + if not results: + await msg.edit("No MSCs matched your query.") + + lines = [] + for pr in results: + lines.append(self.pr_to_display(pr)) + lines_formatted = "\n".join(lines) + + if len(lines) > 5: + new_lines = [ + "", + "
", + "And %d more..." % (len(results) - 5), + "", + "
" + ] + + await msg.edit( + content="\n".join(new_lines), + content_type="html.raw", + override={ + "body": lines_formatted + } + ) + else: + await msg.edit(content=lines_formatted) + return + + if number.startswith("msc"): number = number[3:] elif number.startswith("#"): @@ -86,10 +170,11 @@ class MSCGetter(niobot.Module): if not number.isdigit() or len(number) != 4: return await ctx.respond("Invalid MXC number.") + msg = await ctx.respond("Fetching MSC #{:0>4}...".format(number)) data: dict = await self.get_msc_with_cache(int(number)) if data.get("error"): - return await ctx.respond(data["error"]) - return await ctx.respond(f"[{data['title']}]({data['html_url']})") + return await msg.edit(data["error"]) + return await msg.edit(self.pr_to_display(data)) @niobot.command("automsc.enable") async def auto_msc_enable(self, ctx: niobot.Context):