Fix websocket lock

This commit is contained in:
Nexus 2023-11-07 13:43:39 +00:00
parent 98c916bb9b
commit f79247928c
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -11,7 +11,8 @@ from pathlib import Path
import discord import discord
import httpx import httpx
from config import guilds from config import guilds
from fastapi import FastAPI, Header, HTTPException, Request from asyncio import Lock
from fastapi import FastAPI, Header, HTTPException, Request, WebSocketException as _WSException
from websockets.exceptions import WebSocketException from websockets.exceptions import WebSocketException
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from starlette.websockets import WebSocket, WebSocketDisconnect from starlette.websockets import WebSocket, WebSocketDisconnect
@ -55,7 +56,7 @@ except ImportError:
bot = None bot = None
app.state.last_sender = None app.state.last_sender = None
app.state.last_sender_ts = datetime.utcnow() app.state.last_sender_ts = datetime.utcnow()
app.state.ws_connected = False app.state.ws_connected = Lock()
@app.middleware("http") @app.middleware("http")
@ -290,25 +291,24 @@ async def bridge(req: Request):
@app.websocket("/bridge/recv") @app.websocket("/bridge/recv")
async def bridge_recv(ws: WebSocket, secret: str = Header(None)): async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
await ws.accept()
if secret != app.state.bot.http.token: if secret != app.state.bot.http.token:
raise HTTPException(status_code=401, detail="Invalid secret.") raise _WSException(code=1008, reason="Invalid Secret")
if app.state.ws_connected: if app.state.ws_connected.locked():
raise HTTPException(status_code=409, detail="Already connected.") raise _WSException(code=1008, reason="Already connected.")
queue: asyncio.Queue = app.state.bot.bridge_queue queue: asyncio.Queue = app.state.bot.bridge_queue
await ws.accept() async with app.state.ws_connected:
app.state.ws_connected = True while True:
while True: try:
try: data = queue.get_nowait()
data = queue.get_nowait() except asyncio.QueueEmpty:
except asyncio.QueueEmpty: await asyncio.sleep(0.5)
await asyncio.sleep(0.5) continue
continue
try: try:
await ws.send_json(data) await ws.send_json(data)
except (WebSocketDisconnect, WebSocketException): except (WebSocketDisconnect, WebSocketException):
break break
finally: finally:
queue.task_done() queue.task_done()
app.state.ws_connected = False