diff --git a/.gitignore b/.gitignore index 676323b..abc05f8 100644 --- a/.gitignore +++ b/.gitignore @@ -310,4 +310,5 @@ pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,visualstudiocode cookies.txt config.toml -chrome/ \ No newline at end of file +chrome/ +src/assets/sensitive/* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index da18c0e..0163e13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ humanize~=4.9 redis~=5.0 beautifulsoup4~=4.12 lxml~=5.1 +matplotlib~=3.8 diff --git a/src/cogs/net.py b/src/cogs/net.py index 2d0af08..7716193 100644 --- a/src/cogs/net.py +++ b/src/cogs/net.py @@ -4,6 +4,7 @@ import os import re import time import typing +from pathlib import Path import discord from discord.ext import commands @@ -242,6 +243,16 @@ class NetworkCog(commands.Cog): paginator.add_line(f"Error: {e}") for page in paginator.pages: await ctx.respond(page) + + @commands.slash_command(name="what-are-matthews-bank-details") + async def matthew_bank(self, ctx: discord.ApplicationContext): + """For the 80th time""" + f = Path.cwd() / "assets" / "sensitive" / "matthew-bank.webp" + if not f.exists(): + return await ctx.respond("Idk") + else: + await ctx.defer() + await ctx.respond(file=discord.File(f)) def setup(bot): diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 7828bf9..36dad41 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -8,6 +8,7 @@ import typing import base64 import io import redis +from discord import Interaction from discord.ui import View, button from fnmatch import fnmatch @@ -89,7 +90,7 @@ class ChatHistory: "threads:" + thread_id, json.dumps(self._internal[thread_id]) ) - def create_thread(self, member: discord.Member) -> str: + def create_thread(self, member: discord.Member, default: str | None = None) -> str: """ Creates a thread, returns its ID. """ @@ -100,7 +101,7 @@ class ChatHistory: "messages": [] } with open("./assets/ollama-prompt.txt") as file: - system_prompt = file.read() + system_prompt = default or file.read() self.add_message( key, "system", @@ -190,6 +191,70 @@ class ChatHistory: SERVER_KEYS = list(CONFIG["ollama"].keys()) +class OllamaGetPrompt(discord.ui.Modal): + + def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"): + super().__init__( + discord.ui.InputText( + style=discord.InputTextStyle.long, + label="%s prompt" % prompt_type, + placeholder="Enter your prompt here.", + ), + timeout=300, + title="Ollama %s prompt" % prompt_type, + ) + self.ctx = ctx + self.prompt_type = prompt_type + self.value = None + + async def interaction_check(self, interaction: discord.Interaction) -> bool: + return interaction.user == self.ctx.user + + async def callback(self, interaction: Interaction): + await interaction.response.defer() + self.value = self.children[0].value + self.stop() + + +class PromptSelector(discord.ui.View): + def __init__(self, ctx: discord.ApplicationContext): + super().__init__(timeout=600, disable_on_timeout=True) + self.ctx = ctx + self.system_prompt = None + self.user_prompt = None + + async def interaction_check(self, interaction: Interaction) -> bool: + return interaction.user == self.ctx.user + + def update_ui(self): + if self.system_prompt is not None: + self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore + if self.user_prompt is not None: + self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore + + @discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys") + async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction): + modal = OllamaGetPrompt(self.ctx, "System") + await interaction.response.send_modal(modal) + await modal.wait() + self.system_prompt = modal.value + self.update_ui() + await interaction.edit_original_response(view=self) + + @discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr") + async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction): + modal = OllamaGetPrompt(self.ctx) + await interaction.response.send_modal(modal) + await modal.wait() + self.user_prompt = modal.value + self.update_ui() + await interaction.edit_original_response(view=self) + + @discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done") + async def done(self, btn: discord.ui.Button, interaction: Interaction): + self.stop() + + class Ollama(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot @@ -282,11 +347,24 @@ class Ollama(commands.Cog): ) ] ): + system_query = None if context is not None: if not self.history.get_thread(context): await ctx.respond("Invalid context key.") return - await ctx.defer() + + try: + await ctx.defer() + except discord.HTTPException: + pass + + if query == "$": + v = PromptSelector(ctx) + await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v) + await v.wait() + query = v.user_prompt or query + system_query = v.system_prompt + await ctx.delete(delay=0.1) model = model.casefold() try: @@ -294,7 +372,7 @@ class Ollama(commands.Cog): model = model + ":" + tag self.log.debug("Model %r already has a tag", model) except ValueError: - model = model + ":latest" + model += ":latest" self.log.debug("Resolved model to %r" % model) if image: @@ -315,7 +393,7 @@ class Ollama(commands.Cog): data = io.BytesIO() await image.save(data) data.seek(0) - image_data = base64.b64encode(data.read()).decode("utf-8") + image_data = base64.b64encode(data.read()).decode() else: image_data = None @@ -336,7 +414,12 @@ class Ollama(commands.Cog): async with aiohttp.ClientSession( base_url=server_config["base_url"], - timeout=aiohttp.ClientTimeout(0) + timeout=aiohttp.ClientTimeout( + connect=30, + sock_read=10800, + sock_connect=30, + total=10830 + ) ) as session: embed = discord.Embed( title="Checking server...", @@ -482,7 +565,7 @@ class Ollama(commands.Cog): self.log.debug("Beginning to generate response with key %r.", key) if context is None: - context = self.history.create_thread(ctx.user) + context = self.history.create_thread(ctx.user, system_query) elif context is not None and self.history.get_thread(context) is None: __thread = self.history.find_thread(context) if not __thread: diff --git a/src/cogs/quote_quota.py b/src/cogs/quote_quota.py new file mode 100644 index 0000000..b48cabc --- /dev/null +++ b/src/cogs/quote_quota.py @@ -0,0 +1,188 @@ +import asyncio +import re + +import discord +import io +import matplotlib.pyplot as plt +from datetime import timedelta +from discord.ext import commands +from typing import Iterable, Annotated + +from conf import CONFIG + + +class QuoteQuota(commands.Cog): + + def __init__(self, bot): + self.bot = bot + self.quotes_channel_id = CONFIG["quote_a"].get("channel_id") + self.names = CONFIG["quote_a"].get("names", {}) + + @property + def quotes_channel(self) -> discord.TextChannel | None: + if self.quotes_channel_id: + c = self.bot.get_channel(self.quotes_channel_id) + if c: + return c + + @staticmethod + def generate_pie_chart( + usernames: list[str], + counts: list[int], + no_other: bool = False + ) -> discord.File: + """ + Converts the given username and count tuples into a nice pretty pie chart. + + :param usernames: The usernames + :param counts: The number of times the username appears in the chat + :param no_other: Disables the "other" grouping + :returns: The pie chart image + """ + + def pct(v: int): + return f"{v:.1f}% ({round((v / 100) * sum(counts))})" + + if no_other is False: + other = [] + # Any authors with less than 5% of the total count will be grouped into "other" + for i, author in enumerate(usernames.copy()): + if (c := counts[i]) / sum(counts) < 0.05: + other.append(c) + counts[i] = -1 + usernames.remove(author) + if other: + usernames.append("Other") + counts.append(sum(other)) + # And now filter out any -1% counts + counts = [c for c in counts if c != -1] + + mapping = {} + for i, author in enumerate(usernames): + mapping[author] = counts[i] + + # Sort the authors by count + new_mapping = {} + for author, count in sorted(mapping.items(), key=lambda x: x[1], reverse=True): + new_mapping[author] = count + + usernames = list(new_mapping.keys()) + counts = list(new_mapping.values()) + + fig, ax = plt.subplots(figsize=(7, 7)) + ax.pie( + counts, + labels=usernames, + autopct=pct, + startangle=90, + radius=1.2, + ) + fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4) + fio = io.BytesIO() + fig.savefig(fio, format='png') + fio.seek(0) + return discord.File(fio, filename="pie.png") + + @commands.slash_command() + async def quota( + self, + ctx: discord.ApplicationContext, + days: Annotated[ + int, + discord.Option( + int, + name="lookback", + description="How many days to look back on. Defaults to 7.", + default=7, + min_value=1, + max_value=365 + ) + ], + merge_other: Annotated[ + bool, + discord.Option( + bool, + name="merge_other", + description="Whether to merge authors with less than 5% of the total count into 'Other'.", + default=True + ) + ] + ): + """Checks the quote quota for the quotes channel.""" + now = discord.utils.utcnow() + oldest = now - timedelta(days=days) + await ctx.defer() + channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes") + if not channel: + return await ctx.respond(":x: Cannot find quotes channel.") + + await ctx.respond("Gathering messages, this may take a moment.") + + authors = {} + filtered_messages = 0 + total = 0 + async for message in channel.history( + limit=None, + after=oldest, + oldest_first=False + ): + total += 1 + if not message.content: + filtered_messages += 1 + continue + if message.attachments: + regex = r".*\s*-\s*@?([\w\s]+)" + else: + regex = r".+\s+-\s*@?([\w\s]+)" + + if not (m := re.match(regex, str(message.clean_content))): + filtered_messages += 1 + continue + name = m.group(1) + name = name.strip().casefold() + if name == "me": + name = message.author.name.strip().casefold() + if name in self.names: + name = self.names[name].title() + else: + filtered_messages += 1 + continue + elif name in self.names: + name = self.names[name].title() + elif name.isdigit(): + filtered_messages += 1 + continue + + name = name.title() + authors.setdefault(name, 0) + authors[name] += 1 + + if not authors: + if total: + return await ctx.edit( + content="No valid messages found in the last {!s} days. " + "Make sure quotes are formatted properly ending with ` - AuthorName`" + " (e.g. `\"This is my quote\" - Jimmy`)".format(days) + ) + else: + return await ctx.edit( + content="No messages found in the last {!s} days.".format(days) + ) + + file = await asyncio.to_thread( + self.generate_pie_chart, + list(authors.keys()), + list(authors.values()), + merge_other + ) + return await ctx.edit( + content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format( + filtered_messages, + total + ), + file=file + ) + + +def setup(bot): + bot.add_cog(QuoteQuota(bot)) diff --git a/src/cogs/screenshot.py b/src/cogs/screenshot.py index 9390ba0..3d28939 100644 --- a/src/cogs/screenshot.py +++ b/src/cogs/screenshot.py @@ -5,6 +5,7 @@ import logging import os import tempfile import time +import copy from urllib.parse import urlparse import discord @@ -15,6 +16,8 @@ from selenium import webdriver from selenium.webdriver.chrome.options import Options as ChromeOptions from selenium.webdriver.chrome.service import Service as ChromeService +from conf import CONFIG + class ScreenshotCog(commands.Cog): def __init__(self, bot: commands.Bot): @@ -76,7 +79,8 @@ class ScreenshotCog(commands.Cog): load_timeout: int = 10, render_timeout: int = None, eager: bool = None, - resolution: str = "1920x1080" + resolution: str = "1920x1080", + use_proxy: bool = False ): """Screenshots a webpage.""" await ctx.defer() @@ -104,11 +108,14 @@ class ScreenshotCog(commands.Cog): start_init = time.time() try: + options = copy.copy(self.chrome_options) + if use_proxy and (server := CONFIG["screenshot"].get("proxy")): + options.add_argument("--proxy-server=" + server) service = await asyncio.to_thread(ChromeService) driver: webdriver.Chrome = await asyncio.to_thread( webdriver.Chrome, service=service, - options=self.chrome_options + options=options ) driver.set_page_load_timeout(load_timeout) if resolution: @@ -173,6 +180,7 @@ class ScreenshotCog(commands.Cog): end_save = time.time() if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024: + await ctx.edit(content="Compressing screenshot...") start_compress = time.time() file = await asyncio.to_thread(self.compress_png, file) fn = "screenshot.webp" diff --git a/src/cogs/ytdl.py b/src/cogs/ytdl.py index 13463df..42888f3 100644 --- a/src/cogs/ytdl.py +++ b/src/cogs/ytdl.py @@ -82,23 +82,34 @@ class YTDLCog(commands.Cog): await db.commit() return - async def save_link(self, message: discord.Message, webpage_url: str, format_id: str, attachment_index: int = 0): + async def save_link( + self, + message: discord.Message, + webpage_url: str, + format_id: str, + attachment_index: int = 0, + *, + snip: typing.Optional[str] = None + ): """ Saves a link to discord to prevent having to re-download it. :param message: The download message with the attachment. :param webpage_url: The "webpage_url" key of the metadata :param format_id: The "format_Id" key of the metadata :param attachment_index: The index of the attachment. Defaults to 0 + :param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00 :return: The created hash key """ + snip = snip or '*' await self._init_db() async with aiosqlite.connect("./data/ytdl.db") as db: - _hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest() + _hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest() self.log.debug( - "Saving %r (%r:%r) with message %d>%d, index %d", + "Saving %r (%r:%r:%r) with message %d>%d, index %d", _hash, webpage_url, format_id, + snip, message.channel.id, message.id, attachment_index @@ -117,20 +128,27 @@ class YTDLCog(commands.Cog): await db.commit() return _hash - async def get_saved(self, webpage_url: str, format_id: str) -> typing.Optional[str]: + async def get_saved( + self, + webpage_url: str, + format_id: str, + snip: str + ) -> typing.Optional[str]: """ Attempts to retrieve the attachment URL of a previously saved download. :param webpage_url: The webpage url :param format_id: The format ID + :param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00 :return: the URL, if found and valid. """ await self._init_db() async with aiosqlite.connect("./data/ytdl.db") as db: - _hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest() + _hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest() self.log.debug( - "Attempting to find a saved download for '%s:%s' (%r).", + "Attempting to find a saved download for '%s:%s:%s' (%r).", webpage_url, format_id, + snip, _hash ) cursor = await db.execute( @@ -160,7 +178,7 @@ class YTDLCog(commands.Cog): except IndexError: self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments) return - + def convert_to_m4a(self, file: Path) -> Path: """ Converts a file to m4a format. @@ -229,7 +247,6 @@ class YTDLCog(commands.Cog): snip: typing.Annotated[ typing.Optional[str], discord.Option( - str, description="A start and end position to trim. e.g. 00:00:00-00:10:00.", required=False ) @@ -347,7 +364,7 @@ class YTDLCog(commands.Cog): colour=self.colours.get(domain, discord.Colour.og_blurple()) ).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url) ) - previous = await self.get_saved(webpage_url, extracted_info["format_id"]) + previous = await self.get_saved(webpage_url, extracted_info["format_id"], snip or '*') if previous: await ctx.edit( content=previous, @@ -467,7 +484,7 @@ class YTDLCog(commands.Cog): ) ) file = new_file - + if audio_only and file.suffix != ".m4a": self.log.info("Converting %r to m4a.", file) file = await asyncio.to_thread(self.convert_to_m4a, file) @@ -505,7 +522,7 @@ class YTDLCog(commands.Cog): url=webpage_url ) ) - await self.save_link(msg, webpage_url, chosen_format_id) + await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or '*') except discord.HTTPException as e: self.log.error(e, exc_info=True) return await ctx.edit( diff --git a/src/conf.py b/src/conf.py index 88bd862..e260f8d 100644 --- a/src/conf.py +++ b/src/conf.py @@ -28,6 +28,8 @@ try: CONFIG.setdefault("jimmy", {}) CONFIG.setdefault("ollama", {}) CONFIG.setdefault("rss", {"meta": {"channel": None}}) + CONFIG.setdefault("screenshot", {}) + CONFIG.setdefault("quote_a", {"channel": None}) CONFIG.setdefault( "server", { diff --git a/src/main.py b/src/main.py index 82947c0..fcb6318 100644 --- a/src/main.py +++ b/src/main.py @@ -10,7 +10,6 @@ import random import httpx import uvicorn -from web import app from logging import FileHandler import discord @@ -104,25 +103,12 @@ class Client(commands.Bot): CONFIG["jimmy"].get("uptime_kuma_interval", 60.0) ) self.uptime_thread.start() - app.state.bot = self - config = uvicorn.Config( - app, - host=CONFIG["server"].get("host", "0.0.0.0"), - port=CONFIG["server"].get("port", 8080), - loop="asyncio", - lifespan="on", - server_header=False - ) - server = uvicorn.Server(config=config) - self.web = self.loop.create_task(asyncio.to_thread(server.serve())) await super().start(token, reconnect=reconnect) async def close(self) -> None: - if self.web: - self.web.cancel() - if self.thread: - self.thread.kill.set() - await asyncio.get_event_loop().run_in_executor(None, self.thread.join) + if self.uptime_thread: + self.uptime_thread.kill.set() + await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join) await super().close() @@ -133,7 +119,7 @@ bot = Client( debug_guilds=CONFIG["jimmy"].get("debug_guilds") ) -for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta"): +for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta", "quote_quota"): try: bot.load_extension(f"cogs.{ext}") except discord.ExtensionError as e: diff --git a/src/web.py b/src/web.py deleted file mode 100644 index a2ec382..0000000 --- a/src/web.py +++ /dev/null @@ -1,160 +0,0 @@ -import asyncio -import datetime -import logging -import textwrap - -import psutil -import time -import pydantic -from typing import Optional, Any -from conf import CONFIG -import discord -from discord.ext.commands import Paginator - -from fastapi import FastAPI, HTTPException, status, WebSocketException, WebSocket, WebSocketDisconnect, Header - -class BridgeResponse(pydantic.BaseModel): - status: str - pages: list[str] - - -class BridgePayload(pydantic.BaseModel): - secret: str - message: str - sender: str - - -class MessagePayload(pydantic.BaseModel): - class MessageAttachmentPayload(pydantic.BaseModel): - url: str - proxy_url: str - filename: str - size: int - width: Optional[int] = None - height: Optional[int] = None - content_type: str - ATTACHMENT: Optional[Any] = None - - event_type: Optional[str] = "create" - message_id: int - author: str - is_automated: bool = False - avatar: str - content: str - clean_content: str - at: float - attachments: list[MessageAttachmentPayload] = [] - reply_to: Optional["MessagePayload"] = None - - -app = FastAPI( - title="JimmyAPI", - version="2.0.0a1" -) -log = logging.getLogger("jimmy.web.api") -app.state.bot = None -app.state.bridge_lock = asyncio.Lock() -app.state.last_sender_ts = 0 - - -@app.get("/ping") -def ping(): - """Checks the bot is online and provides some uptime information""" - if not app.state.bot: - raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) - return { - "ping": "pong", - "online": app.state.bot.is_ready(), - "latency": max(round(app.state.bot.latency, 2), 0.01), - "uptime": round(time.time() - psutil.Process().create_time()), - "uptime.sys": time.time() - psutil.boot_time() - } - - -@app.post("/bridge", status_code=201) -async def bridge_post_send_message(body: BridgePayload): - """Sends a message FROM matrix TO discord.""" - now = datetime.datetime.now(datetime.timezone.utc) - ts_diff = (now - app.state.last_sender_ts).total_seconds() - if not app.state.bot: - raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) - - if body.secret != CONFIG["jimmy"].get("token"): - log.warning("Authentication failure: %s was not authenticated.", body.secret) - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - - channel = app.state.bot.get_channel(CONFIG["server"]["channel"]) - if not channel or not channel.can_send(): - log.warning("Unable to send message: channel not found or not writable.") - raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) - - if len(body.message) > 4000: - log.warning( - "Unable to send message: message too long ({:,} characters long, 4000 max).".format(len(body.message)) - ) - raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE) - - paginator = Paginator(prefix="", suffix="", max_size=1990) - for line in body["message"].splitlines(): - try: - paginator.add_line(line) - except ValueError: - paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>")) - - if len(paginator.pages) > 1: - msg = None - if app.state.last_sender != body["sender"] or ts_diff >= 600: - msg = await channel.send(f"**{body['sender']}**:") - m = len(paginator.pages) - for n, page in enumerate(paginator.pages, 1): - await channel.send( - f"[{n}/{m}]\n>>> {page}", - allowed_mentions=discord.AllowedMentions.none(), - reference=msg, - silent=True, - suppress=n != m, - ) - app.state.last_sender = body["sender"] - else: - content = f"**{body['sender']}**:\n>>> {body['message']}" - if app.state.last_sender == body["sender"] and ts_diff < 600: - content = f">>> {body['message']}" - await channel.send(content, allowed_mentions=discord.AllowedMentions.none(), silent=True, suppress=False) - app.state.last_sender = body["sender"] - app.state.last_sender_ts = now - return {"status": "ok", "pages": len(paginator.pages)} - - -@app.websocket("/bridge/recv") -async def bridge_recv(ws: WebSocket, secret: str = Header(None)): - await ws.accept() - log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port) - if secret != app.state.bot.http.token: - log.warning("Closing websocket %r, invalid secret.", ws.client.host) - raise WebSocketException(code=1008, reason="Invalid Secret") - if app.state.ws_connected.locked(): - log.warning("Closing websocket %r, already connected." % ws) - raise WebSocketException(code=1008, reason="Already connected.") - queue: asyncio.Queue = app.state.bot.bridge_queue - - async with app.state.ws_connected: - while True: - try: - await ws.send_json({"status": "ping"}) - except (WebSocketDisconnect, WebSocketException): - log.info("Websocket %r disconnected.", ws) - break - - try: - data = await asyncio.wait_for(queue.get(), timeout=5) - except asyncio.TimeoutError: - continue - - try: - await ws.send_json(data) - log.debug("Sent data %r to websocket %r.", data, ws) - except (WebSocketDisconnect, WebSocketException): - log.info("Websocket %r disconnected." % ws) - break - finally: - queue.task_done()