diff --git a/src/main.py b/src/main.py index 274c18f..db2001c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,12 @@ +import asyncio import datetime import logging import sys import traceback +import typing + +import uvicorn +from web import app from logging import FileHandler import discord @@ -35,6 +40,30 @@ for logger in CONFIG["logging"].get("suppress", []): logging.getLogger(logger).setLevel(logging.WARNING) log.info(f"Suppressed logging for {logger}") + +class Client(commands.Bot): + def __init_(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.web: typing.Optional[asyncio.Task] = None + + async def start(self, token: str, *, reconnect: bool = True) -> None: + config = uvicorn.Config( + app, + host=CONFIG["server"].get("host", "0.0.0.0"), + port=CONFIG["server"].get("port", 8080), + loop="asyncio", + lifespan="on", + server_header=False + ) + server = uvicorn.Server(config=config) + self.web= self.loop.create_task(server.serve()) + await super().start(token, reconnect=reconnect) + + async def close(self) -> None: + if self.web: + self.web.cancel() + await super().close() + bot = commands.Bot( command_prefix=commands.when_mentioned_or("h!", "H!"), case_insensitive=True, @@ -91,4 +120,5 @@ async def on_application_command_completion(ctx: discord.ApplicationContext): if not CONFIG["jimmy"].get("token"): log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)") sys.exit(1) + bot.run(CONFIG["jimmy"]["token"]) diff --git a/src/web.py b/src/web.py index e953665..a2ec382 100644 --- a/src/web.py +++ b/src/web.py @@ -1,12 +1,17 @@ import asyncio +import datetime import logging +import textwrap + import psutil import time import pydantic from typing import Optional, Any from conf import CONFIG +import discord +from discord.ext.commands import Paginator -from fastapi import FastAPI, HTTPException, status +from fastapi import FastAPI, HTTPException, status, WebSocketException, WebSocket, WebSocketDisconnect, Header class BridgeResponse(pydantic.BaseModel): status: str @@ -49,6 +54,7 @@ app = FastAPI( log = logging.getLogger("jimmy.web.api") app.state.bot = None app.state.bridge_lock = asyncio.Lock() +app.state.last_sender_ts = 0 @app.get("/ping") @@ -68,4 +74,87 @@ def ping(): @app.post("/bridge", status_code=201) async def bridge_post_send_message(body: BridgePayload): """Sends a message FROM matrix TO discord.""" - raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) + now = datetime.datetime.now(datetime.timezone.utc) + ts_diff = (now - app.state.last_sender_ts).total_seconds() + if not app.state.bot: + raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) + + if body.secret != CONFIG["jimmy"].get("token"): + log.warning("Authentication failure: %s was not authenticated.", body.secret) + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + + channel = app.state.bot.get_channel(CONFIG["server"]["channel"]) + if not channel or not channel.can_send(): + log.warning("Unable to send message: channel not found or not writable.") + raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) + + if len(body.message) > 4000: + log.warning( + "Unable to send message: message too long ({:,} characters long, 4000 max).".format(len(body.message)) + ) + raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE) + + paginator = Paginator(prefix="", suffix="", max_size=1990) + for line in body["message"].splitlines(): + try: + paginator.add_line(line) + except ValueError: + paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>")) + + if len(paginator.pages) > 1: + msg = None + if app.state.last_sender != body["sender"] or ts_diff >= 600: + msg = await channel.send(f"**{body['sender']}**:") + m = len(paginator.pages) + for n, page in enumerate(paginator.pages, 1): + await channel.send( + f"[{n}/{m}]\n>>> {page}", + allowed_mentions=discord.AllowedMentions.none(), + reference=msg, + silent=True, + suppress=n != m, + ) + app.state.last_sender = body["sender"] + else: + content = f"**{body['sender']}**:\n>>> {body['message']}" + if app.state.last_sender == body["sender"] and ts_diff < 600: + content = f">>> {body['message']}" + await channel.send(content, allowed_mentions=discord.AllowedMentions.none(), silent=True, suppress=False) + app.state.last_sender = body["sender"] + app.state.last_sender_ts = now + return {"status": "ok", "pages": len(paginator.pages)} + + +@app.websocket("/bridge/recv") +async def bridge_recv(ws: WebSocket, secret: str = Header(None)): + await ws.accept() + log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port) + if secret != app.state.bot.http.token: + log.warning("Closing websocket %r, invalid secret.", ws.client.host) + raise WebSocketException(code=1008, reason="Invalid Secret") + if app.state.ws_connected.locked(): + log.warning("Closing websocket %r, already connected." % ws) + raise WebSocketException(code=1008, reason="Already connected.") + queue: asyncio.Queue = app.state.bot.bridge_queue + + async with app.state.ws_connected: + while True: + try: + await ws.send_json({"status": "ping"}) + except (WebSocketDisconnect, WebSocketException): + log.info("Websocket %r disconnected.", ws) + break + + try: + data = await asyncio.wait_for(queue.get(), timeout=5) + except asyncio.TimeoutError: + continue + + try: + await ws.send_json(data) + log.debug("Sent data %r to websocket %r.", data, ws) + except (WebSocketDisconnect, WebSocketException): + log.info("Websocket %r disconnected." % ws) + break + finally: + queue.task_done()