diff --git a/web/server.py b/web/server.py index c1c79f5..a2ab889 100644 --- a/web/server.py +++ b/web/server.py @@ -9,10 +9,11 @@ from datetime import datetime, timezone from hashlib import sha512 from http import HTTPStatus from pathlib import Path +from typing import Optional import discord import httpx -from fastapi import FastAPI, Header, HTTPException, Request +from fastapi import FastAPI, Header, HTTPException, Request, status from fastapi import WebSocketException as _WSException from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from starlette.websockets import WebSocket, WebSocketDisconnect @@ -64,6 +65,30 @@ app.state.last_sender_ts = datetime.utcnow() app.state.ws_connected = Lock() +async def get_access_token(code: str): + response = app.state.http.post( + "https://discord.com/api/oauth2/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": OAUTH_REDIRECT_URI, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"} + auth=(CLIENT_ID, CLIENT_SECRET) + ) + response.raise_for_status() + return response.json() + + +async def get_authorised_user(access_token: str): + response = app.state.http.get( + "https://discord.com/api/users/@me", + headers={"Authorization": "Bearer " + access_token} + ) + response.raise_for_status() + return response.json() + + @app.middleware("http") async def check_bot_instanced(request, call_next): if not request.app.state.bot: @@ -123,32 +148,14 @@ async def authenticate(req: Request, code: str = None, state: str = None): else: app.state.states.pop(state) # First, we need to do the auth code flow - response = app.state.http.post( - "https://discord.com/api/oauth2/token", - data={ - "client_id": OAUTH_ID, - "client_secret": OAUTH_SECRET, - "grant_type": "authorization_code", - "code": code, - "redirect_uri": OAUTH_REDIRECT_URI, - }, - ) - if response.status_code != 200: - raise HTTPException(status_code=response.status_code, detail=response.text) - data = response.json() + data = await get_access_token(code) access_token = data["access_token"] # Now we can generate a token token = sha512(access_token.encode()).hexdigest() # Now we can get the user's info - response = app.state.http.get( - "https://discord.com/api/users/@me", headers={"Authorization": "Bearer " + data["access_token"]} - ) - if response.status_code != 200: - raise HTTPException(status_code=response.status_code, detail=response.text) - - user = response.json() + user = await get_authorised_user(access_token) # Now we need to fetch the student from the database student = await get_or_none(AccessTokens, user_id=user["id"]) @@ -330,7 +337,7 @@ async def bridge_recv(ws: WebSocket, secret: str = Header(None)): @app.get("/bridge/bind/new") -async def bridge_new_bind(mx_id: str): +async def bridge_bind_new(mx_id: str): """Begins a new bind session.""" existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id) if existing: @@ -342,6 +349,46 @@ async def bridge_new_bind(mx_id: str): token = secrets.token_urlsafe() app.state.binds[token] = mx_id url = discord.utils.oauth_url( - OAUTH_ID, redirect_uri=OAUTH_REDIRECT_URI, scopes=("identify", "connections", "guilds", "email") - ) - + f"&state={value}&prompt=none" + OAUTH_ID, redirect_uri=OAUTH_REDIRECT_URI, scopes=("identify") + ) + f"&state={value}&prompt=none" + return { + "status": "pending", + "url": url, + } + + +@app.get("/bridge/bind/callback") +async def bridge_bind_callback(code: str, state: str): + """Finishes the bind.""" + # Getting an entire access token seems like a waste, but oh well. Only need to do this once. + mx_id = app.state.binds.pop(state, None) + if not mx_id: + raise HTTPException(status_code=400, "Invalid state") + data = await get_access_token(code) + access_token = data["access_token"] + user = await get_authorised_user(access_token) + user_id = int(user["id"]) + await BridgeBind.objects.create(matrix_id=mx_id, user_id=user_id) + return JSONResponse({"matrix": mx_id, "discord": user_id}, 201) + + +@app.delete("/bridge/bind/{mx_id}") +async def bridge_bind_delete(mx_id: str, code: str = None, state: str = None): + """Unbinds a matrix account.""" + existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id) + if not existing: + raise HTTPException(404, "Not found") + + if not (code and state) or state not in app.state.binds: + token = secrets.token_urlsafe() + app.state.binds[token] = mx_id + url = discord.utils.oauth_url( + OAUTH_ID, redirect_uri=OAUTH_REDIRECT_URI, scopes=("identify") + ) + f"&state={value}&prompt=none" + return JSONResponse({"status": "pending", "url": url}) + else: + real_mx_id = app.state.binds.pop(state, None) + if real_mx_id != mx_id: + raise HTTPException(400, "Invalid state") + await existing.delete() + return JSONResponse({"status": "ok"}, 200)