Implement proper authentication for endpoints

This commit is contained in:
Nexus 2024-02-26 11:25:52 +00:00
parent 10fc05da8d
commit 03fb84b3f8
2 changed files with 55 additions and 60 deletions

View file

@ -231,10 +231,12 @@ class BridgeBind(orm.Model):
fields = {
"entry_id": orm.UUID(primary_key=True, default=uuid.uuid4),
"matrix_id": orm.Text(unique=True),
"discord_id": orm.BigInteger()
"discord_id": orm.BigInteger(),
"webhook": orm.Text(nullable=True, default=None),
}
if TYPE_CHECKING:
entry_id: uuid.UUID
matrix_id: str
discord_id: int
webhook: str | None

View file

@ -9,11 +9,13 @@ from datetime import datetime, timezone
from hashlib import sha512
from http import HTTPStatus
from pathlib import Path
from typing import Optional
from typing import Optional, Annotated
from discord.ext.commands import Paginator
import discord
import httpx
from fastapi import FastAPI, Header, HTTPException, Request, status
from fastapi import FastAPI, Header, HTTPException, Request, dependencies, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials as HTTPAuthCreds
from fastapi import WebSocketException as _WSException
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
@ -51,6 +53,7 @@ app.state.bot = None
app.state.states = {}
app.state.binds = {}
app.state.http = httpx.Client()
security = HTTPBearer()
if StaticFiles:
app.mount("/static", StaticFiles(directory=SF_ROOT), name="static")
@ -66,6 +69,11 @@ app.state.last_sender_ts = datetime.utcnow()
app.state.ws_connected = Lock()
async def is_authenticated(credentials: Annotated[HTTPAuthCreds, security]):
if credentials.credentials != app.state.bot.http.token:
raise HTTPException(status_code=401, detail="Invalid secret.")
async def get_access_token(code: str, redirect_uri: str = OAUTH_REDIRECT_URI):
response = app.state.http.post(
"https://discord.com/api/oauth2/token",
@ -215,69 +223,34 @@ async def authenticate(req: Request, code: str = None, state: str = None):
return response
@app.get("/verify/{code}")
async def verify(code: str):
guild = app.state.bot.get_guild(guilds[0])
if not guild:
raise HTTPException(status_code=503, detail="Not ready.")
# First, we need to fetch the code from the database
verify_code = await get_or_none(VerifyCode, code=code)
if not verify_code:
raise HTTPException(status_code=404, detail="Code not found.")
# Now we need to fetch the student from the database
student = await get_or_none(Student, user_id=verify_code.bind)
if student:
raise HTTPException(status_code=400, detail="Already verified.")
ban = await get_or_none(BannedStudentID, student_id=verify_code.student_id)
if ban is not None:
return await guild.kick(
reason=f"Attempted to verify with banned student ID {ban.student_id}"
f" (originally associated with account {ban.associated_account})"
)
await Student.objects.create(id=verify_code.student_id, user_id=verify_code.bind, name=verify_code.name)
await verify_code.delete()
role = discord.utils.find(lambda r: r.name.lower() == "verified", guild.roles)
member = await guild.fetch_member(verify_code.bind)
if role and role < guild.me.top_role:
await member.add_roles(role, reason="Verified")
try:
await member.edit(nick=f"{verify_code.name}", reason="Verified")
except discord.HTTPException:
pass
# And delete the code
await verify_code.delete()
log.info(f"[green]{verify_code.bind} verified ({verify_code.bind}/{verify_code.student_id})")
return RedirectResponse(GENERAL, status_code=308)
@app.post("/bridge", include_in_schema=False, status_code=201)
@app.post("/bridge", include_in_schema=False, status_code=201, dependencies=[Depends(is_authenticated)]
async def bridge(req: Request):
now = datetime.utcnow()
ts_diff = (now - app.state.last_sender_ts).total_seconds()
from discord.ext.commands import Paginator
body = await req.json()
if body["secret"] != app.state.bot.http.token:
raise HTTPException(status_code=401, detail="Invalid secret.")
channel = app.state.bot.get_channel(1032974266527907901) # type: discord.TextChannel | None
room_id = body.get("room")
if not room_id:
raise HTTPException(status_code=400, detail="Missing room ID. Required as of 26/02/2024.")
bind = await get_or_none(BridgeBind, matrix_id=room_id)
# ^ Binds are only supposed to be used for User binds, however, in this case we can just recycle it.
if not bind:
channel_id = 1032974266527907901
else:
channel_id = bind.discord_id
channel = app.state.bot.get_channel(channel_id) # type: discord.TextChannel | None
if not channel:
raise HTTPException(status_code=404, detail="Channel does not exist.")
raise HTTPException(status_code=404, detail="Channel %r does not exist." % channel_id)
if len(body["message"]) > 4000:
raise HTTPException(status_code=400, detail="Message too long.")
raise HTTPException(status_code=400, detail="Message too long. 4000 characters maximum.")
paginator = Paginator(prefix="", suffix="", max_size=1990)
for line in body["message"].splitlines():
try:
paginator.add_line(line)
except ValueError:
paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>"))
paginator.add_line(textwrap.shorten(line, width=1980, placeholder="<...>"))
if len(paginator.pages) > 1:
msg = None
if app.state.last_sender != body["sender"] or ts_diff >= 600:
@ -303,7 +276,7 @@ async def bridge(req: Request):
@app.websocket("/bridge/recv")
async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
async def bridge_recv(ws: WebSocket, secret: str = Query(None)):
await ws.accept()
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
if secret != app.state.bot.http.token:
@ -337,12 +310,12 @@ async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
queue.task_done()
@app.get("/bridge/bind/new")
@app.get("/bridge/bind/new", dependencies=[Depends(is_authenticated)])
async def bridge_bind_new(mx_id: str):
"""Begins a new bind session."""
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
if existing:
raise HTTPException(409, "Account already bound")
raise HTTPException(409, "Target already bound")
if not OAUTH_ENABLED:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
@ -360,7 +333,7 @@ async def bridge_bind_new(mx_id: str):
}
@app.get("/bridge/bind/callback")
@app.get("/bridge/bind/callback", include_in_schema=False)
async def bridge_bind_callback(code: str, state: str):
"""Finishes the bind."""
# Getting an entire access token seems like a waste, but oh well. Only need to do this once.
@ -372,7 +345,24 @@ async def bridge_bind_callback(code: str, state: str):
user = await get_authorised_user(access_token,)
user_id = int(user["id"])
await BridgeBind.objects.create(matrix_id=mx_id, discord_id=user_id)
return JSONResponse({"matrix": mx_id, "discord": user_id}, 201)
return JSONResponse({"success": True, "matrix": mx_id, "discord": user_id}, 201)
@app.post("/bridge/bind/_create", include_in_schema=False, dependencies=[Depends(is_authenticated)])
async def bridge_bind_create_nonuser(
req: Request
):
body = await req.json()
if "mx_id" not in body or "discord_id" not in body:
raise HTTPException(400, "Missing fields")
mx_id = body["mx_id"]
discord_id = body["discord_id"]
webhook = body.get("webhook")
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
if existing:
raise HTTPException(409, "Target already bound")
await BridgeBind.objects.create(matrix_id=mx_id, discord_id=discord_id, webhook=webhook)
return JSONResponse({"status": "ok"}, 201)
@app.delete("/bridge/bind/{mx_id}")
@ -402,10 +392,13 @@ async def bridge_bind_delete(mx_id: str, code: str = None, state: str = None):
await existing.delete()
return JSONResponse({"status": "ok"}, 200)
@app.get("/bridge/bind/{mx_id}")
@app.get("/bridge/bind/{mx_id}", dependencies=[Depends(is_authenticated)])
async def bridge_bind_fetch(mx_id: str):
"""Fetch the discord account associated with a matrix account."""
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
if not existing:
raise HTTPException(404, "Not found")
return JSONResponse({"discord": existing.discord_id}, 200)
payload = {"discord": existing.discord_id, "matrix": mx_id}
if existing.webhook:
payload["webhook"] = existing.webhook
return JSONResponse(payload, 200)