diff --git a/web/server.py b/web/server.py index 5aba5c0..7206869 100644 --- a/web/server.py +++ b/web/server.py @@ -11,7 +11,8 @@ from pathlib import Path import discord import httpx from config import guilds -from fastapi import FastAPI, Header, HTTPException, Request +from asyncio import Lock +from fastapi import FastAPI, Header, HTTPException, Request, WebSocketException as _WSException from websockets.exceptions import WebSocketException from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from starlette.websockets import WebSocket, WebSocketDisconnect @@ -55,7 +56,7 @@ except ImportError: bot = None app.state.last_sender = None app.state.last_sender_ts = datetime.utcnow() -app.state.ws_connected = False +app.state.ws_connected = Lock() @app.middleware("http") @@ -290,25 +291,24 @@ async def bridge(req: Request): @app.websocket("/bridge/recv") async def bridge_recv(ws: WebSocket, secret: str = Header(None)): + await ws.accept() if secret != app.state.bot.http.token: - raise HTTPException(status_code=401, detail="Invalid secret.") - if app.state.ws_connected: - raise HTTPException(status_code=409, detail="Already connected.") + raise _WSException(code=1008, reason="Invalid Secret") + if app.state.ws_connected.locked(): + raise _WSException(code=1008, reason="Already connected.") queue: asyncio.Queue = app.state.bot.bridge_queue - await ws.accept() - app.state.ws_connected = True - while True: - try: - data = queue.get_nowait() - except asyncio.QueueEmpty: - await asyncio.sleep(0.5) - continue + async with app.state.ws_connected: + while True: + try: + data = queue.get_nowait() + except asyncio.QueueEmpty: + await asyncio.sleep(0.5) + continue - try: - await ws.send_json(data) - except (WebSocketDisconnect, WebSocketException): - break - finally: - queue.task_done() - app.state.ws_connected = False + try: + await ws.send_json(data) + except (WebSocketDisconnect, WebSocketException): + break + finally: + queue.task_done()