mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Merge pull request #3 from nexy7574/feature/bind
Implement account binding
This commit is contained in:
commit
64dfa6c456
2 changed files with 113 additions and 22 deletions
16
utils/db.py
16
utils/db.py
|
@ -33,6 +33,7 @@ __all__ = [
|
||||||
"Tutors",
|
"Tutors",
|
||||||
"UptimeEntry",
|
"UptimeEntry",
|
||||||
"JimmyBans",
|
"JimmyBans",
|
||||||
|
"BridgeBind"
|
||||||
]
|
]
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -222,3 +223,18 @@ class AccessTokens(orm.Model):
|
||||||
user_id: int
|
user_id: int
|
||||||
access_token: str
|
access_token: str
|
||||||
ip_info: dict | None
|
ip_info: dict | None
|
||||||
|
|
||||||
|
|
||||||
|
class BridgeBinds(orm.Model):
|
||||||
|
tablename = "bridge_binds"
|
||||||
|
registry = registry
|
||||||
|
fields = {
|
||||||
|
"entry_id": orm.UUID(primary_key=True, default=uuid.uuid4),
|
||||||
|
"matrix_id": orm.Text(unique=True),
|
||||||
|
"discord_id": orm.BigInteger()
|
||||||
|
}
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
entry_id: uuid.UUID
|
||||||
|
matrix_id: str
|
||||||
|
discord_id: int
|
||||||
|
|
119
web/server.py
119
web/server.py
|
@ -3,22 +3,24 @@ import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
|
import secrets
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from hashlib import sha512
|
from hashlib import sha512
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Header, HTTPException, Request
|
from fastapi import FastAPI, Header, HTTPException, Request, status
|
||||||
from fastapi import WebSocketException as _WSException
|
from fastapi import WebSocketException as _WSException
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||||
from websockets.exceptions import WebSocketException
|
from websockets.exceptions import WebSocketException
|
||||||
|
|
||||||
from config import guilds
|
from config import guilds
|
||||||
from utils import BannedStudentID, Student, VerifyCode, console, get_or_none
|
from utils import BannedStudentID, Student, VerifyCode, console, get_or_none, BridgeBind
|
||||||
from utils.db import AccessTokens
|
from utils.db import AccessTokens
|
||||||
|
|
||||||
SF_ROOT = Path(__file__).parent / "static"
|
SF_ROOT = Path(__file__).parent / "static"
|
||||||
|
@ -46,6 +48,7 @@ OAUTH_ENABLED = OAUTH_ID and OAUTH_SECRET and OAUTH_REDIRECT_URI
|
||||||
app = FastAPI(root_path=WEB_ROOT_PATH)
|
app = FastAPI(root_path=WEB_ROOT_PATH)
|
||||||
app.state.bot = None
|
app.state.bot = None
|
||||||
app.state.states = {}
|
app.state.states = {}
|
||||||
|
app.state.binds = {}
|
||||||
app.state.http = httpx.Client()
|
app.state.http = httpx.Client()
|
||||||
|
|
||||||
if StaticFiles:
|
if StaticFiles:
|
||||||
|
@ -62,6 +65,30 @@ app.state.last_sender_ts = datetime.utcnow()
|
||||||
app.state.ws_connected = Lock()
|
app.state.ws_connected = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_access_token(code: str):
|
||||||
|
response = app.state.http.post(
|
||||||
|
"https://discord.com/api/oauth2/token",
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": OAUTH_REDIRECT_URI,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
|
auth=(CLIENT_ID, CLIENT_SECRET)
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_authorised_user(access_token: str):
|
||||||
|
response = app.state.http.get(
|
||||||
|
"https://discord.com/api/users/@me",
|
||||||
|
headers={"Authorization": "Bearer " + access_token}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def check_bot_instanced(request, call_next):
|
async def check_bot_instanced(request, call_next):
|
||||||
if not request.app.state.bot:
|
if not request.app.state.bot:
|
||||||
|
@ -121,32 +148,14 @@ async def authenticate(req: Request, code: str = None, state: str = None):
|
||||||
else:
|
else:
|
||||||
app.state.states.pop(state)
|
app.state.states.pop(state)
|
||||||
# First, we need to do the auth code flow
|
# First, we need to do the auth code flow
|
||||||
response = app.state.http.post(
|
data = await get_access_token(code)
|
||||||
"https://discord.com/api/oauth2/token",
|
|
||||||
data={
|
|
||||||
"client_id": OAUTH_ID,
|
|
||||||
"client_secret": OAUTH_SECRET,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": OAUTH_REDIRECT_URI,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
|
||||||
data = response.json()
|
|
||||||
access_token = data["access_token"]
|
access_token = data["access_token"]
|
||||||
|
|
||||||
# Now we can generate a token
|
# Now we can generate a token
|
||||||
token = sha512(access_token.encode()).hexdigest()
|
token = sha512(access_token.encode()).hexdigest()
|
||||||
|
|
||||||
# Now we can get the user's info
|
# Now we can get the user's info
|
||||||
response = app.state.http.get(
|
user = await get_authorised_user(access_token)
|
||||||
"https://discord.com/api/users/@me", headers={"Authorization": "Bearer " + data["access_token"]}
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
|
||||||
|
|
||||||
user = response.json()
|
|
||||||
|
|
||||||
# Now we need to fetch the student from the database
|
# Now we need to fetch the student from the database
|
||||||
student = await get_or_none(AccessTokens, user_id=user["id"])
|
student = await get_or_none(AccessTokens, user_id=user["id"])
|
||||||
|
@ -325,3 +334,69 @@ async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/bridge/bind/new")
|
||||||
|
async def bridge_bind_new(mx_id: str):
|
||||||
|
"""Begins a new bind session."""
|
||||||
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(409, "Account already bound")
|
||||||
|
|
||||||
|
if not OAUTH_ENABLED:
|
||||||
|
raise HTTPException(503)
|
||||||
|
|
||||||
|
token = secrets.token_urlsafe()
|
||||||
|
app.state.binds[token] = mx_id
|
||||||
|
url = discord.utils.oauth_url(
|
||||||
|
OAUTH_ID, redirect_uri=OAUTH_REDIRECT_URI, scopes=("identify")
|
||||||
|
) + f"&state={value}&prompt=none"
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"url": url,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/bridge/bind/callback")
|
||||||
|
async def bridge_bind_callback(code: str, state: str):
|
||||||
|
"""Finishes the bind."""
|
||||||
|
# Getting an entire access token seems like a waste, but oh well. Only need to do this once.
|
||||||
|
mx_id = app.state.binds.pop(state, None)
|
||||||
|
if not mx_id:
|
||||||
|
raise HTTPException(status_code=400, "Invalid state")
|
||||||
|
data = await get_access_token(code)
|
||||||
|
access_token = data["access_token"]
|
||||||
|
user = await get_authorised_user(access_token)
|
||||||
|
user_id = int(user["id"])
|
||||||
|
await BridgeBind.objects.create(matrix_id=mx_id, user_id=user_id)
|
||||||
|
return JSONResponse({"matrix": mx_id, "discord": user_id}, 201)
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/bridge/bind/{mx_id}")
|
||||||
|
async def bridge_bind_delete(mx_id: str, code: str = None, state: str = None):
|
||||||
|
"""Unbinds a matrix account."""
|
||||||
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
|
if not existing:
|
||||||
|
raise HTTPException(404, "Not found")
|
||||||
|
|
||||||
|
if not (code and state) or state not in app.state.binds:
|
||||||
|
token = secrets.token_urlsafe()
|
||||||
|
app.state.binds[token] = mx_id
|
||||||
|
url = discord.utils.oauth_url(
|
||||||
|
OAUTH_ID, redirect_uri=OAUTH_REDIRECT_URI, scopes=("identify")
|
||||||
|
) + f"&state={value}&prompt=none"
|
||||||
|
return JSONResponse({"status": "pending", "url": url})
|
||||||
|
else:
|
||||||
|
real_mx_id = app.state.binds.pop(state, None)
|
||||||
|
if real_mx_id != mx_id:
|
||||||
|
raise HTTPException(400, "Invalid state")
|
||||||
|
await existing.delete()
|
||||||
|
return JSONResponse({"status": "ok"}, 200)
|
||||||
|
|
||||||
|
@app.get("/bridge/bind/{mx_id}")
|
||||||
|
async def bridge_bind_fetch(mx_id: str):
|
||||||
|
"""Fetch the discord account associated with a matrix account."""
|
||||||
|
existing: Optional[BridgeBind] = await get_or_none(BridgeBind, matrix_id=mx_id)
|
||||||
|
if not existing:
|
||||||
|
raise HTTPException(404, "Not found")
|
||||||
|
return JSONResponse({"discord": existing.user_id}, 200)
|
||||||
|
|
Loading…
Reference in a new issue