Merge branch 'master' of github.com:nexy7574/college-bot-v2

This commit is contained in:
Nexus 2024-04-02 02:38:23 +01:00
commit c8019f3993
10 changed files with 336 additions and 199 deletions

3
.gitignore vendored
View file

@ -310,4 +310,5 @@ pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,pycharm,visualstudiocode # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,visualstudiocode
cookies.txt cookies.txt
config.toml config.toml
chrome/ chrome/
src/assets/sensitive/*

View file

@ -18,3 +18,4 @@ humanize~=4.9
redis~=5.0 redis~=5.0
beautifulsoup4~=4.12 beautifulsoup4~=4.12
lxml~=5.1 lxml~=5.1
matplotlib~=3.8

View file

@ -4,6 +4,7 @@ import os
import re import re
import time import time
import typing import typing
from pathlib import Path
import discord import discord
from discord.ext import commands from discord.ext import commands
@ -242,6 +243,16 @@ class NetworkCog(commands.Cog):
paginator.add_line(f"Error: {e}") paginator.add_line(f"Error: {e}")
for page in paginator.pages: for page in paginator.pages:
await ctx.respond(page) await ctx.respond(page)
@commands.slash_command(name="what-are-matthews-bank-details")
async def matthew_bank(self, ctx: discord.ApplicationContext):
"""For the 80th time"""
f = Path.cwd() / "assets" / "sensitive" / "matthew-bank.webp"
if not f.exists():
return await ctx.respond("Idk")
else:
await ctx.defer()
await ctx.respond(file=discord.File(f))
def setup(bot): def setup(bot):

View file

@ -8,6 +8,7 @@ import typing
import base64 import base64
import io import io
import redis import redis
from discord import Interaction
from discord.ui import View, button from discord.ui import View, button
from fnmatch import fnmatch from fnmatch import fnmatch
@ -89,7 +90,7 @@ class ChatHistory:
"threads:" + thread_id, json.dumps(self._internal[thread_id]) "threads:" + thread_id, json.dumps(self._internal[thread_id])
) )
def create_thread(self, member: discord.Member) -> str: def create_thread(self, member: discord.Member, default: str | None = None) -> str:
""" """
Creates a thread, returns its ID. Creates a thread, returns its ID.
""" """
@ -100,7 +101,7 @@ class ChatHistory:
"messages": [] "messages": []
} }
with open("./assets/ollama-prompt.txt") as file: with open("./assets/ollama-prompt.txt") as file:
system_prompt = file.read() system_prompt = default or file.read()
self.add_message( self.add_message(
key, key,
"system", "system",
@ -190,6 +191,70 @@ class ChatHistory:
SERVER_KEYS = list(CONFIG["ollama"].keys()) SERVER_KEYS = list(CONFIG["ollama"].keys())
class OllamaGetPrompt(discord.ui.Modal):
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
super().__init__(
discord.ui.InputText(
style=discord.InputTextStyle.long,
label="%s prompt" % prompt_type,
placeholder="Enter your prompt here.",
),
timeout=300,
title="Ollama %s prompt" % prompt_type,
)
self.ctx = ctx
self.prompt_type = prompt_type
self.value = None
async def interaction_check(self, interaction: discord.Interaction) -> bool:
return interaction.user == self.ctx.user
async def callback(self, interaction: Interaction):
await interaction.response.defer()
self.value = self.children[0].value
self.stop()
class PromptSelector(discord.ui.View):
def __init__(self, ctx: discord.ApplicationContext):
super().__init__(timeout=600, disable_on_timeout=True)
self.ctx = ctx
self.system_prompt = None
self.user_prompt = None
async def interaction_check(self, interaction: Interaction) -> bool:
return interaction.user == self.ctx.user
def update_ui(self):
if self.system_prompt is not None:
self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore
if self.user_prompt is not None:
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx, "System")
await interaction.response.send_modal(modal)
await modal.wait()
self.system_prompt = modal.value
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx)
await interaction.response.send_modal(modal)
await modal.wait()
self.user_prompt = modal.value
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
async def done(self, btn: discord.ui.Button, interaction: Interaction):
self.stop()
class Ollama(commands.Cog): class Ollama(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
@ -282,11 +347,24 @@ class Ollama(commands.Cog):
) )
] ]
): ):
system_query = None
if context is not None: if context is not None:
if not self.history.get_thread(context): if not self.history.get_thread(context):
await ctx.respond("Invalid context key.") await ctx.respond("Invalid context key.")
return return
await ctx.defer()
try:
await ctx.defer()
except discord.HTTPException:
pass
if query == "$":
v = PromptSelector(ctx)
await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v)
await v.wait()
query = v.user_prompt or query
system_query = v.system_prompt
await ctx.delete(delay=0.1)
model = model.casefold() model = model.casefold()
try: try:
@ -294,7 +372,7 @@ class Ollama(commands.Cog):
model = model + ":" + tag model = model + ":" + tag
self.log.debug("Model %r already has a tag", model) self.log.debug("Model %r already has a tag", model)
except ValueError: except ValueError:
model = model + ":latest" model += ":latest"
self.log.debug("Resolved model to %r" % model) self.log.debug("Resolved model to %r" % model)
if image: if image:
@ -315,7 +393,7 @@ class Ollama(commands.Cog):
data = io.BytesIO() data = io.BytesIO()
await image.save(data) await image.save(data)
data.seek(0) data.seek(0)
image_data = base64.b64encode(data.read()).decode("utf-8") image_data = base64.b64encode(data.read()).decode()
else: else:
image_data = None image_data = None
@ -336,7 +414,12 @@ class Ollama(commands.Cog):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
base_url=server_config["base_url"], base_url=server_config["base_url"],
timeout=aiohttp.ClientTimeout(0) timeout=aiohttp.ClientTimeout(
connect=30,
sock_read=10800,
sock_connect=30,
total=10830
)
) as session: ) as session:
embed = discord.Embed( embed = discord.Embed(
title="Checking server...", title="Checking server...",
@ -482,7 +565,7 @@ class Ollama(commands.Cog):
self.log.debug("Beginning to generate response with key %r.", key) self.log.debug("Beginning to generate response with key %r.", key)
if context is None: if context is None:
context = self.history.create_thread(ctx.user) context = self.history.create_thread(ctx.user, system_query)
elif context is not None and self.history.get_thread(context) is None: elif context is not None and self.history.get_thread(context) is None:
__thread = self.history.find_thread(context) __thread = self.history.find_thread(context)
if not __thread: if not __thread:

188
src/cogs/quote_quota.py Normal file
View file

@ -0,0 +1,188 @@
import asyncio
import re
import discord
import io
import matplotlib.pyplot as plt
from datetime import timedelta
from discord.ext import commands
from typing import Iterable, Annotated
from conf import CONFIG
class QuoteQuota(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
self.names = CONFIG["quote_a"].get("names", {})
@property
def quotes_channel(self) -> discord.TextChannel | None:
if self.quotes_channel_id:
c = self.bot.get_channel(self.quotes_channel_id)
if c:
return c
@staticmethod
def generate_pie_chart(
usernames: list[str],
counts: list[int],
no_other: bool = False
) -> discord.File:
"""
Converts the given username and count tuples into a nice pretty pie chart.
:param usernames: The usernames
:param counts: The number of times the username appears in the chat
:param no_other: Disables the "other" grouping
:returns: The pie chart image
"""
def pct(v: int):
return f"{v:.1f}% ({round((v / 100) * sum(counts))})"
if no_other is False:
other = []
# Any authors with less than 5% of the total count will be grouped into "other"
for i, author in enumerate(usernames.copy()):
if (c := counts[i]) / sum(counts) < 0.05:
other.append(c)
counts[i] = -1
usernames.remove(author)
if other:
usernames.append("Other")
counts.append(sum(other))
# And now filter out any -1% counts
counts = [c for c in counts if c != -1]
mapping = {}
for i, author in enumerate(usernames):
mapping[author] = counts[i]
# Sort the authors by count
new_mapping = {}
for author, count in sorted(mapping.items(), key=lambda x: x[1], reverse=True):
new_mapping[author] = count
usernames = list(new_mapping.keys())
counts = list(new_mapping.values())
fig, ax = plt.subplots(figsize=(7, 7))
ax.pie(
counts,
labels=usernames,
autopct=pct,
startangle=90,
radius=1.2,
)
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4)
fio = io.BytesIO()
fig.savefig(fio, format='png')
fio.seek(0)
return discord.File(fio, filename="pie.png")
@commands.slash_command()
async def quota(
self,
ctx: discord.ApplicationContext,
days: Annotated[
int,
discord.Option(
int,
name="lookback",
description="How many days to look back on. Defaults to 7.",
default=7,
min_value=1,
max_value=365
)
],
merge_other: Annotated[
bool,
discord.Option(
bool,
name="merge_other",
description="Whether to merge authors with less than 5% of the total count into 'Other'.",
default=True
)
]
):
"""Checks the quote quota for the quotes channel."""
now = discord.utils.utcnow()
oldest = now - timedelta(days=days)
await ctx.defer()
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes")
if not channel:
return await ctx.respond(":x: Cannot find quotes channel.")
await ctx.respond("Gathering messages, this may take a moment.")
authors = {}
filtered_messages = 0
total = 0
async for message in channel.history(
limit=None,
after=oldest,
oldest_first=False
):
total += 1
if not message.content:
filtered_messages += 1
continue
if message.attachments:
regex = r".*\s*-\s*@?([\w\s]+)"
else:
regex = r".+\s+-\s*@?([\w\s]+)"
if not (m := re.match(regex, str(message.clean_content))):
filtered_messages += 1
continue
name = m.group(1)
name = name.strip().casefold()
if name == "me":
name = message.author.name.strip().casefold()
if name in self.names:
name = self.names[name].title()
else:
filtered_messages += 1
continue
elif name in self.names:
name = self.names[name].title()
elif name.isdigit():
filtered_messages += 1
continue
name = name.title()
authors.setdefault(name, 0)
authors[name] += 1
if not authors:
if total:
return await ctx.edit(
content="No valid messages found in the last {!s} days. "
"Make sure quotes are formatted properly ending with ` - AuthorName`"
" (e.g. `\"This is my quote\" - Jimmy`)".format(days)
)
else:
return await ctx.edit(
content="No messages found in the last {!s} days.".format(days)
)
file = await asyncio.to_thread(
self.generate_pie_chart,
list(authors.keys()),
list(authors.values()),
merge_other
)
return await ctx.edit(
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
filtered_messages,
total
),
file=file
)
def setup(bot):
bot.add_cog(QuoteQuota(bot))

View file

@ -5,6 +5,7 @@ import logging
import os import os
import tempfile import tempfile
import time import time
import copy
from urllib.parse import urlparse from urllib.parse import urlparse
import discord import discord
@ -15,6 +16,8 @@ from selenium import webdriver
from selenium.webdriver.chrome.options import Options as ChromeOptions from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService from selenium.webdriver.chrome.service import Service as ChromeService
from conf import CONFIG
class ScreenshotCog(commands.Cog): class ScreenshotCog(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
@ -76,7 +79,8 @@ class ScreenshotCog(commands.Cog):
load_timeout: int = 10, load_timeout: int = 10,
render_timeout: int = None, render_timeout: int = None,
eager: bool = None, eager: bool = None,
resolution: str = "1920x1080" resolution: str = "1920x1080",
use_proxy: bool = False
): ):
"""Screenshots a webpage.""" """Screenshots a webpage."""
await ctx.defer() await ctx.defer()
@ -104,11 +108,14 @@ class ScreenshotCog(commands.Cog):
start_init = time.time() start_init = time.time()
try: try:
options = copy.copy(self.chrome_options)
if use_proxy and (server := CONFIG["screenshot"].get("proxy")):
options.add_argument("--proxy-server=" + server)
service = await asyncio.to_thread(ChromeService) service = await asyncio.to_thread(ChromeService)
driver: webdriver.Chrome = await asyncio.to_thread( driver: webdriver.Chrome = await asyncio.to_thread(
webdriver.Chrome, webdriver.Chrome,
service=service, service=service,
options=self.chrome_options options=options
) )
driver.set_page_load_timeout(load_timeout) driver.set_page_load_timeout(load_timeout)
if resolution: if resolution:
@ -173,6 +180,7 @@ class ScreenshotCog(commands.Cog):
end_save = time.time() end_save = time.time()
if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024: if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024:
await ctx.edit(content="Compressing screenshot...")
start_compress = time.time() start_compress = time.time()
file = await asyncio.to_thread(self.compress_png, file) file = await asyncio.to_thread(self.compress_png, file)
fn = "screenshot.webp" fn = "screenshot.webp"

View file

@ -82,23 +82,34 @@ class YTDLCog(commands.Cog):
await db.commit() await db.commit()
return return
async def save_link(self, message: discord.Message, webpage_url: str, format_id: str, attachment_index: int = 0): async def save_link(
self,
message: discord.Message,
webpage_url: str,
format_id: str,
attachment_index: int = 0,
*,
snip: typing.Optional[str] = None
):
""" """
Saves a link to discord to prevent having to re-download it. Saves a link to discord to prevent having to re-download it.
:param message: The download message with the attachment. :param message: The download message with the attachment.
:param webpage_url: The "webpage_url" key of the metadata :param webpage_url: The "webpage_url" key of the metadata
:param format_id: The "format_Id" key of the metadata :param format_id: The "format_Id" key of the metadata
:param attachment_index: The index of the attachment. Defaults to 0 :param attachment_index: The index of the attachment. Defaults to 0
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
:return: The created hash key :return: The created hash key
""" """
snip = snip or '*'
await self._init_db() await self._init_db()
async with aiosqlite.connect("./data/ytdl.db") as db: async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest() _hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
self.log.debug( self.log.debug(
"Saving %r (%r:%r) with message %d>%d, index %d", "Saving %r (%r:%r:%r) with message %d>%d, index %d",
_hash, _hash,
webpage_url, webpage_url,
format_id, format_id,
snip,
message.channel.id, message.channel.id,
message.id, message.id,
attachment_index attachment_index
@ -117,20 +128,27 @@ class YTDLCog(commands.Cog):
await db.commit() await db.commit()
return _hash return _hash
async def get_saved(self, webpage_url: str, format_id: str) -> typing.Optional[str]: async def get_saved(
self,
webpage_url: str,
format_id: str,
snip: str
) -> typing.Optional[str]:
""" """
Attempts to retrieve the attachment URL of a previously saved download. Attempts to retrieve the attachment URL of a previously saved download.
:param webpage_url: The webpage url :param webpage_url: The webpage url
:param format_id: The format ID :param format_id: The format ID
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
:return: the URL, if found and valid. :return: the URL, if found and valid.
""" """
await self._init_db() await self._init_db()
async with aiosqlite.connect("./data/ytdl.db") as db: async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest() _hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
self.log.debug( self.log.debug(
"Attempting to find a saved download for '%s:%s' (%r).", "Attempting to find a saved download for '%s:%s:%s' (%r).",
webpage_url, webpage_url,
format_id, format_id,
snip,
_hash _hash
) )
cursor = await db.execute( cursor = await db.execute(
@ -160,7 +178,7 @@ class YTDLCog(commands.Cog):
except IndexError: except IndexError:
self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments) self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments)
return return
def convert_to_m4a(self, file: Path) -> Path: def convert_to_m4a(self, file: Path) -> Path:
""" """
Converts a file to m4a format. Converts a file to m4a format.
@ -229,7 +247,6 @@ class YTDLCog(commands.Cog):
snip: typing.Annotated[ snip: typing.Annotated[
typing.Optional[str], typing.Optional[str],
discord.Option( discord.Option(
str,
description="A start and end position to trim. e.g. 00:00:00-00:10:00.", description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
required=False required=False
) )
@ -347,7 +364,7 @@ class YTDLCog(commands.Cog):
colour=self.colours.get(domain, discord.Colour.og_blurple()) colour=self.colours.get(domain, discord.Colour.og_blurple())
).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url) ).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url)
) )
previous = await self.get_saved(webpage_url, extracted_info["format_id"]) previous = await self.get_saved(webpage_url, extracted_info["format_id"], snip or '*')
if previous: if previous:
await ctx.edit( await ctx.edit(
content=previous, content=previous,
@ -467,7 +484,7 @@ class YTDLCog(commands.Cog):
) )
) )
file = new_file file = new_file
if audio_only and file.suffix != ".m4a": if audio_only and file.suffix != ".m4a":
self.log.info("Converting %r to m4a.", file) self.log.info("Converting %r to m4a.", file)
file = await asyncio.to_thread(self.convert_to_m4a, file) file = await asyncio.to_thread(self.convert_to_m4a, file)
@ -505,7 +522,7 @@ class YTDLCog(commands.Cog):
url=webpage_url url=webpage_url
) )
) )
await self.save_link(msg, webpage_url, chosen_format_id) await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or '*')
except discord.HTTPException as e: except discord.HTTPException as e:
self.log.error(e, exc_info=True) self.log.error(e, exc_info=True)
return await ctx.edit( return await ctx.edit(

View file

@ -28,6 +28,8 @@ try:
CONFIG.setdefault("jimmy", {}) CONFIG.setdefault("jimmy", {})
CONFIG.setdefault("ollama", {}) CONFIG.setdefault("ollama", {})
CONFIG.setdefault("rss", {"meta": {"channel": None}}) CONFIG.setdefault("rss", {"meta": {"channel": None}})
CONFIG.setdefault("screenshot", {})
CONFIG.setdefault("quote_a", {"channel": None})
CONFIG.setdefault( CONFIG.setdefault(
"server", "server",
{ {

View file

@ -10,7 +10,6 @@ import random
import httpx import httpx
import uvicorn import uvicorn
from web import app
from logging import FileHandler from logging import FileHandler
import discord import discord
@ -104,25 +103,12 @@ class Client(commands.Bot):
CONFIG["jimmy"].get("uptime_kuma_interval", 60.0) CONFIG["jimmy"].get("uptime_kuma_interval", 60.0)
) )
self.uptime_thread.start() self.uptime_thread.start()
app.state.bot = self
config = uvicorn.Config(
app,
host=CONFIG["server"].get("host", "0.0.0.0"),
port=CONFIG["server"].get("port", 8080),
loop="asyncio",
lifespan="on",
server_header=False
)
server = uvicorn.Server(config=config)
self.web = self.loop.create_task(asyncio.to_thread(server.serve()))
await super().start(token, reconnect=reconnect) await super().start(token, reconnect=reconnect)
async def close(self) -> None: async def close(self) -> None:
if self.web: if self.uptime_thread:
self.web.cancel() self.uptime_thread.kill.set()
if self.thread: await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join)
self.thread.kill.set()
await asyncio.get_event_loop().run_in_executor(None, self.thread.join)
await super().close() await super().close()
@ -133,7 +119,7 @@ bot = Client(
debug_guilds=CONFIG["jimmy"].get("debug_guilds") debug_guilds=CONFIG["jimmy"].get("debug_guilds")
) )
for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta"): for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta", "quote_quota"):
try: try:
bot.load_extension(f"cogs.{ext}") bot.load_extension(f"cogs.{ext}")
except discord.ExtensionError as e: except discord.ExtensionError as e:

View file

@ -1,160 +0,0 @@
import asyncio
import datetime
import logging
import textwrap
import psutil
import time
import pydantic
from typing import Optional, Any
from conf import CONFIG
import discord
from discord.ext.commands import Paginator
from fastapi import FastAPI, HTTPException, status, WebSocketException, WebSocket, WebSocketDisconnect, Header
class BridgeResponse(pydantic.BaseModel):
status: str
pages: list[str]
class BridgePayload(pydantic.BaseModel):
secret: str
message: str
sender: str
class MessagePayload(pydantic.BaseModel):
class MessageAttachmentPayload(pydantic.BaseModel):
url: str
proxy_url: str
filename: str
size: int
width: Optional[int] = None
height: Optional[int] = None
content_type: str
ATTACHMENT: Optional[Any] = None
event_type: Optional[str] = "create"
message_id: int
author: str
is_automated: bool = False
avatar: str
content: str
clean_content: str
at: float
attachments: list[MessageAttachmentPayload] = []
reply_to: Optional["MessagePayload"] = None
app = FastAPI(
title="JimmyAPI",
version="2.0.0a1"
)
log = logging.getLogger("jimmy.web.api")
app.state.bot = None
app.state.bridge_lock = asyncio.Lock()
app.state.last_sender_ts = 0
@app.get("/ping")
def ping():
"""Checks the bot is online and provides some uptime information"""
if not app.state.bot:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
return {
"ping": "pong",
"online": app.state.bot.is_ready(),
"latency": max(round(app.state.bot.latency, 2), 0.01),
"uptime": round(time.time() - psutil.Process().create_time()),
"uptime.sys": time.time() - psutil.boot_time()
}
@app.post("/bridge", status_code=201)
async def bridge_post_send_message(body: BridgePayload):
"""Sends a message FROM matrix TO discord."""
now = datetime.datetime.now(datetime.timezone.utc)
ts_diff = (now - app.state.last_sender_ts).total_seconds()
if not app.state.bot:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
if body.secret != CONFIG["jimmy"].get("token"):
log.warning("Authentication failure: %s was not authenticated.", body.secret)
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
channel = app.state.bot.get_channel(CONFIG["server"]["channel"])
if not channel or not channel.can_send():
log.warning("Unable to send message: channel not found or not writable.")
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
if len(body.message) > 4000:
log.warning(
"Unable to send message: message too long ({:,} characters long, 4000 max).".format(len(body.message))
)
raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE)
paginator = Paginator(prefix="", suffix="", max_size=1990)
for line in body["message"].splitlines():
try:
paginator.add_line(line)
except ValueError:
paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>"))
if len(paginator.pages) > 1:
msg = None
if app.state.last_sender != body["sender"] or ts_diff >= 600:
msg = await channel.send(f"**{body['sender']}**:")
m = len(paginator.pages)
for n, page in enumerate(paginator.pages, 1):
await channel.send(
f"[{n}/{m}]\n>>> {page}",
allowed_mentions=discord.AllowedMentions.none(),
reference=msg,
silent=True,
suppress=n != m,
)
app.state.last_sender = body["sender"]
else:
content = f"**{body['sender']}**:\n>>> {body['message']}"
if app.state.last_sender == body["sender"] and ts_diff < 600:
content = f">>> {body['message']}"
await channel.send(content, allowed_mentions=discord.AllowedMentions.none(), silent=True, suppress=False)
app.state.last_sender = body["sender"]
app.state.last_sender_ts = now
return {"status": "ok", "pages": len(paginator.pages)}
@app.websocket("/bridge/recv")
async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
await ws.accept()
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
if secret != app.state.bot.http.token:
log.warning("Closing websocket %r, invalid secret.", ws.client.host)
raise WebSocketException(code=1008, reason="Invalid Secret")
if app.state.ws_connected.locked():
log.warning("Closing websocket %r, already connected." % ws)
raise WebSocketException(code=1008, reason="Already connected.")
queue: asyncio.Queue = app.state.bot.bridge_queue
async with app.state.ws_connected:
while True:
try:
await ws.send_json({"status": "ping"})
except (WebSocketDisconnect, WebSocketException):
log.info("Websocket %r disconnected.", ws)
break
try:
data = await asyncio.wait_for(queue.get(), timeout=5)
except asyncio.TimeoutError:
continue
try:
await ws.send_json(data)
log.debug("Sent data %r to websocket %r.", data, ws)
except (WebSocketDisconnect, WebSocketException):
log.info("Websocket %r disconnected." % ws)
break
finally:
queue.task_done()