diff --git a/utils/db.py b/utils/db.py index 5bb815f..6fa5a60 100644 --- a/utils/db.py +++ b/utils/db.py @@ -231,10 +231,12 @@ class BridgeBind(orm.Model): fields = { "entry_id": orm.UUID(primary_key=True, default=uuid.uuid4), "matrix_id": orm.Text(unique=True), - "discord_id": orm.BigInteger() + "discord_id": orm.BigInteger(), + "webhook": orm.Text(nullable=True, default=None), } if TYPE_CHECKING: entry_id: uuid.UUID matrix_id: str discord_id: int + webhook: str | None diff --git a/web/server.py b/web/server.py index d1b3520..22622d0 100644 --- a/web/server.py +++ b/web/server.py @@ -9,11 +9,13 @@ from datetime import datetime, timezone from hashlib import sha512 from http import HTTPStatus from pathlib import Path -from typing import Optional +from typing import Optional, Annotated +from discord.ext.commands import Paginator import discord import httpx -from fastapi import FastAPI, Header, HTTPException, Request, status +from fastapi import FastAPI, Header, HTTPException, Request, dependencies, status, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials as HTTPAuthCreds from fastapi import WebSocketException as _WSException from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from starlette.websockets import WebSocket, WebSocketDisconnect @@ -51,6 +53,7 @@ app.state.bot = None app.state.states = {} app.state.binds = {} app.state.http = httpx.Client() +security = HTTPBearer() if StaticFiles: app.mount("/static", StaticFiles(directory=SF_ROOT), name="static") @@ -66,6 +69,11 @@ app.state.last_sender_ts = datetime.utcnow() app.state.ws_connected = Lock() +async def is_authenticated(credentials: Annotated[HTTPAuthCreds, security]): + if credentials.credentials != app.state.bot.http.token: + raise HTTPException(status_code=401, detail="Invalid secret.") + + async def get_access_token(code: str, redirect_uri: str = OAUTH_REDIRECT_URI): response = app.state.http.post( "https://discord.com/api/oauth2/token", @@ -215,69 +223,34 @@ async def authenticate(req: Request, code: str = None, state: str = None): return response -@app.get("/verify/{code}") -async def verify(code: str): - guild = app.state.bot.get_guild(guilds[0]) - if not guild: - raise HTTPException(status_code=503, detail="Not ready.") - - # First, we need to fetch the code from the database - verify_code = await get_or_none(VerifyCode, code=code) - if not verify_code: - raise HTTPException(status_code=404, detail="Code not found.") - - # Now we need to fetch the student from the database - student = await get_or_none(Student, user_id=verify_code.bind) - if student: - raise HTTPException(status_code=400, detail="Already verified.") - - ban = await get_or_none(BannedStudentID, student_id=verify_code.student_id) - if ban is not None: - return await guild.kick( - reason=f"Attempted to verify with banned student ID {ban.student_id}" - f" (originally associated with account {ban.associated_account})" - ) - await Student.objects.create(id=verify_code.student_id, user_id=verify_code.bind, name=verify_code.name) - await verify_code.delete() - role = discord.utils.find(lambda r: r.name.lower() == "verified", guild.roles) - member = await guild.fetch_member(verify_code.bind) - if role and role < guild.me.top_role: - await member.add_roles(role, reason="Verified") - try: - await member.edit(nick=f"{verify_code.name}", reason="Verified") - except discord.HTTPException: - pass - - # And delete the code - await verify_code.delete() - - log.info(f"[green]{verify_code.bind} verified ({verify_code.bind}/{verify_code.student_id})") - - return RedirectResponse(GENERAL, status_code=308) - - -@app.post("/bridge", include_in_schema=False, status_code=201) +@app.post("/bridge", include_in_schema=False, status_code=201, dependencies=[Depends(is_authenticated)] async def bridge(req: Request): now = datetime.utcnow() ts_diff = (now - app.state.last_sender_ts).total_seconds() - from discord.ext.commands import Paginator - body = await req.json() - if body["secret"] != app.state.bot.http.token: - raise HTTPException(status_code=401, detail="Invalid secret.") - channel = app.state.bot.get_channel(1032974266527907901) # type: discord.TextChannel | None + room_id = body.get("room") + if not room_id: + raise HTTPException(status_code=400, detail="Missing room ID. Required as of 26/02/2024.") + bind = await get_or_none(BridgeBind, matrix_id=room_id) + # ^ Binds are only supposed to be used for User binds, however, in this case we can just recycle it. + if not bind: + channel_id = 1032974266527907901 + else: + channel_id = bind.discord_id + + channel = app.state.bot.get_channel(channel_id) # type: discord.TextChannel | None if not channel: - raise HTTPException(status_code=404, detail="Channel does not exist.") + raise HTTPException(status_code=404, detail="Channel %r does not exist." % channel_id) if len(body["message"]) > 4000: - raise HTTPException(status_code=400, detail="Message too long.") + raise HTTPException(status_code=400, detail="Message too long. 4000 characters maximum.") paginator = Paginator(prefix="", suffix="", max_size=1990) for line in body["message"].splitlines(): try: paginator.add_line(line) except ValueError: - paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>")) + paginator.add_line(textwrap.shorten(line, width=1980, placeholder="<...>")) if len(paginator.pages) > 1: msg = None if app.state.last_sender != body["sender"] or ts_diff >= 600: @@ -303,7 +276,7 @@ async def bridge(req: Request): @app.websocket("/bridge/recv") -async def bridge_recv(ws: WebSocket, secret: str = Header(None)): +async def bridge_recv(ws: WebSocket, secret: str = Query(None)): await ws.accept() log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port) if secret != app.state.bot.http.token: @@ -337,12 +310,12 @@ async def bridge_recv(ws: WebSocket, secret: str = Header(None)): queue.task_done() -@app.get("/bridge/bind/new") +@app.get("/bridge/bind/new", dependencies=[Depends(is_authenticated)]) async def bridge_bind_new(mx_id: str): """Begins a new bind session.""" existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id) if existing: - raise HTTPException(409, "Account already bound") + raise HTTPException(409, "Target already bound") if not OAUTH_ENABLED: raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) @@ -360,7 +333,7 @@ async def bridge_bind_new(mx_id: str): } -@app.get("/bridge/bind/callback") +@app.get("/bridge/bind/callback", include_in_schema=False) async def bridge_bind_callback(code: str, state: str): """Finishes the bind.""" # Getting an entire access token seems like a waste, but oh well. Only need to do this once. @@ -372,7 +345,24 @@ async def bridge_bind_callback(code: str, state: str): user = await get_authorised_user(access_token,) user_id = int(user["id"]) await BridgeBind.objects.create(matrix_id=mx_id, discord_id=user_id) - return JSONResponse({"matrix": mx_id, "discord": user_id}, 201) + return JSONResponse({"success": True, "matrix": mx_id, "discord": user_id}, 201) + + +@app.post("/bridge/bind/_create", include_in_schema=False, dependencies=[Depends(is_authenticated)]) +async def bridge_bind_create_nonuser( + req: Request +): + body = await req.json() + if "mx_id" not in body or "discord_id" not in body: + raise HTTPException(400, "Missing fields") + mx_id = body["mx_id"] + discord_id = body["discord_id"] + webhook = body.get("webhook") + existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id) + if existing: + raise HTTPException(409, "Target already bound") + await BridgeBind.objects.create(matrix_id=mx_id, discord_id=discord_id, webhook=webhook) + return JSONResponse({"status": "ok"}, 201) @app.delete("/bridge/bind/{mx_id}") @@ -402,10 +392,13 @@ async def bridge_bind_delete(mx_id: str, code: str = None, state: str = None): await existing.delete() return JSONResponse({"status": "ok"}, 200) -@app.get("/bridge/bind/{mx_id}") +@app.get("/bridge/bind/{mx_id}", dependencies=[Depends(is_authenticated)]) async def bridge_bind_fetch(mx_id: str): """Fetch the discord account associated with a matrix account.""" existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id) if not existing: raise HTTPException(404, "Not found") - return JSONResponse({"discord": existing.discord_id}, 200) + payload = {"discord": existing.discord_id, "matrix": mx_id} + if existing.webhook: + payload["webhook"] = existing.webhook + return JSONResponse(payload, 200)