mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 10:03:40 +01:00
Implement proper authentication for endpoints
This commit is contained in:
parent
10fc05da8d
commit
03fb84b3f8
2 changed files with 55 additions and 60 deletions
|
@ -231,10 +231,12 @@ class BridgeBind(orm.Model):
|
||||||
fields = {
|
fields = {
|
||||||
"entry_id": orm.UUID(primary_key=True, default=uuid.uuid4),
|
"entry_id": orm.UUID(primary_key=True, default=uuid.uuid4),
|
||||||
"matrix_id": orm.Text(unique=True),
|
"matrix_id": orm.Text(unique=True),
|
||||||
"discord_id": orm.BigInteger()
|
"discord_id": orm.BigInteger(),
|
||||||
|
"webhook": orm.Text(nullable=True, default=None),
|
||||||
}
|
}
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
entry_id: uuid.UUID
|
entry_id: uuid.UUID
|
||||||
matrix_id: str
|
matrix_id: str
|
||||||
discord_id: int
|
discord_id: int
|
||||||
|
webhook: str | None
|
||||||
|
|
111
web/server.py
111
web/server.py
|
@ -9,11 +9,13 @@ from datetime import datetime, timezone
|
||||||
from hashlib import sha512
|
from hashlib import sha512
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Annotated
|
||||||
|
from discord.ext.commands import Paginator
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Header, HTTPException, Request, status
|
from fastapi import FastAPI, Header, HTTPException, Request, dependencies, status, Depends
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials as HTTPAuthCreds
|
||||||
from fastapi import WebSocketException as _WSException
|
from fastapi import WebSocketException as _WSException
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||||
|
@ -51,6 +53,7 @@ app.state.bot = None
|
||||||
app.state.states = {}
|
app.state.states = {}
|
||||||
app.state.binds = {}
|
app.state.binds = {}
|
||||||
app.state.http = httpx.Client()
|
app.state.http = httpx.Client()
|
||||||
|
security = HTTPBearer()
|
||||||
|
|
||||||
if StaticFiles:
|
if StaticFiles:
|
||||||
app.mount("/static", StaticFiles(directory=SF_ROOT), name="static")
|
app.mount("/static", StaticFiles(directory=SF_ROOT), name="static")
|
||||||
|
@ -66,6 +69,11 @@ app.state.last_sender_ts = datetime.utcnow()
|
||||||
app.state.ws_connected = Lock()
|
app.state.ws_connected = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def is_authenticated(credentials: Annotated[HTTPAuthCreds, security]):
|
||||||
|
if credentials.credentials != app.state.bot.http.token:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid secret.")
|
||||||
|
|
||||||
|
|
||||||
async def get_access_token(code: str, redirect_uri: str = OAUTH_REDIRECT_URI):
|
async def get_access_token(code: str, redirect_uri: str = OAUTH_REDIRECT_URI):
|
||||||
response = app.state.http.post(
|
response = app.state.http.post(
|
||||||
"https://discord.com/api/oauth2/token",
|
"https://discord.com/api/oauth2/token",
|
||||||
|
@ -215,69 +223,34 @@ async def authenticate(req: Request, code: str = None, state: str = None):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.get("/verify/{code}")
|
@app.post("/bridge", include_in_schema=False, status_code=201, dependencies=[Depends(is_authenticated)]
|
||||||
async def verify(code: str):
|
|
||||||
guild = app.state.bot.get_guild(guilds[0])
|
|
||||||
if not guild:
|
|
||||||
raise HTTPException(status_code=503, detail="Not ready.")
|
|
||||||
|
|
||||||
# First, we need to fetch the code from the database
|
|
||||||
verify_code = await get_or_none(VerifyCode, code=code)
|
|
||||||
if not verify_code:
|
|
||||||
raise HTTPException(status_code=404, detail="Code not found.")
|
|
||||||
|
|
||||||
# Now we need to fetch the student from the database
|
|
||||||
student = await get_or_none(Student, user_id=verify_code.bind)
|
|
||||||
if student:
|
|
||||||
raise HTTPException(status_code=400, detail="Already verified.")
|
|
||||||
|
|
||||||
ban = await get_or_none(BannedStudentID, student_id=verify_code.student_id)
|
|
||||||
if ban is not None:
|
|
||||||
return await guild.kick(
|
|
||||||
reason=f"Attempted to verify with banned student ID {ban.student_id}"
|
|
||||||
f" (originally associated with account {ban.associated_account})"
|
|
||||||
)
|
|
||||||
await Student.objects.create(id=verify_code.student_id, user_id=verify_code.bind, name=verify_code.name)
|
|
||||||
await verify_code.delete()
|
|
||||||
role = discord.utils.find(lambda r: r.name.lower() == "verified", guild.roles)
|
|
||||||
member = await guild.fetch_member(verify_code.bind)
|
|
||||||
if role and role < guild.me.top_role:
|
|
||||||
await member.add_roles(role, reason="Verified")
|
|
||||||
try:
|
|
||||||
await member.edit(nick=f"{verify_code.name}", reason="Verified")
|
|
||||||
except discord.HTTPException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# And delete the code
|
|
||||||
await verify_code.delete()
|
|
||||||
|
|
||||||
log.info(f"[green]{verify_code.bind} verified ({verify_code.bind}/{verify_code.student_id})")
|
|
||||||
|
|
||||||
return RedirectResponse(GENERAL, status_code=308)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bridge", include_in_schema=False, status_code=201)
|
|
||||||
async def bridge(req: Request):
|
async def bridge(req: Request):
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
ts_diff = (now - app.state.last_sender_ts).total_seconds()
|
ts_diff = (now - app.state.last_sender_ts).total_seconds()
|
||||||
from discord.ext.commands import Paginator
|
|
||||||
|
|
||||||
body = await req.json()
|
body = await req.json()
|
||||||
if body["secret"] != app.state.bot.http.token:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid secret.")
|
|
||||||
|
|
||||||
channel = app.state.bot.get_channel(1032974266527907901) # type: discord.TextChannel | None
|
room_id = body.get("room")
|
||||||
|
if not room_id:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing room ID. Required as of 26/02/2024.")
|
||||||
|
bind = await get_or_none(BridgeBind, matrix_id=room_id)
|
||||||
|
# ^ Binds are only supposed to be used for User binds, however, in this case we can just recycle it.
|
||||||
|
if not bind:
|
||||||
|
channel_id = 1032974266527907901
|
||||||
|
else:
|
||||||
|
channel_id = bind.discord_id
|
||||||
|
|
||||||
|
channel = app.state.bot.get_channel(channel_id) # type: discord.TextChannel | None
|
||||||
if not channel:
|
if not channel:
|
||||||
raise HTTPException(status_code=404, detail="Channel does not exist.")
|
raise HTTPException(status_code=404, detail="Channel %r does not exist." % channel_id)
|
||||||
|
|
||||||
if len(body["message"]) > 4000:
|
if len(body["message"]) > 4000:
|
||||||
raise HTTPException(status_code=400, detail="Message too long.")
|
raise HTTPException(status_code=400, detail="Message too long. 4000 characters maximum.")
|
||||||
paginator = Paginator(prefix="", suffix="", max_size=1990)
|
paginator = Paginator(prefix="", suffix="", max_size=1990)
|
||||||
for line in body["message"].splitlines():
|
for line in body["message"].splitlines():
|
||||||
try:
|
try:
|
||||||
paginator.add_line(line)
|
paginator.add_line(line)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>"))
|
paginator.add_line(textwrap.shorten(line, width=1980, placeholder="<...>"))
|
||||||
if len(paginator.pages) > 1:
|
if len(paginator.pages) > 1:
|
||||||
msg = None
|
msg = None
|
||||||
if app.state.last_sender != body["sender"] or ts_diff >= 600:
|
if app.state.last_sender != body["sender"] or ts_diff >= 600:
|
||||||
|
@ -303,7 +276,7 @@ async def bridge(req: Request):
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/bridge/recv")
|
@app.websocket("/bridge/recv")
|
||||||
async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
|
async def bridge_recv(ws: WebSocket, secret: str = Query(None)):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
|
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
|
||||||
if secret != app.state.bot.http.token:
|
if secret != app.state.bot.http.token:
|
||||||
|
@ -337,12 +310,12 @@ async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/bridge/bind/new")
|
@app.get("/bridge/bind/new", dependencies=[Depends(is_authenticated)])
|
||||||
async def bridge_bind_new(mx_id: str):
|
async def bridge_bind_new(mx_id: str):
|
||||||
"""Begins a new bind session."""
|
"""Begins a new bind session."""
|
||||||
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(409, "Account already bound")
|
raise HTTPException(409, "Target already bound")
|
||||||
|
|
||||||
if not OAUTH_ENABLED:
|
if not OAUTH_ENABLED:
|
||||||
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
|
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
|
||||||
|
@ -360,7 +333,7 @@ async def bridge_bind_new(mx_id: str):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/bridge/bind/callback")
|
@app.get("/bridge/bind/callback", include_in_schema=False)
|
||||||
async def bridge_bind_callback(code: str, state: str):
|
async def bridge_bind_callback(code: str, state: str):
|
||||||
"""Finishes the bind."""
|
"""Finishes the bind."""
|
||||||
# Getting an entire access token seems like a waste, but oh well. Only need to do this once.
|
# Getting an entire access token seems like a waste, but oh well. Only need to do this once.
|
||||||
|
@ -372,7 +345,24 @@ async def bridge_bind_callback(code: str, state: str):
|
||||||
user = await get_authorised_user(access_token,)
|
user = await get_authorised_user(access_token,)
|
||||||
user_id = int(user["id"])
|
user_id = int(user["id"])
|
||||||
await BridgeBind.objects.create(matrix_id=mx_id, discord_id=user_id)
|
await BridgeBind.objects.create(matrix_id=mx_id, discord_id=user_id)
|
||||||
return JSONResponse({"matrix": mx_id, "discord": user_id}, 201)
|
return JSONResponse({"success": True, "matrix": mx_id, "discord": user_id}, 201)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/bridge/bind/_create", include_in_schema=False, dependencies=[Depends(is_authenticated)])
|
||||||
|
async def bridge_bind_create_nonuser(
|
||||||
|
req: Request
|
||||||
|
):
|
||||||
|
body = await req.json()
|
||||||
|
if "mx_id" not in body or "discord_id" not in body:
|
||||||
|
raise HTTPException(400, "Missing fields")
|
||||||
|
mx_id = body["mx_id"]
|
||||||
|
discord_id = body["discord_id"]
|
||||||
|
webhook = body.get("webhook")
|
||||||
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(409, "Target already bound")
|
||||||
|
await BridgeBind.objects.create(matrix_id=mx_id, discord_id=discord_id, webhook=webhook)
|
||||||
|
return JSONResponse({"status": "ok"}, 201)
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/bridge/bind/{mx_id}")
|
@app.delete("/bridge/bind/{mx_id}")
|
||||||
|
@ -402,10 +392,13 @@ async def bridge_bind_delete(mx_id: str, code: str = None, state: str = None):
|
||||||
await existing.delete()
|
await existing.delete()
|
||||||
return JSONResponse({"status": "ok"}, 200)
|
return JSONResponse({"status": "ok"}, 200)
|
||||||
|
|
||||||
@app.get("/bridge/bind/{mx_id}")
|
@app.get("/bridge/bind/{mx_id}", dependencies=[Depends(is_authenticated)])
|
||||||
async def bridge_bind_fetch(mx_id: str):
|
async def bridge_bind_fetch(mx_id: str):
|
||||||
"""Fetch the discord account associated with a matrix account."""
|
"""Fetch the discord account associated with a matrix account."""
|
||||||
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
if not existing:
|
if not existing:
|
||||||
raise HTTPException(404, "Not found")
|
raise HTTPException(404, "Not found")
|
||||||
return JSONResponse({"discord": existing.discord_id}, 200)
|
payload = {"discord": existing.discord_id, "matrix": mx_id}
|
||||||
|
if existing.webhook:
|
||||||
|
payload["webhook"] = existing.webhook
|
||||||
|
return JSONResponse(payload, 200)
|
||||||
|
|
Loading…
Reference in a new issue