Merge branch 'master' of github.com:nexy7574/college-bot-v2
This commit is contained in:
commit
c8019f3993
10 changed files with 336 additions and 199 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -311,3 +311,4 @@ pyrightconfig.json
|
||||||
cookies.txt
|
cookies.txt
|
||||||
config.toml
|
config.toml
|
||||||
chrome/
|
chrome/
|
||||||
|
src/assets/sensitive/*
|
|
@ -18,3 +18,4 @@ humanize~=4.9
|
||||||
redis~=5.0
|
redis~=5.0
|
||||||
beautifulsoup4~=4.12
|
beautifulsoup4~=4.12
|
||||||
lxml~=5.1
|
lxml~=5.1
|
||||||
|
matplotlib~=3.8
|
||||||
|
|
|
@ -4,6 +4,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
@ -243,6 +244,16 @@ class NetworkCog(commands.Cog):
|
||||||
for page in paginator.pages:
|
for page in paginator.pages:
|
||||||
await ctx.respond(page)
|
await ctx.respond(page)
|
||||||
|
|
||||||
|
@commands.slash_command(name="what-are-matthews-bank-details")
|
||||||
|
async def matthew_bank(self, ctx: discord.ApplicationContext):
|
||||||
|
"""For the 80th time"""
|
||||||
|
f = Path.cwd() / "assets" / "sensitive" / "matthew-bank.webp"
|
||||||
|
if not f.exists():
|
||||||
|
return await ctx.respond("Idk")
|
||||||
|
else:
|
||||||
|
await ctx.defer()
|
||||||
|
await ctx.respond(file=discord.File(f))
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(NetworkCog(bot))
|
bot.add_cog(NetworkCog(bot))
|
||||||
|
|
|
@ -8,6 +8,7 @@ import typing
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import redis
|
import redis
|
||||||
|
from discord import Interaction
|
||||||
|
|
||||||
from discord.ui import View, button
|
from discord.ui import View, button
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
|
@ -89,7 +90,7 @@ class ChatHistory:
|
||||||
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
"threads:" + thread_id, json.dumps(self._internal[thread_id])
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_thread(self, member: discord.Member) -> str:
|
def create_thread(self, member: discord.Member, default: str | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a thread, returns its ID.
|
Creates a thread, returns its ID.
|
||||||
"""
|
"""
|
||||||
|
@ -100,7 +101,7 @@ class ChatHistory:
|
||||||
"messages": []
|
"messages": []
|
||||||
}
|
}
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
system_prompt = file.read()
|
system_prompt = default or file.read()
|
||||||
self.add_message(
|
self.add_message(
|
||||||
key,
|
key,
|
||||||
"system",
|
"system",
|
||||||
|
@ -190,6 +191,70 @@ class ChatHistory:
|
||||||
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaGetPrompt(discord.ui.Modal):
|
||||||
|
|
||||||
|
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
|
||||||
|
super().__init__(
|
||||||
|
discord.ui.InputText(
|
||||||
|
style=discord.InputTextStyle.long,
|
||||||
|
label="%s prompt" % prompt_type,
|
||||||
|
placeholder="Enter your prompt here.",
|
||||||
|
),
|
||||||
|
timeout=300,
|
||||||
|
title="Ollama %s prompt" % prompt_type,
|
||||||
|
)
|
||||||
|
self.ctx = ctx
|
||||||
|
self.prompt_type = prompt_type
|
||||||
|
self.value = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||||
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
|
async def callback(self, interaction: Interaction):
|
||||||
|
await interaction.response.defer()
|
||||||
|
self.value = self.children[0].value
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class PromptSelector(discord.ui.View):
|
||||||
|
def __init__(self, ctx: discord.ApplicationContext):
|
||||||
|
super().__init__(timeout=600, disable_on_timeout=True)
|
||||||
|
self.ctx = ctx
|
||||||
|
self.system_prompt = None
|
||||||
|
self.user_prompt = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: Interaction) -> bool:
|
||||||
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
|
def update_ui(self):
|
||||||
|
if self.system_prompt is not None:
|
||||||
|
self.get_item("sys").style = discord.ButtonStyle.secondary # type: ignore
|
||||||
|
if self.user_prompt is not None:
|
||||||
|
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
|
||||||
|
|
||||||
|
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
|
||||||
|
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
|
modal = OllamaGetPrompt(self.ctx, "System")
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
await modal.wait()
|
||||||
|
self.system_prompt = modal.value
|
||||||
|
self.update_ui()
|
||||||
|
await interaction.edit_original_response(view=self)
|
||||||
|
|
||||||
|
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
|
||||||
|
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
|
modal = OllamaGetPrompt(self.ctx)
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
await modal.wait()
|
||||||
|
self.user_prompt = modal.value
|
||||||
|
self.update_ui()
|
||||||
|
await interaction.edit_original_response(view=self)
|
||||||
|
|
||||||
|
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
|
||||||
|
async def done(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
class Ollama(commands.Cog):
|
class Ollama(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
@ -282,11 +347,24 @@ class Ollama(commands.Cog):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
|
system_query = None
|
||||||
if context is not None:
|
if context is not None:
|
||||||
if not self.history.get_thread(context):
|
if not self.history.get_thread(context):
|
||||||
await ctx.respond("Invalid context key.")
|
await ctx.respond("Invalid context key.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if query == "$":
|
||||||
|
v = PromptSelector(ctx)
|
||||||
|
await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v)
|
||||||
|
await v.wait()
|
||||||
|
query = v.user_prompt or query
|
||||||
|
system_query = v.system_prompt
|
||||||
|
await ctx.delete(delay=0.1)
|
||||||
|
|
||||||
model = model.casefold()
|
model = model.casefold()
|
||||||
try:
|
try:
|
||||||
|
@ -294,7 +372,7 @@ class Ollama(commands.Cog):
|
||||||
model = model + ":" + tag
|
model = model + ":" + tag
|
||||||
self.log.debug("Model %r already has a tag", model)
|
self.log.debug("Model %r already has a tag", model)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
model = model + ":latest"
|
model += ":latest"
|
||||||
self.log.debug("Resolved model to %r" % model)
|
self.log.debug("Resolved model to %r" % model)
|
||||||
|
|
||||||
if image:
|
if image:
|
||||||
|
@ -315,7 +393,7 @@ class Ollama(commands.Cog):
|
||||||
data = io.BytesIO()
|
data = io.BytesIO()
|
||||||
await image.save(data)
|
await image.save(data)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
image_data = base64.b64encode(data.read()).decode("utf-8")
|
image_data = base64.b64encode(data.read()).decode()
|
||||||
else:
|
else:
|
||||||
image_data = None
|
image_data = None
|
||||||
|
|
||||||
|
@ -336,7 +414,12 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
base_url=server_config["base_url"],
|
base_url=server_config["base_url"],
|
||||||
timeout=aiohttp.ClientTimeout(0)
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
connect=30,
|
||||||
|
sock_read=10800,
|
||||||
|
sock_connect=30,
|
||||||
|
total=10830
|
||||||
|
)
|
||||||
) as session:
|
) as session:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Checking server...",
|
title="Checking server...",
|
||||||
|
@ -482,7 +565,7 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug("Beginning to generate response with key %r.", key)
|
self.log.debug("Beginning to generate response with key %r.", key)
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = self.history.create_thread(ctx.user)
|
context = self.history.create_thread(ctx.user, system_query)
|
||||||
elif context is not None and self.history.get_thread(context) is None:
|
elif context is not None and self.history.get_thread(context) is None:
|
||||||
__thread = self.history.find_thread(context)
|
__thread = self.history.find_thread(context)
|
||||||
if not __thread:
|
if not __thread:
|
||||||
|
|
188
src/cogs/quote_quota.py
Normal file
188
src/cogs/quote_quota.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
|
||||||
|
import discord
|
||||||
|
import io
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datetime import timedelta
|
||||||
|
from discord.ext import commands
|
||||||
|
from typing import Iterable, Annotated
|
||||||
|
|
||||||
|
from conf import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
class QuoteQuota(commands.Cog):
|
||||||
|
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot = bot
|
||||||
|
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
|
||||||
|
self.names = CONFIG["quote_a"].get("names", {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def quotes_channel(self) -> discord.TextChannel | None:
|
||||||
|
if self.quotes_channel_id:
|
||||||
|
c = self.bot.get_channel(self.quotes_channel_id)
|
||||||
|
if c:
|
||||||
|
return c
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_pie_chart(
|
||||||
|
usernames: list[str],
|
||||||
|
counts: list[int],
|
||||||
|
no_other: bool = False
|
||||||
|
) -> discord.File:
|
||||||
|
"""
|
||||||
|
Converts the given username and count tuples into a nice pretty pie chart.
|
||||||
|
|
||||||
|
:param usernames: The usernames
|
||||||
|
:param counts: The number of times the username appears in the chat
|
||||||
|
:param no_other: Disables the "other" grouping
|
||||||
|
:returns: The pie chart image
|
||||||
|
"""
|
||||||
|
|
||||||
|
def pct(v: int):
|
||||||
|
return f"{v:.1f}% ({round((v / 100) * sum(counts))})"
|
||||||
|
|
||||||
|
if no_other is False:
|
||||||
|
other = []
|
||||||
|
# Any authors with less than 5% of the total count will be grouped into "other"
|
||||||
|
for i, author in enumerate(usernames.copy()):
|
||||||
|
if (c := counts[i]) / sum(counts) < 0.05:
|
||||||
|
other.append(c)
|
||||||
|
counts[i] = -1
|
||||||
|
usernames.remove(author)
|
||||||
|
if other:
|
||||||
|
usernames.append("Other")
|
||||||
|
counts.append(sum(other))
|
||||||
|
# And now filter out any -1% counts
|
||||||
|
counts = [c for c in counts if c != -1]
|
||||||
|
|
||||||
|
mapping = {}
|
||||||
|
for i, author in enumerate(usernames):
|
||||||
|
mapping[author] = counts[i]
|
||||||
|
|
||||||
|
# Sort the authors by count
|
||||||
|
new_mapping = {}
|
||||||
|
for author, count in sorted(mapping.items(), key=lambda x: x[1], reverse=True):
|
||||||
|
new_mapping[author] = count
|
||||||
|
|
||||||
|
usernames = list(new_mapping.keys())
|
||||||
|
counts = list(new_mapping.values())
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(7, 7))
|
||||||
|
ax.pie(
|
||||||
|
counts,
|
||||||
|
labels=usernames,
|
||||||
|
autopct=pct,
|
||||||
|
startangle=90,
|
||||||
|
radius=1.2,
|
||||||
|
)
|
||||||
|
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4)
|
||||||
|
fio = io.BytesIO()
|
||||||
|
fig.savefig(fio, format='png')
|
||||||
|
fio.seek(0)
|
||||||
|
return discord.File(fio, filename="pie.png")
|
||||||
|
|
||||||
|
@commands.slash_command()
|
||||||
|
async def quota(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
days: Annotated[
|
||||||
|
int,
|
||||||
|
discord.Option(
|
||||||
|
int,
|
||||||
|
name="lookback",
|
||||||
|
description="How many days to look back on. Defaults to 7.",
|
||||||
|
default=7,
|
||||||
|
min_value=1,
|
||||||
|
max_value=365
|
||||||
|
)
|
||||||
|
],
|
||||||
|
merge_other: Annotated[
|
||||||
|
bool,
|
||||||
|
discord.Option(
|
||||||
|
bool,
|
||||||
|
name="merge_other",
|
||||||
|
description="Whether to merge authors with less than 5% of the total count into 'Other'.",
|
||||||
|
default=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
):
|
||||||
|
"""Checks the quote quota for the quotes channel."""
|
||||||
|
now = discord.utils.utcnow()
|
||||||
|
oldest = now - timedelta(days=days)
|
||||||
|
await ctx.defer()
|
||||||
|
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes")
|
||||||
|
if not channel:
|
||||||
|
return await ctx.respond(":x: Cannot find quotes channel.")
|
||||||
|
|
||||||
|
await ctx.respond("Gathering messages, this may take a moment.")
|
||||||
|
|
||||||
|
authors = {}
|
||||||
|
filtered_messages = 0
|
||||||
|
total = 0
|
||||||
|
async for message in channel.history(
|
||||||
|
limit=None,
|
||||||
|
after=oldest,
|
||||||
|
oldest_first=False
|
||||||
|
):
|
||||||
|
total += 1
|
||||||
|
if not message.content:
|
||||||
|
filtered_messages += 1
|
||||||
|
continue
|
||||||
|
if message.attachments:
|
||||||
|
regex = r".*\s*-\s*@?([\w\s]+)"
|
||||||
|
else:
|
||||||
|
regex = r".+\s+-\s*@?([\w\s]+)"
|
||||||
|
|
||||||
|
if not (m := re.match(regex, str(message.clean_content))):
|
||||||
|
filtered_messages += 1
|
||||||
|
continue
|
||||||
|
name = m.group(1)
|
||||||
|
name = name.strip().casefold()
|
||||||
|
if name == "me":
|
||||||
|
name = message.author.name.strip().casefold()
|
||||||
|
if name in self.names:
|
||||||
|
name = self.names[name].title()
|
||||||
|
else:
|
||||||
|
filtered_messages += 1
|
||||||
|
continue
|
||||||
|
elif name in self.names:
|
||||||
|
name = self.names[name].title()
|
||||||
|
elif name.isdigit():
|
||||||
|
filtered_messages += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = name.title()
|
||||||
|
authors.setdefault(name, 0)
|
||||||
|
authors[name] += 1
|
||||||
|
|
||||||
|
if not authors:
|
||||||
|
if total:
|
||||||
|
return await ctx.edit(
|
||||||
|
content="No valid messages found in the last {!s} days. "
|
||||||
|
"Make sure quotes are formatted properly ending with ` - AuthorName`"
|
||||||
|
" (e.g. `\"This is my quote\" - Jimmy`)".format(days)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await ctx.edit(
|
||||||
|
content="No messages found in the last {!s} days.".format(days)
|
||||||
|
)
|
||||||
|
|
||||||
|
file = await asyncio.to_thread(
|
||||||
|
self.generate_pie_chart,
|
||||||
|
list(authors.keys()),
|
||||||
|
list(authors.values()),
|
||||||
|
merge_other
|
||||||
|
)
|
||||||
|
return await ctx.edit(
|
||||||
|
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
|
||||||
|
filtered_messages,
|
||||||
|
total
|
||||||
|
),
|
||||||
|
file=file
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup(bot):
|
||||||
|
bot.add_cog(QuoteQuota(bot))
|
|
@ -5,6 +5,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
import copy
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
@ -15,6 +16,8 @@ from selenium import webdriver
|
||||||
from selenium.webdriver.chrome.options import Options as ChromeOptions
|
from selenium.webdriver.chrome.options import Options as ChromeOptions
|
||||||
from selenium.webdriver.chrome.service import Service as ChromeService
|
from selenium.webdriver.chrome.service import Service as ChromeService
|
||||||
|
|
||||||
|
from conf import CONFIG
|
||||||
|
|
||||||
|
|
||||||
class ScreenshotCog(commands.Cog):
|
class ScreenshotCog(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
|
@ -76,7 +79,8 @@ class ScreenshotCog(commands.Cog):
|
||||||
load_timeout: int = 10,
|
load_timeout: int = 10,
|
||||||
render_timeout: int = None,
|
render_timeout: int = None,
|
||||||
eager: bool = None,
|
eager: bool = None,
|
||||||
resolution: str = "1920x1080"
|
resolution: str = "1920x1080",
|
||||||
|
use_proxy: bool = False
|
||||||
):
|
):
|
||||||
"""Screenshots a webpage."""
|
"""Screenshots a webpage."""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -104,11 +108,14 @@ class ScreenshotCog(commands.Cog):
|
||||||
|
|
||||||
start_init = time.time()
|
start_init = time.time()
|
||||||
try:
|
try:
|
||||||
|
options = copy.copy(self.chrome_options)
|
||||||
|
if use_proxy and (server := CONFIG["screenshot"].get("proxy")):
|
||||||
|
options.add_argument("--proxy-server=" + server)
|
||||||
service = await asyncio.to_thread(ChromeService)
|
service = await asyncio.to_thread(ChromeService)
|
||||||
driver: webdriver.Chrome = await asyncio.to_thread(
|
driver: webdriver.Chrome = await asyncio.to_thread(
|
||||||
webdriver.Chrome,
|
webdriver.Chrome,
|
||||||
service=service,
|
service=service,
|
||||||
options=self.chrome_options
|
options=options
|
||||||
)
|
)
|
||||||
driver.set_page_load_timeout(load_timeout)
|
driver.set_page_load_timeout(load_timeout)
|
||||||
if resolution:
|
if resolution:
|
||||||
|
@ -173,6 +180,7 @@ class ScreenshotCog(commands.Cog):
|
||||||
end_save = time.time()
|
end_save = time.time()
|
||||||
|
|
||||||
if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024:
|
if len(await asyncio.to_thread(file.getvalue)) > 24 * 1024 * 1024:
|
||||||
|
await ctx.edit(content="Compressing screenshot...")
|
||||||
start_compress = time.time()
|
start_compress = time.time()
|
||||||
file = await asyncio.to_thread(self.compress_png, file)
|
file = await asyncio.to_thread(self.compress_png, file)
|
||||||
fn = "screenshot.webp"
|
fn = "screenshot.webp"
|
||||||
|
|
|
@ -82,23 +82,34 @@ class YTDLCog(commands.Cog):
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return
|
return
|
||||||
|
|
||||||
async def save_link(self, message: discord.Message, webpage_url: str, format_id: str, attachment_index: int = 0):
|
async def save_link(
|
||||||
|
self,
|
||||||
|
message: discord.Message,
|
||||||
|
webpage_url: str,
|
||||||
|
format_id: str,
|
||||||
|
attachment_index: int = 0,
|
||||||
|
*,
|
||||||
|
snip: typing.Optional[str] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Saves a link to discord to prevent having to re-download it.
|
Saves a link to discord to prevent having to re-download it.
|
||||||
:param message: The download message with the attachment.
|
:param message: The download message with the attachment.
|
||||||
:param webpage_url: The "webpage_url" key of the metadata
|
:param webpage_url: The "webpage_url" key of the metadata
|
||||||
:param format_id: The "format_Id" key of the metadata
|
:param format_id: The "format_Id" key of the metadata
|
||||||
:param attachment_index: The index of the attachment. Defaults to 0
|
:param attachment_index: The index of the attachment. Defaults to 0
|
||||||
|
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
|
||||||
:return: The created hash key
|
:return: The created hash key
|
||||||
"""
|
"""
|
||||||
|
snip = snip or '*'
|
||||||
await self._init_db()
|
await self._init_db()
|
||||||
async with aiosqlite.connect("./data/ytdl.db") as db:
|
async with aiosqlite.connect("./data/ytdl.db") as db:
|
||||||
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest()
|
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"Saving %r (%r:%r) with message %d>%d, index %d",
|
"Saving %r (%r:%r:%r) with message %d>%d, index %d",
|
||||||
_hash,
|
_hash,
|
||||||
webpage_url,
|
webpage_url,
|
||||||
format_id,
|
format_id,
|
||||||
|
snip,
|
||||||
message.channel.id,
|
message.channel.id,
|
||||||
message.id,
|
message.id,
|
||||||
attachment_index
|
attachment_index
|
||||||
|
@ -117,20 +128,27 @@ class YTDLCog(commands.Cog):
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return _hash
|
return _hash
|
||||||
|
|
||||||
async def get_saved(self, webpage_url: str, format_id: str) -> typing.Optional[str]:
|
async def get_saved(
|
||||||
|
self,
|
||||||
|
webpage_url: str,
|
||||||
|
format_id: str,
|
||||||
|
snip: str
|
||||||
|
) -> typing.Optional[str]:
|
||||||
"""
|
"""
|
||||||
Attempts to retrieve the attachment URL of a previously saved download.
|
Attempts to retrieve the attachment URL of a previously saved download.
|
||||||
:param webpage_url: The webpage url
|
:param webpage_url: The webpage url
|
||||||
:param format_id: The format ID
|
:param format_id: The format ID
|
||||||
|
:param snip: The start and end time to snip the video. e.g. 00:00:00-00:10:00
|
||||||
:return: the URL, if found and valid.
|
:return: the URL, if found and valid.
|
||||||
"""
|
"""
|
||||||
await self._init_db()
|
await self._init_db()
|
||||||
async with aiosqlite.connect("./data/ytdl.db") as db:
|
async with aiosqlite.connect("./data/ytdl.db") as db:
|
||||||
_hash = hashlib.md5(f"{webpage_url}:{format_id}".encode()).hexdigest()
|
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"Attempting to find a saved download for '%s:%s' (%r).",
|
"Attempting to find a saved download for '%s:%s:%s' (%r).",
|
||||||
webpage_url,
|
webpage_url,
|
||||||
format_id,
|
format_id,
|
||||||
|
snip,
|
||||||
_hash
|
_hash
|
||||||
)
|
)
|
||||||
cursor = await db.execute(
|
cursor = await db.execute(
|
||||||
|
@ -229,7 +247,6 @@ class YTDLCog(commands.Cog):
|
||||||
snip: typing.Annotated[
|
snip: typing.Annotated[
|
||||||
typing.Optional[str],
|
typing.Optional[str],
|
||||||
discord.Option(
|
discord.Option(
|
||||||
str,
|
|
||||||
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
|
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
|
||||||
required=False
|
required=False
|
||||||
)
|
)
|
||||||
|
@ -347,7 +364,7 @@ class YTDLCog(commands.Cog):
|
||||||
colour=self.colours.get(domain, discord.Colour.og_blurple())
|
colour=self.colours.get(domain, discord.Colour.og_blurple())
|
||||||
).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url)
|
).set_footer(text="Downloading (step 2/10)").set_thumbnail(url=thumbnail_url)
|
||||||
)
|
)
|
||||||
previous = await self.get_saved(webpage_url, extracted_info["format_id"])
|
previous = await self.get_saved(webpage_url, extracted_info["format_id"], snip or '*')
|
||||||
if previous:
|
if previous:
|
||||||
await ctx.edit(
|
await ctx.edit(
|
||||||
content=previous,
|
content=previous,
|
||||||
|
@ -505,7 +522,7 @@ class YTDLCog(commands.Cog):
|
||||||
url=webpage_url
|
url=webpage_url
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await self.save_link(msg, webpage_url, chosen_format_id)
|
await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or '*')
|
||||||
except discord.HTTPException as e:
|
except discord.HTTPException as e:
|
||||||
self.log.error(e, exc_info=True)
|
self.log.error(e, exc_info=True)
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
|
|
|
@ -28,6 +28,8 @@ try:
|
||||||
CONFIG.setdefault("jimmy", {})
|
CONFIG.setdefault("jimmy", {})
|
||||||
CONFIG.setdefault("ollama", {})
|
CONFIG.setdefault("ollama", {})
|
||||||
CONFIG.setdefault("rss", {"meta": {"channel": None}})
|
CONFIG.setdefault("rss", {"meta": {"channel": None}})
|
||||||
|
CONFIG.setdefault("screenshot", {})
|
||||||
|
CONFIG.setdefault("quote_a", {"channel": None})
|
||||||
CONFIG.setdefault(
|
CONFIG.setdefault(
|
||||||
"server",
|
"server",
|
||||||
{
|
{
|
||||||
|
|
22
src/main.py
22
src/main.py
|
@ -10,7 +10,6 @@ import random
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from web import app
|
|
||||||
from logging import FileHandler
|
from logging import FileHandler
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
@ -104,25 +103,12 @@ class Client(commands.Bot):
|
||||||
CONFIG["jimmy"].get("uptime_kuma_interval", 60.0)
|
CONFIG["jimmy"].get("uptime_kuma_interval", 60.0)
|
||||||
)
|
)
|
||||||
self.uptime_thread.start()
|
self.uptime_thread.start()
|
||||||
app.state.bot = self
|
|
||||||
config = uvicorn.Config(
|
|
||||||
app,
|
|
||||||
host=CONFIG["server"].get("host", "0.0.0.0"),
|
|
||||||
port=CONFIG["server"].get("port", 8080),
|
|
||||||
loop="asyncio",
|
|
||||||
lifespan="on",
|
|
||||||
server_header=False
|
|
||||||
)
|
|
||||||
server = uvicorn.Server(config=config)
|
|
||||||
self.web = self.loop.create_task(asyncio.to_thread(server.serve()))
|
|
||||||
await super().start(token, reconnect=reconnect)
|
await super().start(token, reconnect=reconnect)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self.web:
|
if self.uptime_thread:
|
||||||
self.web.cancel()
|
self.uptime_thread.kill.set()
|
||||||
if self.thread:
|
await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join)
|
||||||
self.thread.kill.set()
|
|
||||||
await asyncio.get_event_loop().run_in_executor(None, self.thread.join)
|
|
||||||
await super().close()
|
await super().close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,7 +119,7 @@ bot = Client(
|
||||||
debug_guilds=CONFIG["jimmy"].get("debug_guilds")
|
debug_guilds=CONFIG["jimmy"].get("debug_guilds")
|
||||||
)
|
)
|
||||||
|
|
||||||
for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta"):
|
for ext in ("ytdl", "net", "screenshot", "ollama", "ffmeta", "quote_quota"):
|
||||||
try:
|
try:
|
||||||
bot.load_extension(f"cogs.{ext}")
|
bot.load_extension(f"cogs.{ext}")
|
||||||
except discord.ExtensionError as e:
|
except discord.ExtensionError as e:
|
||||||
|
|
160
src/web.py
160
src/web.py
|
@ -1,160 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import textwrap
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import time
|
|
||||||
import pydantic
|
|
||||||
from typing import Optional, Any
|
|
||||||
from conf import CONFIG
|
|
||||||
import discord
|
|
||||||
from discord.ext.commands import Paginator
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, status, WebSocketException, WebSocket, WebSocketDisconnect, Header
|
|
||||||
|
|
||||||
class BridgeResponse(pydantic.BaseModel):
|
|
||||||
status: str
|
|
||||||
pages: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class BridgePayload(pydantic.BaseModel):
|
|
||||||
secret: str
|
|
||||||
message: str
|
|
||||||
sender: str
|
|
||||||
|
|
||||||
|
|
||||||
class MessagePayload(pydantic.BaseModel):
|
|
||||||
class MessageAttachmentPayload(pydantic.BaseModel):
|
|
||||||
url: str
|
|
||||||
proxy_url: str
|
|
||||||
filename: str
|
|
||||||
size: int
|
|
||||||
width: Optional[int] = None
|
|
||||||
height: Optional[int] = None
|
|
||||||
content_type: str
|
|
||||||
ATTACHMENT: Optional[Any] = None
|
|
||||||
|
|
||||||
event_type: Optional[str] = "create"
|
|
||||||
message_id: int
|
|
||||||
author: str
|
|
||||||
is_automated: bool = False
|
|
||||||
avatar: str
|
|
||||||
content: str
|
|
||||||
clean_content: str
|
|
||||||
at: float
|
|
||||||
attachments: list[MessageAttachmentPayload] = []
|
|
||||||
reply_to: Optional["MessagePayload"] = None
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title="JimmyAPI",
|
|
||||||
version="2.0.0a1"
|
|
||||||
)
|
|
||||||
log = logging.getLogger("jimmy.web.api")
|
|
||||||
app.state.bot = None
|
|
||||||
app.state.bridge_lock = asyncio.Lock()
|
|
||||||
app.state.last_sender_ts = 0
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ping")
|
|
||||||
def ping():
|
|
||||||
"""Checks the bot is online and provides some uptime information"""
|
|
||||||
if not app.state.bot:
|
|
||||||
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
|
|
||||||
return {
|
|
||||||
"ping": "pong",
|
|
||||||
"online": app.state.bot.is_ready(),
|
|
||||||
"latency": max(round(app.state.bot.latency, 2), 0.01),
|
|
||||||
"uptime": round(time.time() - psutil.Process().create_time()),
|
|
||||||
"uptime.sys": time.time() - psutil.boot_time()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/bridge", status_code=201)
|
|
||||||
async def bridge_post_send_message(body: BridgePayload):
|
|
||||||
"""Sends a message FROM matrix TO discord."""
|
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
ts_diff = (now - app.state.last_sender_ts).total_seconds()
|
|
||||||
if not app.state.bot:
|
|
||||||
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
|
|
||||||
|
|
||||||
if body.secret != CONFIG["jimmy"].get("token"):
|
|
||||||
log.warning("Authentication failure: %s was not authenticated.", body.secret)
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
|
||||||
|
|
||||||
channel = app.state.bot.get_channel(CONFIG["server"]["channel"])
|
|
||||||
if not channel or not channel.can_send():
|
|
||||||
log.warning("Unable to send message: channel not found or not writable.")
|
|
||||||
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE)
|
|
||||||
|
|
||||||
if len(body.message) > 4000:
|
|
||||||
log.warning(
|
|
||||||
"Unable to send message: message too long ({:,} characters long, 4000 max).".format(len(body.message))
|
|
||||||
)
|
|
||||||
raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE)
|
|
||||||
|
|
||||||
paginator = Paginator(prefix="", suffix="", max_size=1990)
|
|
||||||
for line in body["message"].splitlines():
|
|
||||||
try:
|
|
||||||
paginator.add_line(line)
|
|
||||||
except ValueError:
|
|
||||||
paginator.add_line(textwrap.shorten(line, width=1900, placeholder="<...>"))
|
|
||||||
|
|
||||||
if len(paginator.pages) > 1:
|
|
||||||
msg = None
|
|
||||||
if app.state.last_sender != body["sender"] or ts_diff >= 600:
|
|
||||||
msg = await channel.send(f"**{body['sender']}**:")
|
|
||||||
m = len(paginator.pages)
|
|
||||||
for n, page in enumerate(paginator.pages, 1):
|
|
||||||
await channel.send(
|
|
||||||
f"[{n}/{m}]\n>>> {page}",
|
|
||||||
allowed_mentions=discord.AllowedMentions.none(),
|
|
||||||
reference=msg,
|
|
||||||
silent=True,
|
|
||||||
suppress=n != m,
|
|
||||||
)
|
|
||||||
app.state.last_sender = body["sender"]
|
|
||||||
else:
|
|
||||||
content = f"**{body['sender']}**:\n>>> {body['message']}"
|
|
||||||
if app.state.last_sender == body["sender"] and ts_diff < 600:
|
|
||||||
content = f">>> {body['message']}"
|
|
||||||
await channel.send(content, allowed_mentions=discord.AllowedMentions.none(), silent=True, suppress=False)
|
|
||||||
app.state.last_sender = body["sender"]
|
|
||||||
app.state.last_sender_ts = now
|
|
||||||
return {"status": "ok", "pages": len(paginator.pages)}
|
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/bridge/recv")
|
|
||||||
async def bridge_recv(ws: WebSocket, secret: str = Header(None)):
|
|
||||||
await ws.accept()
|
|
||||||
log.info("Websocket %s:%s accepted.", ws.client.host, ws.client.port)
|
|
||||||
if secret != app.state.bot.http.token:
|
|
||||||
log.warning("Closing websocket %r, invalid secret.", ws.client.host)
|
|
||||||
raise WebSocketException(code=1008, reason="Invalid Secret")
|
|
||||||
if app.state.ws_connected.locked():
|
|
||||||
log.warning("Closing websocket %r, already connected." % ws)
|
|
||||||
raise WebSocketException(code=1008, reason="Already connected.")
|
|
||||||
queue: asyncio.Queue = app.state.bot.bridge_queue
|
|
||||||
|
|
||||||
async with app.state.ws_connected:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await ws.send_json({"status": "ping"})
|
|
||||||
except (WebSocketDisconnect, WebSocketException):
|
|
||||||
log.info("Websocket %r disconnected.", ws)
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(queue.get(), timeout=5)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
await ws.send_json(data)
|
|
||||||
log.debug("Sent data %r to websocket %r.", data, ws)
|
|
||||||
except (WebSocketDisconnect, WebSocketException):
|
|
||||||
log.info("Websocket %r disconnected." % ws)
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
queue.task_done()
|
|
Loading…
Reference in a new issue