Merge branch 'master' of github.com:nexy7574/college-bot-v2

This commit is contained in:
Nexus 2024-04-02 02:38:23 +01:00
commit c8019f3993
10 changed files with 336 additions and 199 deletions

1
.gitignore vendored
View file

@ -311,3 +311,4 @@ pyrightconfig.json
cookies.txt
config.toml
chrome/
src/assets/sensitive/*

View file

@ -18,3 +18,4 @@ humanize~=4.9
redis~=5.0
beautifulsoup4~=4.12
lxml~=5.1
matplotlib~=3.8

View file

@ -4,6 +4,7 @@ import os
import re
import time
import typing
from pathlib import Path
import discord
from discord.ext import commands
@ -243,6 +244,16 @@ class NetworkCog(commands.Cog):
for page in paginator.pages:
await ctx.respond(page)
@commands.slash_command(name="what-are-matthews-bank-details")
async def matthew_bank(self, ctx: discord.ApplicationContext):
"""For the 80th time"""
f = Path.cwd() / "assets" / "sensitive" / "matthew-bank.webp"
if not f.exists():
return await ctx.respond("Idk")
else:
await ctx.defer()
await ctx.respond(file=discord.File(f))
def setup(bot):
bot.add_cog(NetworkCog(bot))

View file

@ -8,6 +8,7 @@ import typing
import base64
import io
import redis
from discord import Interaction
from discord.ui import View, button
from fnmatch import fnmatch
@ -89,7 +90,7 @@ class ChatHistory:
"threads:" + thread_id, json.dumps(self._internal[thread_id])
)
def create_thread(self, member: discord.Member) -> str:
def create_thread(self, member: discord.Member, default: str | None = None) -> str:
"""
Creates a thread, returns its ID.
"""
@ -100,7 +101,7 @@ class ChatHistory:
"messages": []
}
with open("./assets/ollama-prompt.txt") as file:
system_prompt = file.read()
system_prompt = default or file.read()
self.add_message(
key,
"system",
@ -190,6 +191,70 @@ class ChatHistory:
SERVER_KEYS = list(CONFIG["ollama"].keys())
class OllamaGetPrompt(discord.ui.Modal):
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
super().__init__(
discord.ui.InputText(
style=discord.InputTextStyle.long,
label="%s prompt" % prompt_type,
placeholder="Enter your prompt here.",
),
timeout=300,
title="Ollama %s prompt" % prompt_type,
)
self.ctx = ctx
self.prompt_type = prompt_type
self.value = None
async def interaction_check(self, interaction: discord.Interaction) -> bool:
return interaction.user == self.ctx.user
async def callback(self, interaction: Interaction):
await interaction.response.defer()
self.value = self.children[0].value
self.stop()
class PromptSelector(discord.ui.View):
def __init__(self, ctx: discord.ApplicationContext):
super().__init__(timeout=600, disable_on_timeout=True)
self.ctx = ctx
self.system_prompt = None
self.user_prompt = None
async def interaction_check(self, interaction: Interaction) -> bool:
return interaction.user == self.ctx.user
def update_ui(self):
if self.system_prompt is not None:
self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore
if self.user_prompt is not None:
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx, "System")
await interaction.response.send_modal(modal)
await modal.wait()
self.system_prompt = modal.value
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx)
await interaction.response.send_modal(modal)
await modal.wait()
self.user_prompt = modal.value
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
async def done(self, btn: discord.ui.Button, interaction: Interaction):
self.stop()
class Ollama(commands.Cog):
def __init__(self, bot: commands.Bot):
self.bot = bot
@ -282,11 +347,24 @@ class Ollama(commands.Cog):
)
]
):
system_query = None
if context is not None:
if not self.history.get_thread(context):
await ctx.respond("Invalid context key.")
return
try:
await ctx.defer()
except discord.HTTPException:
pass
if query == "$":
v = PromptSelector(ctx)
await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v)
await v.wait()
query = v.user_prompt or query
system_query = v.system_prompt
await ctx.delete(delay=0.1)
model = model.casefold()
try:
@ -294,7 +372,7 @@ class Ollama(commands.Cog):
model = model + ":" + tag
self.log.debug("Model %r already has a tag", model)
except ValueError:
model = model + ":latest"
model += ":latest"
self.log.debug("Resolved model to %r" % model)
if image:
@ -315,7 +393,7 @@ class Ollama(commands.Cog):
data = io.BytesIO()
await image.save(data)
data.seek(0)
image_data = base64.b64encode(data.read()).decode("utf-8")
image_data = base64.b64encode(data.read()).decode()
else:
image_data = None
@ -336,7 +414,12 @@ class Ollama(commands.Cog):
async with aiohttp.ClientSession(
base_url=server_config["base_url"],
timeout=aiohttp.ClientTimeout(0)
timeout=aiohttp.ClientTimeout(
connect=30,
sock_read=10800,
sock_connect=30,
total=10830
)
) as session:
embed = discord.Embed(
title="Checking server...",
@ -482,7 +565,7 @@ class Ollama(commands.Cog):
self.log.debug("Beginning to generate response with key %r.", key)
if context is None:
context = self.history.create_thread(ctx.user)
context = self.history.create_thread(ctx.user, system_query)
elif context is not None and self.history.get_thread(context) is None:
__thread = self.history.find_thread(context)
if not __thread:

188
src/cogs/quote_quota.py Normal file
View file

@ -0,0 +1,188 @@
import asyncio
import re
import discord
import io
import matplotlib.pyplot as plt
from datetime import timedelta
from discord.ext import commands
from typing import Iterable, Annotated
from conf import CONFIG
class QuoteQuota(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
self.names = CONFIG["quote_a"].get("names", {})
@property
def quotes_channel(self) -> discord.TextChannel | None:
if self.quotes_channel_id:
c = self.bot.get_channel(self.quotes_channel_id)
if c:
return c
@staticmethod
def generate_pie_chart(
usernames: list[str],
counts: list[int],
no_other: bool = False
) -> discord.File:
"""
Converts the given username and count tuples into a nice pretty pie chart.
:param usernames: The usernames
:param counts: The number of times the username appears in the chat
:param no_other: Disables the "other" grouping
:returns: The pie chart image
"""
def pct(v: int):
return f"{v:.1f}% ({round((v / 100) * sum(counts))})"
if no_other is False:
other = []
# Any authors with less than 5% of the total count will be grouped into "other"
for i, author in enumerate(usernames.copy()):
if (c := counts[i]) / sum(counts) < 0.05:
other.append(c)
counts[i] = -1
usernames.remove(author)
if other:
usernames.append("Other")
counts.append(sum(other))
# And now filter out any -1% counts
counts = [c for c in counts if c != -1]
mapping = {}
for i, author in enumerate(usernames):
mapping[author] = counts[i]
# Sort the authors by count
new_mapping = {}
for author, count in sorted(mapping.items(), key=lambda x: x[1], reverse=True):
new_mapping[author] = count
usernames = list(new_mapping.keys())
counts = list(new_mapping.values())
fig, ax = plt.subplots(figsize=(7, 7))
ax.pie(
counts,
labels=usernames,
autopct=pct,
startangle=90,
radius=1.2,
)
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4)
fio = io.BytesIO()
fig.savefig(fio, format='png')
fio.seek(0)
return discord.File(fio, filename="pie.png")
@commands.slash_command()
async def quota(
self,
ctx: discord.ApplicationContext,
days: Annotated[
int,
discord.Option(
int,
name="lookback",
description="How many days to look back on. Defaults to 7.",
default=7,
min_value=1,
max_value=365
)
],
merge_other: Annotated[
bool,
discord.Option(
bool,
name="merge_other",
description="Whether to merge authors with less than 5% of the total count into 'Other'.",
default=True
)
]
):
"""Checks the quote quota for the quotes channel."""
now = discord.utils.utcnow()
oldest = now - timedelta(days=days)
await ctx.defer()
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes")
if not channel:
return await ctx.respond(":x: Cannot find quotes channel.")
await ctx.respond("Gathering messages, this may take a moment.")
authors = {}
filtered_messages = 0
total = 0
async for message in channel.history(
limit=None,
after=oldest,
oldest_first=False
):
total += 1
if not message.content:
filtered_messages += 1
continue
if message.attachments:
regex = r".*\s*-\s*@?([\w\s]+)"
else:
regex = r".+\s+-\s*@?([\w\s]+)"
if not (m := re.match(regex, str(message.clean_content))):
filtered_messages += 1
continue
name = m.group(1)
name = name.strip().casefold()
if name == "me":
name = message.author.name.strip().casefold()
if name in self.names:
name = self.names[name].title()
else:
filtered_messages += 1
continue
elif name in self.names:
name = self.names[name].title()
elif name.isdigit():
filtered_messages += 1
continue
name = name.title()
authors.setdefault(name, 0)
authors[name] += 1
if not authors:
if total:
return await ctx.edit(
content="No valid messages found in the last {!s} days. "
"Make sure quotes are formatted properly ending with ` - AuthorName`"
" (e.g. `\"This is my quote\" - Jimmy`)".format(days)
)
else:
return await ctx.edit(
content="No messages found in the last {!s} days.".format(days)
)
file = await asyncio.to_thread(
self.generate_pie_chart,
list(authors.keys()),
list(authors.values()),
merge_other
)
return await ctx.edit(
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
filtered_messages,
total
),
file=file
)
def setup(bot):
bot.add_cog(QuoteQuota(bot))

View file

@ -5,6 +5,7 @@ import logging
import os
import tempfile
import time
import copy
from urllib.parse import urlparse
import discord
@ -15,6 +16,8 @@ from selenium import webdriver
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService
from conf import CONFIG
class ScreenshotCog(commands.Cog):
def __init__(self, bot: commands.Bot):
@ -76,7 +79,8 @@ class ScreenshotCog(commands.Cog):
load_timeout: int = 10,
render_timeout: int = None,
eager: bool = None,
resolution: str = "1920x1080"
resolution: str = "1920x1080",
use_proxy: bool = False
):
"""Screenshots a webpage."""
await ctx.defer()
@ -104,11 +108,14 @@ class ScreenshotCog(commands.Cog):
start_init = time.time()
try:
options = copy.copy(self.chrome_options)
if use_proxy and (server := CONFIG["screenshot"].get("proxy")):
options.add_argument("--proxy-server=" + server)
service = await asyncio.to_thread(ChromeService)
driver: webdriver.Chrome = await asyncio.to_thread(
webdriver.Chrome,
service=service,
options=self.chrome_options
options=options
)
driver.set_page_load_timeout(load_timeout)
if resolution:
@ -173,6 +180,7 @@ class ScreenshotCog(commands.Cog):
end_save = time.time()
if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024:
await ctx.edit(content="Compressing screenshot...")
start_compress = time.time()
file = await asyncio.to_thread(self.compress_png, file)
fn = "screenshot.webp"

View file

@ -82,23 +82,34 @@ class YTDLCog(commands.Cog):
await db.commit()
return
async def save_link(self, message: discord.Message, webpage_url: str, format_id: str, attachment_index: int = 0):
async def save_link(
self,
message: discord.Message,
webpage_url: str,
format_id: str,
attachment_index: int = 0,
*,
snip: typing.Optional[str] = None
):
"""
Saves a link to discord to prevent having to re-download it.
:param message: The download message with the attachment.
:param webpage_url: The "webpage_url" key of the metadata
:param format_id: The "format_Id" key of the metadata
:param attachment_index: The index of the attachment. Defaults to 0
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
:return: The created hash key
"""
snip = snip or '*'
await self._init_db()
async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest()
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
self.log.debug(
"Saving %r (%r:%r) with message %d>%d, index %d",
"Saving %r (%r:%r:%r) with message %d>%d, index %d",
_hash,
webpage_url,
format_id,
snip,
message.channel.id,
message.id,
attachment_index
@ -117,20 +128,27 @@ class YTDLCog(commands.Cog):
await db.commit()
return _hash
async def get_saved(self, webpage_url: str, format_id: str) -> typing.Optional[str]:
async def get_saved(
self,
webpage_url: str,
format_id: str,
snip: str
) -> typing.Optional[str]:
"""
Attempts to retrieve the attachment URL of a previously saved download.
:param webpage_url: The webpage url
:param format_id: The format ID
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
:return: the URL, if found and valid.
"""
await self._init_db()
async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest()
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
self.log.debug(
"Attempting to find a saved download for '%s:%s' (%r).",
"Attempting to find a saved download for '%s:%s:%s' (%r).",
webpage_url,
format_id,
snip,
_hash
)
cursor = await db.execute(
@ -229,7 +247,6 @@ class YTDLCog(commands.Cog):
snip: typing.Annotated[
typing.Optional[str],
discord.Option(
str,
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
required=False
)
@ -347,7 +364,7 @@ class YTDLCog(commands.Cog):
colour=self.colours.get(domain, discord.Colour.og_blurple())
).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url)
)
previous = await self.get_saved(webpage_url, extracted_info["format_id"])
previous = await self.get_saved(webpage_url, extracted_info["format_id"], snip or '*')
if previous:
await ctx.edit(
content=previous,
@ -505,7 +522,7 @@ class YTDLCog(commands.Cog):
url=webpage_url
)
)
await self.save_link(msg, webpage_url, chosen_format_id)
await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or '*')
except discord.HTTPException as e:
self.log.error(e, exc_info=True)
return await ctx.edit(

View file

@ -28,6 +28,8 @@ try:
CONFIG.setdefault("jimmy", {})
CONFIG.setdefault("ollama", {})
CONFIG.setdefault("rss", {"meta": {"channel": None}})
CONFIG.setdefault("screenshot", {})
CONFIG.setdefault("quote_a", {"channel": None})
CONFIG.setdefault(
"server",
{

View file

@ -10,7 +10,6 @@ import random
import httpx
import uvicorn
from web import app
from logging import FileHandler
import discord
@ -104,25 +103,12 @@ class Client(commands.Bot):
CONFIG["jimmy"].get("uptime_kuma_interval", 60.0)
)
self.uptime_thread.start()
app.state.bot = self
config = uvicorn.Config(
app,
host=CONFIG["server"].get("host", "0.0.0.0"),
port=CONFIG["server"].get("port", 8080),
loop="asyncio",
lifespan="on",
server_header=False
)
server = uvicorn.Server(config=config)
self.web = self.loop.create_task(asyncio.to_thread(server.serve()))
await super().start(token, reconnect=reconnect)
async def close(self) -> None:
if self.web:
self.web.cancel()
if self.thread:
self.thread.kill.set()
await asyncio.get_event_loop().run_in_executor(None, self.thread.join)
if self.uptime_thread:
self.uptime_thread.kill.set()
await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join)
await super().close()
@ -133,7 +119,7 @@ bot = Client(
debug_guilds=CONFIG["jimmy"].get("debug_guilds")
)
for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta"):
for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta", "quote_quota"):
try:
bot.load_extension(f"cogs.{ext}")
except discord.ExtensionError as e:

View file

@ -1,160 +0,0 @@
import asyncio
import datetime
import logging
import textwrap
import psutil
import time
import pydantic
from typing import Optional, Any
from conf import CONFIG
import discord
from discord.ext.commands import Paginator
from fastapi import FastAPI, HTTPException, status, WebSocketException, WebSocket, WebSocketDisconnect, Header
class BridgeResponse(pydantic.BaseModel):
status: str
pages: list[str]
class BridgePayload(pydantic.BaseModel):
secret: str
message: str
sender: str
class MessagePayload(pydantic.BaseModel):
class MessageAttachmentPayload(pydantic.BaseModel):
url: str
proxy_url: str
filename: str
size: int
width: Optional[int] = None
height: Optional[int] = None
content_type: str
ATTACHMENT: Optional[Any] = None
event_type: Optional[str] = "create"
message_id: int
author: str
is_automated: bool = False
avatar: str
content: str
clean_content: str
at: float
attachments: list[MessageAttachmentPayload] = []
reply_to: Optional["MessagePayload"] = None
app = FastAPI(
title="JimmyAPI",
version="2.0.0a1"
)
log = logging.getLogger("jimmy.web.api")
app.state.bot = None
app.state.bridge_lock = asyncio.Lock()
app.state.last_sender_ts = 0
@app.get("/ping")
def ping():
"""Checks the bot is online and provides some uptime information"""
if not app.state.bot:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
return {
"ping": "pong",
"online": app.state.bot.is_ready(),
"latency": max(round(app.state.bot.latency, 2), 0.01),
"uptime": round(time.time() - psutil.Process().create_time()),
"uptime.sys": time.time() - psutil.boot_time()
}
@app.post("/bridge", status_code=201)
async def bridge_post_send_message(body: BridgePayload):
"""Sends a message FROM matrix TO discord."""
now = datetime.datetime.now(datetime.timezone.utc)
ts_diff = (now - app.state.last_sender_ts).total_seconds()
if not app.state.bot:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
if body.secret != CONFIG["jimmy"].get("token"):
log.warning("Authentication failure: %s was not authenticated.", body.secret)
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
channel = app.state.bot.get_channel(CONFIG["server"]["channel"])
if not channel or not channel.can_send():
log.warning("Unable to send message: channel not found or not writable.")
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
if len(body.message) > 4000:
log.warning(
"Unable to send message: message too long ({:,} characters long, 4000 max).".format(len(body.message))
)
raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE)
paginator = Paginator(prefix="", suffix="", max_size=1990)
for line in body["message"].splitlines():
try:
paginator.add_line(line)
except ValueError:
paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>"))
if len(paginator.pages) > 1:
msg = None
if app.state.last_sender != body["sender"] or ts_diff >= 600:
msg = await channel.send(f"**{body['sender']}**:")
m = len(paginator.pages)
for n, page in enumerate(paginator.pages, 1):
await channel.send(
f"[{n}/{m}]\n>>> {page}",
allowed_mentions=discord.AllowedMentions.none(),
reference=msg,
silent=True,
suppress=n != m,
)
app.state.last_sender = body["sender"]
else:
content = f"**{body['sender']}**:\n>>> {body['message']}"
if app.state.last_sender == body["sender"] and ts_diff < 600:
content = f">>> {body['message']}"
await channel.send(content, allowed_mentions=discord.AllowedMentions.none(), silent=True, suppress=False)
app.state.last_sender = body["sender"]
app.state.last_sender_ts = now
return {"status": "ok", "pages": len(paginator.pages)}
@app.websocket("/bridge/recv")
async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
await ws.accept()
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
if secret != app.state.bot.http.token:
log.warning("Closing websocket %r, invalid secret.", ws.client.host)
raise WebSocketException(code=1008, reason="Invalid Secret")
if app.state.ws_connected.locked():
log.warning("Closing websocket %r, already connected." % ws)
raise WebSocketException(code=1008, reason="Already connected.")
queue: asyncio.Queue = app.state.bot.bridge_queue
async with app.state.ws_connected:
while True:
try:
await ws.send_json({"status": "ping"})
except (WebSocketDisconnect, WebSocketException):
log.info("Websocket %r disconnected.", ws)
break
try:
data = await asyncio.wait_for(queue.get(), timeout=5)
except asyncio.TimeoutError:
continue
try:
await ws.send_json(data)
log.debug("Sent data %r to websocket %r.", data, ws)
except (WebSocketDisconnect, WebSocketException):
log.info("Websocket %r disconnected." % ws)
break
finally:
queue.task_done()