Merge remote-tracking branch 'origin/master'
All checks were successful
Build and Publish / build_and_publish (push) Successful in 1m6s

This commit is contained in:
Nexus 2024-07-22 23:50:55 +01:00
commit 27b23d3e3a
Signed by: nex
GPG key ID: 0FA334385D0B689F
15 changed files with 1311 additions and 553 deletions

View file

@ -21,3 +21,4 @@ aiofiles~=23.2
fuzzywuzzy[speedup]~=0.18
tortoise-orm[asyncpg]~=0.21
superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6
orjson~=3.10

View file

@ -33,11 +33,16 @@ class AutoResponder(commands.Cog):
raise ValueError(
"overpowered_by is set, however members intent is disabled, making it useless."
)
if self.config.get("overrule_offline_superiors", True) is True and bot.intents.presences is False:
if (
self.config.get("overrule_offline_superiors", True) is True
and bot.intents.presences is False
):
raise ValueError(
"overrule_offline_superiors is enabled, however presences intent is not!"
)
self.config.setdefault("transcoding", {"enabled": True, "hevc": True, "on_demand": True})
self.config.setdefault(
"transcoding", {"enabled": True, "hevc": True, "on_demand": True}
)
self.lmgtfy_cache = []
@property
@ -50,7 +55,9 @@ class AutoResponder(commands.Cog):
@property
def allow_on_demand_transcoding(self) -> bool:
return self.transcoding_enabled and self.config["transcoding"].get("on_demand", True)
return self.transcoding_enabled and self.config["transcoding"].get(
"on_demand", True
)
def overpowered(self, guild: discord.Guild | None) -> bool:
if not guild:
@ -68,14 +75,20 @@ class AutoResponder(commands.Cog):
@staticmethod
@typing.overload
def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: ...
def extract_links(
text: str, *domains: str, raw: typing.Literal[True] = False
) -> list[ParseResult]: ...
@staticmethod
@typing.overload
def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: ...
def extract_links(
text: str, *domains: str, raw: typing.Literal[False] = False
) -> list[str]: ...
@staticmethod
def extract_links(text: str, *domains: str, raw: bool = False) -> list[str | ParseResult]:
def extract_links(
text: str, *domains: str, raw: bool = False
) -> list[str | ParseResult]:
"""
Extracts all links from a given text.
@ -124,10 +137,13 @@ class AutoResponder(commands.Cog):
if not update:
return
if last_reaction is not None:
_ = asyncio.create_task(update.remove_reaction(last_reaction, self.bot.user))
_ = asyncio.create_task(
update.remove_reaction(last_reaction, self.bot.user)
)
if new:
_ = asyncio.create_task(update.add_reaction(new))
last_reaction = new
self.log.info("Waiting for transcode lock to release")
async with self.transcode_lock:
cog = FFMeta(self.bot)
@ -147,6 +163,7 @@ class AutoResponder(commands.Cog):
return update_reaction("\N{TIMER CLOCK}\U0000fe0f")
streams = info.get("streams", [])
hwaccel = True
maxrate = "5M"
for stream in streams:
self.log.info("Found stream: %s", stream.get("codec_name"))
if stream.get("codec_name") == "hevc":
@ -159,8 +176,14 @@ class AutoResponder(commands.Cog):
hwaccel = False
break
else:
self.log.info("No HEVC streams found in %s", uri)
return update_reaction()
if int(info["format"]["size"]) >= 25 * 1024 * 1024:
self.log.warning(
"%s is too large to render in discord, compressing", uri
)
maxrate = "1M"
else:
self.log.info("No HEVC streams found in %s", uri)
return update_reaction()
extension = pathlib.Path(uri).suffix
with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl:
self.log.info("Downloading %r to %r", uri, tmp_dl.name)
@ -195,10 +218,8 @@ class AutoResponder(commands.Cog):
tmp_dl.name,
"-c:v",
"libx264",
"-crf",
"25",
"-maxrate",
"5M",
maxrate,
"-minrate",
"100K",
"-bufsize",
@ -216,7 +237,7 @@ class AutoResponder(commands.Cog):
"-movflags",
"faststart",
"-profile:v",
"main",
"high",
"-y",
"-hide_banner",
]
@ -230,7 +251,9 @@ class AutoResponder(commands.Cog):
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
self.log.info("finished transcode with return code %d", process.returncode)
self.log.info(
"finished transcode with return code %d", process.returncode
)
self.log.debug("stdout: %r", stdout.decode)
self.log.debug("stderr: %r", stderr.decode)
update_reaction()
@ -243,12 +266,9 @@ class AutoResponder(commands.Cog):
)
self._cooldown_transcode()
return discord.File(tmp_path), tmp_path
async def transcode_hevc_to_h264(
self,
message: discord.Message,
*domains: str,
additional: Iterable[str] = None
self, message: discord.Message, *domains: str, additional: Iterable[str] = None
) -> None:
if not shutil.which("ffmpeg"):
self.log.error("ffmpeg not installed")
@ -283,7 +303,9 @@ class AutoResponder(commands.Cog):
self.log.info("Found link to transcode: %r", link)
try:
async with message.channel.typing():
_r = await self._transcode_hevc_to_h264(link, update=message)
_r = await self._transcode_hevc_to_h264(
link, update=message
)
if not _r:
continue
file, _p = _r
@ -291,7 +313,9 @@ class AutoResponder(commands.Cog):
if _p.stat().st_size <= 24.5 * 1024 * 1024:
await message.add_reaction("\N{OUTBOX TRAY}")
await message.reply(file=file)
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user)
await message.remove_reaction(
"\N{OUTBOX TRAY}", self.bot.user
)
else:
await message.add_reaction("\N{OUTBOX TRAY}")
self.log.warning(
@ -301,16 +325,31 @@ class AutoResponder(commands.Cog):
)
if _p.stat().st_size <= 510 * 1024 * 1024:
file.fp.seek(0)
self.log.info("Trying to upload file to pastebin.")
self.log.info(
"Trying to upload file to pastebin."
)
async with httpx.AsyncClient() as client:
response = await client.post(
"https://0x0.st",
files={"file": (_p.name, file.fp, "video/mp4")},
headers={"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"},
files={
"file": (
_p.name,
file.fp,
"video/mp4",
)
},
headers={
"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"
},
)
await message.remove_reaction(
"\N{OUTBOX TRAY}", self.bot.user
)
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user)
if response.status_code == 200:
await message.reply("https://embeds.video/" + response.text.strip())
await message.reply(
"https://embeds.video/"
+ response.text.strip()
)
else:
await message.add_reaction("\N{BUG}")
response.raise_for_status()
@ -321,12 +360,16 @@ class AutoResponder(commands.Cog):
except Exception as e:
self.log.error("Failed to transcode %r: %r", link, e)
async def copy_ncfe_docs(self, message: discord.Message, links: list[ParseResult]) -> None:
async def copy_ncfe_docs(
self, message: discord.Message, links: list[ParseResult]
) -> None:
files = []
if self.config.get("download_pdfs", True) is False:
self.log.debug("Download PDFs is disabled in config, disengaging.")
return
self.log.info("Preparing to download: %s", ", ".join(map(ParseResult.geturl, links)))
self.log.info(
"Preparing to download: %s", ", ".join(map(ParseResult.geturl, links))
)
async with aiohttp.ClientSession() as session:
for link in set(links):
if link.path.endswith(".pdf"):
@ -338,7 +381,7 @@ class AutoResponder(commands.Cog):
"Failed to download %s: HTTP %d - %r",
link,
response.status,
await response.text()
await response.text(),
)
continue
async for chunk in response.content.iter_any():
@ -351,7 +394,9 @@ class AutoResponder(commands.Cog):
self.log.warning("File was too large to upload. Skipping.")
continue
p = pathlib.Path(link.path).name
file = discord.File(buffer, filename=p, description="Copy of %s" % link.geturl())
file = discord.File(
buffer, filename=p, description="Copy of %s" % link.geturl()
)
files.append(file)
for file in files:
await message.reply(file=file)
@ -379,27 +424,38 @@ class AutoResponder(commands.Cog):
self.log.info("Got VHS reaction, scanning for transcode")
extra = []
if reaction.message.attachments:
extra = [attachment.url for attachment in reaction.message.attachments]
extra = [
attachment.url for attachment in reaction.message.attachments
]
if self.allow_on_demand_transcoding:
await self.transcode_hevc_to_h264(reaction.message, additional=extra)
await self.transcode_hevc_to_h264(
reaction.message, additional=extra
)
elif str(reaction.emoji) == "\U0001f310":
if reaction.message.id not in self.lmgtfy_cache:
url = "https://lmddgtfy.net/?q=" + quote_plus(reaction.message.content)
m = await reaction.message.reply(f"[Here's the answer to your question]({url})")
url = "https://lmddgtfy.net/?q=" + quote_plus(
reaction.message.content
)
m = await reaction.message.reply(
f"[Here's the answer to your question]({url})"
)
await m.edit(suppress=True)
self.lmgtfy_cache.append(reaction.message.id)
elif str(reaction.emoji)[0] == "\N{wastebasket}":
if reaction.message.channel.permissions_for(reaction.message.guild.me).manage_messages:
self.log.info("Deleting message %s (Wastebasket)" % reaction.message.jump_url)
elif str(reaction.emoji)[0] == "\N{WASTEBASKET}":
if reaction.message.channel.permissions_for(
reaction.message.guild.me
).manage_messages:
self.log.info(
"Deleting message %s (Wastebasket)" % reaction.message.jump_url
)
await reaction.message.delete(
reason="%s requested deletion of message" % user,
delay=0.2
reason="%s requested deletion of message" % user, delay=0.2
)
else:
self.log.warning(
"Unable to delete message %s (wastebasket) - missing permissions",
reaction.message.jump_url
reaction.message.jump_url,
)
@commands.Cog.listener("on_raw_reaction_add")
@ -415,8 +471,13 @@ class AutoResponder(commands.Cog):
_e = discord.PartialEmoji.from_str(str(payload.emoji))
reaction = discord.Reaction(
message=message,
data={"emoji": _e, "count": 1, "me": payload.user_id == self.bot.user.id, "burst": False},
emoji=payload.emoji
data={
"emoji": _e,
"count": 1,
"me": payload.user_id == self.bot.user.id,
"burst": False,
},
emoji=payload.emoji,
)
user = self.bot.get_user(payload.user_id)
await self.on_reaction_add(reaction, user)

View file

@ -1,140 +1,220 @@
"""
This module is only meant to be loaded during election times.
"""
import asyncio
import logging
import random
import datetime
import re
import discord
import httpx
from bs4 import BeautifulSoup
from discord.ext import commands
from discord.ext import commands, tasks
SPAN_REGEX = re.compile(
r"^(?P<party>\D+)(?P<councillors>[0-9,]+)\scouncillors\s(?P<net>[0-9,]+)\scouncillors\s(?P<net_change>(gained|lost))$"
)
MULTI: dict[str, int] = {
"gained": 1,
"lost": -1
}
MULTI: dict[str, int] = {"gained": 1, "lost": -1}
class ElectionCog(commands.Cog):
SOURCE = "https://bbc.com/"
HEADERS = {
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,"
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"124\", \"Google Chrome\";v=\"124\", \"Not-A.Brand\";v=\"99\"",
"Sec-Ch-Ua": '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"',
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Linux\"",
"Sec-Ch-Ua-Platform": '"Linux"',
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 "
"Safari/537.36",
"Safari/537.36",
}
ETA = datetime.datetime(
2024, 7, 4, 23, 30, tzinfo=datetime.datetime.now().astimezone().tzinfo
)
def __init__(self, bot):
self.bot = bot
self.bot: commands.Bot = bot
self.log = logging.getLogger("jimmy.cogs.election")
self.countdown_message = None
# self.check_election.start()
def cog_unload(self) -> None:
self.check_election.cancel()
@tasks.loop(minutes=1)
async def check_election(self):
if not self.bot.is_ready():
await self.bot.wait_until_ready()
guild = self.bot.get_guild(994710566612500550)
if not guild:
return self.log.error("Nonsense guild not found. Can't do countdown.")
channel = discord.utils.get(guild.text_channels, name="countdown")
if not channel:
return self.log.error("Countdown channel not found.")
await asyncio.sleep(random.randint(0, 10))
now = discord.utils.utcnow()
diff = (self.ETA - now).total_seconds()
if diff < -86400:
return self.log.debug("Countdown long expired.")
if diff > 3600:
hours, remainder = map(round, divmod(diff, 3600))
minutes, seconds = map(round, divmod(remainder, 60))
message = f"+ {hours} hours, {minutes} minutes, and {seconds} seconds."
elif diff > 60:
minutes, seconds = map(round, divmod(diff, 60))
message = f"+ {minutes} minutes and {seconds} seconds."
elif diff >= 0:
message = f"+ {round(diff)} seconds."
else:
message = "Results time!"
if self.countdown_message:
try:
return await self.countdown_message.edit(
content=f"```diff\n{message}```"
)
except discord.HTTPException:
self.log.exception("Failed to edit countdown message.")
self.countdown_message = await channel.send(f"```diff\n{message}```")
self.log.debug("Sent countdown message")
def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None:
good_soup = soup.find(attrs={"data-testid": "election-banner-results-bar"})
if not good_soup:
good_soups = list(
soup.find_all(attrs={"data-testid": "election-banner-results-bar"})
)
if not good_soups:
self.log.error(
"No 'election-banner-results-bar' elements found:\n%r", soup.prettify()
)
return
css: str = "\n".join([x.get_text() for x in soup.find_all("style")])
def find_colour(style_name: str, want: str = "background-color") -> str | None:
self.log.info("Looking for style %r", style_name)
index = css.index(style_name) + len(style_name) + 1
value = ""
for char in css[index:]:
if char == "}":
break
value += char
attributes = filter(None, value.split(";"))
parsed = {}
for attr in attributes:
name, val = attr.split(":")
parsed[name] = val
self.log.info("Parsed the following attributes: %r", parsed)
if want in parsed:
self.log.info("Returning %r: %r", want, parsed[want])
return parsed[want]
self.log.warning("%r was not in attributes.", want)
good_soup = list(good_soups)[1]
results: dict[str, list[int]] = {}
for child_ul in good_soup.children:
child_ul: BeautifulSoup
span = child_ul.find("span", recursive=False)
if not span:
self.log.warning("%r did not have a 'span' element.", child_ul)
for child_li in good_soup.children:
span = list(child_li.children)[-1]
try:
party, extra = span.get_text().strip().split(":", 1)
seats, extra = extra.split(",", 1)
extra = extra.strip()
if extra.lower() == "no change":
change = 0
else:
_values = extra.split()
change = int(_values[0])
if _values[-1] != "gained":
change *= -1
seats = int(seats.split()[0])
except ValueError:
self.log.error("failed to parse %r", span)
continue
text = span.get_text().replace(",", "")
groups = SPAN_REGEX.match(text)
if groups:
groups = groups.groupdict()
else:
self.log.warning(
"Found span element (%r), however resolved text (%r) did not match regex.",
span, text
)
results[party] = [seats, change, 0, 0]
for child_li in good_soups[0].children:
span = list(child_li.children)[-1]
try:
party, extra = span.get_text().strip().split(":", 1)
seats, _ = extra.strip().split(" ", 1)
seats = int(seats.strip())
except ValueError:
self.log.error("failed to parse %r", span)
continue
results[str(groups["party"]).strip()] = [
int(groups["councillors"].strip()),
int(groups["net"].strip()) * MULTI[groups["net_change"]],
int(find_colour(child_ul.next["class"][0])[1:], base=16)
]
if party in results:
results[party][3] = seats
return results
async def _get_embed(self) -> discord.Embed | None:
async with httpx.AsyncClient(headers=self.HEADERS) as client:
response = await client.get(self.SOURCE, follow_redirects=True)
if response.status_code != 200:
raise RuntimeError(
f"HTTP {response.status_code} while fetching results from BBC"
)
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser")
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup)
if results:
now = discord.utils.utcnow()
date = now.date().strftime("%B %Y")
colour_scores = {}
embed = discord.Embed(
title="Election results - " + date,
url="https://bbc.co.uk/",
timestamp=now,
)
embed.set_footer(text="Source from bbc.co.uk.")
description_parts = []
for party_name, values in results.items():
councillors, net, colour, last_election = values
colour_scores[party_name] = councillors
symbol = "+" if net > 0 else ""
description_parts.append(
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, "
f"{last_election:,} predicted by exit poll)"
)
top_party = list(
sorted(
colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True
)
)[0]
embed.colour = discord.Colour(results[top_party][2])
embed.description = "\n".join(description_parts)
return embed
@commands.slash_command(name="election")
async def get_election_results(self, ctx: discord.ApplicationContext):
"""Gets the current election results"""
await ctx.defer()
async with httpx.AsyncClient(headers=self.HEADERS) as client:
response = await client.get(self.SOURCE, follow_redirects=True)
if response.status_code != 200:
return await ctx.respond(
"Sorry, I can't do that right now (HTTP %d while fetching results from BBC)" % response.status_code
)
# noinspection PyTypeChecker
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser")
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup)
if results:
now = discord.utils.utcnow()
date = now.date().strftime("%B %Y")
colour_scores = {}
embed = discord.Embed(
title="Election results - " + date,
url="https://bbc.co.uk/",
timestamp=now
)
embed.set_footer(text="Source from bbc.co.uk.")
description_parts = []
class RefreshView(discord.ui.View):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.last_edit = discord.utils.utcnow()
for party_name, values in results.items():
councillors, net, colour = values
colour_scores[party_name] = councillors
symbol = "+" if net > 0 else ''
description_parts.append(
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)"
@discord.ui.button(
label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501"
)
async def refresh(_self, _btn, interaction):
await interaction.response.defer(invisible=True)
if (discord.utils.utcnow() - self.last_edit).total_seconds() < 10:
return await interaction.followup.send("Slow down.", ephemeral=True)
try:
embed = await self._get_embed()
except Exception as e:
self.log.exception("Failed to get election results.")
return await interaction.followup.send(
f"Sorry, I cannot contact the BBC at this time: {e}"
)
if embed is None:
return await interaction.followup.send(
"Sorry, I could not find any election results."
)
await interaction.edit_original_response(embed=embed)
top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0]
embed.colour = discord.Colour(results[top_party][2])
embed.description = "\n".join(description_parts)
return await ctx.respond(embed=embed)
else:
return await ctx.respond("Unable to get election results at this time.")
await ctx.defer()
try:
embed = await self._get_embed()
except Exception as e:
self.log.exception("Failed to get election results.")
return await ctx.respond(
f"Sorry, I cannot contact the BBC at this time: {e}"
)
if embed is None:
return await ctx.respond("Sorry, I could not find any election results.")
await ctx.respond(
embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True)
)
def setup(bot):

View file

@ -29,25 +29,44 @@ class FFMeta(commands.Cog):
self.bot = bot
self.log = logging.getLogger("jimmy.cogs.ffmeta")
def jpegify_image(self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg") -> io.BytesIO:
def jpegify_image(
self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg"
) -> io.BytesIO:
quality = min(1, max(quality, 100))
img_src = PIL.Image.open(input_file)
if image_format == "jpeg":
img_src = img_src.convert("RGB")
img_dst = io.BytesIO()
self.log.debug("Saving input file (%r) as %r with quality %r%%", input_file, image_format, quality)
self.log.debug(
"Saving input file (%r) as %r with quality %r%%",
input_file,
image_format,
quality,
)
img_src.save(img_dst, format=image_format, quality=quality)
img_dst.seek(0)
return img_dst
async def _run_ffprobe(self, uri: str | pathlib.Path, as_json: bool = False) -> dict | str:
async def _run_ffprobe(
self, uri: str | pathlib.Path, as_json: bool = False
) -> dict | str:
"""
Runs ffprobe on the given target (either file path or URL) and returns the result
:param uri: the URI to run ffprobe on
:return: The result
"""
_bin = "ffprobe"
cmd = ["-hide_banner", "-v", "quiet", "-print_format", "json", "-show_streams", "-show_format", "-i", str(uri)]
cmd = [
"-hide_banner",
"-v",
"quiet",
"-print_format",
"json",
"-show_streams",
"-show_format",
"-i",
str(uri),
]
if not as_json:
cmd = ["-hide_banner", "-i", str(uri)]
process = await asyncio.create_subprocess_exec(
@ -62,7 +81,12 @@ class FFMeta(commands.Cog):
return stderr.decode(errors="replace")
@commands.slash_command()
async def ffprobe(self, ctx: discord.ApplicationContext, url: str = None, attachment: discord.Attachment = None):
async def ffprobe(
self,
ctx: discord.ApplicationContext,
url: str = None,
attachment: discord.Attachment = None,
):
"""Runs ffprobe on a given URL or attachment"""
if not shutil.which("ffprobe"):
return await ctx.respond("ffprobe is not installed on this system.")
@ -101,7 +125,10 @@ class FFMeta(commands.Cog):
image_format: typing.Annotated[
str,
discord.Option(
str, description="The format of the resulting image", choices=["jpeg", "webp"], default="jpeg"
str,
description="The format of the resulting image",
choices=["jpeg", "webp"],
default="jpeg",
),
] = "jpeg",
):
@ -115,7 +142,9 @@ class FFMeta(commands.Cog):
src = io.BytesIO()
async with httpx.AsyncClient(
headers={"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"}
headers={
"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"
}
) as client:
response = await client.get(url)
if response.status_code != 200:
@ -123,13 +152,17 @@ class FFMeta(commands.Cog):
src.write(response.content)
try:
dst = await asyncio.to_thread(self.jpegify_image, src, quality, image_format)
dst = await asyncio.to_thread(
self.jpegify_image, src, quality, image_format
)
except Exception as e:
await ctx.respond(f"Failed to convert image: `{e}`.")
self.log.error("Failed to convert image %r: %r", url, e)
return
else:
await ctx.respond(file=discord.File(dst, filename=f"jpegified.{image_format}"))
await ctx.respond(
file=discord.File(dst, filename=f"jpegified.{image_format}")
)
@commands.slash_command()
async def opusinate(
@ -148,7 +181,10 @@ class FFMeta(commands.Cog):
),
] = 96,
mono: typing.Annotated[
bool, discord.Option(bool, description="Whether to convert the audio to mono", default=False)
bool,
discord.Option(
bool, description="Whether to convert the audio to mono", default=False
),
] = False,
):
"""Converts a given URL or attachment to an Opus file"""
@ -195,7 +231,9 @@ class FFMeta(commands.Cog):
)
stdout, stderr = await probe_process.communicate()
stdout = stdout.decode("utf-8", "replace")
data = {"format": {"duration": 195}} # 3 minutes and 15 seconds is the 2023 average.
data = {
"format": {"duration": 195}
} # 3 minutes and 15 seconds is the 2023 average.
if stdout:
try:
data = json.loads(stdout)
@ -206,7 +244,9 @@ class FFMeta(commands.Cog):
max_end_size = ((bitrate * duration * channels) / 8) * 1024
if max_end_size > (24.75 * 1024 * 1024):
return await ctx.respond(
"The file would be too large to send ({:,.2f} MiB).".format(max_end_size / 1024 / 1024)
"The file would be too large to send ({:,.2f} MiB).".format(
max_end_size / 1024 / 1024
)
)
process = await asyncio.create_subprocess_exec(
@ -237,7 +277,9 @@ class FFMeta(commands.Cog):
file = io.BytesIO(stdout)
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024))
return await ctx.respond(
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
)
if not fs:
await ctx.respond("Failed to convert audio. See below.")
else:
@ -251,9 +293,11 @@ class FFMeta(commands.Cog):
for page in paginator.pages:
await ctx.respond(page, ephemeral=True)
@commands.slash_command(name="right-behind-you")
async def right_behind_you(self, ctx: discord.ApplicationContext, image: discord.Attachment):
async def right_behind_you(
self, ctx: discord.ApplicationContext, image: discord.Attachment
):
await ctx.defer()
rbh = Path("assets/right-behind-you.ogg").resolve()
if not rbh.exists():
@ -264,7 +308,9 @@ class FFMeta(commands.Cog):
return await ctx.respond("That's not an image!")
with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp:
img_tmp = io.BytesIO(await image.read())
dst: io.BytesIO = await asyncio.to_thread(self.jpegify_image, img_tmp, 90, "webp")
dst: io.BytesIO = await asyncio.to_thread(
self.jpegify_image, img_tmp, 90, "webp"
)
temp.write(dst.getvalue())
temp.flush()
process = await asyncio.create_subprocess_exec(
@ -307,16 +353,20 @@ class FFMeta(commands.Cog):
"5",
"pipe:1",
stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr
stderr=sys.stderr,
)
stdout, stderr = await process.communicate()
file = io.BytesIO(stdout)
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024))
return await ctx.respond(
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
)
if not fs:
await ctx.respond("Failed to convert audio. See below.")
else:
return await ctx.respond(file=discord.File(file, filename="right-behind-you.mp4"))
return await ctx.respond(
file=discord.File(file, filename="right-behind-you.mp4")
)
paginator = commands.Paginator()
for line in stderr.decode().splitlines():
if line.strip().startswith(":"):

View file

@ -11,13 +11,15 @@ class MeterCog(commands.Cog):
self.bot = bot
self.log = logging.getLogger("jimmy.cogs.auto_responder")
self.cache = {}
@commands.slash_command(name="gay-meter")
@discord.guild_only()
async def gay_meter(self, ctx: discord.ApplicationContext, user: discord.User = None):
async def gay_meter(
self, ctx: discord.ApplicationContext, user: discord.User = None
):
"""Checks how gay someone is"""
user = user or ctx.user
await ctx.respond("Calculating...")
for i in range(0, 125, 25):
await ctx.edit(content="Calculating... %d%%" % i)
@ -28,9 +30,11 @@ class MeterCog(commands.Cog):
else:
pct = user.id % 100
await ctx.edit(content=f"{user.mention} is {pct}% gay.")
@commands.slash_command(name="penis-length")
async def penis_meter(self, ctx: discord.ApplicationContext, user: discord.User = None):
async def penis_meter(
self, ctx: discord.ApplicationContext, user: discord.User = None
):
"""Checks the length of someone's penis."""
user = user or ctx.user
if random.randint(0, 1):
@ -49,10 +53,10 @@ class MeterCog(commands.Cog):
return await ctx.respond(
embed=discord.Embed(
title=f"{user.display_name}'s penis length:",
description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks))
description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)),
)
)
@commands.command()
@commands.is_owner()
async def clear_cache(self, ctx: commands.Context, user: discord.User = None):
@ -61,7 +65,7 @@ class MeterCog(commands.Cog):
self.cache = {}
else:
self.cache.pop(user, None)
return await ctx.message.add_reaction("\N{white heavy check mark}")
return await ctx.message.add_reaction("\N{WHITE HEAVY CHECK MARK}")
def setup(bot):

View file

@ -33,17 +33,11 @@ class GetFilteredTextView(discord.ui.View):
self.text = text
super().__init__(timeout=600)
@discord.ui.button(
label="See filtered data",
emoji="\N{INBOX TRAY}"
)
@discord.ui.button(label="See filtered data", emoji="\N{INBOX TRAY}")
async def see_filtered_data(self, _, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
await interaction.followup.send(
file=discord.File(
io.BytesIO(self.text.encode()),
"filtered.txt"
)
file=discord.File(io.BytesIO(self.text.encode()), "filtered.txt")
)
@ -106,7 +100,11 @@ class NetworkCog(commands.Cog):
def decide(ln: str) -> typing.Optional[bool]:
if ln.startswith(">>> Last update"):
return
if "REDACTED" in ln or "Please query the WHOIS server of the owning registrar" in ln or ":" not in ln:
if (
"REDACTED" in ln
or "Please query the WHOIS server of the owning registrar" in ln
or ":" not in ln
):
return False
else:
return True
@ -134,7 +132,9 @@ class NetworkCog(commands.Cog):
if not paginator.pages:
stdout, stderr, status = await run_command(with_disclaimer=True)
if not any((stdout, stderr)):
return await ctx.respond(f"No output was returned with status code {status}.")
return await ctx.respond(
f"No output was returned with status code {status}."
)
file = io.BytesIO()
file.write(stdout)
if stderr:
@ -143,7 +143,7 @@ class NetworkCog(commands.Cog):
file.seek(0)
return await ctx.respond(
"Seemingly all output was filtered. Returning raw command output.",
file=discord.File(file, "whois.txt")
file=discord.File(file, "whois.txt"),
)
last: discord.Interaction | discord.WebhookMessage | None = None
@ -223,9 +223,15 @@ class NetworkCog(commands.Cog):
default="default",
),
use_ip_version: discord.Option(
str, name="ip-version", description="IP version to use.", choices=["ipv4", "ipv6"], default="ipv4"
str,
name="ip-version",
description="IP version to use.",
choices=["ipv4", "ipv6"],
default="ipv4",
),
max_ttl: discord.Option(
int, name="ttl", description="Max number of hops", default=30
),
max_ttl: discord.Option(int, name="ttl", description="Max number of hops", default=30),
):
"""Performs a traceroute request."""
await ctx.defer()
@ -258,7 +264,9 @@ class NetworkCog(commands.Cog):
args.append(str(port))
args.append(url)
paginator = commands.Paginator()
paginator.add_line(f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}")
paginator.add_line(
f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}"
)
paginator.add_line(empty=True)
try:
start = time.time_ns()
@ -297,10 +305,7 @@ class NetworkCog(commands.Cog):
await ctx.respond(file=discord.File(f))
async def _fetch_ip_response(
self,
server: str,
lookup: str,
client: httpx.AsyncClient
self, server: str, lookup: str, client: httpx.AsyncClient
) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError:
try:
start = time.perf_counter()
@ -315,18 +320,16 @@ class NetworkCog(commands.Cog):
async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None):
"""Fetches IP info from SHRONK IP servers"""
await ctx.defer()
async with httpx.AsyncClient(headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}) as client:
async with httpx.AsyncClient(
headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}
) as client:
if not lookup:
response = await client.get("https://api.ipify.org")
lookup = response.text
servers = self.config.get(
"ip_servers",
[
"ip.shronk.net",
"ip.i-am.nexus",
"ip.shronk.nicroxio.co.uk"
]
["ip.shronk.net", "ip.i-am.nexus", "ip.shronk.nicroxio.co.uk"],
)
embed = discord.Embed(
title="IP lookup information for: %s" % lookup,
@ -353,13 +356,14 @@ class NetworkCog(commands.Cog):
t = response.text[:512]
embed.add_field(
name=server,
value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```" % t,
value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```"
% t,
)
else:
embed.add_field(
name="%s (%.2fms)" % (server, (end - start) * 1000),
value="```json\n%s\n```" % v,
inline=False
inline=False,
)
await ctx.respond(embed=embed)
@ -367,41 +371,41 @@ class NetworkCog(commands.Cog):
@commands.slash_command()
@commands.max_concurrency(1, commands.BucketType.user)
async def nmap(
self,
ctx: discord.ApplicationContext,
target: str,
technique: typing.Annotated[
self,
ctx: discord.ApplicationContext,
target: str,
technique: typing.Annotated[
str,
discord.Option(
str,
discord.Option(
str,
choices=[
discord.OptionChoice(name="TCP SYN", value="S"),
discord.OptionChoice(name="TCP Connect", value="T"),
discord.OptionChoice(name="TCP ACK", value="A"),
discord.OptionChoice(name="TCP Window", value="W"),
discord.OptionChoice(name="TCP Maimon", value="M"),
discord.OptionChoice(name="UDP", value="U"),
discord.OptionChoice(name="TCP Null", value="N"),
discord.OptionChoice(name="TCP FIN", value="F"),
discord.OptionChoice(name="TCP XMAS", value="X"),
],
default="T"
)
] = "T",
treat_all_hosts_online: bool = False,
service_scan: bool = False,
fast_mode: bool = False,
enable_os_detection: bool = False,
timing: typing.Annotated[
choices=[
discord.OptionChoice(name="TCP SYN", value="S"),
discord.OptionChoice(name="TCP Connect", value="T"),
discord.OptionChoice(name="TCP ACK", value="A"),
discord.OptionChoice(name="TCP Window", value="W"),
discord.OptionChoice(name="TCP Maimon", value="M"),
discord.OptionChoice(name="UDP", value="U"),
discord.OptionChoice(name="TCP Null", value="N"),
discord.OptionChoice(name="TCP FIN", value="F"),
discord.OptionChoice(name="TCP XMAS", value="X"),
],
default="T",
),
] = "T",
treat_all_hosts_online: bool = False,
service_scan: bool = False,
fast_mode: bool = False,
enable_os_detection: bool = False,
timing: typing.Annotated[
int,
discord.Option(
int,
discord.Option(
int,
description="Timing template to use 0 is slowest, 5 is fastest.",
choices=[0, 1, 2, 3, 4, 5],
default=3
)
] = 3,
ports: str = None
description="Timing template to use 0 is slowest, 5 is fastest.",
choices=[0, 1, 2, 3, 4, 5],
default=3,
),
] = 3,
ports: str = None,
):
"""Runs nmap on a target. You cannot specify multiple targets."""
await ctx.defer()
@ -420,7 +424,9 @@ class NetworkCog(commands.Cog):
if enable_os_detection and not is_superuser:
return await ctx.respond("OS detection is not available on this system.")
with tempfile.TemporaryDirectory(prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}") as tmp:
with tempfile.TemporaryDirectory(
prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}"
) as tmp:
tmp_dir = Path(tmp)
args = [
"nmap",
@ -430,7 +436,7 @@ class NetworkCog(commands.Cog):
str(timing),
"-s" + technique,
"--reason",
"--noninteractive"
"--noninteractive",
]
if treat_all_hosts_online:
args.append("-Pn")
@ -447,8 +453,7 @@ class NetworkCog(commands.Cog):
await ctx.respond(
embed=discord.Embed(
title="Running nmap...",
description="Command:\n"
"```{}```".format(shlex.join(args)),
description="Command:\n" "```{}```".format(shlex.join(args)),
)
)
process = await asyncio.create_subprocess_exec(
@ -457,14 +462,16 @@ class NetworkCog(commands.Cog):
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await process.communicate()
files = [discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()]
files = [
discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()
]
if not files:
if len(stderr) <= 4089:
return await ctx.edit(
embed=discord.Embed(
title="Nmap failed.",
description="```\n" + stderr.decode() + "```",
color=discord.Color.red()
color=discord.Color.red(),
)
)
@ -475,19 +482,19 @@ class NetworkCog(commands.Cog):
embed=discord.Embed(
title="Nmap failed.",
description=f"Output was too long. [View full output]({result.url})",
color=discord.Color.red()
color=discord.Color.red(),
)
)
await ctx.edit(
embed=discord.Embed(
title="Nmap finished!",
description="Result files are attached.\n"
"* `gnmap` is 'greppable'\n"
"* `xml` is XML output\n"
"* `nmap` is normal output",
color=discord.Color.green()
"* `gnmap` is 'greppable'\n"
"* `xml` is XML output\n"
"* `nmap` is normal output",
color=discord.Color.green(),
),
files=files
files=files,
)

View file

@ -107,7 +107,9 @@ class OllamaDownloadHandler:
async def __aiter__(self):
async with aiohttp.ClientSession(base_url=self.base_url) as client:
async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
async with client.post(
"/api/pull", json={"name": self.model, "stream": True}, timeout=None
) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
self.parse_line(line)
@ -122,7 +124,7 @@ class OllamaDownloadHandler:
"Downloading orca-mini:7b on server %r - %s (%.2f%%)",
self.base_url,
self.status,
self.percent
self.percent,
)
return self
@ -176,12 +178,13 @@ class OllamaChatHandler:
async def __aiter__(self):
async with aiohttp.ClientSession(base_url=self.base_url) as client:
async with client.post(
"/api/chat", json={
"model": self.model,
"stream": True,
"/api/chat",
json={
"model": self.model,
"stream": True,
"messages": self.messages,
"options": self.options
}
"options": self.options,
},
) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
@ -201,7 +204,9 @@ class OllamaClient:
self.base_url = base_url
self.authorisation = authorisation
def with_client(self, timeout: aiohttp.ClientTimeout | float | int | None = None) -> aiohttp.ClientSession:
def with_client(
self, timeout: aiohttp.ClientTimeout | float | int | None = None
) -> aiohttp.ClientSession:
"""
Creates an instance for a request, with properly populated values.
:param timeout:
@ -213,9 +218,13 @@ class OllamaClient:
timeout = aiohttp.ClientTimeout(timeout)
else:
timeout = timeout or aiohttp.ClientTimeout(120)
return aiohttp.ClientSession(self.base_url, timeout=timeout, auth=self.authorisation)
return aiohttp.ClientSession(
self.base_url, timeout=timeout, auth=self.authorisation
)
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
async def get_tags(
self,
) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
"""
Gets the tags for the server.
:return:
@ -250,7 +259,7 @@ class OllamaClient:
self,
model: str,
messages: list[dict[str, str]],
options: dict[str, typing.Any] = None
options: dict[str, typing.Any] = None,
) -> OllamaChatHandler:
"""
Starts a chat with the given messages.
@ -273,7 +282,11 @@ class OllamaView(View):
async def interaction_check(self, interaction: discord.Interaction) -> bool:
return interaction.user == self.ctx.user
@button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f")
@button(
label="Stop",
style=discord.ButtonStyle.danger,
emoji="\N{WASTEBASKET}\U0000fe0f",
)
async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
self.cancel.set()
btn.disabled = True
@ -310,14 +323,20 @@ class ChatHistory:
:return: The thread's ID.
"""
key = os.urandom(3).hex()
self._internal[key] = {"member": member.id, "seed": round(time.time()), "messages": []}
self._internal[key] = {
"member": member.id,
"seed": round(time.time()),
"messages": [],
}
with open("./assets/ollama-prompt.txt") as file:
system_prompt = default or file.read()
self.add_message(key, "system", system_prompt)
return key
@staticmethod
def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]:
def _construct_message(
role: str, content: str, images: typing.Optional[list[str]]
) -> dict[str, str]:
x = {"role": role, "content": content}
if images:
x["images"] = images
@ -331,7 +350,9 @@ class ChatHistory:
return list(
filter(
lambda v: (ctx.value or v) in v,
map(lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)),
map(
lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)
),
)
)
@ -339,7 +360,9 @@ class ChatHistory:
"""Returns all saved threads."""
return self._internal.copy()
def threads_for(self, user: discord.Member) -> dict[str, dict[str, list[dict[str, str]] | int]]:
def threads_for(
self, user: discord.Member
) -> dict[str, dict[str, list[dict[str, str]] | int]]:
"""Returns all saved threads for a specific user"""
t = self.all_threads()
for k, v in t.copy().items():
@ -354,7 +377,7 @@ class ChatHistory:
content: str,
images: typing.Optional[list[str]] = None,
*,
save: bool = True
save: bool = True,
) -> None:
"""
Appends a message to the given thread.
@ -380,7 +403,9 @@ class ChatHistory:
return []
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
def get_thread(
self, thread: str
) -> dict[str, list[dict[str, str]] | discord.Member | int]:
"""Gets a copy of an entire thread"""
return self._internal.get(thread, {}).copy()
@ -401,7 +426,6 @@ SERVER_KEYS_AUTOCOMPLETE.remove("order")
class OllamaGetPrompt(discord.ui.Modal):
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
super().__init__(
discord.ui.InputText(
@ -441,7 +465,9 @@ class PromptSelector(discord.ui.View):
if self.user_prompt is not None:
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
@discord.ui.button(
label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys"
)
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx, "System")
await interaction.response.send_modal(modal)
@ -450,7 +476,9 @@ class PromptSelector(discord.ui.View):
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
@discord.ui.button(
label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr"
)
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
modal = OllamaGetPrompt(self.ctx)
await interaction.response.send_modal(modal)
@ -459,7 +487,9 @@ class PromptSelector(discord.ui.View):
self.update_ui()
await interaction.edit_original_response(view=self)
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
@discord.ui.button(
label="Done", style=discord.ButtonStyle.success, custom_id="done"
)
async def done(self, btn: discord.ui.Button, interaction: Interaction):
self.ctx.interaction = interaction
self.stop()
@ -474,13 +504,17 @@ class ConfirmCPURun(discord.ui.View):
async def interaction_check(self, interaction: Interaction) -> bool:
return interaction.user == self.ctx.user
@discord.ui.button(label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu")
@discord.ui.button(
label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu"
)
async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction):
await interaction.response.defer(invisible=True)
self.proceed = True
self.stop()
@discord.ui.button(label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu")
@discord.ui.button(
label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu"
)
async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction):
await interaction.response.defer(invisible=True)
self.stop()
@ -492,9 +526,7 @@ class Ollama(commands.Cog):
self.log = logging.getLogger("jimmy.cogs.ollama")
self.contexts = {}
self.history = ChatHistory()
self.servers = {
server: asyncio.Lock() for server in CONFIG["ollama"]
}
self.servers = {server: asyncio.Lock() for server in CONFIG["ollama"]}
self.servers.pop("order", None)
if CONFIG["ollama"].get("order"):
self.servers = {}
@ -520,7 +552,9 @@ class Ollama(commands.Cog):
if url in SERVER_KEYS:
url = CONFIG["ollama"][url]["base_url"]
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(connect=3, sock_connect=3, sock_read=10, total=3)
timeout=aiohttp.ClientTimeout(
connect=3, sock_connect=3, sock_read=10, total=3
)
) as session:
self.log.info("Checking if %r is online.", url)
try:
@ -551,20 +585,37 @@ class Ollama(commands.Cog):
),
],
server: typing.Annotated[
str, discord.Option(str, "The server to use for ollama.", default="next", choices=SERVER_KEYS_AUTOCOMPLETE)
str,
discord.Option(
str,
"The server to use for ollama.",
default="next",
choices=SERVER_KEYS_AUTOCOMPLETE,
),
],
context: typing.Annotated[
str, discord.Option(str, "The context key of a previous ollama response to use as context.", default=None)
str,
discord.Option(
str,
"The context key of a previous ollama response to use as context.",
default=None,
),
],
give_acid: typing.Annotated[
bool,
discord.Option(
bool, "Whether to give the AI acid, LSD, and other hallucinogens before responding.", default=False
bool,
"Whether to give the AI acid, LSD, and other hallucinogens before responding.",
default=False,
),
],
image: typing.Annotated[
discord.Attachment,
discord.Option(discord.Attachment, "An image to feed into ollama. Only works with llava.", default=None),
discord.Option(
discord.Attachment,
"An image to feed into ollama. Only works with llava.",
default=None,
),
],
):
if server == "next":
@ -587,7 +638,7 @@ class Ollama(commands.Cog):
await ctx.respond(
"Select edit your prompts, as desired. Click done when you want to continue.",
view=v,
ephemeral=True
ephemeral=True,
)
await v.wait()
query = v.user_prompt or query
@ -604,22 +655,23 @@ class Ollama(commands.Cog):
self.log.debug("Resolved model to %r" % model)
if image:
patterns = [
"llava:*",
"llava-llama*:*"
]
patterns = ["llava:*", "llava-llama*:*"]
if any(fnmatch(model, p) for p in patterns) is False:
await ctx.send(
f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.",
delete_after=5
f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.",
delete_after=5,
)
model = "llava:latest"
if image.size > 1024 * 1024 * 25:
await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.")
await ctx.respond(
"Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it."
)
return
elif not fnmatch(image.content_type, "image/*"):
await ctx.respond("Attachment is not an image. Try using a different file.")
await ctx.respond(
"Attachment is not an image. Try using a different file."
)
return
else:
data = io.BytesIO()
@ -638,13 +690,19 @@ class Ollama(commands.Cog):
if fnmatch(model, model_pattern):
break
else:
allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
allowed_models = ", ".join(
map(discord.utils.escape_markdown, server_config["allowed_models"])
)
await ctx.respond(
f"Invalid model. You can only use one of the following models: {allowed_models}"
)
return
async with aiohttp.ClientSession(
base_url=server_config["base_url"],
timeout=aiohttp.ClientTimeout(connect=5, sock_read=10800, sock_connect=5, total=10830),
timeout=aiohttp.ClientTimeout(
connect=5, sock_read=10800, sock_connect=5, total=10830
),
) as session:
embed = discord.Embed(
title="Checking server...",
@ -652,26 +710,32 @@ class Ollama(commands.Cog):
color=discord.Color.blurple(),
timestamp=discord.utils.utcnow(),
)
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
embed.set_footer(
text="Using server %r" % server, icon_url=server_config.get("icon_url")
)
await ctx.respond(embed=embed)
if not await self.check_server(server_config["base_url"]):
tried = {server}
for i in range(10):
try:
server = self.next_server(tried)
if server and CONFIG["ollama"]["server"].get("is_gpu", False) is not True:
if (
server
and CONFIG["ollama"]["server"].get("is_gpu", False)
is not True
):
cf = ConfirmCPURun(ctx)
await ctx.edit(
embed=discord.Embed(
title=f"Server {server} is available, but...",
description="It is CPU only, which means it is very slow and will likely crash.\n"
"If you really want, you can continue your generation, using this "
"server. Be aware though, once in motion, it cannot be stopped.\n\n"
""
"Continue?",
"If you really want, you can continue your generation, using this "
"server. Be aware though, once in motion, it cannot be stopped.\n\n"
""
"Continue?",
color=discord.Color.red(),
),
view=cf
view=cf,
)
await cf.wait()
await ctx.edit(view=None)
@ -688,7 +752,10 @@ class Ollama(commands.Cog):
color=discord.Color.gold(),
timestamp=discord.utils.utcnow(),
)
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
embed.set_footer(
text="Using server %r" % server,
icon_url=server_config.get("icon_url"),
)
await ctx.edit(embed=embed)
await asyncio.sleep(1)
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
@ -713,7 +780,9 @@ class Ollama(commands.Cog):
embed = discord.Embed(
url=resp.url,
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
description=f"```{await resp.text() or 'No response body'}```"[:4096],
description=f"```{await resp.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(),
timestamp=discord.utils.utcnow(),
)
@ -733,8 +802,8 @@ class Ollama(commands.Cog):
self.log.debug("Beginning download of %r", model)
def progress_bar(_v: float, action: str = None, _mbps: float = None):
bar = "\N{large green square}" * round(_v / 10)
bar += "\N{white large square}" * (10 - len(bar))
bar = "\N{LARGE GREEN SQUARE}" * round(_v / 10)
bar += "\N{WHITE LARGE SQUARE}" * (10 - len(bar))
bar += f" {_v:.2f}%"
if _mbps:
bar += f" ({_mbps:.2f} MiB/s)"
@ -753,12 +822,16 @@ class Ollama(commands.Cog):
last_update = time.time()
async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response:
async with session.post(
"/api/pull", json={"name": model, "stream": True}, timeout=None
) as response:
if response.status != 200:
embed = discord.Embed(
url=response.url,
title=f"HTTP {response.status} {response.reason!r} while downloading model.",
description=f"```{await response.text() or 'No response body'}```"[:4096],
description=f"```{await response.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(),
timestamp=discord.utils.utcnow(),
)
@ -775,7 +848,10 @@ class Ollama(commands.Cog):
)
return await ctx.edit(embed=embed, view=None)
if time.time() >= (last_update + 5.1):
if line.get("total") is not None and line.get("completed") is not None:
if (
line.get("total") is not None
and line.get("completed") is not None
):
new_bytes = line["completed"] - last_downloaded
mbps = new_bytes / 1024 / 1024 / 5
last_downloaded = line["completed"]
@ -784,7 +860,9 @@ class Ollama(commands.Cog):
percent = 50.0
mbps = 0.0
embed.fields[0].value = progress_bar(percent, line["status"], mbps)
embed.fields[0].value = progress_bar(
percent, line["status"], mbps
)
await ctx.edit(embed=embed, view=view)
last_update = time.time()
else:
@ -799,13 +877,13 @@ class Ollama(commands.Cog):
embed=discord.Embed(
title="Before you continue",
description="You've selected a CPU-only server. This will be really slow. This will also likely"
" bring the host to a halt. Consider the pain the CPU is about to endure. "
"Are you super"
" sure you want to continue? You can run `h!ollama-status` to see what servers are"
" available.",
" bring the host to a halt. Consider the pain the CPU is about to endure. "
"Are you super"
" sure you want to continue? You can run `h!ollama-status` to see what servers are"
" available.",
color=discord.Color.red(),
),
view=cf2
view=cf2,
)
await cf2.wait()
if cf2.proceed is False:
@ -825,9 +903,13 @@ class Ollama(commands.Cog):
icon_url="https://ollama.com/public/ollama.png",
)
embed.add_field(
name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False
name="Prompt",
value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."),
inline=False,
)
embed.set_footer(
text="Using server %r" % server, icon_url=server_config.get("icon_url")
)
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
if image_data:
if (image.height / image.width) >= 1.5:
embed.set_image(url=image.url)
@ -842,7 +924,10 @@ class Ollama(commands.Cog):
if context is None:
context = self.history.create_thread(ctx.user, system_query)
elif context is not None and (__thread := self.history.find_thread(context)) is None:
elif (
context is not None
and (__thread := self.history.find_thread(context)) is None
):
if not __thread:
return await ctx.respond("Invalid thread ID.")
else:
@ -861,7 +946,12 @@ class Ollama(commands.Cog):
params["top_p"] = 2
params["repeat_penalty"] = 2
payload = {"model": model, "stream": True, "options": params, "messages": messages}
payload = {
"model": model,
"stream": True,
"options": params,
"messages": messages,
}
async with session.post(
"/api/chat",
json=payload,
@ -870,7 +960,9 @@ class Ollama(commands.Cog):
embed = discord.Embed(
url=response.url,
title=f"HTTP {response.status} {response.reason!r} while generating response.",
description=f"```{await response.text() or 'No response body'}```"[:4096],
description=f"```{await response.text() or 'No response body'}```"[
:4096
],
color=discord.Color.red(),
timestamp=discord.utils.utcnow(),
)
@ -888,20 +980,31 @@ class Ollama(commands.Cog):
embed.description = "[...]" + line["message"]["content"]
if len(embed.description) >= 3250:
embed.colour = discord.Color.gold()
embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
embed.set_footer(
text="Warning: {:,}/4096 characters.".format(
len(embed.description)
)
)
else:
embed.colour = discord.Color.blurple()
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
embed.set_footer(
text="Using server %r" % server,
icon_url=server_config.get("icon_url"),
)
if view.cancel.is_set():
break
if time.time() >= (last_update + 5.1):
await ctx.edit(embed=embed, view=view)
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
self.log.debug(
f"Updating message ({last_update} -> {time.time()})"
)
last_update = time.time()
view.stop()
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
self.history.add_message(
context, "user", user_message["content"], user_message.get("images")
)
self.history.add_message(context, "assistant", buffer.getvalue())
embed.add_field(name="Context Key", value=context, inline=True)
@ -911,7 +1014,9 @@ class Ollama(commands.Cog):
value = buffer.getvalue()
if len(value) >= 4096:
embeds = [discord.Embed(title="Done!", colour=discord.Color.green())]
embeds = [
discord.Embed(title="Done!", colour=discord.Color.green())
]
current_page = ""
for word in value.split():
@ -971,13 +1076,19 @@ class Ollama(commands.Cog):
if message["role"] == "system":
continue
max_length = 4000 - len("> **%s**: " % message["role"])
paginator.add_line("> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length)))
paginator.add_line(
"> **{}**: {}".format(
message["role"], textwrap.shorten(message["content"], max_length)
)
)
embeds = []
for page in paginator.pages:
embeds.append(discord.Embed(description=page))
ephemeral = len(embeds) > 1
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
for chunk in discord.utils.as_chunks(
iter(embeds or [discord.Embed(title="No Content.")]), 10
):
await ctx.respond(embeds=chunk, ephemeral=ephemeral)
@commands.command(name="ollama-status", aliases=["ollama_status", "os"])
@ -990,14 +1101,18 @@ class Ollama(commands.Cog):
if CONFIG["ollama"].get("order"):
ln = ["Server order:"]
for n, key in enumerate(CONFIG["ollama"].get("order"), start=1):
zap = '\N{high voltage sign}'
ln.append(f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}")
zap = "\N{HIGH VOLTAGE SIGN}"
ln.append(
f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}"
)
embed.description = "\n".join(ln)
for server, lock in self.servers.items():
embed.add_field(
name=server,
value="\U000026a0\U0000fe0fIn use" if lock.locked() else "\U000023f3 Checking..."
value="\U000026a0\U0000fe0fIn use"
if lock.locked()
else "\U000023f3 Checking...",
)
msg = await ctx.reply(embed=embed)
@ -1006,23 +1121,35 @@ class Ollama(commands.Cog):
for server in self.servers.keys():
if self.servers[server].locked():
continue
tasks[server] = asyncio.create_task(self.check_server(CONFIG["ollama"][server]["base_url"]))
tasks[server] = asyncio.create_task(
self.check_server(CONFIG["ollama"][server]["base_url"])
)
await asyncio.gather(*tasks.values())
for server, task in tasks.items():
if not task.done:
await task
if e := task.exception():
self.log.error("Error while checking server %r: %r", server, e, exc_info=True)
self.log.error(
"Error while checking server %r: %r", server, e, exc_info=True
)
else:
ok = task.result()
for field in embed.fields:
if field.name == server:
field.value = "\U00002705 Online" if ok else "\U0000274c Offline"
field.value = (
"\U00002705 Online" if ok else "\U0000274c Offline"
)
break
await msg.edit(embed=embed)
async def generate_truth(self, ctx: commands.Context, entity: str, limit: int = None, question: str = None):
async def generate_truth(
self,
ctx: commands.Context,
entity: str,
limit: int = None,
question: str = None,
):
if entity == "trump":
system = (
"You are the former united states president, Donald Trump, convicted of 34 felonies."
@ -1034,7 +1161,7 @@ class Ollama(commands.Cog):
system = (
"You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the "
"world of influencing. His personality is one of arrogance and self-aggrandizement, "
"often referred to as \"King of Toxic Masculinity\" by critics due to his controversial views on "
'often referred to as "King of Toxic Masculinity" by critics due to his controversial views on '
"gender roles, relationships, and other topics. He has been involved in several controversies related "
"to his content online including promoting extremist ideologies and misogynistic views. "
"Despite this, he still has a large following and is known for being an entrepreneur who built multiple"
@ -1120,9 +1247,9 @@ class Ollama(commands.Cog):
"of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and "
"controversial policies, Johnson's leadership was marred by numerous scandals and crises."
"Johnson's personality is characterized by his optimistic and charismatic demeanor, "
"often described as a \"dashing politician.\" However, his political career has been "
'often described as a "dashing politician." However, his political career has been '
"plagued by scandals involving ethics violations, questionable financial arrangements, "
"and allegations of inappropriate behavior. The \"Partygate\" controversy, where "
'and allegations of inappropriate behavior. The "Partygate" controversy, where '
"gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. "
"Johnson's achievements include leading the U.K. out of the European Union and brokering a "
"post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by "
@ -1162,21 +1289,22 @@ class Ollama(commands.Cog):
" Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous "
"content. Do not include URLs or links. If an @user asks you a question, reply to them in your post."
)
thread_id = self.history.create_thread(
ctx.author,
system
)
thread_id = self.history.create_thread(ctx.author, system)
r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api")
username = CONFIG["truth"].get("username", "1")
password = CONFIG["truth"].get("password", "2")
async with httpx.AsyncClient(base_url=r, auth=(username, password), trust_env=False) as http_client:
async with httpx.AsyncClient(
base_url=r, auth=(username, password), trust_env=False
) as http_client:
response = await http_client.get(
"/truths",
timeout=60,
)
response.raise_for_status()
truths = response.json()
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths))
truths: list[TruthPayload] = list(
map(lambda t: TruthPayload.model_validate(t), truths)
)
if entity:
truths = list(filter(lambda t: t.author == entity, truths))
@ -1191,11 +1319,13 @@ class Ollama(commands.Cog):
thread_id,
"assistant",
truth.content,
save=False
save=False,
)
)
if not question:
self.history.add_message(thread_id, "user", "Generate a new truth post.")
self.history.add_message(
thread_id, "user", "Generate a new truth post."
)
else:
self.history.add_message(thread_id, "user", question)
@ -1204,32 +1334,36 @@ class Ollama(commands.Cog):
server = self.next_server(tried)
is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
if not is_gpu:
self.log.info("Skipping server %r as it is not a GPU server.", server)
self.log.info(
"Skipping server %r as it is not a GPU server.", server
)
continue
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
break
tried.add(server)
else:
return await ctx.reply("All servers are offline. Please try again later.", delete_after=300)
return await ctx.reply(
"All servers are offline. Please try again later.", delete_after=300
)
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
async with self.servers[server]:
if not await client.has_model_named("llama2-uncensored", "7b-chat"):
with client.download_model("llama2-uncensored", "7b-chat") as handler:
with client.download_model(
"llama2-uncensored", "7b-chat"
) as handler:
await handler.flatten()
embed = discord.Embed(
title=f"New {post_type.title()}!",
description="",
colour=0x6559FF
title=f"New {post_type.title()}!", description="", colour=0x6559FF
)
msg = await ctx.reply(embed=embed)
last_edit = time.time()
messages = self.history.get_history(thread_id)
with client.new_chat(
"llama2-uncensored:7b-chat",
messages,
options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5}
"llama2-uncensored:7b-chat",
messages,
options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5},
) as handler:
async for ln in handler:
embed.description += ln["message"]["content"]
@ -1245,7 +1379,7 @@ class Ollama(commands.Cog):
if truth.content == embed.description:
embed.add_field(
name=f"Repeated {post_type} :(",
value=f"This truth was already {post_type}ed. Shit AI."
value=f"This truth was already {post_type}ed. Shit AI.",
)
elif _ratio >= 70:
similar[truth.id] = _ratio
@ -1253,12 +1387,16 @@ class Ollama(commands.Cog):
if similar:
if len(similar) > 1:
lns = []
keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True)
keys = sorted(
similar.keys(), key=lambda k: similar[k], reverse=True
)
for truth_id in keys:
_ratio = similar[truth_id]
truth = discord.utils.get(truths, id=truth_id)
first_line = truth.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100))
preview = discord.utils.escape_markdown(
textwrap.shorten(first_line, 100)
)
lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`")
if len(lns) > 5:
lc = len(lns) - 5
@ -1266,33 +1404,39 @@ class Ollama(commands.Cog):
lns.append(f"*... and {lc} more*")
embed.add_field(
name=f"Possibly repeated {post_type}",
value=f"This {post_type} was similar to the following existing ones:\n" + "\n".join(lns),
inline=False
value=f"This {post_type} was similar to the following existing ones:\n"
+ "\n".join(lns),
inline=False,
)
else:
truth_id = tuple(similar)[0]
_ratio = similar[truth_id]
truth = discord.utils.get(truths, id=truth_id)
first_line = truth.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512))
preview = discord.utils.escape_markdown(
textwrap.shorten(first_line, 512)
)
embed.add_field(
name=f"Possibly repeated {post_type}",
value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}"
value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}",
)
embed.set_footer(
text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format(
post_type,
len(messages) - 2,
server,
thread_id
post_type, len(messages) - 2, server, thread_id
)
)
await msg.edit(embed=embed)
@commands.command(aliases=["trump"])
@commands.guild_only()
async def donald(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def donald(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a truth social post from trump!
@ -1307,7 +1451,13 @@ class Ollama(commands.Cog):
@commands.command(aliases=["tate"])
@commands.guild_only()
async def andrew(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def andrew(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a truth social post from Andrew Tate
@ -1319,10 +1469,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "tate", latest, question=question)
@commands.command(aliases=["sunak"])
@commands.guild_only()
async def rishi(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def rishi(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Rishi Sunak
@ -1334,10 +1490,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "Rishi Sunak", latest, question=question)
@commands.command(aliases=["robinson"])
@commands.guild_only()
async def tommy(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def tommy(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Tommy Robinson
@ -1348,11 +1510,19 @@ class Ollama(commands.Cog):
if question:
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "Tommy Robinson 🇬🇧", latest, question=question)
await self.generate_truth(
ctx, "Tommy Robinson 🇬🇧", latest, question=question
)
@commands.command(aliases=["fox"])
@commands.guild_only()
async def laurence(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def laurence(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Laurence Fox
@ -1364,10 +1534,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "Laurence Fox", latest, question=question)
@commands.command(aliases=["farage"])
@commands.guild_only()
async def nigel(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def nigel(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Nigel Farage
@ -1382,7 +1558,13 @@ class Ollama(commands.Cog):
@commands.command(aliases=["starmer"])
@commands.guild_only()
async def keir(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def keir(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Keir Starmer
@ -1394,10 +1576,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "Keir Starmer", latest, question=question)
@commands.command(aliases=["johnson"])
@commands.guild_only()
async def boris(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def boris(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Boris Johnson
@ -1409,10 +1597,16 @@ class Ollama(commands.Cog):
question = f"'@{ctx.author.display_name}' asks: {question!r}"
async with ctx.channel.typing():
await self.generate_truth(ctx, "Boris Johnson", latest, question=question)
@commands.command(aliases=["desantis"])
@commands.guild_only()
async def ron(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
async def ron(
self,
ctx: commands.Context,
latest: typing.Optional[int] = None,
*,
question: str = None,
):
"""
Generates a twitter post from Ron Desantis
@ -1447,12 +1641,15 @@ class Ollama(commands.Cog):
title="Truth:",
description=truth.content,
colour=0x6559FF,
timestamp=datetime.datetime.fromtimestamp(truth.timestamp, tz=datetime.timezone.utc),
timestamp=datetime.datetime.fromtimestamp(
truth.timestamp, tz=datetime.timezone.utc
),
)
if truth.extra:
embed.add_field(
name="Extra Data",
value="```json\n%s```" % json.dumps(truth.extra, indent=2, default=repr),
value="```json\n%s```"
% json.dumps(truth.extra, indent=2, default=repr),
)
embed.set_author(name=truth.author)
await ctx.reply(embed=embed)

112
src/cogs/onion_feed.py Normal file
View file

@ -0,0 +1,112 @@
"""
Scrapes the onion RSS feed once every hour and posts any new articles to the desired channel
"""
import asyncio
import dataclasses
import datetime
import logging
from bs4 import BeautifulSoup
import discord
from discord.ext import commands, tasks
import httpx
from conf import CONFIG
import redis
@dataclasses.dataclass
class RSSItem:
title: str
link: str
description: str
pubDate: datetime.datetime
guid: str
thumbnail: str
class OnionFeed(commands.Cog):
SOURCE = "https://www.theonion.com/rss"
EPOCH = datetime.datetime(2024, 7, 1, tzinfo=datetime.timezone.utc)
def __init__(self, bot):
self.bot: commands.Bot = bot
self.log = logging.getLogger("jimmy.cogs.onion_feed")
self.check_onion_feed.start()
self.redis = redis.Redis(**CONFIG["redis"])
def cog_unload(self) -> None:
self.check_onion_feed.cancel()
@staticmethod
def parse_item(item: BeautifulSoup) -> RSSItem:
description = (
BeautifulSoup(item.description.get_text(), "html.parser")
.p.get_text(strip=True)
.strip()[:-1]
)
kwargs = {
"title": item.title.get_text(strip=True).strip(),
"link": item.link.get_text(strip=True).strip(),
"pubDate": datetime.datetime.strptime(
item.pubDate.get_text(strip=True).strip(), "%a, %d %b %Y %H:%M:%S %Z"
),
"guid": item.guid.get_text(strip=True).strip(),
"description": description,
"thumbnail": item.find("media:thumbnail")["url"],
}
return RSSItem(**kwargs)
def parse_feed(self, text: str) -> list[RSSItem]:
soup = BeautifulSoup(text, "xml")
return [self.parse_item(item) for item in soup.find_all("item")]
@tasks.loop(hours=1)
async def check_onion_feed(self):
if not self.bot.is_ready():
await self.bot.wait_until_ready()
guild = self.bot.get_guild(994710566612500550)
if not guild:
return self.log.error("Nonsense guild not found. Can't do onion feed.")
channel = discord.utils.get(guild.text_channels, name="spam")
if not channel:
return self.log.error("Spam channel not found.")
async with httpx.AsyncClient() as client:
response = await client.get(self.SOURCE)
if response.status_code != 200:
return self.log.error(
f"Failed to fetch onion feed: {response.status_code}"
)
items: list[RSSItem] = await asyncio.to_thread(
self.parse_feed, response.text
)
for item in items:
if self.redis.get("onion-" + item.guid):
continue
embed = discord.Embed(
title=item.title,
url=item.link,
description=item.description + f"... [Read More]({item.link})",
color=0x00DF78,
timestamp=item.pubDate,
)
embed.set_thumbnail(url=item.thumbnail)
try:
msg = await channel.send(embed=embed)
# noinspection PyAsyncCall
self.redis.set("onion-" + item.guid, str(msg.id))
except discord.HTTPException:
self.log.exception(
f"Failed to send onion feed message: {item.title}"
)
else:
self.log.debug(f"Sent onion feed message: {item.title}")
@check_onion_feed.before_loop
async def before_check_onion_feed(self):
await self.bot.wait_until_ready()
def setup(bot):
bot.add_cog(OnionFeed(bot))

View file

@ -16,11 +16,9 @@ from discord.ext import commands
from conf import CONFIG
from pydantic import BaseModel, Field
JSON: typing.Union[
str, int, float, bool, None, dict[str, "JSON"], list["JSON"]
] = typing.Union[
str, int, float, bool, None, dict, list
]
JSON: typing.Union[str, int, float, bool, None, dict[str, "JSON"], list["JSON"]] = (
typing.Union[str, int, float, bool, None, dict, list]
)
class TruthPayload(BaseModel):
@ -32,7 +30,6 @@ class TruthPayload(BaseModel):
class QuoteQuota(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
@ -47,7 +44,9 @@ class QuoteQuota(commands.Cog):
return c
@staticmethod
def generate_pie_chart(usernames: list[str], counts: list[int], no_other: bool = False) -> discord.File:
def generate_pie_chart(
usernames: list[str], counts: list[int], no_other: bool = False
) -> discord.File:
"""
Converts the given username and count tuples into a nice pretty pie chart.
@ -95,7 +94,9 @@ class QuoteQuota(commands.Cog):
startangle=90,
radius=1.2,
)
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4)
fig.subplots_adjust(
left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4
)
fio = io.BytesIO()
fig.savefig(fio, format="png")
fio.seek(0)
@ -130,7 +131,9 @@ class QuoteQuota(commands.Cog):
now = discord.utils.utcnow()
oldest = now - timedelta(days=days)
await ctx.defer()
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes")
channel = self.quotes_channel or discord.utils.get(
ctx.guild.text_channels, name="quotes"
)
if not channel:
return await ctx.respond(":x: Cannot find quotes channel.")
@ -139,7 +142,9 @@ class QuoteQuota(commands.Cog):
authors = {}
filtered_messages = 0
total = 0
async for message in channel.history(limit=None, after=oldest, oldest_first=False):
async for message in channel.history(
limit=None, after=oldest, oldest_first=False
):
total += 1
if not message.content:
filtered_messages += 1
@ -179,10 +184,15 @@ class QuoteQuota(commands.Cog):
' (e.g. `"This is my quote" - Jimmy`)'.format(days)
)
else:
return await ctx.edit(content="No messages found in the last {!s} days.".format(days))
return await ctx.edit(
content="No messages found in the last {!s} days.".format(days)
)
file = await asyncio.to_thread(
self.generate_pie_chart, list(authors.keys()), list(authors.values()), merge_other
self.generate_pie_chart,
list(authors.keys()),
list(authors.values()),
merge_other,
)
return await ctx.edit(
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
@ -192,7 +202,11 @@ class QuoteQuota(commands.Cog):
)
def _metacounter(
self, truths: list[TruthPayload], filter_func: Callable[[TruthPayload], bool], *, now: datetime = None
self,
truths: list[TruthPayload],
filter_func: Callable[[TruthPayload], bool],
*,
now: datetime = None,
) -> dict[str, float | int | dict[str, int]]:
def _is_today(date: datetime) -> bool:
return date.date() == now.date()
@ -215,7 +229,9 @@ class QuoteQuota(commands.Cog):
if filter_func(truth):
created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc)
age = now - created_at
self.log.debug("%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds())
self.log.debug(
"%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()
)
counts["all_time"] += 1
if _is_today(created_at):
counts["today"] += 1
@ -275,7 +291,9 @@ class QuoteQuota(commands.Cog):
plt.bar(hrs, list(hours.values()), color="#5448EE")
average = sum(hours.values()) / len(hours)
plt.axhline(average, color="red", linestyle="--", label=f"Average: {average:.1f}")
plt.axhline(
average, color="red", linestyle="--", label=f"Average: {average:.1f}"
)
file = io.BytesIO()
plt.savefig(file, format="png")
@ -283,8 +301,7 @@ class QuoteQuota(commands.Cog):
return discord.File(file, "truths.png")
async def _process_all_messages(
self,
truths: list[TruthPayload]
self, truths: list[TruthPayload]
) -> tuple[discord.Embed, discord.File]:
"""
Processes all the messages in the given channel.
@ -292,7 +309,11 @@ class QuoteQuota(commands.Cog):
:param truths: The truths to process
:returns: The stats
"""
embed = discord.Embed(title="Truth Counts", color=discord.Color.blurple(), timestamp=discord.utils.utcnow())
embed = discord.Embed(
title="Truth Counts",
color=discord.Color.blurple(),
timestamp=discord.utils.utcnow(),
)
trump_stats = await self._process_trump_truths(truths)
tate_stats = await self._process_tate_truths(truths)
@ -346,7 +367,9 @@ class QuoteQuota(commands.Cog):
)
response.raise_for_status()
truths: list[dict] = response.json()
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths))
truths: list[TruthPayload] = list(
map(lambda t: TruthPayload.model_validate(t), truths)
)
embed, file = await self._process_all_messages(truths)
await ctx.edit(embed=embed, file=file)

View file

@ -1,5 +1,4 @@
import asyncio
import copy
import datetime
import io
import logging
@ -133,7 +132,9 @@ class ScreenshotCog(commands.Cog):
else:
use_proxy = False
service = await asyncio.to_thread(ChromeService)
driver: webdriver.Chrome = await asyncio.to_thread(webdriver.Chrome, service=service, options=options)
driver: webdriver.Chrome = await asyncio.to_thread(
webdriver.Chrome, service=service, options=options
)
driver.set_page_load_timeout(load_timeout)
if resolution:
resolution = RESOLUTIONS.get(resolution.lower(), resolution)
@ -141,9 +142,13 @@ class ScreenshotCog(commands.Cog):
width, height = map(int, resolution.split("x"))
driver.set_window_size(width, height)
if height > 4320 or width > 7680:
return await ctx.respond("Invalid resolution. Max resolution is 7680x4320 (8K).")
return await ctx.respond(
"Invalid resolution. Max resolution is 7680x4320 (8K)."
)
except ValueError:
return await ctx.respond("Invalid resolution. please provide width x height, e.g. 1920x1080")
return await ctx.respond(
"Invalid resolution. please provide width x height, e.g. 1920x1080"
)
if eager:
driver.implicitly_wait(render_timeout)
except Exception as e:
@ -151,7 +156,13 @@ class ScreenshotCog(commands.Cog):
raise
end_init = time.time()
await ctx.edit(content=("Loading webpage..." if not eager else "Loading & screenshotting webpage..."))
await ctx.edit(
content=(
"Loading webpage..."
if not eager
else "Loading & screenshotting webpage..."
)
)
start_request = time.time()
try:
await asyncio.to_thread(driver.get, url)
@ -160,7 +171,9 @@ class ScreenshotCog(commands.Cog):
if "TimeoutException" in str(e):
return await ctx.respond("Timed out while loading webpage.")
else:
return await ctx.respond("Failed to load webpage:\n```\n%s\n```" % str(e.msg))
return await ctx.respond(
"Failed to load webpage:\n```\n%s\n```" % str(e.msg)
)
except Exception as e:
await self.bot.loop.run_in_executor(None, driver.quit)
await ctx.respond("Failed to get the webpage: " + str(e))
@ -170,7 +183,9 @@ class ScreenshotCog(commands.Cog):
if not eager:
now = discord.utils.utcnow()
expires = now + datetime.timedelta(seconds=render_timeout)
await ctx.edit(content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})...")
await ctx.edit(
content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})..."
)
start_wait = time.time()
await asyncio.sleep(render_timeout)
end_wait = time.time()
@ -200,7 +215,9 @@ class ScreenshotCog(commands.Cog):
await self.bot.loop.run_in_executor(None, driver.quit)
end_cleanup = time.time()
screenshot_size_mb = round(len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2)
screenshot_size_mb = round(
len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2
)
def seconds(start: float, end: float) -> float:
return round(end - start, 2)
@ -221,7 +238,9 @@ class ScreenshotCog(commands.Cog):
timestamp=discord.utils.utcnow(),
)
embed.set_image(url="attachment://" + fn)
return await ctx.edit(content=None, embed=embed, file=discord.File(file, filename=fn))
return await ctx.edit(
content=None, embed=embed, file=discord.File(file, filename=fn)
)
def setup(bot):

View file

@ -10,24 +10,32 @@ from conf import CONFIG
class Starboard(commands.Cog):
DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{collision symbol}")
DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{COLLISION SYMBOL}")
def __init__(self, bot):
self.bot: commands.Bot = bot
self.config = CONFIG["starboard"]
self.emoji = discord.PartialEmoji.from_str(self.config.get("emoji", "\N{white medium star}"))
self.emoji = discord.PartialEmoji.from_str(
self.config.get("emoji", "\N{WHITE MEDIUM STAR}")
)
self.log = logging.getLogger("jimmy.cogs.starboard")
self.redis = Redis(**CONFIG["redis"])
async def generate_starboard_embed(self, message: discord.Message) -> tuple[list[discord.Embed], int]:
async def generate_starboard_embed(
self, message: discord.Message
) -> tuple[list[discord.Embed], int]:
"""
Generates an embed ready for a starboard message.
:param message: The message to base off of.
:return: The created embed
"""
reactions: list[discord.Reaction] = [x for x in message.reactions if str(x.emoji) == str(self.emoji)]
downvote_reactions = [x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)]
reactions: list[discord.Reaction] = [
x for x in message.reactions if str(x.emoji) == str(self.emoji)
]
downvote_reactions = [
x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)
]
if not reactions:
# Nobody has added the star reaction.
star_count = 0
@ -35,12 +43,18 @@ class Starboard(commands.Cog):
else:
# Count the number of reactions
star_count = sum([x.count for x in reactions])
self.log.debug("There are a total of %d star reactions on message.", star_count)
self.log.debug(
"There are a total of %d star reactions on message.", star_count
)
if downvote_reactions:
_dv = sum([x.count for x in downvote_reactions])
star_count -= _dv
self.log.debug("There are %d downvotes on message, resulting in %d stars.", _dv, star_count)
self.log.debug(
"There are %d downvotes on message, resulting in %d stars.",
_dv,
star_count,
)
if star_count >= 0:
star_emoji_count = (str(self.emoji) * star_count)[:10]
@ -54,18 +68,20 @@ class Starboard(commands.Cog):
author=discord.EmbedAuthor(
message.author.display_name,
message.author.jump_url,
message.author.display_avatar.url
message.author.display_avatar.url,
),
fields=[
discord.EmbedField(
name="Info",
value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})"
value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})",
)
]
],
)
if message.reference:
try:
ref_message = await self.get_or_fetch_message(message.reference.channel_id, message.reference.message_id)
ref_message = await self.get_or_fetch_message(
message.reference.channel_id, message.reference.message_id
)
except discord.HTTPException:
pass
else:
@ -74,31 +90,26 @@ class Starboard(commands.Cog):
remaining = 1024 - len(v)
t = textwrap.shorten(text, remaining, placeholder="...")
v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})"
embed.add_field(
name="Replying to",
value=v
)
embed.add_field(name="Replying to", value=v)
elif message.interaction:
if message.interaction.type == discord.InteractionType.application_command:
real_author: discord.User = await discord.utils.get_or_fetch(
self.bot,
"user",
int(message.interaction.data["user"]["id"])
self.bot, "user", int(message.interaction.data["user"]["id"])
)
real_author = (
await discord.utils.get_or_fetch(
message.guild, "member", real_author.id, default=real_author
)
or message.author
)
real_author = await discord.utils.get_or_fetch(
message.guild,
"member",
real_author.id,
default=real_author
) or message.author
embed.set_author(
name=real_author.display_name,
icon_url=real_author.display_avatar.url,
url=real_author.jump_url
url=real_author.jump_url,
)
embed.add_field(
name="Interaction",
value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}"
value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}",
)
if message.content:
@ -107,17 +118,24 @@ class Starboard(commands.Cog):
for message_embed in message.embeds:
if message_embed.type != "rich":
if message_embed.type == "image":
if message_embed.thumbnail and message_embed.thumbnail.proxy_url:
if (
message_embed.thumbnail
and message_embed.thumbnail.proxy_url
):
embed.set_image(url=message_embed.thumbnail.proxy_url)
continue
if message_embed.description:
embed.description = message_embed.description
elif not message.attachments:
raise ValueError("Message does not appear to contain any text, embeds, or attachments.")
raise ValueError(
"Message does not appear to contain any text, embeds, or attachments."
)
if message.attachments:
new_fields = []
for n, attachment in reversed(tuple(enumerate(message.attachments, start=1))):
for n, attachment in reversed(
tuple(enumerate(message.attachments, start=1))
):
attachment: discord.Attachment
if attachment.size >= 1024 * 1024:
size = f"{attachment.size / 1024 / 1024:,.1f}MiB"
@ -129,7 +147,7 @@ class Starboard(commands.Cog):
{
"name": "Attachment #%d:" % n,
"value": f"[{attachment.filename} ({size})]({attachment.url})",
"inline": True
"inline": True,
}
)
if attachment.content_type.startswith("image/"):
@ -142,13 +160,19 @@ class Starboard(commands.Cog):
return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count
async def get_or_fetch_message(self, channel_id: int, message_id: int) -> discord.Message:
async def get_or_fetch_message(
self, channel_id: int, message_id: int
) -> discord.Message:
"""
Fetches a message from cache where possible, falling back to the API.
"""
message: discord.Message | None = discord.utils.get(self.bot.cached_messages, id=message_id)
message: discord.Message | None = discord.utils.get(
self.bot.cached_messages, id=message_id
)
if not message:
message: discord.Message = await self.bot.get_channel(channel_id).fetch_message(message_id)
message: discord.Message = await self.bot.get_channel(
channel_id
).fetch_message(message_id)
return message
@commands.Cog.listener("on_raw_reaction_add")
@ -171,27 +195,35 @@ class Starboard(commands.Cog):
guild: discord.Guild = self.bot.get_guild(payload.guild_id)
starboard_channel: discord.TextChannel | None = discord.utils.get(
guild.text_channels,
name=self.config.get("channel_name", "starboard")
guild.text_channels, name=self.config.get("channel_name", "starboard")
)
if not starboard_channel:
self.log.warning("Could not find starboard channel in %s (%d)", guild, guild.id)
self.log.warning(
"Could not find starboard channel in %s (%d)", guild, guild.id
)
return
if payload.channel_id == starboard_channel.id:
self.log.debug("Ignoring reaction in starboard channel.")
return
message = await self.get_or_fetch_message(payload.channel_id, payload.message_id)
message = await self.get_or_fetch_message(
payload.channel_id, payload.message_id
)
if payload.user_id == message.author.id:
self.log.info("%s tried to star their own message.", message.author)
return await message.reply(
"You can't star your own message you pretentious dick. Go outside, %s." % message.author.mention,
delete_after=30
"You can't star your own message you pretentious dick. Go outside, %s."
% message.author.mention,
delete_after=30,
)
self.log.info("Processing starboard reaction from %s: %s", payload.user_id, message.jump_url)
self.log.info(
"Processing starboard reaction from %s: %s",
payload.user_id,
message.jump_url,
)
embed, star_count = await self.generate_starboard_embed(message)
data = await self.redis.get(str(message.id))
@ -205,7 +237,7 @@ class Starboard(commands.Cog):
"source_message_id": payload.message_id,
"history": [],
"starboard_channel_id": starboard_channel.id,
"starboard_message_id": None
"starboard_message_id": None,
}
if not starboard_channel.can_send(embed[0]):
self.log.warning(
@ -213,29 +245,24 @@ class Starboard(commands.Cog):
starboard_channel.id,
payload.guild_id,
starboard_channel.name,
guild.name
guild.name,
)
return
starboard_message = await starboard_channel.send(
embeds=embed,
silent=True
)
starboard_message = await starboard_channel.send(embeds=embed, silent=True)
data["starboard_message_id"] = starboard_message.id
await self.redis.set(str(message.id), json.dumps(data))
else:
data = json.loads(data)
try:
starboard_message = await self.get_or_fetch_message(
data["starboard_channel_id"],
data["starboard_message_id"]
data["starboard_channel_id"], data["starboard_message_id"]
)
except discord.NotFound:
if star_count <= 0:
return
starboard_message = await starboard_channel.send(
embeds=embed,
silent=True
embeds=embed, silent=True
)
data["starboard_message_id"] = starboard_message.id
data["starboard_channel_id"] = starboard_message.channel.id
@ -247,18 +274,20 @@ class Starboard(commands.Cog):
"Deleted message %s in %s, %s - lost all stars",
starboard_message.id,
starboard_message.channel.name,
starboard_message.guild.name
starboard_message.guild.name,
)
elif starboard_message.embeds[0] != embed[0]:
await starboard_message.edit(embeds=embed)
@commands.message_command(name="Preview Starboard Message")
async def preview_starboard_message(self, ctx: discord.ApplicationContext, message: discord.Message):
async def preview_starboard_message(
self, ctx: discord.ApplicationContext, message: discord.Message
):
embed, stars = await self.generate_starboard_embed(message)
data = await self.redis.get(str(message.id)) or 'null'
data = await self.redis.get(str(message.id)) or "null"
return await ctx.respond(
f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}",
embeds=embed
embeds=embed,
)

View file

@ -135,12 +135,21 @@ class YTDLCog(commands.Cog):
channel_id=excluded.channel_id,
attachment_index=excluded.attachment_index
""",
(_hash, message.id, message.channel.id, webpage_url, format_id, attachment_index),
(
_hash,
message.id,
message.channel.id,
webpage_url,
format_id,
attachment_index,
),
)
await db.commit()
return _hash
async def get_saved(self, webpage_url: str, format_id: str, snip: str) -> typing.Optional[str]:
async def get_saved(
self, webpage_url: str, format_id: str, snip: str
) -> typing.Optional[str]:
"""
Attempts to retrieve the attachment URL of a previously saved download.
:param webpage_url: The webpage url
@ -154,12 +163,19 @@ class YTDLCog(commands.Cog):
logging.error("Failed to initialise ytdl database: %s", e, exc_info=True)
return
async with aiosqlite.connect("./data/ytdl.db") as db:
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
_hash = hashlib.md5(
f"{webpage_url}:{format_id}:{snip}".encode()
).hexdigest()
self.log.debug(
"Attempting to find a saved download for '%s:%s:%s' (%r).", webpage_url, format_id, snip, _hash
"Attempting to find a saved download for '%s:%s:%s' (%r).",
webpage_url,
format_id,
snip,
_hash,
)
cursor = await db.execute(
"SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?", (_hash,)
"SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?",
(_hash,),
)
entry = await cursor.fetchone()
if not entry:
@ -173,7 +189,9 @@ class YTDLCog(commands.Cog):
try:
message = await channel.fetch_message(message_id)
except discord.HTTPException:
self.log.debug("%r did not contain a message with ID %r", channel, message_id)
self.log.debug(
"%r did not contain a message with ID %r", channel, message_id
)
await db.execute("DELETE FROM downloads WHERE key=?", (_hash,))
return
@ -182,7 +200,11 @@ class YTDLCog(commands.Cog):
self.log.debug("Found URL %r, returning.", url)
return url
except IndexError:
self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments)
self.log.debug(
"Attachment index %d is out of range (%r)",
attachment_index,
message.attachments,
)
return
def convert_to_m4a(self, file: Path) -> Path:
@ -209,23 +231,32 @@ class YTDLCog(commands.Cog):
str(new_file),
]
self.log.debug("Running command: ffmpeg %s", " ".join(args))
process = subprocess.run(["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
process = subprocess.run(
["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if process.returncode != 0:
raise RuntimeError(process.stderr.decode())
return new_file
@staticmethod
async def upload_to_0x0(name: str, data: typing.IO[bytes], mime_type: str | None = None) -> str:
async def upload_to_0x0(
name: str, data: typing.IO[bytes], mime_type: str | None = None
) -> str:
if not mime_type:
import magic
mime_type = await asyncio.to_thread(magic.from_buffer, data.read(4096), mime=True)
mime_type = await asyncio.to_thread(
magic.from_buffer, data.read(4096), mime=True
)
data.seek(0)
async with httpx.AsyncClient() as client:
response = await client.post(
"https://0x0.st",
files={"file": (name, data, mime_type)},
data={"expires": 12},
headers={"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"},
headers={
"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"
},
)
if response.status_code == 200:
return urlparse(response.text).path[1:]
@ -235,7 +266,10 @@ class YTDLCog(commands.Cog):
async def yt_dl_command(
self,
ctx: discord.ApplicationContext,
url: typing.Annotated[str, discord.Option(str, description="The URL to download from.", required=True)],
url: typing.Annotated[
str,
discord.Option(str, description="The URL to download from.", required=True),
],
user_format: typing.Annotated[
typing.Optional[str],
discord.Option(
@ -258,7 +292,10 @@ class YTDLCog(commands.Cog):
],
snip: typing.Annotated[
typing.Optional[str],
discord.Option(description="A start and end position to trim. e.g. 00:00:00-00:10:00.", required=False),
discord.Option(
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
required=False,
),
],
subtitles: typing.Annotated[
typing.Optional[str],
@ -267,7 +304,7 @@ class YTDLCog(commands.Cog):
description="The language code of the subtitles to download. e.g. 'en', 'auto'",
required=False,
),
]
],
):
"""Runs yt-dlp and outputs into discord."""
await ctx.defer()
@ -279,28 +316,40 @@ class YTDLCog(commands.Cog):
if stop.is_set():
raise RuntimeError("Download cancelled.")
n = time.time()
_total = _data.get("total_bytes", _data.get("total_bytes_estimate")) or ctx.guild.filesize_limit
_total = (
_data.get("total_bytes", _data.get("total_bytes_estimate"))
or ctx.guild.filesize_limit
)
if _total:
_percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2)
else:
_total = max(1, _data.get("fragment_count", 4096) or 4096)
_percent = round(max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2)
_percent = round(
max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2
)
_speed_bytes_per_second = _data.get("speed", 1) or 1 or 1
_speed_megabits_per_second = round((_speed_bytes_per_second * 8) / 1024 / 1024)
_speed_megabits_per_second = round(
(_speed_bytes_per_second * 8) / 1024 / 1024
)
if _data.get("eta"):
_eta = discord.utils.utcnow() + datetime.timedelta(seconds=_data.get("eta"))
_eta = discord.utils.utcnow() + datetime.timedelta(
seconds=_data.get("eta")
)
else:
_eta = discord.utils.utcnow() + datetime.timedelta(minutes=1)
blocks = "#" * math.floor(_percent / 10)
bar = f"{blocks}{'.' * (10 - len(blocks))}"
line = (f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | "
f"ETA {discord.utils.format_dt(_eta, 'R')}")
line = (
f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | "
f"ETA {discord.utils.format_dt(_eta, 'R')}"
)
nonlocal last_edit
if (n - last_edit) >= 1.1:
embed.clear_fields()
embed.add_field(name="Progress", value=line)
ctx.bot.loop.create_task(ctx.edit(embed=embed))
last_edit = time.time()
options["progress_hooks"] = [_download_hook]
description = ""
@ -331,13 +380,19 @@ class YTDLCog(commands.Cog):
options["format_sort"] = ["abr", "br"]
# noinspection PyTypeChecker
options["postprocessors"].append(
{"key": "FFmpegExtractAudio", "preferredquality": "96", "preferredcodec": "best"}
{
"key": "FFmpegExtractAudio",
"preferredquality": "96",
"preferredcodec": "best",
}
)
options["format"] = chosen_format
options["paths"] = paths
if subtitles and ffmpeg_installed:
subtitles, burn = subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0")
subtitles, burn = (
subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0")
)
burn = burn[0].lower() in ("y", "1", "t")
if subtitles.lower() == "auto":
options["writeautosubtitles"] = True
@ -352,10 +407,14 @@ class YTDLCog(commands.Cog):
)
with yt_dlp.YoutubeDL(options) as downloader:
await ctx.respond(embed=discord.Embed().set_footer(text="Downloading (step 1/10)"))
await ctx.respond(
embed=discord.Embed().set_footer(text="Downloading (step 1/10)")
)
try:
# noinspection PyTypeChecker
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False)
extracted_info = await asyncio.to_thread(
downloader.extract_info, url, download=False
)
except yt_dlp.utils.DownloadError as e:
extracted_info = {
"title": "error",
@ -382,22 +441,38 @@ class YTDLCog(commands.Cog):
thumbnail_url = extracted_info.get("thumbnail") or None
webpage_url = extracted_info.get("webpage_url", url)
chosen_format = extracted_info.get("format") or chosen_format or str(uuid.uuid4())
chosen_format_id = extracted_info.get("format_id") or str(uuid.uuid4())
chosen_format = (
extracted_info.get("format")
or chosen_format
or str(uuid.uuid4())
)
chosen_format_id = extracted_info.get("format_id") or str(
uuid.uuid4()
)
final_extension = extracted_info.get("ext") or "mp4"
format_note = extracted_info.get("format_note", "%s (%s)" % (chosen_format, chosen_format_id)) or ""
format_note = (
extracted_info.get(
"format_note", "%s (%s)" % (chosen_format, chosen_format_id)
)
or ""
)
resolution = extracted_info.get("resolution") or "1x1"
fps = extracted_info.get("fps", 0.0) or 0.0
vcodec = extracted_info.get("vcodec") or "h264"
acodec = extracted_info.get("acodec") or "aac"
filesize = extracted_info.get("filesize", extracted_info.get("filesize_approx", 1))
likes = extracted_info.get("like_count", extracted_info.get("average_rating", 0))
filesize = extracted_info.get(
"filesize", extracted_info.get("filesize_approx", 1)
)
likes = extracted_info.get(
"like_count", extracted_info.get("average_rating", 0)
)
views = extracted_info.get("view_count", 0)
lines = []
if chosen_format and chosen_format_id:
lines.append(
"* Chosen format: `%s` (`%s`)" % (chosen_format, chosen_format_id),
"* Chosen format: `%s` (`%s`)"
% (chosen_format, chosen_format_id),
)
if format_note:
lines.append("* Format note: %r" % format_note)
@ -411,7 +486,9 @@ class YTDLCog(commands.Cog):
if vcodec or acodec:
lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A"))
if filesize:
lines.append("* Filesize: %s" % yt_dlp.utils.format_bytes(filesize))
lines.append(
"* Filesize: %s" % yt_dlp.utils.format_bytes(filesize)
)
if lines:
description += "\n"
@ -424,27 +501,29 @@ class YTDLCog(commands.Cog):
url=webpage_url,
colour=self.colours.get(domain, discord.Colour.og_blurple()),
)
embed.add_field(
name="Progress",
value="0% [..........]"
)
embed.add_field(name="Progress", value="0% [..........]")
embed.set_footer(text="Downloading (step 2/10)")
embed.set_thumbnail(url=thumbnail_url)
class StopView(discord.ui.View):
@discord.ui.button(label="Cancel download", style=discord.ButtonStyle.danger)
async def _stop(self, button: discord.ui.Button, interaction: discord.Interaction):
@discord.ui.button(
label="Cancel download", style=discord.ButtonStyle.danger
)
async def _stop(
self,
button: discord.ui.Button,
interaction: discord.Interaction,
):
stop.set()
button.label = "Cancelling..."
button.disabled = True
await interaction.response.edit_message(view=self)
self.stop()
await ctx.edit(
embed=embed,
view=StopView(timeout=86400)
await ctx.edit(embed=embed, view=StopView(timeout=86400))
previous = await self.get_saved(
webpage_url, chosen_format_id, snip or "*"
)
previous = await self.get_saved(webpage_url, chosen_format_id, snip or "*")
if previous:
await ctx.edit(
content=previous,
@ -454,7 +533,11 @@ class YTDLCog(commands.Cog):
colour=discord.Colour.green(),
timestamp=discord.utils.utcnow(),
url=previous,
fields=[discord.EmbedField(name="URL", value=previous, inline=False)],
fields=[
discord.EmbedField(
name="URL", value=previous, inline=False
)
],
).set_image(url=previous),
)
return
@ -462,7 +545,9 @@ class YTDLCog(commands.Cog):
last_edit = time.time()
try:
await asyncio.to_thread(functools.partial(downloader.download, [url]))
await asyncio.to_thread(
functools.partial(downloader.download, [url])
)
except yt_dlp.DownloadError as e:
logging.error(e, exc_info=True)
return await ctx.edit(
@ -473,7 +558,7 @@ class YTDLCog(commands.Cog):
url=webpage_url,
),
delete_after=120,
view=None
view=None,
)
except RuntimeError:
return await ctx.edit(
@ -484,16 +569,26 @@ class YTDLCog(commands.Cog):
url=webpage_url,
),
delete_after=120,
view=None
view=None,
)
await ctx.edit(view=None)
try:
if audio_only is False:
file: Path = next(temp_dir.glob("*." + extracted_info.get("ext", "*")))
file: Path = next(
temp_dir.glob("*." + extracted_info.get("ext", "*"))
)
else:
# can be .opus, .m4a, .mp3, .ogg, .oga
for _file in temp_dir.iterdir():
if _file.suffix in (".opus", ".m4a", ".mp3", ".ogg", ".oga", ".aac", ".wav"):
if _file.suffix in (
".opus",
".m4a",
".mp3",
".ogg",
".oga",
".aac",
".wav",
):
file: Path = _file
break
else:
@ -523,7 +618,9 @@ class YTDLCog(commands.Cog):
except ValueError:
trim_start, trim_end = snip, None
trim_start = trim_start or "00:00:00"
trim_end = trim_end or extracted_info.get("duration_string", "00:30:00")
trim_end = trim_end or extracted_info.get(
"duration_string", "00:30:00"
)
new_file = temp_dir / ("output" + file.suffix)
args = [
"-hwaccel",
@ -562,7 +659,10 @@ class YTDLCog(commands.Cog):
)
self.log.debug("Running command: 'ffmpeg %s'", " ".join(args))
process = await asyncio.create_subprocess_exec(
"ffmpeg", *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
"ffmpeg",
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
self.log.debug("STDOUT:\n%r", stdout.decode())
@ -606,31 +706,40 @@ class YTDLCog(commands.Cog):
)
embed = discord.Embed(
title=f"Downloaded {title}!",
description="Views: {:,} | Likes: {:,}".format(views or 0, likes or 0),
description="Views: {:,} | Likes: {:,}".format(
views or 0, likes or 0
),
colour=discord.Colour.green(),
timestamp=discord.utils.utcnow(),
url=webpage_url,
)
try:
if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in ["hevc", "h265", "av1", "av01"]:
if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in [
"hevc",
"h265",
"av1",
"av01",
]:
with file.open("rb") as fb:
part = await self.upload_to_0x0(
file.name,
fb
)
embed.add_field(name="URL", value=f"https://0x0.st/{part}", inline=False)
await ctx.edit(
embed=embed
part = await self.upload_to_0x0(file.name, fb)
embed.add_field(
name="URL", value=f"https://0x0.st/{part}", inline=False
)
await ctx.edit(embed=embed)
await ctx.respond("https://embeds.video/0x0/" + part)
else:
upload_file = await asyncio.to_thread(discord.File, file, filename=file.name)
msg = await ctx.edit(
file=upload_file,
embed=embed
upload_file = await asyncio.to_thread(
discord.File, file, filename=file.name
)
await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or "*")
except (discord.HTTPException, ConnectionError, httpx.HTTPStatusError) as e:
msg = await ctx.edit(file=upload_file, embed=embed)
await self.save_link(
msg, webpage_url, chosen_format_id, snip=snip or "*"
)
except (
discord.HTTPException,
ConnectionError,
httpx.HTTPStatusError,
) as e:
self.log.error(e, exc_info=True)
return await ctx.edit(
embed=discord.Embed(

View file

@ -21,7 +21,9 @@ if (Path.cwd() / ".git").exists():
log.debug("Unable to auto-detect running version using git.", exc_info=True)
VERSION = "unknown"
else:
log.debug("Unable to auto-detect running version using git, no .git directory exists.")
log.debug(
"Unable to auto-detect running version using git, no .git directory exists."
)
VERSION = "unknown"
try:
@ -32,7 +34,7 @@ except FileNotFoundError:
log.critical(
"Unable to locate config.toml in %s. Using default configuration. Good luck!",
cwd,
exc_info=True
exc_info=True,
)
CONFIG = {}
CONFIG.setdefault("logging", {})
@ -50,11 +52,13 @@ CONFIG.setdefault(
"api": "https://bots.nexy7574.co.uk/jimmy/v2/",
"username": os.getenv("WEB_USERNAME", os.urandom(32).hex()),
"password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()),
}
},
)
if CONFIG["redis"].pop("no_ping", None) is not None:
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")
log.warning(
"`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory."
)
CONFIG["redis"]["decode_responses"] = True

View file

@ -61,9 +61,16 @@ class KumaThread(Thread):
response.raise_for_status()
self.previous = bot.is_ready()
except httpx.HTTPError as error:
self.log.error("Failed to connect to uptime-kuma: %r: %r", url, error, exc_info=error)
self.log.error(
"Failed to connect to uptime-kuma: %r: %r",
url,
error,
exc_info=error,
)
timeout = self.calculate_backoff()
self.log.warning("Waiting %d seconds before retrying ping.", timeout)
self.log.warning(
"Waiting %d seconds before retrying ping.", timeout
)
time.sleep(timeout)
continue
@ -91,7 +98,11 @@ logging.basicConfig(
markup=True,
console=Console(width=cols, height=lns),
),
FileHandler(filename=CONFIG["logging"].get("file", "jimmy.log"), encoding="utf-8", errors="replace"),
FileHandler(
filename=CONFIG["logging"].get("file", "jimmy.log"),
encoding="utf-8",
errors="replace",
),
],
)
for logger in CONFIG["logging"].get("suppress", []):
@ -130,7 +141,8 @@ class Client(commands.Bot):
async def start(self, token: str, *, reconnect: bool = True) -> None:
if CONFIG["jimmy"].get("uptime_kuma_url"):
self.uptime_thread = KumaThread(
CONFIG["jimmy"]["uptime_kuma_url"], CONFIG["jimmy"].get("uptime_kuma_interval", 58.0)
CONFIG["jimmy"]["uptime_kuma_url"],
CONFIG["jimmy"].get("uptime_kuma_interval", 58.0),
)
self.uptime_thread.start()
await super().start(token, reconnect=reconnect)
@ -138,7 +150,9 @@ class Client(commands.Bot):
async def close(self) -> None:
if getattr(self, "uptime_thread", None):
self.uptime_thread.kill.set()
await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join)
await asyncio.get_event_loop().run_in_executor(
None, self.uptime_thread.join
)
await super().close()
@ -189,9 +203,13 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc)
if isinstance(exc, commands.CommandOnCooldown):
expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after)
await ctx.respond(f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}.")
await ctx.respond(
f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}."
)
elif isinstance(exc, commands.MaxConcurrencyReached):
await ctx.respond("You've reached the maximum number of concurrent uses for this command.")
await ctx.respond(
"You've reached the maximum number of concurrent uses for this command."
)
else:
if await bot.is_owner(ctx.author):
paginator = commands.Paginator(prefix="```py")
@ -200,7 +218,10 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
for page in paginator.pages:
await ctx.respond(page)
else:
await ctx.respond(f"An error occurred while processing your command. Please try again later.\n" f"{exc}")
await ctx.respond(
f"An error occurred while processing your command. Please try again later.\n"
f"{exc}"
)
@bot.listen()
@ -218,11 +239,22 @@ async def delete_message(ctx: discord.ApplicationContext, message: discord.Messa
await ctx.defer()
if not ctx.channel.permissions_for(ctx.me).manage_messages:
if message.author != bot.user:
return await ctx.respond("I don't have permission to delete messages in this channel.", delete_after=30)
return await ctx.respond(
"I don't have permission to delete messages in this channel.",
delete_after=30,
)
log.info("%s deleted message %s>%s: %r", ctx.author, ctx.channel.name, message.id, message.content)
log.info(
"%s deleted message %s>%s: %r",
ctx.author,
ctx.channel.name,
message.id,
message.content,
)
await message.delete(delay=1)
await ctx.respond(f"\N{white heavy check mark} Deleted message by {message.author.display_name}.")
await ctx.respond(
f"\N{WHITE HEAVY CHECK MARK} Deleted message by {message.author.display_name}."
)
await ctx.delete(delay=5)
@ -231,16 +263,19 @@ async def about_me(ctx: discord.ApplicationContext):
embed = discord.Embed(
title="Jimmy v3",
description="A bot specifically for the LCC discord server(s).",
colour=discord.Colour.green()
colour=discord.Colour.green(),
)
embed.add_field(
name="Source",
value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)"
value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)",
)
user = await ctx.bot.fetch_user(421698654189912064)
appinfo = await ctx.bot.application_info()
author = appinfo.owner
embed.add_field(name="Author", value=f"{user.mention} created me. {author.mention} is running this instance.")
embed.add_field(
name="Author",
value=f"{user.mention} created me. {author.mention} is running this instance.",
)
return await ctx.respond(embed=embed)
@ -249,9 +284,8 @@ async def check_is_enabled(ctx: commands.Context | discord.ApplicationContext) -
disabled = CONFIG["jimmy"].get("disabled_commands", [])
if ctx.command.qualified_name in disabled:
raise commands.DisabledCommand(
"%s is disabled via this instance's configuration file." % (
ctx.command.qualified_name
)
"%s is disabled via this instance's configuration file."
% (ctx.command.qualified_name)
)
return True
@ -273,7 +307,9 @@ async def add_delays(ctx: commands.Context | discord.ApplicationContext) -> bool
if not CONFIG["jimmy"].get("token"):
log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)")
log.critical(
"No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)"
)
sys.exit(1)
__disabled = CONFIG["jimmy"].get("disabled_commands", [])

View file

@ -2,24 +2,24 @@
import json
import redis
import orjson
import os
import secrets
import typing
import time
from fastapi import FastAPI, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, Response
from fastapi.responses import JSONResponse, Response, ORJSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel, Field
JSON: typing.Union[
str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"]
] = typing.Union[
str, int, float, bool, None, dict, list
]
] = typing.Union[str, int, float, bool, None, dict, list]
class TruthPayload(BaseModel):
"""Represents a truth. This can be used to both create and get truths."""
id: str
content: str
author: str
@ -33,6 +33,7 @@ class OllamaThread(BaseModel):
class ThreadMessage(BaseModel):
"""Represents a message in an Ollama thread."""
role: typing.Literal["assistant", "system", "user"]
content: str
images: typing.Optional[list[str]] = []
@ -54,9 +55,7 @@ USERNAME = os.getenv("WEB_USERNAME", os.urandom(32).hex())
PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex())
accounts = {
USERNAME: PASSWORD
}
accounts = {USERNAME: PASSWORD}
if os.path.exists("./web-accounts.json"):
with open("./web-accounts.json") as f:
accounts.update(json.load(f))
@ -71,7 +70,9 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
def get_db_factory(
n: int = 11,
) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
def inner():
uri = os.getenv("REDIS_URL", "redis://redis")
conn = redis.Redis.from_url(uri)
@ -80,38 +81,58 @@ def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Re
yield conn
finally:
conn.close()
return inner
app = FastAPI(
title="Jimmy v3 API",
version="3.1.0",
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api",
)
truth_router = APIRouter(
prefix="/truths",
dependencies=[Depends(check_credentials)],
tags=["Truth Social"]
prefix="/truths", dependencies=[Depends(check_credentials)], tags=["Truth Social"]
)
@truth_router.get("", response_model=list[TruthPayload])
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
"""Retrieves all stored truths"""
@truth_router.get("", response_class=ORJSONResponse, response_model=list[TruthPayload])
def get_all_truths(
rich: bool = True,
limit: int = -1,
page: int = 0,
db: redis.Redis = Depends(get_db_factory()),
):
"""Retrieves all stored truths
If ?limit is a positive integer, pagination will be enabled."""
query_start = time.perf_counter()
keys = db.keys()
query_end = time.perf_counter()
if rich is False:
return [
{"id": key, "content": "", "author": "", "timestamp": 0, "extra": None}
for key in keys
]
load_start = time.perf_counter()
if limit >= 0:
keys = keys[page * limit: (page + 1) * limit]
truths = [json.loads(db.get(key)) for key in keys]
return truths
load_end = time.perf_counter()
server_timing = "query;dur=%.2f, load;dur=%.2f" % (
(query_end - query_start) * 1000,
(load_end - load_start) * 1000,
)
return ORJSONResponse(truths, headers={"Server-Timing": server_timing})
@truth_router.get("/all", deprecated=True, response_model=list[TruthPayload])
def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
def get_all_truths_deprecated(
response: JSONResponse,
rich: bool = True,
db: redis.Redis = Depends(get_db_factory()),
):
"""DEPRECATED - USE get_all_truths INSTEAD"""
return get_all_truths(rich, db)
return get_all_truths(rich=rich, db=db)
@truth_router.get("/{truth_id}", response_model=TruthPayload)
@ -133,7 +154,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
@truth_router.post("", status_code=201, response_model=TruthPayload)
def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
def new_truth(
payload: TruthPayload,
response: JSONResponse,
db: redis.Redis = Depends(get_db_factory()),
):
"""Creates a new truth"""
data = payload.model_dump()
existing: str = db.get(data["id"])
@ -148,7 +173,9 @@ def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = D
@truth_router.put("/{truth_id}", response_model=TruthPayload)
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
def put_truth(
truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())
):
"""Replaces a stored truth"""
data = payload.model_dump()
existing = db.get(truth_id)
@ -168,9 +195,7 @@ def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
app.include_router(truth_router)
ollama_router = APIRouter(
prefix="/ollama",
dependencies=[Depends(check_credentials)],
tags=["Ollama"]
prefix="/ollama", dependencies=[Depends(check_credentials)], tags=["Ollama"]
)
@ -185,13 +210,13 @@ def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
return keys
@ollama_router.get("/thread/{thread_id}", response_model=OllamaThread)
@ollama_router.get("/thread/{thread_id}", response_class=ORJSONResponse, response_model=OllamaThread)
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
"""Retrieves a stored thread"""
data: str = db.get(thread_id)
if not data:
raise HTTPException(404, detail="%r not found." % thread_id)
return json.loads(data)
return ORJSONResponse(orjson.loads(data.encode()))
@ollama_router.delete("/thread/{thread_id}", status_code=204)
@ -216,4 +241,5 @@ def health(db: redis.Redis = Depends(get_db_factory())):
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*")