from __future__ import annotations import json import logging import re import typing from typing import Annotated, TYPE_CHECKING import httpx from pathlib import Path import niobot if typing.TYPE_CHECKING: from ..main import NonsenseBot class MSCGetter(niobot.Module): if TYPE_CHECKING: bot: "NonsenseBot" log = logging.getLogger(__name__) def __init__(self, bot): super().__init__(bot) self.latest_msc = None self.msc_cache = Path.cwd() / ".msc-cache" self.msc_cache.mkdir(parents=True, exist_ok=True) async def get_msc_with_cache(self, number: int) -> dict: if number not in range(1, 10_000): return {"error": "Invalid MSC ID"} file = self.msc_cache / ("%d.json" % number) content = None if file.exists(): try: content = json.loads(file.read_text("utf-8", "replace")) except json.JSONDecodeError: file.unlink() if content: return content self.log.debug("Requesting MSC: %d", number) async with httpx.AsyncClient() as client: response = await client.get( "https://api.github.com/repos/matrix-org/matrix-spec-proposals/issues/%d" % number ) if response.status_code != 200: return { "error": "Failed to fetch issue from GitHub (HTTP {!s})," "and no cached version was available.".format(response.status_code) } content = response.json() file.write_text(json.dumps(content)) return content async def search_for_msc(self, query: str) -> list[dict]: force_fetch = "+fetch" in query query = query.replace("+fetch", "") found = [] for cached_msc in self.msc_cache.glob("*.json"): data = json.loads(cached_msc.read_text()) if query.casefold() in data["title"].casefold(): found.append(data) if len(found) < 10 or force_fetch: async with httpx.AsyncClient() as client: response = await client.get( "https://api.github.com/search/issues", params={ "q": query + "+is:pull-request+repo:matrix-org/matrix-spec-proposals", "per_page": 10 } ) if response.status_code == 200: data = response.json() for pr in data["items"]: found.append(pr) number = pr["number"] file = self.msc_cache / ("%d.json" % number) if not file.exists(): file.write_text(json.dumps(pr)) else: return [ { "title": "Error querying GitHub (HTTP %s)." % str(response.status_code), "html_url": "https://http.cat/" + str(response.status_code) } ] return found @niobot.event("message") async def on_message(self, room: niobot.MatrixRoom, message: niobot.RoomMessage): if self.bot.is_old(message): return if message.sender == self.bot.user: return if not isinstance(message, niobot.RoomMessageText): return if "m.in_reply_to" in message.source.get("m.relates_to", []): return if await self.bot.redis.get( self.bot.redis_key(room.room_id, "auto_msc.enabled") ): matches = re.finditer(r"((msc)\W?)([0-9]{1,4})", message.body, re.IGNORECASE) lines = [] for m in matches: no = m.group(1) if no: data = await self.get_msc_with_cache(int(no)) if data.get("error"): continue lines.append(f"[{data['title']}]({data['html_url']})") if lines: return await self.bot.send_message( room, "\n".join((f"* {ln}" for ln in lines)), reply_to=message ) @staticmethod def pr_to_display(data: dict) -> str: return f"* [{data['title']}]({data['html_url']})" @niobot.command() async def msc( self, ctx: niobot.Context, number_or_query ): """Fetches the given MSC""" if number_or_query.startswith("?"): # search msg = await ctx.respond("Searching for relevant MSCs with query %r..." % number_or_query[1:]) results = await self.search_for_msc(number_or_query[1:]) if not results: await msg.edit("No MSCs matched your query.") lines = [] for pr in results: lines.append(self.pr_to_display(pr)) lines_formatted = "\n".join(lines) if len(lines) > 3: new_lines = [ "", "
", "And %d more..." % (len(results) - 3), "", "
" ] await msg.edit( content="\n".join(new_lines), content_type="html.raw", override={ "body": lines_formatted } ) else: await msg.edit(content=lines_formatted) return if number_or_query.startswith("msc"): number_or_query = number_or_query[3:] elif number_or_query.startswith("#"): number_or_query = number_or_query[1:] if not number_or_query.isdigit() or len(number_or_query) != 4: return await ctx.respond("Invalid MXC number.") msg = await ctx.respond("Fetching MSC #{:0>4}...".format(number_or_query)) data: dict = await self.get_msc_with_cache(int(number_or_query)) if data.get("error"): return await msg.edit(data["error"]) return await msg.edit(self.pr_to_display(data)) @niobot.command("automsc.enable") async def auto_msc_enable(self, ctx: niobot.Context): """Automatically enables MSC linking. Requires a power level of at least 50.""" if (sp := ctx.room.power_levels.users.get(ctx.message.sender, -999)) < 50: return await ctx.respond( f"You need to have at least a power level of 50 to use this (you have {sp})." ) key = self.bot.redis_key(ctx.room.room_id, "auto_msc.enabled") exists = await self.bot.redis.get(key) if exists: return await ctx.respond("AutoMSC is already enabled in this room.") await self.bot.redis.set(key, 1) await self.bot.redis.save() return await self.bot.add_reaction( ctx.room, ctx.message, "\N{WHITE HEAVY CHECK MARK}" ) @niobot.command("automsc.disable") async def auto_msc_disable(self, ctx: niobot.Context): """Disables automatic MSC linking. Requires a power level of at least 50.""" if (sp := ctx.room.power_levels.users.get(ctx.message.sender, -999)) < 50: return await ctx.respond( f"You need to have at least a power level of 50 to use this (you have {sp})." ) key = self.bot.redis_key(ctx.room.room_id, "auto_msc.enabled") exists = await self.bot.redis.get(key) if exists: await self.bot.redis.delete(key) await self.bot.redis.save() await self.bot.add_reaction( ctx.room, ctx.message, "\N{WHITE HEAVY CHECK MARK}" ) else: return await ctx.respond("AutoMSC is already disabled in this room.")