Merge remote-tracking branch 'origin/master'
All checks were successful
Build and Publish / build_and_publish (push) Successful in 1m6s

This commit is contained in:
Nexus 2024-07-22 23:50:55 +01:00
commit 27b23d3e3a
Signed by: nex
GPG key ID: 0FA334385D0B689F
15 changed files with 1311 additions and 553 deletions

View file

@ -21,3 +21,4 @@ aiofiles~=23.2
fuzzywuzzy[speedup]~=0.18 fuzzywuzzy[speedup]~=0.18
tortoise-orm[asyncpg]~=0.21 tortoise-orm[asyncpg]~=0.21
superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6 superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6
orjson~=3.10

View file

@ -33,11 +33,16 @@ class AutoResponder(commands.Cog):
raise ValueError( raise ValueError(
"overpowered_by is set, however members intent is disabled, making it useless." "overpowered_by is set, however members intent is disabled, making it useless."
) )
if self.config.get("overrule_offline_superiors", True) is True and bot.intents.presences is False: if (
self.config.get("overrule_offline_superiors", True) is True
and bot.intents.presences is False
):
raise ValueError( raise ValueError(
"overrule_offline_superiors is enabled, however presences intent is not!" "overrule_offline_superiors is enabled, however presences intent is not!"
) )
self.config.setdefault("transcoding", {"enabled": True, "hevc": True, "on_demand": True}) self.config.setdefault(
"transcoding", {"enabled": True, "hevc": True, "on_demand": True}
)
self.lmgtfy_cache = [] self.lmgtfy_cache = []
@property @property
@ -50,7 +55,9 @@ class AutoResponder(commands.Cog):
@property @property
def allow_on_demand_transcoding(self) -> bool: def allow_on_demand_transcoding(self) -> bool:
return self.transcoding_enabled and self.config["transcoding"].get("on_demand", True) return self.transcoding_enabled and self.config["transcoding"].get(
"on_demand", True
)
def overpowered(self, guild: discord.Guild | None) -> bool: def overpowered(self, guild: discord.Guild | None) -> bool:
if not guild: if not guild:
@ -68,14 +75,20 @@ class AutoResponder(commands.Cog):
@staticmethod @staticmethod
@typing.overload @typing.overload
def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: ... def extract_links(
text: str, *domains: str, raw: typing.Literal[True] = False
) -> list[ParseResult]: ...
@staticmethod @staticmethod
@typing.overload @typing.overload
def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: ... def extract_links(
text: str, *domains: str, raw: typing.Literal[False] = False
) -> list[str]: ...
@staticmethod @staticmethod
def extract_links(text: str, *domains: str, raw: bool = False) -> list[str | ParseResult]: def extract_links(
text: str, *domains: str, raw: bool = False
) -> list[str | ParseResult]:
""" """
Extracts all links from a given text. Extracts all links from a given text.
@ -124,10 +137,13 @@ class AutoResponder(commands.Cog):
if not update: if not update:
return return
if last_reaction is not None: if last_reaction is not None:
_ = asyncio.create_task(update.remove_reaction(last_reaction, self.bot.user)) _ = asyncio.create_task(
update.remove_reaction(last_reaction, self.bot.user)
)
if new: if new:
_ = asyncio.create_task(update.add_reaction(new)) _ = asyncio.create_task(update.add_reaction(new))
last_reaction = new last_reaction = new
self.log.info("Waiting for transcode lock to release") self.log.info("Waiting for transcode lock to release")
async with self.transcode_lock: async with self.transcode_lock:
cog = FFMeta(self.bot) cog = FFMeta(self.bot)
@ -147,6 +163,7 @@ class AutoResponder(commands.Cog):
return update_reaction("\N{TIMER CLOCK}\U0000fe0f") return update_reaction("\N{TIMER CLOCK}\U0000fe0f")
streams = info.get("streams", []) streams = info.get("streams", [])
hwaccel = True hwaccel = True
maxrate = "5M"
for stream in streams: for stream in streams:
self.log.info("Found stream: %s", stream.get("codec_name")) self.log.info("Found stream: %s", stream.get("codec_name"))
if stream.get("codec_name") == "hevc": if stream.get("codec_name") == "hevc":
@ -159,8 +176,14 @@ class AutoResponder(commands.Cog):
hwaccel = False hwaccel = False
break break
else: else:
self.log.info("No HEVC streams found in %s", uri) if int(info["format"]["size"]) >= 25 * 1024 * 1024:
return update_reaction() self.log.warning(
"%s is too large to render in discord, compressing", uri
)
maxrate = "1M"
else:
self.log.info("No HEVC streams found in %s", uri)
return update_reaction()
extension = pathlib.Path(uri).suffix extension = pathlib.Path(uri).suffix
with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl: with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl:
self.log.info("Downloading %r to %r", uri, tmp_dl.name) self.log.info("Downloading %r to %r", uri, tmp_dl.name)
@ -195,10 +218,8 @@ class AutoResponder(commands.Cog):
tmp_dl.name, tmp_dl.name,
"-c:v", "-c:v",
"libx264", "libx264",
"-crf",
"25",
"-maxrate", "-maxrate",
"5M", maxrate,
"-minrate", "-minrate",
"100K", "100K",
"-bufsize", "-bufsize",
@ -216,7 +237,7 @@ class AutoResponder(commands.Cog):
"-movflags", "-movflags",
"faststart", "faststart",
"-profile:v", "-profile:v",
"main", "high",
"-y", "-y",
"-hide_banner", "-hide_banner",
] ]
@ -230,7 +251,9 @@ class AutoResponder(commands.Cog):
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
self.log.info("finished transcode with return code %d", process.returncode) self.log.info(
"finished transcode with return code %d", process.returncode
)
self.log.debug("stdout: %r", stdout.decode) self.log.debug("stdout: %r", stdout.decode)
self.log.debug("stderr: %r", stderr.decode) self.log.debug("stderr: %r", stderr.decode)
update_reaction() update_reaction()
@ -243,12 +266,9 @@ class AutoResponder(commands.Cog):
) )
self._cooldown_transcode() self._cooldown_transcode()
return discord.File(tmp_path), tmp_path return discord.File(tmp_path), tmp_path
async def transcode_hevc_to_h264( async def transcode_hevc_to_h264(
self, self, message: discord.Message, *domains: str, additional: Iterable[str] = None
message: discord.Message,
*domains: str,
additional: Iterable[str] = None
) -> None: ) -> None:
if not shutil.which("ffmpeg"): if not shutil.which("ffmpeg"):
self.log.error("ffmpeg not installed") self.log.error("ffmpeg not installed")
@ -283,7 +303,9 @@ class AutoResponder(commands.Cog):
self.log.info("Found link to transcode: %r", link) self.log.info("Found link to transcode: %r", link)
try: try:
async with message.channel.typing(): async with message.channel.typing():
_r = await self._transcode_hevc_to_h264(link, update=message) _r = await self._transcode_hevc_to_h264(
link, update=message
)
if not _r: if not _r:
continue continue
file, _p = _r file, _p = _r
@ -291,7 +313,9 @@ class AutoResponder(commands.Cog):
if _p.stat().st_size <= 24.5 * 1024 * 1024: if _p.stat().st_size <= 24.5 * 1024 * 1024:
await message.add_reaction("\N{OUTBOX TRAY}") await message.add_reaction("\N{OUTBOX TRAY}")
await message.reply(file=file) await message.reply(file=file)
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user) await message.remove_reaction(
"\N{OUTBOX TRAY}", self.bot.user
)
else: else:
await message.add_reaction("\N{OUTBOX TRAY}") await message.add_reaction("\N{OUTBOX TRAY}")
self.log.warning( self.log.warning(
@ -301,16 +325,31 @@ class AutoResponder(commands.Cog):
) )
if _p.stat().st_size <= 510 * 1024 * 1024: if _p.stat().st_size <= 510 * 1024 * 1024:
file.fp.seek(0) file.fp.seek(0)
self.log.info("Trying to upload file to pastebin.") self.log.info(
"Trying to upload file to pastebin."
)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://0x0.st", "https://0x0.st",
files={"file": (_p.name, file.fp, "video/mp4")}, files={
headers={"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"}, "file": (
_p.name,
file.fp,
"video/mp4",
)
},
headers={
"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"
},
)
await message.remove_reaction(
"\N{OUTBOX TRAY}", self.bot.user
) )
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user)
if response.status_code == 200: if response.status_code == 200:
await message.reply("https://embeds.video/" + response.text.strip()) await message.reply(
"https://embeds.video/"
+ response.text.strip()
)
else: else:
await message.add_reaction("\N{BUG}") await message.add_reaction("\N{BUG}")
response.raise_for_status() response.raise_for_status()
@ -321,12 +360,16 @@ class AutoResponder(commands.Cog):
except Exception as e: except Exception as e:
self.log.error("Failed to transcode %r: %r", link, e) self.log.error("Failed to transcode %r: %r", link, e)
async def copy_ncfe_docs(self, message: discord.Message, links: list[ParseResult]) -> None: async def copy_ncfe_docs(
self, message: discord.Message, links: list[ParseResult]
) -> None:
files = [] files = []
if self.config.get("download_pdfs", True) is False: if self.config.get("download_pdfs", True) is False:
self.log.debug("Download PDFs is disabled in config, disengaging.") self.log.debug("Download PDFs is disabled in config, disengaging.")
return return
self.log.info("Preparing to download: %s", ", ".join(map(ParseResult.geturl, links))) self.log.info(
"Preparing to download: %s", ", ".join(map(ParseResult.geturl, links))
)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
for link in set(links): for link in set(links):
if link.path.endswith(".pdf"): if link.path.endswith(".pdf"):
@ -338,7 +381,7 @@ class AutoResponder(commands.Cog):
"Failed to download %s: HTTP %d - %r", "Failed to download %s: HTTP %d - %r",
link, link,
response.status, response.status,
await response.text() await response.text(),
) )
continue continue
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
@ -351,7 +394,9 @@ class AutoResponder(commands.Cog):
self.log.warning("File was too large to upload. Skipping.") self.log.warning("File was too large to upload. Skipping.")
continue continue
p = pathlib.Path(link.path).name p = pathlib.Path(link.path).name
file = discord.File(buffer, filename=p, description="Copy of %s" % link.geturl()) file = discord.File(
buffer, filename=p, description="Copy of %s" % link.geturl()
)
files.append(file) files.append(file)
for file in files: for file in files:
await message.reply(file=file) await message.reply(file=file)
@ -379,27 +424,38 @@ class AutoResponder(commands.Cog):
self.log.info("Got VHS reaction, scanning for transcode") self.log.info("Got VHS reaction, scanning for transcode")
extra = [] extra = []
if reaction.message.attachments: if reaction.message.attachments:
extra = [attachment.url for attachment in reaction.message.attachments] extra = [
attachment.url for attachment in reaction.message.attachments
]
if self.allow_on_demand_transcoding: if self.allow_on_demand_transcoding:
await self.transcode_hevc_to_h264(reaction.message, additional=extra) await self.transcode_hevc_to_h264(
reaction.message, additional=extra
)
elif str(reaction.emoji) == "\U0001f310": elif str(reaction.emoji) == "\U0001f310":
if reaction.message.id not in self.lmgtfy_cache: if reaction.message.id not in self.lmgtfy_cache:
url = "https://lmddgtfy.net/?q=" + quote_plus(reaction.message.content) url = "https://lmddgtfy.net/?q=" + quote_plus(
m = await reaction.message.reply(f"[Here's the answer to your question]({url})") reaction.message.content
)
m = await reaction.message.reply(
f"[Here's the answer to your question]({url})"
)
await m.edit(suppress=True) await m.edit(suppress=True)
self.lmgtfy_cache.append(reaction.message.id) self.lmgtfy_cache.append(reaction.message.id)
elif str(reaction.emoji)[0] == "\N{wastebasket}": elif str(reaction.emoji)[0] == "\N{WASTEBASKET}":
if reaction.message.channel.permissions_for(reaction.message.guild.me).manage_messages: if reaction.message.channel.permissions_for(
self.log.info("Deleting message %s (Wastebasket)" % reaction.message.jump_url) reaction.message.guild.me
).manage_messages:
self.log.info(
"Deleting message %s (Wastebasket)" % reaction.message.jump_url
)
await reaction.message.delete( await reaction.message.delete(
reason="%s requested deletion of message" % user, reason="%s requested deletion of message" % user, delay=0.2
delay=0.2
) )
else: else:
self.log.warning( self.log.warning(
"Unable to delete message %s (wastebasket) - missing permissions", "Unable to delete message %s (wastebasket) - missing permissions",
reaction.message.jump_url reaction.message.jump_url,
) )
@commands.Cog.listener("on_raw_reaction_add") @commands.Cog.listener("on_raw_reaction_add")
@ -415,8 +471,13 @@ class AutoResponder(commands.Cog):
_e = discord.PartialEmoji.from_str(str(payload.emoji)) _e = discord.PartialEmoji.from_str(str(payload.emoji))
reaction = discord.Reaction( reaction = discord.Reaction(
message=message, message=message,
data={"emoji": _e, "count": 1, "me": payload.user_id == self.bot.user.id, "burst": False}, data={
emoji=payload.emoji "emoji": _e,
"count": 1,
"me": payload.user_id == self.bot.user.id,
"burst": False,
},
emoji=payload.emoji,
) )
user = self.bot.get_user(payload.user_id) user = self.bot.get_user(payload.user_id)
await self.on_reaction_add(reaction, user) await self.on_reaction_add(reaction, user)

View file

@ -1,140 +1,220 @@
""" """
This module is only meant to be loaded during election times. This module is only meant to be loaded during election times.
""" """
import asyncio import asyncio
import logging import logging
import random
import datetime
import re import re
import discord import discord
import httpx import httpx
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from discord.ext import commands from discord.ext import commands, tasks
SPAN_REGEX = re.compile( SPAN_REGEX = re.compile(
r"^(?P<party>\D+)(?P<councillors>[0-9,]+)\scouncillors\s(?P<net>[0-9,]+)\scouncillors\s(?P<net_change>(gained|lost))$" r"^(?P<party>\D+)(?P<councillors>[0-9,]+)\scouncillors\s(?P<net>[0-9,]+)\scouncillors\s(?P<net_change>(gained|lost))$"
) )
MULTI: dict[str, int] = { MULTI: dict[str, int] = {"gained": 1, "lost": -1}
"gained": 1,
"lost": -1
}
class ElectionCog(commands.Cog): class ElectionCog(commands.Cog):
SOURCE = "https://bbc.com/" SOURCE = "https://bbc.com/"
HEADERS = { HEADERS = {
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp," "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,"
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8", "Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
"Priority": "u=0, i", "Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"124\", \"Google Chrome\";v=\"124\", \"Not-A.Brand\";v=\"99\"", "Sec-Ch-Ua": '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"',
"Sec-Ch-Ua-Mobile": "?0", "Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Linux\"", "Sec-Ch-Ua-Platform": '"Linux"',
"Sec-Fetch-Dest": "document", "Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate", "Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none", "Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1", "Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1", "Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 " "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 "
"Safari/537.36", "Safari/537.36",
} }
ETA = datetime.datetime(
2024, 7, 4, 23, 30, tzinfo=datetime.datetime.now().astimezone().tzinfo
)
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot: commands.Bot = bot
self.log = logging.getLogger("jimmy.cogs.election") self.log = logging.getLogger("jimmy.cogs.election")
self.countdown_message = None
# self.check_election.start()
def cog_unload(self) -> None:
self.check_election.cancel()
@tasks.loop(minutes=1)
async def check_election(self):
if not self.bot.is_ready():
await self.bot.wait_until_ready()
guild = self.bot.get_guild(994710566612500550)
if not guild:
return self.log.error("Nonsense guild not found. Can't do countdown.")
channel = discord.utils.get(guild.text_channels, name="countdown")
if not channel:
return self.log.error("Countdown channel not found.")
await asyncio.sleep(random.randint(0, 10))
now = discord.utils.utcnow()
diff = (self.ETA - now).total_seconds()
if diff < -86400:
return self.log.debug("Countdown long expired.")
if diff > 3600:
hours, remainder = map(round, divmod(diff, 3600))
minutes, seconds = map(round, divmod(remainder, 60))
message = f"+ {hours} hours, {minutes} minutes, and {seconds} seconds."
elif diff > 60:
minutes, seconds = map(round, divmod(diff, 60))
message = f"+ {minutes} minutes and {seconds} seconds."
elif diff >= 0:
message = f"+ {round(diff)} seconds."
else:
message = "Results time!"
if self.countdown_message:
try:
return await self.countdown_message.edit(
content=f"```diff\n{message}```"
)
except discord.HTTPException:
self.log.exception("Failed to edit countdown message.")
self.countdown_message = await channel.send(f"```diff\n{message}```")
self.log.debug("Sent countdown message")
def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None:
good_soup = soup.find(attrs={"data-testid": "election-banner-results-bar"}) good_soups = list(
if not good_soup: soup.find_all(attrs={"data-testid": "election-banner-results-bar"})
)
if not good_soups:
self.log.error(
"No 'election-banner-results-bar' elements found:\n%r", soup.prettify()
)
return return
good_soup = list(good_soups)[1]
css: str = "\n".join([x.get_text() for x in soup.find_all("style")])
def find_colour(style_name: str, want: str = "background-color") -> str | None:
self.log.info("Looking for style %r", style_name)
index = css.index(style_name) + len(style_name) + 1
value = ""
for char in css[index:]:
if char == "}":
break
value += char
attributes = filter(None, value.split(";"))
parsed = {}
for attr in attributes:
name, val = attr.split(":")
parsed[name] = val
self.log.info("Parsed the following attributes: %r", parsed)
if want in parsed:
self.log.info("Returning %r: %r", want, parsed[want])
return parsed[want]
self.log.warning("%r was not in attributes.", want)
results: dict[str, list[int]] = {} results: dict[str, list[int]] = {}
for child_ul in good_soup.children: for child_li in good_soup.children:
child_ul: BeautifulSoup span = list(child_li.children)[-1]
span = child_ul.find("span", recursive=False) try:
if not span: party, extra = span.get_text().strip().split(":", 1)
self.log.warning("%r did not have a 'span' element.", child_ul) seats, extra = extra.split(",", 1)
extra = extra.strip()
if extra.lower() == "no change":
change = 0
else:
_values = extra.split()
change = int(_values[0])
if _values[-1] != "gained":
change *= -1
seats = int(seats.split()[0])
except ValueError:
self.log.error("failed to parse %r", span)
continue continue
results[party] = [seats, change, 0, 0]
text = span.get_text().replace(",", "") for child_li in good_soups[0].children:
groups = SPAN_REGEX.match(text) span = list(child_li.children)[-1]
if groups: try:
groups = groups.groupdict() party, extra = span.get_text().strip().split(":", 1)
else: seats, _ = extra.strip().split(" ", 1)
self.log.warning( seats = int(seats.strip())
"Found span element (%r), however resolved text (%r) did not match regex.", except ValueError:
span, text self.log.error("failed to parse %r", span)
)
continue continue
if party in results:
results[str(groups["party"]).strip()] = [ results[party][3] = seats
int(groups["councillors"].strip()),
int(groups["net"].strip()) * MULTI[groups["net_change"]],
int(find_colour(child_ul.next["class"][0])[1:], base=16)
]
return results return results
async def _get_embed(self) -> discord.Embed | None:
async with httpx.AsyncClient(headers=self.HEADERS) as client:
response = await client.get(self.SOURCE, follow_redirects=True)
if response.status_code != 200:
raise RuntimeError(
f"HTTP {response.status_code} while fetching results from BBC"
)
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser")
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup)
if results:
now = discord.utils.utcnow()
date = now.date().strftime("%B %Y")
colour_scores = {}
embed = discord.Embed(
title="Election results - " + date,
url="https://bbc.co.uk/",
timestamp=now,
)
embed.set_footer(text="Source from bbc.co.uk.")
description_parts = []
for party_name, values in results.items():
councillors, net, colour, last_election = values
colour_scores[party_name] = councillors
symbol = "+" if net > 0 else ""
description_parts.append(
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, "
f"{last_election:,} predicted by exit poll)"
)
top_party = list(
sorted(
colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True
)
)[0]
embed.colour = discord.Colour(results[top_party][2])
embed.description = "\n".join(description_parts)
return embed
@commands.slash_command(name="election") @commands.slash_command(name="election")
async def get_election_results(self, ctx: discord.ApplicationContext): async def get_election_results(self, ctx: discord.ApplicationContext):
"""Gets the current election results""" """Gets the current election results"""
await ctx.defer()
async with httpx.AsyncClient(headers=self.HEADERS) as client:
response = await client.get(self.SOURCE, follow_redirects=True)
if response.status_code != 200:
return await ctx.respond(
"Sorry, I can't do that right now (HTTP %d while fetching results from BBC)" % response.status_code
)
# noinspection PyTypeChecker class RefreshView(discord.ui.View):
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser") def __init__(self, **kwargs):
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup) super().__init__(**kwargs)
if results: self.last_edit = discord.utils.utcnow()
now = discord.utils.utcnow()
date = now.date().strftime("%B %Y")
colour_scores = {}
embed = discord.Embed(
title="Election results - " + date,
url="https://bbc.co.uk/",
timestamp=now
)
embed.set_footer(text="Source from bbc.co.uk.")
description_parts = []
for party_name, values in results.items(): @discord.ui.button(
councillors, net, colour = values label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501"
colour_scores[party_name] = councillors )
symbol = "+" if net > 0 else '' async def refresh(_self, _btn, interaction):
description_parts.append( await interaction.response.defer(invisible=True)
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)" if (discord.utils.utcnow() - self.last_edit).total_seconds() < 10:
return await interaction.followup.send("Slow down.", ephemeral=True)
try:
embed = await self._get_embed()
except Exception as e:
self.log.exception("Failed to get election results.")
return await interaction.followup.send(
f"Sorry, I cannot contact the BBC at this time: {e}"
) )
if embed is None:
return await interaction.followup.send(
"Sorry, I could not find any election results."
)
await interaction.edit_original_response(embed=embed)
top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] await ctx.defer()
embed.colour = discord.Colour(results[top_party][2]) try:
embed.description = "\n".join(description_parts) embed = await self._get_embed()
return await ctx.respond(embed=embed) except Exception as e:
else: self.log.exception("Failed to get election results.")
return await ctx.respond("Unable to get election results at this time.") return await ctx.respond(
f"Sorry, I cannot contact the BBC at this time: {e}"
)
if embed is None:
return await ctx.respond("Sorry, I could not find any election results.")
await ctx.respond(
embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True)
)
def setup(bot): def setup(bot):

View file

@ -29,25 +29,44 @@ class FFMeta(commands.Cog):
self.bot = bot self.bot = bot
self.log = logging.getLogger("jimmy.cogs.ffmeta") self.log = logging.getLogger("jimmy.cogs.ffmeta")
def jpegify_image(self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg") -> io.BytesIO: def jpegify_image(
self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg"
) -> io.BytesIO:
quality = min(1, max(quality, 100)) quality = min(1, max(quality, 100))
img_src = PIL.Image.open(input_file) img_src = PIL.Image.open(input_file)
if image_format == "jpeg": if image_format == "jpeg":
img_src = img_src.convert("RGB") img_src = img_src.convert("RGB")
img_dst = io.BytesIO() img_dst = io.BytesIO()
self.log.debug("Saving input file (%r) as %r with quality %r%%", input_file, image_format, quality) self.log.debug(
"Saving input file (%r) as %r with quality %r%%",
input_file,
image_format,
quality,
)
img_src.save(img_dst, format=image_format, quality=quality) img_src.save(img_dst, format=image_format, quality=quality)
img_dst.seek(0) img_dst.seek(0)
return img_dst return img_dst
async def _run_ffprobe(self, uri: str | pathlib.Path, as_json: bool = False) -> dict | str: async def _run_ffprobe(
self, uri: str | pathlib.Path, as_json: bool = False
) -> dict | str:
""" """
Runs ffprobe on the given target (either file path or URL) and returns the result Runs ffprobe on the given target (either file path or URL) and returns the result
:param uri: the URI to run ffprobe on :param uri: the URI to run ffprobe on
:return: The result :return: The result
""" """
_bin = "ffprobe" _bin = "ffprobe"
cmd = ["-hide_banner", "-v", "quiet", "-print_format", "json", "-show_streams", "-show_format", "-i", str(uri)] cmd = [
"-hide_banner",
"-v",
"quiet",
"-print_format",
"json",
"-show_streams",
"-show_format",
"-i",
str(uri),
]
if not as_json: if not as_json:
cmd = ["-hide_banner", "-i", str(uri)] cmd = ["-hide_banner", "-i", str(uri)]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
@ -62,7 +81,12 @@ class FFMeta(commands.Cog):
return stderr.decode(errors="replace") return stderr.decode(errors="replace")
@commands.slash_command() @commands.slash_command()
async def ffprobe(self, ctx: discord.ApplicationContext, url: str = None, attachment: discord.Attachment = None): async def ffprobe(
self,
ctx: discord.ApplicationContext,
url: str = None,
attachment: discord.Attachment = None,
):
"""Runs ffprobe on a given URL or attachment""" """Runs ffprobe on a given URL or attachment"""
if not shutil.which("ffprobe"): if not shutil.which("ffprobe"):
return await ctx.respond("ffprobe is not installed on this system.") return await ctx.respond("ffprobe is not installed on this system.")
@ -101,7 +125,10 @@ class FFMeta(commands.Cog):
image_format: typing.Annotated[ image_format: typing.Annotated[
str, str,
discord.Option( discord.Option(
str, description="The format of the resulting image", choices=["jpeg", "webp"], default="jpeg" str,
description="The format of the resulting image",
choices=["jpeg", "webp"],
default="jpeg",
), ),
] = "jpeg", ] = "jpeg",
): ):
@ -115,7 +142,9 @@ class FFMeta(commands.Cog):
src = io.BytesIO() src = io.BytesIO()
async with httpx.AsyncClient( async with httpx.AsyncClient(
headers={"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"} headers={
"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"
}
) as client: ) as client:
response = await client.get(url) response = await client.get(url)
if response.status_code != 200: if response.status_code != 200:
@ -123,13 +152,17 @@ class FFMeta(commands.Cog):
src.write(response.content) src.write(response.content)
try: try:
dst = await asyncio.to_thread(self.jpegify_image, src, quality, image_format) dst = await asyncio.to_thread(
self.jpegify_image, src, quality, image_format
)
except Exception as e: except Exception as e:
await ctx.respond(f"Failed to convert image: `{e}`.") await ctx.respond(f"Failed to convert image: `{e}`.")
self.log.error("Failed to convert image %r: %r", url, e) self.log.error("Failed to convert image %r: %r", url, e)
return return
else: else:
await ctx.respond(file=discord.File(dst, filename=f"jpegified.{image_format}")) await ctx.respond(
file=discord.File(dst, filename=f"jpegified.{image_format}")
)
@commands.slash_command() @commands.slash_command()
async def opusinate( async def opusinate(
@ -148,7 +181,10 @@ class FFMeta(commands.Cog):
), ),
] = 96, ] = 96,
mono: typing.Annotated[ mono: typing.Annotated[
bool, discord.Option(bool, description="Whether to convert the audio to mono", default=False) bool,
discord.Option(
bool, description="Whether to convert the audio to mono", default=False
),
] = False, ] = False,
): ):
"""Converts a given URL or attachment to an Opus file""" """Converts a given URL or attachment to an Opus file"""
@ -195,7 +231,9 @@ class FFMeta(commands.Cog):
) )
stdout, stderr = await probe_process.communicate() stdout, stderr = await probe_process.communicate()
stdout = stdout.decode("utf-8", "replace") stdout = stdout.decode("utf-8", "replace")
data = {"format": {"duration": 195}} # 3 minutes and 15 seconds is the 2023 average. data = {
"format": {"duration": 195}
} # 3 minutes and 15 seconds is the 2023 average.
if stdout: if stdout:
try: try:
data = json.loads(stdout) data = json.loads(stdout)
@ -206,7 +244,9 @@ class FFMeta(commands.Cog):
max_end_size = ((bitrate * duration * channels) / 8) * 1024 max_end_size = ((bitrate * duration * channels) / 8) * 1024
if max_end_size > (24.75 * 1024 * 1024): if max_end_size > (24.75 * 1024 * 1024):
return await ctx.respond( return await ctx.respond(
"The file would be too large to send ({:,.2f} MiB).".format(max_end_size / 1024 / 1024) "The file would be too large to send ({:,.2f} MiB).".format(
max_end_size / 1024 / 1024
)
) )
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
@ -237,7 +277,9 @@ class FFMeta(commands.Cog):
file = io.BytesIO(stdout) file = io.BytesIO(stdout)
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024): if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)) return await ctx.respond(
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
)
if not fs: if not fs:
await ctx.respond("Failed to convert audio. See below.") await ctx.respond("Failed to convert audio. See below.")
else: else:
@ -251,9 +293,11 @@ class FFMeta(commands.Cog):
for page in paginator.pages: for page in paginator.pages:
await ctx.respond(page, ephemeral=True) await ctx.respond(page, ephemeral=True)
@commands.slash_command(name="right-behind-you") @commands.slash_command(name="right-behind-you")
async def right_behind_you(self, ctx: discord.ApplicationContext, image: discord.Attachment): async def right_behind_you(
self, ctx: discord.ApplicationContext, image: discord.Attachment
):
await ctx.defer() await ctx.defer()
rbh = Path("assets/right-behind-you.ogg").resolve() rbh = Path("assets/right-behind-you.ogg").resolve()
if not rbh.exists(): if not rbh.exists():
@ -264,7 +308,9 @@ class FFMeta(commands.Cog):
return await ctx.respond("That's not an image!") return await ctx.respond("That's not an image!")
with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp: with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp:
img_tmp = io.BytesIO(await image.read()) img_tmp = io.BytesIO(await image.read())
dst: io.BytesIO = await asyncio.to_thread(self.jpegify_image, img_tmp, 90, "webp") dst: io.BytesIO = await asyncio.to_thread(
self.jpegify_image, img_tmp, 90, "webp"
)
temp.write(dst.getvalue()) temp.write(dst.getvalue())
temp.flush() temp.flush()
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
@ -307,16 +353,20 @@ class FFMeta(commands.Cog):
"5", "5",
"pipe:1", "pipe:1",
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr stderr=sys.stderr,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
file = io.BytesIO(stdout) file = io.BytesIO(stdout)
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024): if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)) return await ctx.respond(
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
)
if not fs: if not fs:
await ctx.respond("Failed to convert audio. See below.") await ctx.respond("Failed to convert audio. See below.")
else: else:
return await ctx.respond(file=discord.File(file, filename="right-behind-you.mp4")) return await ctx.respond(
file=discord.File(file, filename="right-behind-you.mp4")
)
paginator = commands.Paginator() paginator = commands.Paginator()
for line in stderr.decode().splitlines(): for line in stderr.decode().splitlines():
if line.strip().startswith(":"): if line.strip().startswith(":"):

View file

@ -11,13 +11,15 @@ class MeterCog(commands.Cog):
self.bot = bot self.bot = bot
self.log = logging.getLogger("jimmy.cogs.auto_responder") self.log = logging.getLogger("jimmy.cogs.auto_responder")
self.cache = {} self.cache = {}
@commands.slash_command(name="gay-meter") @commands.slash_command(name="gay-meter")
@discord.guild_only() @discord.guild_only()
async def gay_meter(self, ctx: discord.ApplicationContext, user: discord.User = None): async def gay_meter(
self, ctx: discord.ApplicationContext, user: discord.User = None
):
"""Checks how gay someone is""" """Checks how gay someone is"""
user = user or ctx.user user = user or ctx.user
await ctx.respond("Calculating...") await ctx.respond("Calculating...")
for i in range(0, 125, 25): for i in range(0, 125, 25):
await ctx.edit(content="Calculating... %d%%" % i) await ctx.edit(content="Calculating... %d%%" % i)
@ -28,9 +30,11 @@ class MeterCog(commands.Cog):
else: else:
pct = user.id % 100 pct = user.id % 100
await ctx.edit(content=f"{user.mention} is {pct}% gay.") await ctx.edit(content=f"{user.mention} is {pct}% gay.")
@commands.slash_command(name="penis-length") @commands.slash_command(name="penis-length")
async def penis_meter(self, ctx: discord.ApplicationContext, user: discord.User = None): async def penis_meter(
self, ctx: discord.ApplicationContext, user: discord.User = None
):
"""Checks the length of someone's penis.""" """Checks the length of someone's penis."""
user = user or ctx.user user = user or ctx.user
if random.randint(0, 1): if random.randint(0, 1):
@ -49,10 +53,10 @@ class MeterCog(commands.Cog):
return await ctx.respond( return await ctx.respond(
embed=discord.Embed( embed=discord.Embed(
title=f"{user.display_name}'s penis length:", title=f"{user.display_name}'s penis length:",
description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)) description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)),
) )
) )
@commands.command() @commands.command()
@commands.is_owner() @commands.is_owner()
async def clear_cache(self, ctx: commands.Context, user: discord.User = None): async def clear_cache(self, ctx: commands.Context, user: discord.User = None):
@ -61,7 +65,7 @@ class MeterCog(commands.Cog):
self.cache = {} self.cache = {}
else: else:
self.cache.pop(user, None) self.cache.pop(user, None)
return await ctx.message.add_reaction("\N{white heavy check mark}") return await ctx.message.add_reaction("\N{WHITE HEAVY CHECK MARK}")
def setup(bot): def setup(bot):

View file

@ -33,17 +33,11 @@ class GetFilteredTextView(discord.ui.View):
self.text = text self.text = text
super().__init__(timeout=600) super().__init__(timeout=600)
@discord.ui.button( @discord.ui.button(label="See filtered data", emoji="\N{INBOX TRAY}")
label="See filtered data",
emoji="\N{INBOX TRAY}"
)
async def see_filtered_data(self, _, interaction: discord.Interaction): async def see_filtered_data(self, _, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True) await interaction.response.defer(ephemeral=True)
await interaction.followup.send( await interaction.followup.send(
file=discord.File( file=discord.File(io.BytesIO(self.text.encode()), "filtered.txt")
io.BytesIO(self.text.encode()),
"filtered.txt"
)
) )
@ -106,7 +100,11 @@ class NetworkCog(commands.Cog):
def decide(ln: str) -> typing.Optional[bool]: def decide(ln: str) -> typing.Optional[bool]:
if ln.startswith(">>> Last update"): if ln.startswith(">>> Last update"):
return return
if "REDACTED" in ln or "Please query the WHOIS server of the owning registrar" in ln or ":" not in ln: if (
"REDACTED" in ln
or "Please query the WHOIS server of the owning registrar" in ln
or ":" not in ln
):
return False return False
else: else:
return True return True
@ -134,7 +132,9 @@ class NetworkCog(commands.Cog):
if not paginator.pages: if not paginator.pages:
stdout, stderr, status = await run_command(with_disclaimer=True) stdout, stderr, status = await run_command(with_disclaimer=True)
if not any((stdout, stderr)): if not any((stdout, stderr)):
return await ctx.respond(f"No output was returned with status code {status}.") return await ctx.respond(
f"No output was returned with status code {status}."
)
file = io.BytesIO() file = io.BytesIO()
file.write(stdout) file.write(stdout)
if stderr: if stderr:
@ -143,7 +143,7 @@ class NetworkCog(commands.Cog):
file.seek(0) file.seek(0)
return await ctx.respond( return await ctx.respond(
"Seemingly all output was filtered. Returning raw command output.", "Seemingly all output was filtered. Returning raw command output.",
file=discord.File(file, "whois.txt") file=discord.File(file, "whois.txt"),
) )
last: discord.Interaction | discord.WebhookMessage | None = None last: discord.Interaction | discord.WebhookMessage | None = None
@ -223,9 +223,15 @@ class NetworkCog(commands.Cog):
default="default", default="default",
), ),
use_ip_version: discord.Option( use_ip_version: discord.Option(
str, name="ip-version", description="IP version to use.", choices=["ipv4", "ipv6"], default="ipv4" str,
name="ip-version",
description="IP version to use.",
choices=["ipv4", "ipv6"],
default="ipv4",
),
max_ttl: discord.Option(
int, name="ttl", description="Max number of hops", default=30
), ),
max_ttl: discord.Option(int, name="ttl", description="Max number of hops", default=30),
): ):
"""Performs a traceroute request.""" """Performs a traceroute request."""
await ctx.defer() await ctx.defer()
@ -258,7 +264,9 @@ class NetworkCog(commands.Cog):
args.append(str(port)) args.append(str(port))
args.append(url) args.append(url)
paginator = commands.Paginator() paginator = commands.Paginator()
paginator.add_line(f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}") paginator.add_line(
f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}"
)
paginator.add_line(empty=True) paginator.add_line(empty=True)
try: try:
start = time.time_ns() start = time.time_ns()
@ -297,10 +305,7 @@ class NetworkCog(commands.Cog):
await ctx.respond(file=discord.File(f)) await ctx.respond(file=discord.File(f))
async def _fetch_ip_response( async def _fetch_ip_response(
self, self, server: str, lookup: str, client: httpx.AsyncClient
server: str,
lookup: str,
client: httpx.AsyncClient
) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError: ) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError:
try: try:
start = time.perf_counter() start = time.perf_counter()
@ -315,18 +320,16 @@ class NetworkCog(commands.Cog):
async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None): async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None):
"""Fetches IP info from SHRONK IP servers""" """Fetches IP info from SHRONK IP servers"""
await ctx.defer() await ctx.defer()
async with httpx.AsyncClient(headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}) as client: async with httpx.AsyncClient(
headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}
) as client:
if not lookup: if not lookup:
response = await client.get("https://api.ipify.org") response = await client.get("https://api.ipify.org")
lookup = response.text lookup = response.text
servers = self.config.get( servers = self.config.get(
"ip_servers", "ip_servers",
[ ["ip.shronk.net", "ip.i-am.nexus", "ip.shronk.nicroxio.co.uk"],
"ip.shronk.net",
"ip.i-am.nexus",
"ip.shronk.nicroxio.co.uk"
]
) )
embed = discord.Embed( embed = discord.Embed(
title="IP lookup information for: %s" % lookup, title="IP lookup information for: %s" % lookup,
@ -353,13 +356,14 @@ class NetworkCog(commands.Cog):
t = response.text[:512] t = response.text[:512]
embed.add_field( embed.add_field(
name=server, name=server,
value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```" % t, value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```"
% t,
) )
else: else:
embed.add_field( embed.add_field(
name="%s (%.2fms)" % (server, (end - start) * 1000), name="%s (%.2fms)" % (server, (end - start) * 1000),
value="```json\n%s\n```" % v, value="```json\n%s\n```" % v,
inline=False inline=False,
) )
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
@ -367,41 +371,41 @@ class NetworkCog(commands.Cog):
@commands.slash_command() @commands.slash_command()
@commands.max_concurrency(1, commands.BucketType.user) @commands.max_concurrency(1, commands.BucketType.user)
async def nmap( async def nmap(
self, self,
ctx: discord.ApplicationContext, ctx: discord.ApplicationContext,
target: str, target: str,
technique: typing.Annotated[ technique: typing.Annotated[
str,
discord.Option(
str, str,
discord.Option( choices=[
str, discord.OptionChoice(name="TCP SYN", value="S"),
choices=[ discord.OptionChoice(name="TCP Connect", value="T"),
discord.OptionChoice(name="TCP SYN", value="S"), discord.OptionChoice(name="TCP ACK", value="A"),
discord.OptionChoice(name="TCP Connect", value="T"), discord.OptionChoice(name="TCP Window", value="W"),
discord.OptionChoice(name="TCP ACK", value="A"), discord.OptionChoice(name="TCP Maimon", value="M"),
discord.OptionChoice(name="TCP Window", value="W"), discord.OptionChoice(name="UDP", value="U"),
discord.OptionChoice(name="TCP Maimon", value="M"), discord.OptionChoice(name="TCP Null", value="N"),
discord.OptionChoice(name="UDP", value="U"), discord.OptionChoice(name="TCP FIN", value="F"),
discord.OptionChoice(name="TCP Null", value="N"), discord.OptionChoice(name="TCP XMAS", value="X"),
discord.OptionChoice(name="TCP FIN", value="F"), ],
discord.OptionChoice(name="TCP XMAS", value="X"), default="T",
], ),
default="T" ] = "T",
) treat_all_hosts_online: bool = False,
] = "T", service_scan: bool = False,
treat_all_hosts_online: bool = False, fast_mode: bool = False,
service_scan: bool = False, enable_os_detection: bool = False,
fast_mode: bool = False, timing: typing.Annotated[
enable_os_detection: bool = False, int,
timing: typing.Annotated[ discord.Option(
int, int,
discord.Option( description="Timing template to use 0 is slowest, 5 is fastest.",
int, choices=[0, 1, 2, 3, 4, 5],
description="Timing template to use 0 is slowest, 5 is fastest.", default=3,
choices=[0, 1, 2, 3, 4, 5], ),
default=3 ] = 3,
) ports: str = None,
] = 3,
ports: str = None
): ):
"""Runs nmap on a target. You cannot specify multiple targets.""" """Runs nmap on a target. You cannot specify multiple targets."""
await ctx.defer() await ctx.defer()
@ -420,7 +424,9 @@ class NetworkCog(commands.Cog):
if enable_os_detection and not is_superuser: if enable_os_detection and not is_superuser:
return await ctx.respond("OS detection is not available on this system.") return await ctx.respond("OS detection is not available on this system.")
with tempfile.TemporaryDirectory(prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}") as tmp: with tempfile.TemporaryDirectory(
prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}"
) as tmp:
tmp_dir = Path(tmp) tmp_dir = Path(tmp)
args = [ args = [
"nmap", "nmap",
@ -430,7 +436,7 @@ class NetworkCog(commands.Cog):
str(timing), str(timing),
"-s" + technique, "-s" + technique,
"--reason", "--reason",
"--noninteractive" "--noninteractive",
] ]
if treat_all_hosts_online: if treat_all_hosts_online:
args.append("-Pn") args.append("-Pn")
@ -447,8 +453,7 @@ class NetworkCog(commands.Cog):
await ctx.respond( await ctx.respond(
embed=discord.Embed( embed=discord.Embed(
title="Running nmap...", title="Running nmap...",
description="Command:\n" description="Command:\n" "```{}```".format(shlex.join(args)),
"```{}```".format(shlex.join(args)),
) )
) )
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
@ -457,14 +462,16 @@ class NetworkCog(commands.Cog):
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
_, stderr = await process.communicate() _, stderr = await process.communicate()
files = [discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()] files = [
discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()
]
if not files: if not files:
if len(stderr) <= 4089: if len(stderr) <= 4089:
return await ctx.edit( return await ctx.edit(
embed=discord.Embed( embed=discord.Embed(
title="Nmap failed.", title="Nmap failed.",
description="```\n" + stderr.decode() + "```", description="```\n" + stderr.decode() + "```",
color=discord.Color.red() color=discord.Color.red(),
) )
) )
@ -475,19 +482,19 @@ class NetworkCog(commands.Cog):
embed=discord.Embed( embed=discord.Embed(
title="Nmap failed.", title="Nmap failed.",
description=f"Output was too long. [View full output]({result.url})", description=f"Output was too long. [View full output]({result.url})",
color=discord.Color.red() color=discord.Color.red(),
) )
) )
await ctx.edit( await ctx.edit(
embed=discord.Embed( embed=discord.Embed(
title="Nmap finished!", title="Nmap finished!",
description="Result files are attached.\n" description="Result files are attached.\n"
"* `gnmap` is 'greppable'\n" "* `gnmap` is 'greppable'\n"
"* `xml` is XML output\n" "* `xml` is XML output\n"
"* `nmap` is normal output", "* `nmap` is normal output",
color=discord.Color.green() color=discord.Color.green(),
), ),
files=files files=files,
) )

View file

@ -107,7 +107,9 @@ class OllamaDownloadHandler:
async def __aiter__(self): async def __aiter__(self):
async with aiohttp.ClientSession(base_url=self.base_url) as client: async with aiohttp.ClientSession(base_url=self.base_url) as client:
async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response: async with client.post(
"/api/pull", json={"name": self.model, "stream": True}, timeout=None
) as response:
response.raise_for_status() response.raise_for_status()
async for line in ollama_stream(response.content): async for line in ollama_stream(response.content):
self.parse_line(line) self.parse_line(line)
@ -122,7 +124,7 @@ class OllamaDownloadHandler:
"Downloading orca-mini:7b on server %r - %s (%.2f%%)", "Downloading orca-mini:7b on server %r - %s (%.2f%%)",
self.base_url, self.base_url,
self.status, self.status,
self.percent self.percent,
) )
return self return self
@ -176,12 +178,13 @@ class OllamaChatHandler:
async def __aiter__(self): async def __aiter__(self):
async with aiohttp.ClientSession(base_url=self.base_url) as client: async with aiohttp.ClientSession(base_url=self.base_url) as client:
async with client.post( async with client.post(
"/api/chat", json={ "/api/chat",
"model": self.model, json={
"stream": True, "model": self.model,
"stream": True,
"messages": self.messages, "messages": self.messages,
"options": self.options "options": self.options,
} },
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
async for line in ollama_stream(response.content): async for line in ollama_stream(response.content):
@ -201,7 +204,9 @@ class OllamaClient:
self.base_url = base_url self.base_url = base_url
self.authorisation = authorisation self.authorisation = authorisation
def with_client(self, timeout: aiohttp.ClientTimeout | float | int | None = None) -> aiohttp.ClientSession: def with_client(
self, timeout: aiohttp.ClientTimeout | float | int | None = None
) -> aiohttp.ClientSession:
""" """
Creates an instance for a request, with properly populated values. Creates an instance for a request, with properly populated values.
:param timeout: :param timeout:
@ -213,9 +218,13 @@ class OllamaClient:
timeout = aiohttp.ClientTimeout(timeout) timeout = aiohttp.ClientTimeout(timeout)
else: else:
timeout = timeout or aiohttp.ClientTimeout(120) timeout = timeout or aiohttp.ClientTimeout(120)
return aiohttp.ClientSession(self.base_url, timeout=timeout, auth=self.authorisation) return aiohttp.ClientSession(
self.base_url, timeout=timeout, auth=self.authorisation
)
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]: async def get_tags(
self,
) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
""" """
Gets the tags for the server. Gets the tags for the server.
:return: :return:
@ -250,7 +259,7 @@ class OllamaClient:
self, self,
model: str, model: str,
messages: list[dict[str, str]], messages: list[dict[str, str]],
options: dict[str, typing.Any] = None options: dict[str, typing.Any] = None,
) -> OllamaChatHandler: ) -> OllamaChatHandler:
""" """
Starts a chat with the given messages. Starts a chat with the given messages.
@ -273,7 +282,11 @@ class OllamaView(View):
async def interaction_check(self, interaction: discord.Interaction) -> bool: async def interaction_check(self, interaction: discord.Interaction) -> bool:
return interaction.user == self.ctx.user return interaction.user == self.ctx.user
@button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f") @button(
label="Stop",
style=discord.ButtonStyle.danger,
emoji="\N{WASTEBASKET}\U0000fe0f",
)
async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction): async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
self.cancel.set() self.cancel.set()
btn.disabled = True btn.disabled = True
@ -310,14 +323,20 @@ class ChatHistory:
:return: The thread's ID. :return: The thread's ID.
""" """
key = os.urandom(3).hex() key = os.urandom(3).hex()
self._internal[key] = {"member": member.id, "seed": round(time.time()), "messages": []} self._internal[key] = {
"member": member.id,
"seed": round(time.time()),
"messages": [],
}
with open("./assets/ollama-prompt.txt") as file: with open("./assets/ollama-prompt.txt") as file:
system_prompt = default or file.read() system_prompt = default or file.read()
self.add_message(key, "system", system_prompt) self.add_message(key, "system", system_prompt)
return key return key
@staticmethod @staticmethod
def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]: def _construct_message(
role: str, content: str, images: typing.Optional[list[str]]
) -> dict[str, str]:
x = {"role": role, "content": content} x = {"role": role, "content": content}
if images: if images:
x["images"] = images x["images"] = images
@ -331,7 +350,9 @@ class ChatHistory:
return list( return list(
filter( filter(
lambda v: (ctx.value or v) in v, lambda v: (ctx.value or v) in v,
map(lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)), map(
lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)
),
) )
) )
@ -339,7 +360,9 @@ class ChatHistory:
"""Returns all saved threads.""" """Returns all saved threads."""
return self._internal.copy() return self._internal.copy()
def threads_for(self, user: discord.Member) -> dict[str, dict[str, list[dict[str, str]] | int]]: def threads_for(
self, user: discord.Member
) -> dict[str, dict[str, list[dict[str, str]] | int]]:
"""Returns all saved threads for a specific user""" """Returns all saved threads for a specific user"""
t = self.all_threads() t = self.all_threads()
for k, v in t.copy().items(): for k, v in t.copy().items():
@ -354,7 +377,7 @@ class ChatHistory:
content: str, content: str,
images: typing.Optional[list[str]] = None, images: typing.Optional[list[str]] = None,
*, *,
save: bool = True save: bool = True,
) -> None: ) -> None:
""" """
Appends a message to the given thread. Appends a message to the given thread.
@ -380,7 +403,9 @@ class ChatHistory:
return [] return []
return self._internal[thread]["messages"].copy() # copy() makes it immutable. return self._internal[thread]["messages"].copy() # copy() makes it immutable.
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]: def get_thread(
self, thread: str
) -> dict[str, list[dict[str, str]] | discord.Member | int]:
"""Gets a copy of an entire thread""" """Gets a copy of an entire thread"""
return self._internal.get(thread, {}).copy() return self._internal.get(thread, {}).copy()
@ -401,7 +426,6 @@ SERVER_KEYS_AUTOCOMPLETE.remove("order")
class OllamaGetPrompt(discord.ui.Modal): class OllamaGetPrompt(discord.ui.Modal):
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"): def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
super().__init__( super().__init__(
discord.ui.InputText( discord.ui.InputText(
@ -441,7 +465,9 @@ class PromptSelector(discord.ui.View):
if self.user_prompt is not None: if self.user_prompt is not None:
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys") @discord.ui.button(
label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys"
)
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction): async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx, "System") modal = OllamaGetPrompt(self.ctx, "System")
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@ -450,7 +476,9 @@ class PromptSelector(discord.ui.View):
self.update_ui() self.update_ui()
await interaction.edit_original_response(view=self) await interaction.edit_original_response(view=self)
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr") @discord.ui.button(
label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr"
)
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction): async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx) modal = OllamaGetPrompt(self.ctx)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@ -459,7 +487,9 @@ class PromptSelector(discord.ui.View):
self.update_ui() self.update_ui()
await interaction.edit_original_response(view=self) await interaction.edit_original_response(view=self)
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done") @discord.ui.button(
label="Done", style=discord.ButtonStyle.success, custom_id="done"
)
async def done(self, btn: discord.ui.Button, interaction: Interaction): async def done(self, btn: discord.ui.Button, interaction: Interaction):
self.ctx.interaction = interaction self.ctx.interaction = interaction
self.stop() self.stop()
@ -474,13 +504,17 @@ class ConfirmCPURun(discord.ui.View):
async def interaction_check(self, interaction: Interaction) -> bool: async def interaction_check(self, interaction: Interaction) -> bool:
return interaction.user == self.ctx.user return interaction.user == self.ctx.user
@discord.ui.button(label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu") @discord.ui.button(
label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu"
)
async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction): async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction):
await interaction.response.defer(invisible=True) await interaction.response.defer(invisible=True)
self.proceed = True self.proceed = True
self.stop() self.stop()
@discord.ui.button(label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu") @discord.ui.button(
label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu"
)
async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction): async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction):
await interaction.response.defer(invisible=True) await interaction.response.defer(invisible=True)
self.stop() self.stop()
@ -492,9 +526,7 @@ class Ollama(commands.Cog):
self.log = logging.getLogger("jimmy.cogs.ollama") self.log = logging.getLogger("jimmy.cogs.ollama")
self.contexts = {} self.contexts = {}
self.history = ChatHistory() self.history = ChatHistory()
self.servers = { self.servers = {server: asyncio.Lock() for server in CONFIG["ollama"]}
server: asyncio.Lock() for server in CONFIG["ollama"]
}
self.servers.pop("order", None) self.servers.pop("order", None)
if CONFIG["ollama"].get("order"): if CONFIG["ollama"].get("order"):
self.servers = {} self.servers = {}
@ -520,7 +552,9 @@ class Ollama(commands.Cog):
if url in SERVER_KEYS: if url in SERVER_KEYS:
url = CONFIG["ollama"][url]["base_url"] url = CONFIG["ollama"][url]["base_url"]
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(connect=3, sock_connect=3, sock_read=10, total=3) timeout=aiohttp.ClientTimeout(
connect=3, sock_connect=3, sock_read=10, total=3
)
) as session: ) as session:
self.log.info("Checking if %r is online.", url) self.log.info("Checking if %r is online.", url)
try: try:
@ -551,20 +585,37 @@ class Ollama(commands.Cog):
), ),
], ],
server: typing.Annotated[ server: typing.Annotated[
str, discord.Option(str, "The server to use for ollama.", default="next", choices=SERVER_KEYS_AUTOCOMPLETE) str,
discord.Option(
str,
"The server to use for ollama.",
default="next",
choices=SERVER_KEYS_AUTOCOMPLETE,
),
], ],
context: typing.Annotated[ context: typing.Annotated[
str, discord.Option(str, "The context key of a previous ollama response to use as context.", default=None) str,
discord.Option(
str,
"The context key of a previous ollama response to use as context.",
default=None,
),
], ],
give_acid: typing.Annotated[ give_acid: typing.Annotated[
bool, bool,
discord.Option( discord.Option(
bool, "Whether to give the AI acid, LSD, and other hallucinogens before responding.", default=False bool,
"Whether to give the AI acid, LSD, and other hallucinogens before responding.",
default=False,
), ),
], ],
image: typing.Annotated[ image: typing.Annotated[
discord.Attachment, discord.Attachment,
discord.Option(discord.Attachment, "An image to feed into ollama. Only works with llava.", default=None), discord.Option(
discord.Attachment,
"An image to feed into ollama. Only works with llava.",
default=None,
),
], ],
): ):
if server == "next": if server == "next":
@ -587,7 +638,7 @@ class Ollama(commands.Cog):
await ctx.respond( await ctx.respond(
"Select edit your prompts, as desired. Click done when you want to continue.", "Select edit your prompts, as desired. Click done when you want to continue.",
view=v, view=v,
ephemeral=True ephemeral=True,
) )
await v.wait() await v.wait()
query = v.user_prompt or query query = v.user_prompt or query
@ -604,22 +655,23 @@ class Ollama(commands.Cog):
self.log.debug("Resolved model to %r" % model) self.log.debug("Resolved model to %r" % model)
if image: if image:
patterns = [ patterns = ["llava:*", "llava-llama*:*"]
"llava:*",
"llava-llama*:*"
]
if any(fnmatch(model, p) for p in patterns) is False: if any(fnmatch(model, p) for p in patterns) is False:
await ctx.send( await ctx.send(
f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.", f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.",
delete_after=5 delete_after=5,
) )
model = "llava:latest" model = "llava:latest"
if image.size > 1024 * 1024 * 25: if image.size > 1024 * 1024 * 25:
await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.") await ctx.respond(
"Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it."
)
return return
elif not fnmatch(image.content_type, "image/*"): elif not fnmatch(image.content_type, "image/*"):
await ctx.respond("Attachment is not an image. Try using a different file.") await ctx.respond(
"Attachment is not an image. Try using a different file."
)
return return
else: else:
data = io.BytesIO() data = io.BytesIO()
@ -638,13 +690,19 @@ class Ollama(commands.Cog):
if fnmatch(model, model_pattern): if fnmatch(model, model_pattern):
break break
else: else:
allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"])) allowed_models = ", ".join(
await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}") map(discord.utils.escape_markdown, server_config["allowed_models"])
)
await ctx.respond(
f"Invalid model. You can only use one of the following models: {allowed_models}"
)
return return
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
base_url=server_config["base_url"], base_url=server_config["base_url"],
timeout=aiohttp.ClientTimeout(connect=5, sock_read=10800, sock_connect=5, total=10830), timeout=aiohttp.ClientTimeout(
connect=5, sock_read=10800, sock_connect=5, total=10830
),
) as session: ) as session:
embed = discord.Embed( embed = discord.Embed(
title="Checking server...", title="Checking server...",
@ -652,26 +710,32 @@ class Ollama(commands.Cog):
color=discord.Color.blurple(), color=discord.Color.blurple(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) embed.set_footer(
text="Using server %r" % server, icon_url=server_config.get("icon_url")
)
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
if not await self.check_server(server_config["base_url"]): if not await self.check_server(server_config["base_url"]):
tried = {server} tried = {server}
for i in range(10): for i in range(10):
try: try:
server = self.next_server(tried) server = self.next_server(tried)
if server and CONFIG["ollama"]["server"].get("is_gpu", False) is not True: if (
server
and CONFIG["ollama"]["server"].get("is_gpu", False)
is not True
):
cf = ConfirmCPURun(ctx) cf = ConfirmCPURun(ctx)
await ctx.edit( await ctx.edit(
embed=discord.Embed( embed=discord.Embed(
title=f"Server {server} is available, but...", title=f"Server {server} is available, but...",
description="It is CPU only, which means it is very slow and will likely crash.\n" description="It is CPU only, which means it is very slow and will likely crash.\n"
"If you really want, you can continue your generation, using this " "If you really want, you can continue your generation, using this "
"server. Be aware though, once in motion, it cannot be stopped.\n\n" "server. Be aware though, once in motion, it cannot be stopped.\n\n"
"" ""
"Continue?", "Continue?",
color=discord.Color.red(), color=discord.Color.red(),
), ),
view=cf view=cf,
) )
await cf.wait() await cf.wait()
await ctx.edit(view=None) await ctx.edit(view=None)
@ -688,7 +752,10 @@ class Ollama(commands.Cog):
color=discord.Color.gold(), color=discord.Color.gold(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) embed.set_footer(
text="Using server %r" % server,
icon_url=server_config.get("icon_url"),
)
await ctx.edit(embed=embed) await ctx.edit(embed=embed)
await asyncio.sleep(1) await asyncio.sleep(1)
if await self.check_server(CONFIG["ollama"][server]["base_url"]): if await self.check_server(CONFIG["ollama"][server]["base_url"]):
@ -713,7 +780,9 @@ class Ollama(commands.Cog):
embed = discord.Embed( embed = discord.Embed(
url=resp.url, url=resp.url,
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.", title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
description=f"```{await resp.text() or 'No response body'}```"[:4096], description=f"```{await resp.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(), color=discord.Color.red(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
@ -733,8 +802,8 @@ class Ollama(commands.Cog):
self.log.debug("Beginning download of %r", model) self.log.debug("Beginning download of %r", model)
def progress_bar(_v: float, action: str = None, _mbps: float = None): def progress_bar(_v: float, action: str = None, _mbps: float = None):
bar = "\N{large green square}" * round(_v / 10) bar = "\N{LARGE GREEN SQUARE}" * round(_v / 10)
bar += "\N{white large square}" * (10 - len(bar)) bar += "\N{WHITE LARGE SQUARE}" * (10 - len(bar))
bar += f" {_v:.2f}%" bar += f" {_v:.2f}%"
if _mbps: if _mbps:
bar += f" ({_mbps:.2f} MiB/s)" bar += f" ({_mbps:.2f} MiB/s)"
@ -753,12 +822,16 @@ class Ollama(commands.Cog):
last_update = time.time() last_update = time.time()
async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response: async with session.post(
"/api/pull", json={"name": model, "stream": True}, timeout=None
) as response:
if response.status != 200: if response.status != 200:
embed = discord.Embed( embed = discord.Embed(
url=response.url, url=response.url,
title=f"HTTP {response.status} {response.reason!r} while downloading model.", title=f"HTTP {response.status} {response.reason!r} while downloading model.",
description=f"```{await response.text() or 'No response body'}```"[:4096], description=f"```{await response.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(), color=discord.Color.red(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
@ -775,7 +848,10 @@ class Ollama(commands.Cog):
) )
return await ctx.edit(embed=embed, view=None) return await ctx.edit(embed=embed, view=None)
if time.time() >= (last_update + 5.1): if time.time() >= (last_update + 5.1):
if line.get("total") is not None and line.get("completed") is not None: if (
line.get("total") is not None
and line.get("completed") is not None
):
new_bytes = line["completed"] - last_downloaded new_bytes = line["completed"] - last_downloaded
mbps = new_bytes / 1024 / 1024 / 5 mbps = new_bytes / 1024 / 1024 / 5
last_downloaded = line["completed"] last_downloaded = line["completed"]
@ -784,7 +860,9 @@ class Ollama(commands.Cog):
percent = 50.0 percent = 50.0
mbps = 0.0 mbps = 0.0
embed.fields[0].value = progress_bar(percent, line["status"], mbps) embed.fields[0].value = progress_bar(
percent, line["status"], mbps
)
await ctx.edit(embed=embed, view=view) await ctx.edit(embed=embed, view=view)
last_update = time.time() last_update = time.time()
else: else:
@ -799,13 +877,13 @@ class Ollama(commands.Cog):
embed=discord.Embed( embed=discord.Embed(
title="Before you continue", title="Before you continue",
description="You've selected a CPU-only server. This will be really slow. This will also likely" description="You've selected a CPU-only server. This will be really slow. This will also likely"
" bring the host to a halt. Consider the pain the CPU is about to endure. " " bring the host to a halt. Consider the pain the CPU is about to endure. "
"Are you super" "Are you super"
" sure you want to continue? You can run `h!ollama-status` to see what servers are" " sure you want to continue? You can run `h!ollama-status` to see what servers are"
" available.", " available.",
color=discord.Color.red(), color=discord.Color.red(),
), ),
view=cf2 view=cf2,
) )
await cf2.wait() await cf2.wait()
if cf2.proceed is False: if cf2.proceed is False:
@ -825,9 +903,13 @@ class Ollama(commands.Cog):
icon_url="https://ollama.com/public/ollama.png", icon_url="https://ollama.com/public/ollama.png",
) )
embed.add_field( embed.add_field(
name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False name="Prompt",
value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."),
inline=False,
)
embed.set_footer(
text="Using server %r" % server, icon_url=server_config.get("icon_url")
) )
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
if image_data: if image_data:
if (image.height / image.width) >= 1.5: if (image.height / image.width) >= 1.5:
embed.set_image(url=image.url) embed.set_image(url=image.url)
@ -842,7 +924,10 @@ class Ollama(commands.Cog):
if context is None: if context is None:
context = self.history.create_thread(ctx.user, system_query) context = self.history.create_thread(ctx.user, system_query)
elif context is not None and (__thread := self.history.find_thread(context)) is None: elif (
context is not None
and (__thread := self.history.find_thread(context)) is None
):
if not __thread: if not __thread:
return await ctx.respond("Invalid thread ID.") return await ctx.respond("Invalid thread ID.")
else: else:
@ -861,7 +946,12 @@ class Ollama(commands.Cog):
params["top_p"] = 2 params["top_p"] = 2
params["repeat_penalty"] = 2 params["repeat_penalty"] = 2
payload = {"model": model, "stream": True, "options": params, "messages": messages} payload = {
"model": model,
"stream": True,
"options": params,
"messages": messages,
}
async with session.post( async with session.post(
"/api/chat", "/api/chat",
json=payload, json=payload,
@ -870,7 +960,9 @@ class Ollama(commands.Cog):
embed = discord.Embed( embed = discord.Embed(
url=response.url, url=response.url,
title=f"HTTP {response.status} {response.reason!r} while generating response.", title=f"HTTP {response.status} {response.reason!r} while generating response.",
description=f"```{await response.text() or 'No response body'}```"[:4096], description=f"```{await response.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(), color=discord.Color.red(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
@ -888,20 +980,31 @@ class Ollama(commands.Cog):
embed.description = "[...]" + line["message"]["content"] embed.description = "[...]" + line["message"]["content"]
if len(embed.description) >= 3250: if len(embed.description) >= 3250:
embed.colour = discord.Color.gold() embed.colour = discord.Color.gold()
embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description))) embed.set_footer(
text="Warning: {:,}/4096 characters.".format(
len(embed.description)
)
)
else: else:
embed.colour = discord.Color.blurple() embed.colour = discord.Color.blurple()
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) embed.set_footer(
text="Using server %r" % server,
icon_url=server_config.get("icon_url"),
)
if view.cancel.is_set(): if view.cancel.is_set():
break break
if time.time() >= (last_update + 5.1): if time.time() >= (last_update + 5.1):
await ctx.edit(embed=embed, view=view) await ctx.edit(embed=embed, view=view)
self.log.debug(f"Updating message ({last_update} -> {time.time()})") self.log.debug(
f"Updating message ({last_update} -> {time.time()})"
)
last_update = time.time() last_update = time.time()
view.stop() view.stop()
self.history.add_message(context, "user", user_message["content"], user_message.get("images")) self.history.add_message(
context, "user", user_message["content"], user_message.get("images")
)
self.history.add_message(context, "assistant", buffer.getvalue()) self.history.add_message(context, "assistant", buffer.getvalue())
embed.add_field(name="Context Key", value=context, inline=True) embed.add_field(name="Context Key", value=context, inline=True)
@ -911,7 +1014,9 @@ class Ollama(commands.Cog):
value = buffer.getvalue() value = buffer.getvalue()
if len(value) >= 4096: if len(value) >= 4096:
embeds = [discord.Embed(title="Done!", colour=discord.Color.green())] embeds = [
discord.Embed(title="Done!", colour=discord.Color.green())
]
current_page = "" current_page = ""
for word in value.split(): for word in value.split():
@ -971,13 +1076,19 @@ class Ollama(commands.Cog):
if message["role"] == "system": if message["role"] == "system":
continue continue
max_length = 4000 - len("> **%s**: " % message["role"]) max_length = 4000 - len("> **%s**: " % message["role"])
paginator.add_line("> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length))) paginator.add_line(
"> **{}**: {}".format(
message["role"], textwrap.shorten(message["content"], max_length)
)
)
embeds = [] embeds = []
for page in paginator.pages: for page in paginator.pages:
embeds.append(discord.Embed(description=page)) embeds.append(discord.Embed(description=page))
ephemeral = len(embeds) > 1 ephemeral = len(embeds) > 1
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10): for chunk in discord.utils.as_chunks(
iter(embeds or [discord.Embed(title="No Content.")]), 10
):
await ctx.respond(embeds=chunk, ephemeral=ephemeral) await ctx.respond(embeds=chunk, ephemeral=ephemeral)
@commands.command(name="ollama-status", aliases=["ollama_status", "os"]) @commands.command(name="ollama-status", aliases=["ollama_status", "os"])
@ -990,14 +1101,18 @@ class Ollama(commands.Cog):
if CONFIG["ollama"].get("order"): if CONFIG["ollama"].get("order"):
ln = ["Server order:"] ln = ["Server order:"]
for n, key in enumerate(CONFIG["ollama"].get("order"), start=1): for n, key in enumerate(CONFIG["ollama"].get("order"), start=1):
zap = '\N{high voltage sign}' zap = "\N{HIGH VOLTAGE SIGN}"
ln.append(f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}") ln.append(
f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}"
)
embed.description = "\n".join(ln) embed.description = "\n".join(ln)
for server, lock in self.servers.items(): for server, lock in self.servers.items():
embed.add_field( embed.add_field(
name=server, name=server,
value="\U000026a0\U0000fe0fIn use" if lock.locked() else "\U000023f3 Checking..." value="\U000026a0\U0000fe0fIn use"
if lock.locked()
else "\U000023f3 Checking...",
) )
msg = await ctx.reply(embed=embed) msg = await ctx.reply(embed=embed)
@ -1006,23 +1121,35 @@ class Ollama(commands.Cog):
for server in self.servers.keys(): for server in self.servers.keys():
if self.servers[server].locked(): if self.servers[server].locked():
continue continue
tasks[server] = asyncio.create_task(self.check_server(CONFIG["ollama"][server]["base_url"])) tasks[server] = asyncio.create_task(
self.check_server(CONFIG["ollama"][server]["base_url"])
)
await asyncio.gather(*tasks.values()) await asyncio.gather(*tasks.values())
for server, task in tasks.items(): for server, task in tasks.items():
if not task.done: if not task.done:
await task await task
if e := task.exception(): if e := task.exception():
self.log.error("Error while checking server %r: %r", server, e, exc_info=True) self.log.error(
"Error while checking server %r: %r", server, e, exc_info=True
)
else: else:
ok = task.result() ok = task.result()
for field in embed.fields: for field in embed.fields:
if field.name == server: if field.name == server:
field.value = "\U00002705 Online" if ok else "\U0000274c Offline" field.value = (
"\U00002705 Online" if ok else "\U0000274c Offline"
)
break break
await msg.edit(embed=embed) await msg.edit(embed=embed)
async def generate_truth(self, ctx: commands.Context, entity: str, limit: int = None, question: str = None): async def generate_truth(
self,
ctx: commands.Context,
entity: str,
limit: int = None,
question: str = None,
):
if entity == "trump": if entity == "trump":
system = ( system = (
"You are the former united states president, Donald Trump, convicted of 34 felonies." "You are the former united states president, Donald Trump, convicted of 34 felonies."
@ -1034,7 +1161,7 @@ class Ollama(commands.Cog):
system = ( system = (
"You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the " "You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the "
"world of influencing. His personality is one of arrogance and self-aggrandizement, " "world of influencing. His personality is one of arrogance and self-aggrandizement, "
"often referred to as \"King of Toxic Masculinity\" by critics due to his controversial views on " 'often referred to as "King of Toxic Masculinity" by critics due to his controversial views on '
"gender roles, relationships, and other topics. He has been involved in several controversies related " "gender roles, relationships, and other topics. He has been involved in several controversies related "
"to his content online including promoting extremist ideologies and misogynistic views. " "to his content online including promoting extremist ideologies and misogynistic views. "
"Despite this, he still has a large following and is known for being an entrepreneur who built multiple" "Despite this, he still has a large following and is known for being an entrepreneur who built multiple"
@ -1120,9 +1247,9 @@ class Ollama(commands.Cog):
"of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and " "of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and "
"controversial policies, Johnson's leadership was marred by numerous scandals and crises." "controversial policies, Johnson's leadership was marred by numerous scandals and crises."
"Johnson's personality is characterized by his optimistic and charismatic demeanor, " "Johnson's personality is characterized by his optimistic and charismatic demeanor, "
"often described as a \"dashing politician.\" However, his political career has been " 'often described as a "dashing politician." However, his political career has been '
"plagued by scandals involving ethics violations, questionable financial arrangements, " "plagued by scandals involving ethics violations, questionable financial arrangements, "
"and allegations of inappropriate behavior. The \"Partygate\" controversy, where " 'and allegations of inappropriate behavior. The "Partygate" controversy, where '
"gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. " "gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. "
"Johnson's achievements include leading the U.K. out of the European Union and brokering a " "Johnson's achievements include leading the U.K. out of the European Union and brokering a "
"post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by " "post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by "
@ -1162,21 +1289,22 @@ class Ollama(commands.Cog):
" Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous " " Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous "
"content. Do not include URLs or links. If an @user asks you a question, reply to them in your post." "content. Do not include URLs or links. If an @user asks you a question, reply to them in your post."
) )
thread_id = self.history.create_thread( thread_id = self.history.create_thread(ctx.author, system)
ctx.author,
system
)
r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api") r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api")
username = CONFIG["truth"].get("username", "1") username = CONFIG["truth"].get("username", "1")
password = CONFIG["truth"].get("password", "2") password = CONFIG["truth"].get("password", "2")
async with httpx.AsyncClient(base_url=r, auth=(username, password), trust_env=False) as http_client: async with httpx.AsyncClient(
base_url=r, auth=(username, password), trust_env=False
) as http_client:
response = await http_client.get( response = await http_client.get(
"/truths", "/truths",
timeout=60, timeout=60,
) )
response.raise_for_status() response.raise_for_status()
truths = response.json() truths = response.json()
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths)) truths: list[TruthPayload] = list(
map(lambda t: TruthPayload.model_validate(t), truths)
)
if entity: if entity:
truths = list(filter(lambda t: t.author == entity, truths)) truths = list(filter(lambda t: t.author == entity, truths))
@ -1191,11 +1319,13 @@ class Ollama(commands.Cog):
thread_id, thread_id,
"assistant", "assistant",
truth.content, truth.content,
save=False save=False,
) )
) )
if not question: if not question:
self.history.add_message(thread_id, "user", "Generate a new truth post.") self.history.add_message(
thread_id, "user", "Generate a new truth post."
)
else: else:
self.history.add_message(thread_id, "user", question) self.history.add_message(thread_id, "user", question)
@ -1204,32 +1334,36 @@ class Ollama(commands.Cog):
server = self.next_server(tried) server = self.next_server(tried)
is_gpu = CONFIG["ollama"][server].get("is_gpu", False) is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
if not is_gpu: if not is_gpu:
self.log.info("Skipping server %r as it is not a GPU server.", server) self.log.info(
"Skipping server %r as it is not a GPU server.", server
)
continue continue
if await self.check_server(CONFIG["ollama"][server]["base_url"]): if await self.check_server(CONFIG["ollama"][server]["base_url"]):
break break
tried.add(server) tried.add(server)
else: else:
return await ctx.reply("All servers are offline. Please try again later.", delete_after=300) return await ctx.reply(
"All servers are offline. Please try again later.", delete_after=300
)
client = OllamaClient(CONFIG["ollama"][server]["base_url"]) client = OllamaClient(CONFIG["ollama"][server]["base_url"])
async with self.servers[server]: async with self.servers[server]:
if not await client.has_model_named("llama2-uncensored", "7b-chat"): if not await client.has_model_named("llama2-uncensored", "7b-chat"):
with client.download_model("llama2-uncensored", "7b-chat") as handler: with client.download_model(
"llama2-uncensored", "7b-chat"
) as handler:
await handler.flatten() await handler.flatten()
embed = discord.Embed( embed = discord.Embed(
title=f"New {post_type.title()}!", title=f"New {post_type.title()}!", description="", colour=0x6559FF
description="",
colour=0x6559FF
) )
msg = await ctx.reply(embed=embed) msg = await ctx.reply(embed=embed)
last_edit = time.time() last_edit = time.time()
messages = self.history.get_history(thread_id) messages = self.history.get_history(thread_id)
with client.new_chat( with client.new_chat(
"llama2-uncensored:7b-chat", "llama2-uncensored:7b-chat",
messages, messages,
options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5} options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5},
) as handler: ) as handler:
async for ln in handler: async for ln in handler:
embed.description += ln["message"]["content"] embed.description += ln["message"]["content"]
@ -1245,7 +1379,7 @@ class Ollama(commands.Cog):
if truth.content == embed.description: if truth.content == embed.description:
embed.add_field( embed.add_field(
name=f"Repeated {post_type} :(", name=f"Repeated {post_type} :(",
value=f"This truth was already {post_type}ed. Shit AI." value=f"This truth was already {post_type}ed. Shit AI.",
) )
elif _ratio >= 70: elif _ratio >= 70:
similar[truth.id] = _ratio similar[truth.id] = _ratio
@ -1253,12 +1387,16 @@ class Ollama(commands.Cog):
if similar: if similar:
if len(similar) > 1: if len(similar) > 1:
lns = [] lns = []
keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True) keys = sorted(
similar.keys(), key=lambda k: similar[k], reverse=True
)
for truth_id in keys: for truth_id in keys:
_ratio = similar[truth_id] _ratio = similar[truth_id]
truth = discord.utils.get(truths, id=truth_id) truth = discord.utils.get(truths, id=truth_id)
first_line = truth.content.splitlines()[0] first_line = truth.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100)) preview = discord.utils.escape_markdown(
textwrap.shorten(first_line, 100)
)
lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`") lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`")
if len(lns) > 5: if len(lns) > 5:
lc = len(lns) - 5 lc = len(lns) - 5
@ -1266,33 +1404,39 @@ class Ollama(commands.Cog):
lns.append(f"*... and {lc} more*") lns.append(f"*... and {lc} more*")
embed.add_field( embed.add_field(
name=f"Possibly repeated {post_type}", name=f"Possibly repeated {post_type}",
value=f"This {post_type} was similar to the following existing ones:\n" + "\n".join(lns), value=f"This {post_type} was similar to the following existing ones:\n"
inline=False + "\n".join(lns),
inline=False,
) )
else: else:
truth_id = tuple(similar)[0] truth_id = tuple(similar)[0]
_ratio = similar[truth_id] _ratio = similar[truth_id]
truth = discord.utils.get(truths, id=truth_id) truth = discord.utils.get(truths, id=truth_id)
first_line = truth.content.splitlines()[0] first_line = truth.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512)) preview = discord.utils.escape_markdown(
textwrap.shorten(first_line, 512)
)
embed.add_field( embed.add_field(
name=f"Possibly repeated {post_type}", name=f"Possibly repeated {post_type}",
value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}" value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}",
) )
embed.set_footer( embed.set_footer(
text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format( text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format(
post_type, post_type, len(messages) - 2, server, thread_id
len(messages) - 2,
server,
thread_id
) )
) )
await msg.edit(embed=embed) await msg.edit(embed=embed)
@commands.command(aliases=["trump"]) @commands.command(aliases=["trump"])
@commands.guild_only() @commands.guild_only()
async def donald(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def donald(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a truth social post from trump! Generates a truth social post from trump!
@ -1307,7 +1451,13 @@ class Ollama(commands.Cog):
@commands.command(aliases=["tate"]) @commands.command(aliases=["tate"])
@commands.guild_only() @commands.guild_only()
async def andrew(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def andrew(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a truth social post from Andrew Tate Generates a truth social post from Andrew Tate
@ -1319,10 +1469,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "tate", latest, question=question) await self.generate_truth(ctx, "tate", latest, question=question)
@commands.command(aliases=["sunak"]) @commands.command(aliases=["sunak"])
@commands.guild_only() @commands.guild_only()
async def rishi(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def rishi(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Rishi Sunak Generates a twitter post from Rishi Sunak
@ -1334,10 +1490,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "Rishi Sunak", latest, question=question) await self.generate_truth(ctx, "Rishi Sunak", latest, question=question)
@commands.command(aliases=["robinson"]) @commands.command(aliases=["robinson"])
@commands.guild_only() @commands.guild_only()
async def tommy(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def tommy(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Tommy Robinson Generates a twitter post from Tommy Robinson
@ -1348,11 +1510,19 @@ class Ollama(commands.Cog):
if question: if question:
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "Tommy Robinson 🇬🇧", latest, question=question) await self.generate_truth(
ctx, "Tommy Robinson 🇬🇧", latest, question=question
)
@commands.command(aliases=["fox"]) @commands.command(aliases=["fox"])
@commands.guild_only() @commands.guild_only()
async def laurence(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def laurence(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Laurence Fox Generates a twitter post from Laurence Fox
@ -1364,10 +1534,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "Laurence Fox", latest, question=question) await self.generate_truth(ctx, "Laurence Fox", latest, question=question)
@commands.command(aliases=["farage"]) @commands.command(aliases=["farage"])
@commands.guild_only() @commands.guild_only()
async def nigel(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def nigel(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Nigel Farage Generates a twitter post from Nigel Farage
@ -1382,7 +1558,13 @@ class Ollama(commands.Cog):
@commands.command(aliases=["starmer"]) @commands.command(aliases=["starmer"])
@commands.guild_only() @commands.guild_only()
async def keir(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def keir(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Keir Starmer Generates a twitter post from Keir Starmer
@ -1394,10 +1576,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "Keir Starmer", latest, question=question) await self.generate_truth(ctx, "Keir Starmer", latest, question=question)
@commands.command(aliases=["johnson"]) @commands.command(aliases=["johnson"])
@commands.guild_only() @commands.guild_only()
async def boris(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def boris(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Boris Johnson Generates a twitter post from Boris Johnson
@ -1409,10 +1597,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}" question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "Boris Johnson", latest, question=question) await self.generate_truth(ctx, "Boris Johnson", latest, question=question)
@commands.command(aliases=["desantis"]) @commands.command(aliases=["desantis"])
@commands.guild_only() @commands.guild_only()
async def ron(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): async def ron(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
""" """
Generates a twitter post from Ron Desantis Generates a twitter post from Ron Desantis
@ -1447,12 +1641,15 @@ class Ollama(commands.Cog):
title="Truth:", title="Truth:",
description=truth.content, description=truth.content,
colour=0x6559FF, colour=0x6559FF,
timestamp=datetime.datetime.fromtimestamp(truth.timestamp, tz=datetime.timezone.utc), timestamp=datetime.datetime.fromtimestamp(
truth.timestamp, tz=datetime.timezone.utc
),
) )
if truth.extra: if truth.extra:
embed.add_field( embed.add_field(
name="Extra Data", name="Extra Data",
value="```json\n%s```" % json.dumps(truth.extra, indent=2, default=repr), value="```json\n%s```"
% json.dumps(truth.extra, indent=2, default=repr),
) )
embed.set_author(name=truth.author) embed.set_author(name=truth.author)
await ctx.reply(embed=embed) await ctx.reply(embed=embed)

112
src/cogs/onion_feed.py Normal file
View file

@ -0,0 +1,112 @@
"""
Scrapes the onion RSS feed once every hour and posts any new articles to the desired channel
"""
import asyncio
import dataclasses
import datetime
import logging
from bs4 import BeautifulSoup
import discord
from discord.ext import commands, tasks
import httpx
from conf import CONFIG
import redis
@dataclasses.dataclass
class RSSItem:
title: str
link: str
description: str
pubDate: datetime.datetime
guid: str
thumbnail: str
class OnionFeed(commands.Cog):
SOURCE = "https://www.theonion.com/rss"
EPOCH = datetime.datetime(2024, 7, 1, tzinfo=datetime.timezone.utc)
def __init__(self, bot):
self.bot: commands.Bot = bot
self.log = logging.getLogger("jimmy.cogs.onion_feed")
self.check_onion_feed.start()
self.redis = redis.Redis(**CONFIG["redis"])
def cog_unload(self) -> None:
self.check_onion_feed.cancel()
@staticmethod
def parse_item(item: BeautifulSoup) -> RSSItem:
description = (
BeautifulSoup(item.description.get_text(), "html.parser")
.p.get_text(strip=True)
.strip()[:-1]
)
kwargs = {
"title": item.title.get_text(strip=True).strip(),
"link": item.link.get_text(strip=True).strip(),
"pubDate": datetime.datetime.strptime(
item.pubDate.get_text(strip=True).strip(), "%a, %d %b %Y %H:%M:%S %Z"
),
"guid": item.guid.get_text(strip=True).strip(),
"description": description,
"thumbnail": item.find("media:thumbnail")["url"],
}
return RSSItem(**kwargs)
def parse_feed(self, text: str) -> list[RSSItem]:
soup = BeautifulSoup(text, "xml")
return [self.parse_item(item) for item in soup.find_all("item")]
@tasks.loop(hours=1)
async def check_onion_feed(self):
if not self.bot.is_ready():
await self.bot.wait_until_ready()
guild = self.bot.get_guild(994710566612500550)
if not guild:
return self.log.error("Nonsense guild not found. Can't do onion feed.")
channel = discord.utils.get(guild.text_channels, name="spam")
if not channel:
return self.log.error("Spam channel not found.")
async with httpx.AsyncClient() as client:
response = await client.get(self.SOURCE)
if response.status_code != 200:
return self.log.error(
f"Failed to fetch onion feed: {response.status_code}"
)
items: list[RSSItem] = await asyncio.to_thread(
self.parse_feed, response.text
)
for item in items:
if self.redis.get("onion-" + item.guid):
continue
embed = discord.Embed(
title=item.title,
url=item.link,
description=item.description + f"... [Read More]({item.link})",
color=0x00DF78,
timestamp=item.pubDate,
)
embed.set_thumbnail(url=item.thumbnail)
try:
msg = await channel.send(embed=embed)
# noinspection PyAsyncCall
self.redis.set("onion-" + item.guid, str(msg.id))
except discord.HTTPException:
self.log.exception(
f"Failed to send onion feed message: {item.title}"
)
else:
self.log.debug(f"Sent onion feed message: {item.title}")
@check_onion_feed.before_loop
async def before_check_onion_feed(self):
await self.bot.wait_until_ready()
def setup(bot):
bot.add_cog(OnionFeed(bot))

View file

@ -16,11 +16,9 @@ from discord.ext import commands
from conf import CONFIG from conf import CONFIG
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
JSON: typing.Union[ JSON: typing.Union[str, int, float, bool, None, dict[str, "JSON"], list["JSON"]] = (
str, int, float, bool, None, dict[str, "JSON"], list["JSON"] typing.Union[str, int, float, bool, None, dict, list]
] = typing.Union[ )
str, int, float, bool, None, dict, list
]
class TruthPayload(BaseModel): class TruthPayload(BaseModel):
@ -32,7 +30,6 @@ class TruthPayload(BaseModel):
class QuoteQuota(commands.Cog): class QuoteQuota(commands.Cog):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id") self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
@ -47,7 +44,9 @@ class QuoteQuota(commands.Cog):
return c return c
@staticmethod @staticmethod
def generate_pie_chart(usernames: list[str], counts: list[int], no_other: bool = False) -> discord.File: def generate_pie_chart(
usernames: list[str], counts: list[int], no_other: bool = False
) -> discord.File:
""" """
Converts the given username and count tuples into a nice pretty pie chart. Converts the given username and count tuples into a nice pretty pie chart.
@ -95,7 +94,9 @@ class QuoteQuota(commands.Cog):
startangle=90, startangle=90,
radius=1.2, radius=1.2,
) )
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4) fig.subplots_adjust(
left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4
)
fio = io.BytesIO() fio = io.BytesIO()
fig.savefig(fio, format="png") fig.savefig(fio, format="png")
fio.seek(0) fio.seek(0)
@ -130,7 +131,9 @@ class QuoteQuota(commands.Cog):
now = discord.utils.utcnow() now = discord.utils.utcnow()
oldest = now - timedelta(days=days) oldest = now - timedelta(days=days)
await ctx.defer() await ctx.defer()
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes") channel = self.quotes_channel or discord.utils.get(
ctx.guild.text_channels, name="quotes"
)
if not channel: if not channel:
return await ctx.respond(":x: Cannot find quotes channel.") return await ctx.respond(":x: Cannot find quotes channel.")
@ -139,7 +142,9 @@ class QuoteQuota(commands.Cog):
authors = {} authors = {}
filtered_messages = 0 filtered_messages = 0
total = 0 total = 0
async for message in channel.history(limit=None, after=oldest, oldest_first=False): async for message in channel.history(
limit=None, after=oldest, oldest_first=False
):
total += 1 total += 1
if not message.content: if not message.content:
filtered_messages += 1 filtered_messages += 1
@ -179,10 +184,15 @@ class QuoteQuota(commands.Cog):
' (e.g. `"This is my quote" - Jimmy`)'.format(days) ' (e.g. `"This is my quote" - Jimmy`)'.format(days)
) )
else: else:
return await ctx.edit(content="No messages found in the last {!s} days.".format(days)) return await ctx.edit(
content="No messages found in the last {!s} days.".format(days)
)
file = await asyncio.to_thread( file = await asyncio.to_thread(
self.generate_pie_chart, list(authors.keys()), list(authors.values()), merge_other self.generate_pie_chart,
list(authors.keys()),
list(authors.values()),
merge_other,
) )
return await ctx.edit( return await ctx.edit(
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format( content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
@ -192,7 +202,11 @@ class QuoteQuota(commands.Cog):
) )
def _metacounter( def _metacounter(
self, truths: list[TruthPayload], filter_func: Callable[[TruthPayload], bool], *, now: datetime = None self,
truths: list[TruthPayload],
filter_func: Callable[[TruthPayload], bool],
*,
now: datetime = None,
) -> dict[str, float | int | dict[str, int]]: ) -> dict[str, float | int | dict[str, int]]:
def _is_today(date: datetime) -> bool: def _is_today(date: datetime) -> bool:
return date.date() == now.date() return date.date() == now.date()
@ -215,7 +229,9 @@ class QuoteQuota(commands.Cog):
if filter_func(truth): if filter_func(truth):
created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc) created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc)
age = now - created_at age = now - created_at
self.log.debug("%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()) self.log.debug(
"%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()
)
counts["all_time"] += 1 counts["all_time"] += 1
if _is_today(created_at): if _is_today(created_at):
counts["today"] += 1 counts["today"] += 1
@ -275,7 +291,9 @@ class QuoteQuota(commands.Cog):
plt.bar(hrs, list(hours.values()), color="#5448EE") plt.bar(hrs, list(hours.values()), color="#5448EE")
average = sum(hours.values()) / len(hours) average = sum(hours.values()) / len(hours)
plt.axhline(average, color="red", linestyle="--", label=f"Average: {average:.1f}") plt.axhline(
average, color="red", linestyle="--", label=f"Average: {average:.1f}"
)
file = io.BytesIO() file = io.BytesIO()
plt.savefig(file, format="png") plt.savefig(file, format="png")
@ -283,8 +301,7 @@ class QuoteQuota(commands.Cog):
return discord.File(file, "truths.png") return discord.File(file, "truths.png")
async def _process_all_messages( async def _process_all_messages(
self, self, truths: list[TruthPayload]
truths: list[TruthPayload]
) -> tuple[discord.Embed, discord.File]: ) -> tuple[discord.Embed, discord.File]:
""" """
Processes all the messages in the given channel. Processes all the messages in the given channel.
@ -292,7 +309,11 @@ class QuoteQuota(commands.Cog):
:param truths: The truths to process :param truths: The truths to process
:returns: The stats :returns: The stats
""" """
embed = discord.Embed(title="Truth Counts", color=discord.Color.blurple(), timestamp=discord.utils.utcnow()) embed = discord.Embed(
title="Truth Counts",
color=discord.Color.blurple(),
timestamp=discord.utils.utcnow(),
)
trump_stats = await self._process_trump_truths(truths) trump_stats = await self._process_trump_truths(truths)
tate_stats = await self._process_tate_truths(truths) tate_stats = await self._process_tate_truths(truths)
@ -346,7 +367,9 @@ class QuoteQuota(commands.Cog):
) )
response.raise_for_status() response.raise_for_status()
truths: list[dict] = response.json() truths: list[dict] = response.json()
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths)) truths: list[TruthPayload] = list(
map(lambda t: TruthPayload.model_validate(t), truths)
)
embed, file = await self._process_all_messages(truths) embed, file = await self._process_all_messages(truths)
await ctx.edit(embed=embed, file=file) await ctx.edit(embed=embed, file=file)

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
import copy
import datetime import datetime
import io import io
import logging import logging
@ -133,7 +132,9 @@ class ScreenshotCog(commands.Cog):
else: else:
use_proxy = False use_proxy = False
service = await asyncio.to_thread(ChromeService) service = await asyncio.to_thread(ChromeService)
driver: webdriver.Chrome = await asyncio.to_thread(webdriver.Chrome, service=service, options=options) driver: webdriver.Chrome = await asyncio.to_thread(
webdriver.Chrome, service=service, options=options
)
driver.set_page_load_timeout(load_timeout) driver.set_page_load_timeout(load_timeout)
if resolution: if resolution:
resolution = RESOLUTIONS.get(resolution.lower(), resolution) resolution = RESOLUTIONS.get(resolution.lower(), resolution)
@ -141,9 +142,13 @@ class ScreenshotCog(commands.Cog):
width, height = map(int, resolution.split("x")) width, height = map(int, resolution.split("x"))
driver.set_window_size(width, height) driver.set_window_size(width, height)
if height > 4320 or width > 7680: if height > 4320 or width > 7680:
return await ctx.respond("Invalid resolution. Max resolution is 7680x4320 (8K).") return await ctx.respond(
"Invalid resolution. Max resolution is 7680x4320 (8K)."
)
except ValueError: except ValueError:
return await ctx.respond("Invalid resolution. please provide width x height, e.g. 1920x1080") return await ctx.respond(
"Invalid resolution. please provide width x height, e.g. 1920x1080"
)
if eager: if eager:
driver.implicitly_wait(render_timeout) driver.implicitly_wait(render_timeout)
except Exception as e: except Exception as e:
@ -151,7 +156,13 @@ class ScreenshotCog(commands.Cog):
raise raise
end_init = time.time() end_init = time.time()
await ctx.edit(content=("Loading webpage..." if not eager else "Loading & screenshotting webpage...")) await ctx.edit(
content=(
"Loading webpage..."
if not eager
else "Loading & screenshotting webpage..."
)
)
start_request = time.time() start_request = time.time()
try: try:
await asyncio.to_thread(driver.get, url) await asyncio.to_thread(driver.get, url)
@ -160,7 +171,9 @@ class ScreenshotCog(commands.Cog):
if "TimeoutException" in str(e): if "TimeoutException" in str(e):
return await ctx.respond("Timed out while loading webpage.") return await ctx.respond("Timed out while loading webpage.")
else: else:
return await ctx.respond("Failed to load webpage:\n```\n%s\n```" % str(e.msg)) return await ctx.respond(
"Failed to load webpage:\n```\n%s\n```" % str(e.msg)
)
except Exception as e: except Exception as e:
await self.bot.loop.run_in_executor(None, driver.quit) await self.bot.loop.run_in_executor(None, driver.quit)
await ctx.respond("Failed to get the webpage: " + str(e)) await ctx.respond("Failed to get the webpage: " + str(e))
@ -170,7 +183,9 @@ class ScreenshotCog(commands.Cog):
if not eager: if not eager:
now = discord.utils.utcnow() now = discord.utils.utcnow()
expires = now + datetime.timedelta(seconds=render_timeout) expires = now + datetime.timedelta(seconds=render_timeout)
await ctx.edit(content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})...") await ctx.edit(
content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})..."
)
start_wait = time.time() start_wait = time.time()
await asyncio.sleep(render_timeout) await asyncio.sleep(render_timeout)
end_wait = time.time() end_wait = time.time()
@ -200,7 +215,9 @@ class ScreenshotCog(commands.Cog):
await self.bot.loop.run_in_executor(None, driver.quit) await self.bot.loop.run_in_executor(None, driver.quit)
end_cleanup = time.time() end_cleanup = time.time()
screenshot_size_mb = round(len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2) screenshot_size_mb = round(
len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2
)
def seconds(start: float, end: float) -> float: def seconds(start: float, end: float) -> float:
return round(end - start, 2) return round(end - start, 2)
@ -221,7 +238,9 @@ class ScreenshotCog(commands.Cog):
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
embed.set_image(url="attachment://" + fn) embed.set_image(url="attachment://" + fn)
return await ctx.edit(content=None, embed=embed, file=discord.File(file, filename=fn)) return await ctx.edit(
content=None, embed=embed, file=discord.File(file, filename=fn)
)
def setup(bot): def setup(bot):

View file

@ -10,24 +10,32 @@ from conf import CONFIG
class Starboard(commands.Cog): class Starboard(commands.Cog):
DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{collision symbol}") DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{COLLISION SYMBOL}")
def __init__(self, bot): def __init__(self, bot):
self.bot: commands.Bot = bot self.bot: commands.Bot = bot
self.config = CONFIG["starboard"] self.config = CONFIG["starboard"]
self.emoji = discord.PartialEmoji.from_str(self.config.get("emoji", "\N{white medium star}")) self.emoji = discord.PartialEmoji.from_str(
self.config.get("emoji", "\N{WHITE MEDIUM STAR}")
)
self.log = logging.getLogger("jimmy.cogs.starboard") self.log = logging.getLogger("jimmy.cogs.starboard")
self.redis = Redis(**CONFIG["redis"]) self.redis = Redis(**CONFIG["redis"])
async def generate_starboard_embed(self, message: discord.Message) -> tuple[list[discord.Embed], int]: async def generate_starboard_embed(
self, message: discord.Message
) -> tuple[list[discord.Embed], int]:
""" """
Generates an embed ready for a starboard message. Generates an embed ready for a starboard message.
:param message: The message to base off of. :param message: The message to base off of.
:return: The created embed :return: The created embed
""" """
reactions: list[discord.Reaction] = [x for x in message.reactions if str(x.emoji) == str(self.emoji)] reactions: list[discord.Reaction] = [
downvote_reactions = [x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)] x for x in message.reactions if str(x.emoji) == str(self.emoji)
]
downvote_reactions = [
x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)
]
if not reactions: if not reactions:
# Nobody has added the star reaction. # Nobody has added the star reaction.
star_count = 0 star_count = 0
@ -35,12 +43,18 @@ class Starboard(commands.Cog):
else: else:
# Count the number of reactions # Count the number of reactions
star_count = sum([x.count for x in reactions]) star_count = sum([x.count for x in reactions])
self.log.debug("There are a total of %d star reactions on message.", star_count) self.log.debug(
"There are a total of %d star reactions on message.", star_count
)
if downvote_reactions: if downvote_reactions:
_dv = sum([x.count for x in downvote_reactions]) _dv = sum([x.count for x in downvote_reactions])
star_count -= _dv star_count -= _dv
self.log.debug("There are %d downvotes on message, resulting in %d stars.", _dv, star_count) self.log.debug(
"There are %d downvotes on message, resulting in %d stars.",
_dv,
star_count,
)
if star_count >= 0: if star_count >= 0:
star_emoji_count = (str(self.emoji) * star_count)[:10] star_emoji_count = (str(self.emoji) * star_count)[:10]
@ -54,18 +68,20 @@ class Starboard(commands.Cog):
author=discord.EmbedAuthor( author=discord.EmbedAuthor(
message.author.display_name, message.author.display_name,
message.author.jump_url, message.author.jump_url,
message.author.display_avatar.url message.author.display_avatar.url,
), ),
fields=[ fields=[
discord.EmbedField( discord.EmbedField(
name="Info", name="Info",
value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})" value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})",
) )
] ],
) )
if message.reference: if message.reference:
try: try:
ref_message = await self.get_or_fetch_message(message.reference.channel_id, message.reference.message_id) ref_message = await self.get_or_fetch_message(
message.reference.channel_id, message.reference.message_id
)
except discord.HTTPException: except discord.HTTPException:
pass pass
else: else:
@ -74,31 +90,26 @@ class Starboard(commands.Cog):
remaining = 1024 - len(v) remaining = 1024 - len(v)
t = textwrap.shorten(text, remaining, placeholder="...") t = textwrap.shorten(text, remaining, placeholder="...")
v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})" v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})"
embed.add_field( embed.add_field(name="Replying to", value=v)
name="Replying to",
value=v
)
elif message.interaction: elif message.interaction:
if message.interaction.type == discord.InteractionType.application_command: if message.interaction.type == discord.InteractionType.application_command:
real_author: discord.User = await discord.utils.get_or_fetch( real_author: discord.User = await discord.utils.get_or_fetch(
self.bot, self.bot, "user", int(message.interaction.data["user"]["id"])
"user", )
int(message.interaction.data["user"]["id"]) real_author = (
await discord.utils.get_or_fetch(
message.guild, "member", real_author.id, default=real_author
)
or message.author
) )
real_author = await discord.utils.get_or_fetch(
message.guild,
"member",
real_author.id,
default=real_author
) or message.author
embed.set_author( embed.set_author(
name=real_author.display_name, name=real_author.display_name,
icon_url=real_author.display_avatar.url, icon_url=real_author.display_avatar.url,
url=real_author.jump_url url=real_author.jump_url,
) )
embed.add_field( embed.add_field(
name="Interaction", name="Interaction",
value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}" value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}",
) )
if message.content: if message.content:
@ -107,17 +118,24 @@ class Starboard(commands.Cog):
for message_embed in message.embeds: for message_embed in message.embeds:
if message_embed.type != "rich": if message_embed.type != "rich":
if message_embed.type == "image": if message_embed.type == "image":
if message_embed.thumbnail and message_embed.thumbnail.proxy_url: if (
message_embed.thumbnail
and message_embed.thumbnail.proxy_url
):
embed.set_image(url=message_embed.thumbnail.proxy_url) embed.set_image(url=message_embed.thumbnail.proxy_url)
continue continue
if message_embed.description: if message_embed.description:
embed.description = message_embed.description embed.description = message_embed.description
elif not message.attachments: elif not message.attachments:
raise ValueError("Message does not appear to contain any text, embeds, or attachments.") raise ValueError(
"Message does not appear to contain any text, embeds, or attachments."
)
if message.attachments: if message.attachments:
new_fields = [] new_fields = []
for n, attachment in reversed(tuple(enumerate(message.attachments, start=1))): for n, attachment in reversed(
tuple(enumerate(message.attachments, start=1))
):
attachment: discord.Attachment attachment: discord.Attachment
if attachment.size >= 1024 * 1024: if attachment.size >= 1024 * 1024:
size = f"{attachment.size / 1024 / 1024:,.1f}MiB" size = f"{attachment.size / 1024 / 1024:,.1f}MiB"
@ -129,7 +147,7 @@ class Starboard(commands.Cog):
{ {
"name": "Attachment #%d:" % n, "name": "Attachment #%d:" % n,
"value": f"[{attachment.filename} ({size})]({attachment.url})", "value": f"[{attachment.filename} ({size})]({attachment.url})",
"inline": True "inline": True,
} }
) )
if attachment.content_type.startswith("image/"): if attachment.content_type.startswith("image/"):
@ -142,13 +160,19 @@ class Starboard(commands.Cog):
return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count
async def get_or_fetch_message(self, channel_id: int, message_id: int) -> discord.Message: async def get_or_fetch_message(
self, channel_id: int, message_id: int
) -> discord.Message:
""" """
Fetches a message from cache where possible, falling back to the API. Fetches a message from cache where possible, falling back to the API.
""" """
message: discord.Message | None = discord.utils.get(self.bot.cached_messages, id=message_id) message: discord.Message | None = discord.utils.get(
self.bot.cached_messages, id=message_id
)
if not message: if not message:
message: discord.Message = await self.bot.get_channel(channel_id).fetch_message(message_id) message: discord.Message = await self.bot.get_channel(
channel_id
).fetch_message(message_id)
return message return message
@commands.Cog.listener("on_raw_reaction_add") @commands.Cog.listener("on_raw_reaction_add")
@ -171,27 +195,35 @@ class Starboard(commands.Cog):
guild: discord.Guild = self.bot.get_guild(payload.guild_id) guild: discord.Guild = self.bot.get_guild(payload.guild_id)
starboard_channel: discord.TextChannel | None = discord.utils.get( starboard_channel: discord.TextChannel | None = discord.utils.get(
guild.text_channels, guild.text_channels, name=self.config.get("channel_name", "starboard")
name=self.config.get("channel_name", "starboard")
) )
if not starboard_channel: if not starboard_channel:
self.log.warning("Could not find starboard channel in %s (%d)", guild, guild.id) self.log.warning(
"Could not find starboard channel in %s (%d)", guild, guild.id
)
return return
if payload.channel_id == starboard_channel.id: if payload.channel_id == starboard_channel.id:
self.log.debug("Ignoring reaction in starboard channel.") self.log.debug("Ignoring reaction in starboard channel.")
return return
message = await self.get_or_fetch_message(payload.channel_id, payload.message_id) message = await self.get_or_fetch_message(
payload.channel_id, payload.message_id
)
if payload.user_id == message.author.id: if payload.user_id == message.author.id:
self.log.info("%s tried to star their own message.", message.author) self.log.info("%s tried to star their own message.", message.author)
return await message.reply( return await message.reply(
"You can't star your own message you pretentious dick. Go outside, %s." % message.author.mention, "You can't star your own message you pretentious dick. Go outside, %s."
delete_after=30 % message.author.mention,
delete_after=30,
) )
self.log.info("Processing starboard reaction from %s: %s", payload.user_id, message.jump_url) self.log.info(
"Processing starboard reaction from %s: %s",
payload.user_id,
message.jump_url,
)
embed, star_count = await self.generate_starboard_embed(message) embed, star_count = await self.generate_starboard_embed(message)
data = await self.redis.get(str(message.id)) data = await self.redis.get(str(message.id))
@ -205,7 +237,7 @@ class Starboard(commands.Cog):
"source_message_id": payload.message_id, "source_message_id": payload.message_id,
"history": [], "history": [],
"starboard_channel_id": starboard_channel.id, "starboard_channel_id": starboard_channel.id,
"starboard_message_id": None "starboard_message_id": None,
} }
if not starboard_channel.can_send(embed[0]): if not starboard_channel.can_send(embed[0]):
self.log.warning( self.log.warning(
@ -213,29 +245,24 @@ class Starboard(commands.Cog):
starboard_channel.id, starboard_channel.id,
payload.guild_id, payload.guild_id,
starboard_channel.name, starboard_channel.name,
guild.name guild.name,
) )
return return
starboard_message = await starboard_channel.send( starboard_message = await starboard_channel.send(embeds=embed, silent=True)
embeds=embed,
silent=True
)
data["starboard_message_id"] = starboard_message.id data["starboard_message_id"] = starboard_message.id
await self.redis.set(str(message.id), json.dumps(data)) await self.redis.set(str(message.id), json.dumps(data))
else: else:
data = json.loads(data) data = json.loads(data)
try: try:
starboard_message = await self.get_or_fetch_message( starboard_message = await self.get_or_fetch_message(
data["starboard_channel_id"], data["starboard_channel_id"], data["starboard_message_id"]
data["starboard_message_id"]
) )
except discord.NotFound: except discord.NotFound:
if star_count <= 0: if star_count <= 0:
return return
starboard_message = await starboard_channel.send( starboard_message = await starboard_channel.send(
embeds=embed, embeds=embed, silent=True
silent=True
) )
data["starboard_message_id"] = starboard_message.id data["starboard_message_id"] = starboard_message.id
data["starboard_channel_id"] = starboard_message.channel.id data["starboard_channel_id"] = starboard_message.channel.id
@ -247,18 +274,20 @@ class Starboard(commands.Cog):
"Deleted message %s in %s, %s - lost all stars", "Deleted message %s in %s, %s - lost all stars",
starboard_message.id, starboard_message.id,
starboard_message.channel.name, starboard_message.channel.name,
starboard_message.guild.name starboard_message.guild.name,
) )
elif starboard_message.embeds[0] != embed[0]: elif starboard_message.embeds[0] != embed[0]:
await starboard_message.edit(embeds=embed) await starboard_message.edit(embeds=embed)
@commands.message_command(name="Preview Starboard Message") @commands.message_command(name="Preview Starboard Message")
async def preview_starboard_message(self, ctx: discord.ApplicationContext, message: discord.Message): async def preview_starboard_message(
self, ctx: discord.ApplicationContext, message: discord.Message
):
embed, stars = await self.generate_starboard_embed(message) embed, stars = await self.generate_starboard_embed(message)
data = await self.redis.get(str(message.id)) or 'null' data = await self.redis.get(str(message.id)) or "null"
return await ctx.respond( return await ctx.respond(
f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}", f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}",
embeds=embed embeds=embed,
) )

View file

@ -135,12 +135,21 @@ class YTDLCog(commands.Cog):
channel_id=excluded.channel_id, channel_id=excluded.channel_id,
attachment_index=excluded.attachment_index attachment_index=excluded.attachment_index
""", """,
(_hash, message.id, message.channel.id, webpage_url, format_id, attachment_index), (
_hash,
message.id,
message.channel.id,
webpage_url,
format_id,
attachment_index,
),
) )
await db.commit() await db.commit()
return _hash return _hash
async def get_saved(self, webpage_url: str, format_id: str, snip: str) -> typing.Optional[str]: async def get_saved(
self, webpage_url: str, format_id: str, snip: str
) -> typing.Optional[str]:
""" """
Attempts to retrieve the attachment URL of a previously saved download. Attempts to retrieve the attachment URL of a previously saved download.
:param webpage_url: The webpage url :param webpage_url: The webpage url
@ -154,12 +163,19 @@ class YTDLCog(commands.Cog):
logging.error("Failed to initialise ytdl database: %s", e, exc_info=True) logging.error("Failed to initialise ytdl database: %s", e, exc_info=True)
return return
async with aiosqlite.connect("./data/ytdl.db") as db: async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest() _hash = hashlib.md5(
f"{webpage_url}:{format_id}:{snip}".encode()
).hexdigest()
self.log.debug( self.log.debug(
"Attempting to find a saved download for '%s:%s:%s' (%r).", webpage_url, format_id, snip, _hash "Attempting to find a saved download for '%s:%s:%s' (%r).",
webpage_url,
format_id,
snip,
_hash,
) )
cursor = await db.execute( cursor = await db.execute(
"SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?", (_hash,) "SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?",
(_hash,),
) )
entry = await cursor.fetchone() entry = await cursor.fetchone()
if not entry: if not entry:
@ -173,7 +189,9 @@ class YTDLCog(commands.Cog):
try: try:
message = await channel.fetch_message(message_id) message = await channel.fetch_message(message_id)
except discord.HTTPException: except discord.HTTPException:
self.log.debug("%r did not contain a message with ID %r", channel, message_id) self.log.debug(
"%r did not contain a message with ID %r", channel, message_id
)
await db.execute("DELETE FROM downloads WHERE key=?", (_hash,)) await db.execute("DELETE FROM downloads WHERE key=?", (_hash,))
return return
@ -182,7 +200,11 @@ class YTDLCog(commands.Cog):
self.log.debug("Found URL %r, returning.", url) self.log.debug("Found URL %r, returning.", url)
return url return url
except IndexError: except IndexError:
self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments) self.log.debug(
"Attachment index %d is out of range (%r)",
attachment_index,
message.attachments,
)
return return
def convert_to_m4a(self, file: Path) -> Path: def convert_to_m4a(self, file: Path) -> Path:
@ -209,23 +231,32 @@ class YTDLCog(commands.Cog):
str(new_file), str(new_file),
] ]
self.log.debug("Running command: ffmpeg %s", " ".join(args)) self.log.debug("Running command: ffmpeg %s", " ".join(args))
process = subprocess.run(["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE) process = subprocess.run(
["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if process.returncode != 0: if process.returncode != 0:
raise RuntimeError(process.stderr.decode()) raise RuntimeError(process.stderr.decode())
return new_file return new_file
@staticmethod @staticmethod
async def upload_to_0x0(name: str, data: typing.IO[bytes], mime_type: str | None = None) -> str: async def upload_to_0x0(
name: str, data: typing.IO[bytes], mime_type: str | None = None
) -> str:
if not mime_type: if not mime_type:
import magic import magic
mime_type = await asyncio.to_thread(magic.from_buffer, data.read(4096), mime=True)
mime_type = await asyncio.to_thread(
magic.from_buffer, data.read(4096), mime=True
)
data.seek(0) data.seek(0)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"https://0x0.st", "https://0x0.st",
files={"file": (name, data, mime_type)}, files={"file": (name, data, mime_type)},
data={"expires": 12}, data={"expires": 12},
headers={"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"}, headers={
"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"
},
) )
if response.status_code == 200: if response.status_code == 200:
return urlparse(response.text).path[1:] return urlparse(response.text).path[1:]
@ -235,7 +266,10 @@ class YTDLCog(commands.Cog):
async def yt_dl_command( async def yt_dl_command(
self, self,
ctx: discord.ApplicationContext, ctx: discord.ApplicationContext,
url: typing.Annotated[str, discord.Option(str, description="The URL to download from.", required=True)], url: typing.Annotated[
str,
discord.Option(str, description="The URL to download from.", required=True),
],
user_format: typing.Annotated[ user_format: typing.Annotated[
typing.Optional[str], typing.Optional[str],
discord.Option( discord.Option(
@ -258,7 +292,10 @@ class YTDLCog(commands.Cog):
], ],
snip: typing.Annotated[ snip: typing.Annotated[
typing.Optional[str], typing.Optional[str],
discord.Option(description="A start and end position to trim. e.g. 00:00:00-00:10:00.", required=False), discord.Option(
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
required=False,
),
], ],
subtitles: typing.Annotated[ subtitles: typing.Annotated[
typing.Optional[str], typing.Optional[str],
@ -267,7 +304,7 @@ class YTDLCog(commands.Cog):
description="The language code of the subtitles to download. e.g. 'en', 'auto'", description="The language code of the subtitles to download. e.g. 'en', 'auto'",
required=False, required=False,
), ),
] ],
): ):
"""Runs yt-dlp and outputs into discord.""" """Runs yt-dlp and outputs into discord."""
await ctx.defer() await ctx.defer()
@ -279,28 +316,40 @@ class YTDLCog(commands.Cog):
if stop.is_set(): if stop.is_set():
raise RuntimeError("Download cancelled.") raise RuntimeError("Download cancelled.")
n = time.time() n = time.time()
_total = _data.get("total_bytes", _data.get("total_bytes_estimate")) or ctx.guild.filesize_limit _total = (
_data.get("total_bytes", _data.get("total_bytes_estimate"))
or ctx.guild.filesize_limit
)
if _total: if _total:
_percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2) _percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2)
else: else:
_total = max(1, _data.get("fragment_count", 4096) or 4096) _total = max(1, _data.get("fragment_count", 4096) or 4096)
_percent = round(max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2) _percent = round(
max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2
)
_speed_bytes_per_second = _data.get("speed", 1) or 1 or 1 _speed_bytes_per_second = _data.get("speed", 1) or 1 or 1
_speed_megabits_per_second = round((_speed_bytes_per_second * 8) / 1024 / 1024) _speed_megabits_per_second = round(
(_speed_bytes_per_second * 8) / 1024 / 1024
)
if _data.get("eta"): if _data.get("eta"):
_eta = discord.utils.utcnow() + datetime.timedelta(seconds=_data.get("eta")) _eta = discord.utils.utcnow() + datetime.timedelta(
seconds=_data.get("eta")
)
else: else:
_eta = discord.utils.utcnow() + datetime.timedelta(minutes=1) _eta = discord.utils.utcnow() + datetime.timedelta(minutes=1)
blocks = "#" * math.floor(_percent / 10) blocks = "#" * math.floor(_percent / 10)
bar = f"{blocks}{'.' * (10 - len(blocks))}" bar = f"{blocks}{'.' * (10 - len(blocks))}"
line = (f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | " line = (
f"ETA {discord.utils.format_dt(_eta, 'R')}") f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | "
f"ETA {discord.utils.format_dt(_eta, 'R')}"
)
nonlocal last_edit nonlocal last_edit
if (n - last_edit) >= 1.1: if (n - last_edit) >= 1.1:
embed.clear_fields() embed.clear_fields()
embed.add_field(name="Progress", value=line) embed.add_field(name="Progress", value=line)
ctx.bot.loop.create_task(ctx.edit(embed=embed)) ctx.bot.loop.create_task(ctx.edit(embed=embed))
last_edit = time.time() last_edit = time.time()
options["progress_hooks"] = [_download_hook] options["progress_hooks"] = [_download_hook]
description = "" description = ""
@ -331,13 +380,19 @@ class YTDLCog(commands.Cog):
options["format_sort"] = ["abr", "br"] options["format_sort"] = ["abr", "br"]
# noinspection PyTypeChecker # noinspection PyTypeChecker
options["postprocessors"].append( options["postprocessors"].append(
{"key": "FFmpegExtractAudio", "preferredquality": "96", "preferredcodec": "best"} {
"key": "FFmpegExtractAudio",
"preferredquality": "96",
"preferredcodec": "best",
}
) )
options["format"] = chosen_format options["format"] = chosen_format
options["paths"] = paths options["paths"] = paths
if subtitles and ffmpeg_installed: if subtitles and ffmpeg_installed:
subtitles, burn = subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0") subtitles, burn = (
subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0")
)
burn = burn[0].lower() in ("y", "1", "t") burn = burn[0].lower() in ("y", "1", "t")
if subtitles.lower() == "auto": if subtitles.lower() == "auto":
options["writeautosubtitles"] = True options["writeautosubtitles"] = True
@ -352,10 +407,14 @@ class YTDLCog(commands.Cog):
) )
with yt_dlp.YoutubeDL(options) as downloader: with yt_dlp.YoutubeDL(options) as downloader:
await ctx.respond(embed=discord.Embed().set_footer(text="Downloading (step 1/10)")) await ctx.respond(
embed=discord.Embed().set_footer(text="Downloading (step 1/10)")
)
try: try:
# noinspection PyTypeChecker # noinspection PyTypeChecker
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False) extracted_info = await asyncio.to_thread(
downloader.extract_info, url, download=False
)
except yt_dlp.utils.DownloadError as e: except yt_dlp.utils.DownloadError as e:
extracted_info = { extracted_info = {
"title": "error", "title": "error",
@ -382,22 +441,38 @@ class YTDLCog(commands.Cog):
thumbnail_url = extracted_info.get("thumbnail") or None thumbnail_url = extracted_info.get("thumbnail") or None
webpage_url = extracted_info.get("webpage_url", url) webpage_url = extracted_info.get("webpage_url", url)
chosen_format = extracted_info.get("format") or chosen_format or str(uuid.uuid4()) chosen_format = (
chosen_format_id = extracted_info.get("format_id") or str(uuid.uuid4()) extracted_info.get("format")
or chosen_format
or str(uuid.uuid4())
)
chosen_format_id = extracted_info.get("format_id") or str(
uuid.uuid4()
)
final_extension = extracted_info.get("ext") or "mp4" final_extension = extracted_info.get("ext") or "mp4"
format_note = extracted_info.get("format_note", "%s (%s)" % (chosen_format, chosen_format_id)) or "" format_note = (
extracted_info.get(
"format_note", "%s (%s)" % (chosen_format, chosen_format_id)
)
or ""
)
resolution = extracted_info.get("resolution") or "1x1" resolution = extracted_info.get("resolution") or "1x1"
fps = extracted_info.get("fps", 0.0) or 0.0 fps = extracted_info.get("fps", 0.0) or 0.0
vcodec = extracted_info.get("vcodec") or "h264" vcodec = extracted_info.get("vcodec") or "h264"
acodec = extracted_info.get("acodec") or "aac" acodec = extracted_info.get("acodec") or "aac"
filesize = extracted_info.get("filesize", extracted_info.get("filesize_approx", 1)) filesize = extracted_info.get(
likes = extracted_info.get("like_count", extracted_info.get("average_rating", 0)) "filesize", extracted_info.get("filesize_approx", 1)
)
likes = extracted_info.get(
"like_count", extracted_info.get("average_rating", 0)
)
views = extracted_info.get("view_count", 0) views = extracted_info.get("view_count", 0)
lines = [] lines = []
if chosen_format and chosen_format_id: if chosen_format and chosen_format_id:
lines.append( lines.append(
"* Chosen format: `%s` (`%s`)" % (chosen_format, chosen_format_id), "* Chosen format: `%s` (`%s`)"
% (chosen_format, chosen_format_id),
) )
if format_note: if format_note:
lines.append("* Format note: %r" % format_note) lines.append("* Format note: %r" % format_note)
@ -411,7 +486,9 @@ class YTDLCog(commands.Cog):
if vcodec or acodec: if vcodec or acodec:
lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A")) lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A"))
if filesize: if filesize:
lines.append("* Filesize: %s" % yt_dlp.utils.format_bytes(filesize)) lines.append(
"* Filesize: %s" % yt_dlp.utils.format_bytes(filesize)
)
if lines: if lines:
description += "\n" description += "\n"
@ -424,27 +501,29 @@ class YTDLCog(commands.Cog):
url=webpage_url, url=webpage_url,
colour=self.colours.get(domain, discord.Colour.og_blurple()), colour=self.colours.get(domain, discord.Colour.og_blurple()),
) )
embed.add_field( embed.add_field(name="Progress", value="0% [..........]")
name="Progress",
value="0% [..........]"
)
embed.set_footer(text="Downloading (step 2/10)") embed.set_footer(text="Downloading (step 2/10)")
embed.set_thumbnail(url=thumbnail_url) embed.set_thumbnail(url=thumbnail_url)
class StopView(discord.ui.View): class StopView(discord.ui.View):
@discord.ui.button(label="Cancel download", style=discord.ButtonStyle.danger) @discord.ui.button(
async def _stop(self, button: discord.ui.Button, interaction: discord.Interaction): label="Cancel download", style=discord.ButtonStyle.danger
)
async def _stop(
self,
button: discord.ui.Button,
interaction: discord.Interaction,
):
stop.set() stop.set()
button.label = "Cancelling..." button.label = "Cancelling..."
button.disabled = True button.disabled = True
await interaction.response.edit_message(view=self) await interaction.response.edit_message(view=self)
self.stop() self.stop()
await ctx.edit( await ctx.edit(embed=embed, view=StopView(timeout=86400))
embed=embed, previous = await self.get_saved(
view=StopView(timeout=86400) webpage_url, chosen_format_id, snip or "*"
) )
previous = await self.get_saved(webpage_url, chosen_format_id, snip or "*")
if previous: if previous:
await ctx.edit( await ctx.edit(
content=previous, content=previous,
@ -454,7 +533,11 @@ class YTDLCog(commands.Cog):
colour=discord.Colour.green(), colour=discord.Colour.green(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
url=previous, url=previous,
fields=[discord.EmbedField(name="URL", value=previous, inline=False)], fields=[
discord.EmbedField(
name="URL", value=previous, inline=False
)
],
).set_image(url=previous), ).set_image(url=previous),
) )
return return
@ -462,7 +545,9 @@ class YTDLCog(commands.Cog):
last_edit = time.time() last_edit = time.time()
try: try:
await asyncio.to_thread(functools.partial(downloader.download, [url])) await asyncio.to_thread(
functools.partial(downloader.download, [url])
)
except yt_dlp.DownloadError as e: except yt_dlp.DownloadError as e:
logging.error(e, exc_info=True) logging.error(e, exc_info=True)
return await ctx.edit( return await ctx.edit(
@ -473,7 +558,7 @@ class YTDLCog(commands.Cog):
url=webpage_url, url=webpage_url,
), ),
delete_after=120, delete_after=120,
view=None view=None,
) )
except RuntimeError: except RuntimeError:
return await ctx.edit( return await ctx.edit(
@ -484,16 +569,26 @@ class YTDLCog(commands.Cog):
url=webpage_url, url=webpage_url,
), ),
delete_after=120, delete_after=120,
view=None view=None,
) )
await ctx.edit(view=None) await ctx.edit(view=None)
try: try:
if audio_only is False: if audio_only is False:
file: Path = next(temp_dir.glob("*." + extracted_info.get("ext", "*"))) file: Path = next(
temp_dir.glob("*." + extracted_info.get("ext", "*"))
)
else: else:
# can be .opus, .m4a, .mp3, .ogg, .oga # can be .opus, .m4a, .mp3, .ogg, .oga
for _file in temp_dir.iterdir(): for _file in temp_dir.iterdir():
if _file.suffix in (".opus", ".m4a", ".mp3", ".ogg", ".oga", ".aac", ".wav"): if _file.suffix in (
".opus",
".m4a",
".mp3",
".ogg",
".oga",
".aac",
".wav",
):
file: Path = _file file: Path = _file
break break
else: else:
@ -523,7 +618,9 @@ class YTDLCog(commands.Cog):
except ValueError: except ValueError:
trim_start, trim_end = snip, None trim_start, trim_end = snip, None
trim_start = trim_start or "00:00:00" trim_start = trim_start or "00:00:00"
trim_end = trim_end or extracted_info.get("duration_string", "00:30:00") trim_end = trim_end or extracted_info.get(
"duration_string", "00:30:00"
)
new_file = temp_dir / ("output" + file.suffix) new_file = temp_dir / ("output" + file.suffix)
args = [ args = [
"-hwaccel", "-hwaccel",
@ -562,7 +659,10 @@ class YTDLCog(commands.Cog):
) )
self.log.debug("Running command: 'ffmpeg %s'", " ".join(args)) self.log.debug("Running command: 'ffmpeg %s'", " ".join(args))
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
"ffmpeg", *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE "ffmpeg",
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
self.log.debug("STDOUT:\n%r", stdout.decode()) self.log.debug("STDOUT:\n%r", stdout.decode())
@ -606,31 +706,40 @@ class YTDLCog(commands.Cog):
) )
embed = discord.Embed( embed = discord.Embed(
title=f"Downloaded {title}!", title=f"Downloaded {title}!",
description="Views: {:,} | Likes: {:,}".format(views or 0, likes or 0), description="Views: {:,} | Likes: {:,}".format(
views or 0, likes or 0
),
colour=discord.Colour.green(), colour=discord.Colour.green(),
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
url=webpage_url, url=webpage_url,
) )
try: try:
if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in ["hevc", "h265", "av1", "av01"]: if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in [
"hevc",
"h265",
"av1",
"av01",
]:
with file.open("rb") as fb: with file.open("rb") as fb:
part = await self.upload_to_0x0( part = await self.upload_to_0x0(file.name, fb)
file.name, embed.add_field(
fb name="URL", value=f"https://0x0.st/{part}", inline=False
)
embed.add_field(name="URL", value=f"https://0x0.st/{part}", inline=False)
await ctx.edit(
embed=embed
) )
await ctx.edit(embed=embed)
await ctx.respond("https://embeds.video/0x0/" + part) await ctx.respond("https://embeds.video/0x0/" + part)
else: else:
upload_file = await asyncio.to_thread(discord.File, file, filename=file.name) upload_file = await asyncio.to_thread(
msg = await ctx.edit( discord.File, file, filename=file.name
file=upload_file,
embed=embed
) )
await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or "*") msg = await ctx.edit(file=upload_file, embed=embed)
except (discord.HTTPException, ConnectionError, httpx.HTTPStatusError) as e: await self.save_link(
msg, webpage_url, chosen_format_id, snip=snip or "*"
)
except (
discord.HTTPException,
ConnectionError,
httpx.HTTPStatusError,
) as e:
self.log.error(e, exc_info=True) self.log.error(e, exc_info=True)
return await ctx.edit( return await ctx.edit(
embed=discord.Embed( embed=discord.Embed(

View file

@ -21,7 +21,9 @@ if (Path.cwd() / ".git").exists():
log.debug("Unable to auto-detect running version using git.", exc_info=True) log.debug("Unable to auto-detect running version using git.", exc_info=True)
VERSION = "unknown" VERSION = "unknown"
else: else:
log.debug("Unable to auto-detect running version using git, no .git directory exists.") log.debug(
"Unable to auto-detect running version using git, no .git directory exists."
)
VERSION = "unknown" VERSION = "unknown"
try: try:
@ -32,7 +34,7 @@ except FileNotFoundError:
log.critical( log.critical(
"Unable to locate config.toml in %s. Using default configuration. Good luck!", "Unable to locate config.toml in %s. Using default configuration. Good luck!",
cwd, cwd,
exc_info=True exc_info=True,
) )
CONFIG = {} CONFIG = {}
CONFIG.setdefault("logging", {}) CONFIG.setdefault("logging", {})
@ -50,11 +52,13 @@ CONFIG.setdefault(
"api": "https://bots.nexy7574.co.uk/jimmy/v2/", "api": "https://bots.nexy7574.co.uk/jimmy/v2/",
"username": os.getenv("WEB_USERNAME", os.urandom(32).hex()), "username": os.getenv("WEB_USERNAME", os.urandom(32).hex()),
"password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()), "password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()),
} },
) )
if CONFIG["redis"].pop("no_ping", None) is not None: if CONFIG["redis"].pop("no_ping", None) is not None:
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.") log.warning(
"`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory."
)
CONFIG["redis"]["decode_responses"] = True CONFIG["redis"]["decode_responses"] = True

View file

@ -61,9 +61,16 @@ class KumaThread(Thread):
response.raise_for_status() response.raise_for_status()
self.previous = bot.is_ready() self.previous = bot.is_ready()
except httpx.HTTPError as error: except httpx.HTTPError as error:
self.log.error("Failed to connect to uptime-kuma: %r: %r", url, error, exc_info=error) self.log.error(
"Failed to connect to uptime-kuma: %r: %r",
url,
error,
exc_info=error,
)
timeout = self.calculate_backoff() timeout = self.calculate_backoff()
self.log.warning("Waiting %d seconds before retrying ping.", timeout) self.log.warning(
"Waiting %d seconds before retrying ping.", timeout
)
time.sleep(timeout) time.sleep(timeout)
continue continue
@ -91,7 +98,11 @@ logging.basicConfig(
markup=True, markup=True,
console=Console(width=cols, height=lns), console=Console(width=cols, height=lns),
), ),
FileHandler(filename=CONFIG["logging"].get("file", "jimmy.log"), encoding="utf-8", errors="replace"), FileHandler(
filename=CONFIG["logging"].get("file", "jimmy.log"),
encoding="utf-8",
errors="replace",
),
], ],
) )
for logger in CONFIG["logging"].get("suppress", []): for logger in CONFIG["logging"].get("suppress", []):
@ -130,7 +141,8 @@ class Client(commands.Bot):
async def start(self, token: str, *, reconnect: bool = True) -> None: async def start(self, token: str, *, reconnect: bool = True) -> None:
if CONFIG["jimmy"].get("uptime_kuma_url"): if CONFIG["jimmy"].get("uptime_kuma_url"):
self.uptime_thread = KumaThread( self.uptime_thread = KumaThread(
CONFIG["jimmy"]["uptime_kuma_url"], CONFIG["jimmy"].get("uptime_kuma_interval", 58.0) CONFIG["jimmy"]["uptime_kuma_url"],
CONFIG["jimmy"].get("uptime_kuma_interval", 58.0),
) )
self.uptime_thread.start() self.uptime_thread.start()
await super().start(token, reconnect=reconnect) await super().start(token, reconnect=reconnect)
@ -138,7 +150,9 @@ class Client(commands.Bot):
async def close(self) -> None: async def close(self) -> None:
if getattr(self, "uptime_thread", None): if getattr(self, "uptime_thread", None):
self.uptime_thread.kill.set() self.uptime_thread.kill.set()
await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join) await asyncio.get_event_loop().run_in_executor(
None, self.uptime_thread.join
)
await super().close() await super().close()
@ -189,9 +203,13 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc) log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc)
if isinstance(exc, commands.CommandOnCooldown): if isinstance(exc, commands.CommandOnCooldown):
expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after) expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after)
await ctx.respond(f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}.") await ctx.respond(
f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}."
)
elif isinstance(exc, commands.MaxConcurrencyReached): elif isinstance(exc, commands.MaxConcurrencyReached):
await ctx.respond("You've reached the maximum number of concurrent uses for this command.") await ctx.respond(
"You've reached the maximum number of concurrent uses for this command."
)
else: else:
if await bot.is_owner(ctx.author): if await bot.is_owner(ctx.author):
paginator = commands.Paginator(prefix="```py") paginator = commands.Paginator(prefix="```py")
@ -200,7 +218,10 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
for page in paginator.pages: for page in paginator.pages:
await ctx.respond(page) await ctx.respond(page)
else: else:
await ctx.respond(f"An error occurred while processing your command. Please try again later.\n" f"{exc}") await ctx.respond(
f"An error occurred while processing your command. Please try again later.\n"
f"{exc}"
)
@bot.listen() @bot.listen()
@ -218,11 +239,22 @@ async def delete_message(ctx: discord.ApplicationContext, message: discord.Messa
await ctx.defer() await ctx.defer()
if not ctx.channel.permissions_for(ctx.me).manage_messages: if not ctx.channel.permissions_for(ctx.me).manage_messages:
if message.author != bot.user: if message.author != bot.user:
return await ctx.respond("I don't have permission to delete messages in this channel.", delete_after=30) return await ctx.respond(
"I don't have permission to delete messages in this channel.",
delete_after=30,
)
log.info("%s deleted message %s>%s: %r", ctx.author, ctx.channel.name, message.id, message.content) log.info(
"%s deleted message %s>%s: %r",
ctx.author,
ctx.channel.name,
message.id,
message.content,
)
await message.delete(delay=1) await message.delete(delay=1)
await ctx.respond(f"\N{white heavy check mark} Deleted message by {message.author.display_name}.") await ctx.respond(
f"\N{WHITE HEAVY CHECK MARK} Deleted message by {message.author.display_name}."
)
await ctx.delete(delay=5) await ctx.delete(delay=5)
@ -231,16 +263,19 @@ async def about_me(ctx: discord.ApplicationContext):
embed = discord.Embed( embed = discord.Embed(
title="Jimmy v3", title="Jimmy v3",
description="A bot specifically for the LCC discord server(s).", description="A bot specifically for the LCC discord server(s).",
colour=discord.Colour.green() colour=discord.Colour.green(),
) )
embed.add_field( embed.add_field(
name="Source", name="Source",
value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)" value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)",
) )
user = await ctx.bot.fetch_user(421698654189912064) user = await ctx.bot.fetch_user(421698654189912064)
appinfo = await ctx.bot.application_info() appinfo = await ctx.bot.application_info()
author = appinfo.owner author = appinfo.owner
embed.add_field(name="Author", value=f"{user.mention} created me. {author.mention} is running this instance.") embed.add_field(
name="Author",
value=f"{user.mention} created me. {author.mention} is running this instance.",
)
return await ctx.respond(embed=embed) return await ctx.respond(embed=embed)
@ -249,9 +284,8 @@ async def check_is_enabled(ctx: commands.Context | discord.ApplicationContext) -
disabled = CONFIG["jimmy"].get("disabled_commands", []) disabled = CONFIG["jimmy"].get("disabled_commands", [])
if ctx.command.qualified_name in disabled: if ctx.command.qualified_name in disabled:
raise commands.DisabledCommand( raise commands.DisabledCommand(
"%s is disabled via this instance's configuration file." % ( "%s is disabled via this instance's configuration file."
ctx.command.qualified_name % (ctx.command.qualified_name)
)
) )
return True return True
@ -273,7 +307,9 @@ async def add_delays(ctx: commands.Context | discord.ApplicationContext) -> bool
if not CONFIG["jimmy"].get("token"): if not CONFIG["jimmy"].get("token"):
log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)") log.critical(
"No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)"
)
sys.exit(1) sys.exit(1)
__disabled = CONFIG["jimmy"].get("disabled_commands", []) __disabled = CONFIG["jimmy"].get("disabled_commands", [])

View file

@ -2,24 +2,24 @@
import json import json
import redis import redis
import orjson
import os import os
import secrets import secrets
import typing import typing
import time import time
from fastapi import FastAPI, Depends, HTTPException, APIRouter from fastapi import FastAPI, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response, ORJSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
JSON: typing.Union[ JSON: typing.Union[
str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"] str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"]
] = typing.Union[ ] = typing.Union[str, int, float, bool, None, dict, list]
str, int, float, bool, None, dict, list
]
class TruthPayload(BaseModel): class TruthPayload(BaseModel):
"""Represents a truth. This can be used to both create and get truths.""" """Represents a truth. This can be used to both create and get truths."""
id: str id: str
content: str content: str
author: str author: str
@ -33,6 +33,7 @@ class OllamaThread(BaseModel):
class ThreadMessage(BaseModel): class ThreadMessage(BaseModel):
"""Represents a message in an Ollama thread.""" """Represents a message in an Ollama thread."""
role: typing.Literal["assistant", "system", "user"] role: typing.Literal["assistant", "system", "user"]
content: str content: str
images: typing.Optional[list[str]] = [] images: typing.Optional[list[str]] = []
@ -54,9 +55,7 @@ USERNAME = os.getenv("WEB_USERNAME", os.urandom(32).hex())
PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex()) PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex())
accounts = { accounts = {USERNAME: PASSWORD}
USERNAME: PASSWORD
}
if os.path.exists("./web-accounts.json"): if os.path.exists("./web-accounts.json"):
with open("./web-accounts.json") as f: with open("./web-accounts.json") as f:
accounts.update(json.load(f)) accounts.update(json.load(f))
@ -71,7 +70,9 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials return credentials
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]: def get_db_factory(
n: int = 11,
) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
def inner(): def inner():
uri = os.getenv("REDIS_URL", "redis://redis") uri = os.getenv("REDIS_URL", "redis://redis")
conn = redis.Redis.from_url(uri) conn = redis.Redis.from_url(uri)
@ -80,38 +81,58 @@ def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Re
yield conn yield conn
finally: finally:
conn.close() conn.close()
return inner return inner
app = FastAPI( app = FastAPI(
title="Jimmy v3 API", title="Jimmy v3 API",
version="3.1.0", version="3.1.0",
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api" root_path=os.getenv("WEB_ROOT_PATH", "") + "/api",
) )
truth_router = APIRouter( truth_router = APIRouter(
prefix="/truths", prefix="/truths", dependencies=[Depends(check_credentials)], tags=["Truth Social"]
dependencies=[Depends(check_credentials)],
tags=["Truth Social"]
) )
@truth_router.get("", response_model=list[TruthPayload]) @truth_router.get("", response_class=ORJSONResponse, response_model=list[TruthPayload])
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())): def get_all_truths(
"""Retrieves all stored truths""" rich: bool = True,
limit: int = -1,
page: int = 0,
db: redis.Redis = Depends(get_db_factory()),
):
"""Retrieves all stored truths
If ?limit is a positive integer, pagination will be enabled."""
query_start = time.perf_counter()
keys = db.keys() keys = db.keys()
query_end = time.perf_counter()
if rich is False: if rich is False:
return [ return [
{"id": key, "content": "", "author": "", "timestamp": 0, "extra": None} {"id": key, "content": "", "author": "", "timestamp": 0, "extra": None}
for key in keys for key in keys
] ]
load_start = time.perf_counter()
if limit >= 0:
keys = keys[page * limit: (page + 1) * limit]
truths = [json.loads(db.get(key)) for key in keys] truths = [json.loads(db.get(key)) for key in keys]
return truths load_end = time.perf_counter()
server_timing = "query;dur=%.2f, load;dur=%.2f" % (
(query_end - query_start) * 1000,
(load_end - load_start) * 1000,
)
return ORJSONResponse(truths, headers={"Server-Timing": server_timing})
@truth_router.get("/all", deprecated=True, response_model=list[TruthPayload]) @truth_router.get("/all", deprecated=True, response_model=list[TruthPayload])
def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())): def get_all_truths_deprecated(
response: JSONResponse,
rich: bool = True,
db: redis.Redis = Depends(get_db_factory()),
):
"""DEPRECATED - USE get_all_truths INSTEAD""" """DEPRECATED - USE get_all_truths INSTEAD"""
return get_all_truths(rich, db) return get_all_truths(rich=rich, db=db)
@truth_router.get("/{truth_id}", response_model=TruthPayload) @truth_router.get("/{truth_id}", response_model=TruthPayload)
@ -133,7 +154,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
@truth_router.post("", status_code=201, response_model=TruthPayload) @truth_router.post("", status_code=201, response_model=TruthPayload)
def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())): def new_truth(
payload: TruthPayload,
response: JSONResponse,
db: redis.Redis = Depends(get_db_factory()),
):
"""Creates a new truth""" """Creates a new truth"""
data = payload.model_dump() data = payload.model_dump()
existing: str = db.get(data["id"]) existing: str = db.get(data["id"])
@ -148,7 +173,9 @@ def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = D
@truth_router.put("/{truth_id}", response_model=TruthPayload) @truth_router.put("/{truth_id}", response_model=TruthPayload)
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())): def put_truth(
truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())
):
"""Replaces a stored truth""" """Replaces a stored truth"""
data = payload.model_dump() data = payload.model_dump()
existing = db.get(truth_id) existing = db.get(truth_id)
@ -168,9 +195,7 @@ def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
app.include_router(truth_router) app.include_router(truth_router)
ollama_router = APIRouter( ollama_router = APIRouter(
prefix="/ollama", prefix="/ollama", dependencies=[Depends(check_credentials)], tags=["Ollama"]
dependencies=[Depends(check_credentials)],
tags=["Ollama"]
) )
@ -185,13 +210,13 @@ def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
return keys return keys
@ollama_router.get("/thread/{thread_id}", response_model=OllamaThread) @ollama_router.get("/thread/{thread_id}", response_class=ORJSONResponse, response_model=OllamaThread)
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))): def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Retrieves a stored thread""" """Retrieves a stored thread"""
data: str = db.get(thread_id) data: str = db.get(thread_id)
if not data: if not data:
raise HTTPException(404, detail="%r not found." % thread_id) raise HTTPException(404, detail="%r not found." % thread_id)
return json.loads(data) return ORJSONResponse(orjson.loads(data.encode()))
@ollama_router.delete("/thread/{thread_id}", status_code=204) @ollama_router.delete("/thread/{thread_id}", status_code=204)
@ -216,4 +241,5 @@ def health(db: redis.Redis = Depends(get_db_factory())):
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*") uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*")