Merge remote-tracking branch 'origin/master'
All checks were successful
Build and Publish / build_and_publish (push) Successful in 1m6s
All checks were successful
Build and Publish / build_and_publish (push) Successful in 1m6s
This commit is contained in:
commit
27b23d3e3a
15 changed files with 1311 additions and 553 deletions
|
@ -21,3 +21,4 @@ aiofiles~=23.2
|
||||||
fuzzywuzzy[speedup]~=0.18
|
fuzzywuzzy[speedup]~=0.18
|
||||||
tortoise-orm[asyncpg]~=0.21
|
tortoise-orm[asyncpg]~=0.21
|
||||||
superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6
|
superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6
|
||||||
|
orjson~=3.10
|
||||||
|
|
|
@ -33,11 +33,16 @@ class AutoResponder(commands.Cog):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"overpowered_by is set, however members intent is disabled, making it useless."
|
"overpowered_by is set, however members intent is disabled, making it useless."
|
||||||
)
|
)
|
||||||
if self.config.get("overrule_offline_superiors", True) is True and bot.intents.presences is False:
|
if (
|
||||||
|
self.config.get("overrule_offline_superiors", True) is True
|
||||||
|
and bot.intents.presences is False
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"overrule_offline_superiors is enabled, however presences intent is not!"
|
"overrule_offline_superiors is enabled, however presences intent is not!"
|
||||||
)
|
)
|
||||||
self.config.setdefault("transcoding", {"enabled": True, "hevc": True, "on_demand": True})
|
self.config.setdefault(
|
||||||
|
"transcoding", {"enabled": True, "hevc": True, "on_demand": True}
|
||||||
|
)
|
||||||
self.lmgtfy_cache = []
|
self.lmgtfy_cache = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -50,7 +55,9 @@ class AutoResponder(commands.Cog):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allow_on_demand_transcoding(self) -> bool:
|
def allow_on_demand_transcoding(self) -> bool:
|
||||||
return self.transcoding_enabled and self.config["transcoding"].get("on_demand", True)
|
return self.transcoding_enabled and self.config["transcoding"].get(
|
||||||
|
"on_demand", True
|
||||||
|
)
|
||||||
|
|
||||||
def overpowered(self, guild: discord.Guild | None) -> bool:
|
def overpowered(self, guild: discord.Guild | None) -> bool:
|
||||||
if not guild:
|
if not guild:
|
||||||
|
@ -68,14 +75,20 @@ class AutoResponder(commands.Cog):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@typing.overload
|
@typing.overload
|
||||||
def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: ...
|
def extract_links(
|
||||||
|
text: str, *domains: str, raw: typing.Literal[True] = False
|
||||||
|
) -> list[ParseResult]: ...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@typing.overload
|
@typing.overload
|
||||||
def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: ...
|
def extract_links(
|
||||||
|
text: str, *domains: str, raw: typing.Literal[False] = False
|
||||||
|
) -> list[str]: ...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_links(text: str, *domains: str, raw: bool = False) -> list[str | ParseResult]:
|
def extract_links(
|
||||||
|
text: str, *domains: str, raw: bool = False
|
||||||
|
) -> list[str | ParseResult]:
|
||||||
"""
|
"""
|
||||||
Extracts all links from a given text.
|
Extracts all links from a given text.
|
||||||
|
|
||||||
|
@ -124,10 +137,13 @@ class AutoResponder(commands.Cog):
|
||||||
if not update:
|
if not update:
|
||||||
return
|
return
|
||||||
if last_reaction is not None:
|
if last_reaction is not None:
|
||||||
_ = asyncio.create_task(update.remove_reaction(last_reaction, self.bot.user))
|
_ = asyncio.create_task(
|
||||||
|
update.remove_reaction(last_reaction, self.bot.user)
|
||||||
|
)
|
||||||
if new:
|
if new:
|
||||||
_ = asyncio.create_task(update.add_reaction(new))
|
_ = asyncio.create_task(update.add_reaction(new))
|
||||||
last_reaction = new
|
last_reaction = new
|
||||||
|
|
||||||
self.log.info("Waiting for transcode lock to release")
|
self.log.info("Waiting for transcode lock to release")
|
||||||
async with self.transcode_lock:
|
async with self.transcode_lock:
|
||||||
cog = FFMeta(self.bot)
|
cog = FFMeta(self.bot)
|
||||||
|
@ -147,6 +163,7 @@ class AutoResponder(commands.Cog):
|
||||||
return update_reaction("\N{TIMER CLOCK}\U0000fe0f")
|
return update_reaction("\N{TIMER CLOCK}\U0000fe0f")
|
||||||
streams = info.get("streams", [])
|
streams = info.get("streams", [])
|
||||||
hwaccel = True
|
hwaccel = True
|
||||||
|
maxrate = "5M"
|
||||||
for stream in streams:
|
for stream in streams:
|
||||||
self.log.info("Found stream: %s", stream.get("codec_name"))
|
self.log.info("Found stream: %s", stream.get("codec_name"))
|
||||||
if stream.get("codec_name") == "hevc":
|
if stream.get("codec_name") == "hevc":
|
||||||
|
@ -159,8 +176,14 @@ class AutoResponder(commands.Cog):
|
||||||
hwaccel = False
|
hwaccel = False
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.log.info("No HEVC streams found in %s", uri)
|
if int(info["format"]["size"]) >= 25 * 1024 * 1024:
|
||||||
return update_reaction()
|
self.log.warning(
|
||||||
|
"%s is too large to render in discord, compressing", uri
|
||||||
|
)
|
||||||
|
maxrate = "1M"
|
||||||
|
else:
|
||||||
|
self.log.info("No HEVC streams found in %s", uri)
|
||||||
|
return update_reaction()
|
||||||
extension = pathlib.Path(uri).suffix
|
extension = pathlib.Path(uri).suffix
|
||||||
with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl:
|
with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl:
|
||||||
self.log.info("Downloading %r to %r", uri, tmp_dl.name)
|
self.log.info("Downloading %r to %r", uri, tmp_dl.name)
|
||||||
|
@ -195,10 +218,8 @@ class AutoResponder(commands.Cog):
|
||||||
tmp_dl.name,
|
tmp_dl.name,
|
||||||
"-c:v",
|
"-c:v",
|
||||||
"libx264",
|
"libx264",
|
||||||
"-crf",
|
|
||||||
"25",
|
|
||||||
"-maxrate",
|
"-maxrate",
|
||||||
"5M",
|
maxrate,
|
||||||
"-minrate",
|
"-minrate",
|
||||||
"100K",
|
"100K",
|
||||||
"-bufsize",
|
"-bufsize",
|
||||||
|
@ -216,7 +237,7 @@ class AutoResponder(commands.Cog):
|
||||||
"-movflags",
|
"-movflags",
|
||||||
"faststart",
|
"faststart",
|
||||||
"-profile:v",
|
"-profile:v",
|
||||||
"main",
|
"high",
|
||||||
"-y",
|
"-y",
|
||||||
"-hide_banner",
|
"-hide_banner",
|
||||||
]
|
]
|
||||||
|
@ -230,7 +251,9 @@ class AutoResponder(commands.Cog):
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
stdout, stderr = await process.communicate()
|
stdout, stderr = await process.communicate()
|
||||||
self.log.info("finished transcode with return code %d", process.returncode)
|
self.log.info(
|
||||||
|
"finished transcode with return code %d", process.returncode
|
||||||
|
)
|
||||||
self.log.debug("stdout: %r", stdout.decode)
|
self.log.debug("stdout: %r", stdout.decode)
|
||||||
self.log.debug("stderr: %r", stderr.decode)
|
self.log.debug("stderr: %r", stderr.decode)
|
||||||
update_reaction()
|
update_reaction()
|
||||||
|
@ -243,12 +266,9 @@ class AutoResponder(commands.Cog):
|
||||||
)
|
)
|
||||||
self._cooldown_transcode()
|
self._cooldown_transcode()
|
||||||
return discord.File(tmp_path), tmp_path
|
return discord.File(tmp_path), tmp_path
|
||||||
|
|
||||||
async def transcode_hevc_to_h264(
|
async def transcode_hevc_to_h264(
|
||||||
self,
|
self, message: discord.Message, *domains: str, additional: Iterable[str] = None
|
||||||
message: discord.Message,
|
|
||||||
*domains: str,
|
|
||||||
additional: Iterable[str] = None
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if not shutil.which("ffmpeg"):
|
if not shutil.which("ffmpeg"):
|
||||||
self.log.error("ffmpeg not installed")
|
self.log.error("ffmpeg not installed")
|
||||||
|
@ -283,7 +303,9 @@ class AutoResponder(commands.Cog):
|
||||||
self.log.info("Found link to transcode: %r", link)
|
self.log.info("Found link to transcode: %r", link)
|
||||||
try:
|
try:
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
_r = await self._transcode_hevc_to_h264(link, update=message)
|
_r = await self._transcode_hevc_to_h264(
|
||||||
|
link, update=message
|
||||||
|
)
|
||||||
if not _r:
|
if not _r:
|
||||||
continue
|
continue
|
||||||
file, _p = _r
|
file, _p = _r
|
||||||
|
@ -291,7 +313,9 @@ class AutoResponder(commands.Cog):
|
||||||
if _p.stat().st_size <= 24.5 * 1024 * 1024:
|
if _p.stat().st_size <= 24.5 * 1024 * 1024:
|
||||||
await message.add_reaction("\N{OUTBOX TRAY}")
|
await message.add_reaction("\N{OUTBOX TRAY}")
|
||||||
await message.reply(file=file)
|
await message.reply(file=file)
|
||||||
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user)
|
await message.remove_reaction(
|
||||||
|
"\N{OUTBOX TRAY}", self.bot.user
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await message.add_reaction("\N{OUTBOX TRAY}")
|
await message.add_reaction("\N{OUTBOX TRAY}")
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
|
@ -301,16 +325,31 @@ class AutoResponder(commands.Cog):
|
||||||
)
|
)
|
||||||
if _p.stat().st_size <= 510 * 1024 * 1024:
|
if _p.stat().st_size <= 510 * 1024 * 1024:
|
||||||
file.fp.seek(0)
|
file.fp.seek(0)
|
||||||
self.log.info("Trying to upload file to pastebin.")
|
self.log.info(
|
||||||
|
"Trying to upload file to pastebin."
|
||||||
|
)
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"https://0x0.st",
|
"https://0x0.st",
|
||||||
files={"file": (_p.name, file.fp, "video/mp4")},
|
files={
|
||||||
headers={"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"},
|
"file": (
|
||||||
|
_p.name,
|
||||||
|
file.fp,
|
||||||
|
"video/mp4",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await message.remove_reaction(
|
||||||
|
"\N{OUTBOX TRAY}", self.bot.user
|
||||||
)
|
)
|
||||||
await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user)
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
await message.reply("https://embeds.video/" + response.text.strip())
|
await message.reply(
|
||||||
|
"https://embeds.video/"
|
||||||
|
+ response.text.strip()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await message.add_reaction("\N{BUG}")
|
await message.add_reaction("\N{BUG}")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -321,12 +360,16 @@ class AutoResponder(commands.Cog):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error("Failed to transcode %r: %r", link, e)
|
self.log.error("Failed to transcode %r: %r", link, e)
|
||||||
|
|
||||||
async def copy_ncfe_docs(self, message: discord.Message, links: list[ParseResult]) -> None:
|
async def copy_ncfe_docs(
|
||||||
|
self, message: discord.Message, links: list[ParseResult]
|
||||||
|
) -> None:
|
||||||
files = []
|
files = []
|
||||||
if self.config.get("download_pdfs", True) is False:
|
if self.config.get("download_pdfs", True) is False:
|
||||||
self.log.debug("Download PDFs is disabled in config, disengaging.")
|
self.log.debug("Download PDFs is disabled in config, disengaging.")
|
||||||
return
|
return
|
||||||
self.log.info("Preparing to download: %s", ", ".join(map(ParseResult.geturl, links)))
|
self.log.info(
|
||||||
|
"Preparing to download: %s", ", ".join(map(ParseResult.geturl, links))
|
||||||
|
)
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
for link in set(links):
|
for link in set(links):
|
||||||
if link.path.endswith(".pdf"):
|
if link.path.endswith(".pdf"):
|
||||||
|
@ -338,7 +381,7 @@ class AutoResponder(commands.Cog):
|
||||||
"Failed to download %s: HTTP %d - %r",
|
"Failed to download %s: HTTP %d - %r",
|
||||||
link,
|
link,
|
||||||
response.status,
|
response.status,
|
||||||
await response.text()
|
await response.text(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
|
@ -351,7 +394,9 @@ class AutoResponder(commands.Cog):
|
||||||
self.log.warning("File was too large to upload. Skipping.")
|
self.log.warning("File was too large to upload. Skipping.")
|
||||||
continue
|
continue
|
||||||
p = pathlib.Path(link.path).name
|
p = pathlib.Path(link.path).name
|
||||||
file = discord.File(buffer, filename=p, description="Copy of %s" % link.geturl())
|
file = discord.File(
|
||||||
|
buffer, filename=p, description="Copy of %s" % link.geturl()
|
||||||
|
)
|
||||||
files.append(file)
|
files.append(file)
|
||||||
for file in files:
|
for file in files:
|
||||||
await message.reply(file=file)
|
await message.reply(file=file)
|
||||||
|
@ -379,27 +424,38 @@ class AutoResponder(commands.Cog):
|
||||||
self.log.info("Got VHS reaction, scanning for transcode")
|
self.log.info("Got VHS reaction, scanning for transcode")
|
||||||
extra = []
|
extra = []
|
||||||
if reaction.message.attachments:
|
if reaction.message.attachments:
|
||||||
extra = [attachment.url for attachment in reaction.message.attachments]
|
extra = [
|
||||||
|
attachment.url for attachment in reaction.message.attachments
|
||||||
|
]
|
||||||
if self.allow_on_demand_transcoding:
|
if self.allow_on_demand_transcoding:
|
||||||
await self.transcode_hevc_to_h264(reaction.message, additional=extra)
|
await self.transcode_hevc_to_h264(
|
||||||
|
reaction.message, additional=extra
|
||||||
|
)
|
||||||
elif str(reaction.emoji) == "\U0001f310":
|
elif str(reaction.emoji) == "\U0001f310":
|
||||||
if reaction.message.id not in self.lmgtfy_cache:
|
if reaction.message.id not in self.lmgtfy_cache:
|
||||||
url = "https://lmddgtfy.net/?q=" + quote_plus(reaction.message.content)
|
url = "https://lmddgtfy.net/?q=" + quote_plus(
|
||||||
m = await reaction.message.reply(f"[Here's the answer to your question]({url})")
|
reaction.message.content
|
||||||
|
)
|
||||||
|
m = await reaction.message.reply(
|
||||||
|
f"[Here's the answer to your question]({url})"
|
||||||
|
)
|
||||||
await m.edit(suppress=True)
|
await m.edit(suppress=True)
|
||||||
self.lmgtfy_cache.append(reaction.message.id)
|
self.lmgtfy_cache.append(reaction.message.id)
|
||||||
|
|
||||||
elif str(reaction.emoji)[0] == "\N{wastebasket}":
|
elif str(reaction.emoji)[0] == "\N{WASTEBASKET}":
|
||||||
if reaction.message.channel.permissions_for(reaction.message.guild.me).manage_messages:
|
if reaction.message.channel.permissions_for(
|
||||||
self.log.info("Deleting message %s (Wastebasket)" % reaction.message.jump_url)
|
reaction.message.guild.me
|
||||||
|
).manage_messages:
|
||||||
|
self.log.info(
|
||||||
|
"Deleting message %s (Wastebasket)" % reaction.message.jump_url
|
||||||
|
)
|
||||||
await reaction.message.delete(
|
await reaction.message.delete(
|
||||||
reason="%s requested deletion of message" % user,
|
reason="%s requested deletion of message" % user, delay=0.2
|
||||||
delay=0.2
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
"Unable to delete message %s (wastebasket) - missing permissions",
|
"Unable to delete message %s (wastebasket) - missing permissions",
|
||||||
reaction.message.jump_url
|
reaction.message.jump_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.Cog.listener("on_raw_reaction_add")
|
@commands.Cog.listener("on_raw_reaction_add")
|
||||||
|
@ -415,8 +471,13 @@ class AutoResponder(commands.Cog):
|
||||||
_e = discord.PartialEmoji.from_str(str(payload.emoji))
|
_e = discord.PartialEmoji.from_str(str(payload.emoji))
|
||||||
reaction = discord.Reaction(
|
reaction = discord.Reaction(
|
||||||
message=message,
|
message=message,
|
||||||
data={"emoji": _e, "count": 1, "me": payload.user_id == self.bot.user.id, "burst": False},
|
data={
|
||||||
emoji=payload.emoji
|
"emoji": _e,
|
||||||
|
"count": 1,
|
||||||
|
"me": payload.user_id == self.bot.user.id,
|
||||||
|
"burst": False,
|
||||||
|
},
|
||||||
|
emoji=payload.emoji,
|
||||||
)
|
)
|
||||||
user = self.bot.get_user(payload.user_id)
|
user = self.bot.get_user(payload.user_id)
|
||||||
await self.on_reaction_add(reaction, user)
|
await self.on_reaction_add(reaction, user)
|
||||||
|
|
|
@ -1,140 +1,220 @@
|
||||||
"""
|
"""
|
||||||
This module is only meant to be loaded during election times.
|
This module is only meant to be loaded during election times.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import httpx
|
import httpx
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from discord.ext import commands
|
from discord.ext import commands, tasks
|
||||||
|
|
||||||
|
|
||||||
SPAN_REGEX = re.compile(
|
SPAN_REGEX = re.compile(
|
||||||
r"^(?P<party>\D+)(?P<councillors>[0-9,]+)\scouncillors\s(?P<net>[0-9,]+)\scouncillors\s(?P<net_change>(gained|lost))$"
|
r"^(?P<party>\D+)(?P<councillors>[0-9,]+)\scouncillors\s(?P<net>[0-9,]+)\scouncillors\s(?P<net_change>(gained|lost))$"
|
||||||
)
|
)
|
||||||
MULTI: dict[str, int] = {
|
MULTI: dict[str, int] = {"gained": 1, "lost": -1}
|
||||||
"gained": 1,
|
|
||||||
"lost": -1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ElectionCog(commands.Cog):
|
class ElectionCog(commands.Cog):
|
||||||
SOURCE = "https://bbc.com/"
|
SOURCE = "https://bbc.com/"
|
||||||
HEADERS = {
|
HEADERS = {
|
||||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,"
|
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,"
|
||||||
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||||
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
|
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
|
||||||
"Priority": "u=0, i",
|
"Priority": "u=0, i",
|
||||||
"Sec-Ch-Ua": "\"Chromium\";v=\"124\", \"Google Chrome\";v=\"124\", \"Not-A.Brand\";v=\"99\"",
|
"Sec-Ch-Ua": '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"',
|
||||||
"Sec-Ch-Ua-Mobile": "?0",
|
"Sec-Ch-Ua-Mobile": "?0",
|
||||||
"Sec-Ch-Ua-Platform": "\"Linux\"",
|
"Sec-Ch-Ua-Platform": '"Linux"',
|
||||||
"Sec-Fetch-Dest": "document",
|
"Sec-Fetch-Dest": "document",
|
||||||
"Sec-Fetch-Mode": "navigate",
|
"Sec-Fetch-Mode": "navigate",
|
||||||
"Sec-Fetch-Site": "none",
|
"Sec-Fetch-Site": "none",
|
||||||
"Sec-Fetch-User": "?1",
|
"Sec-Fetch-User": "?1",
|
||||||
"Upgrade-Insecure-Requests": "1",
|
"Upgrade-Insecure-Requests": "1",
|
||||||
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 "
|
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 "
|
||||||
"Safari/537.36",
|
"Safari/537.36",
|
||||||
}
|
}
|
||||||
|
ETA = datetime.datetime(
|
||||||
|
2024, 7, 4, 23, 30, tzinfo=datetime.datetime.now().astimezone().tzinfo
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot: commands.Bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.election")
|
self.log = logging.getLogger("jimmy.cogs.election")
|
||||||
|
self.countdown_message = None
|
||||||
|
# self.check_election.start()
|
||||||
|
|
||||||
|
def cog_unload(self) -> None:
|
||||||
|
self.check_election.cancel()
|
||||||
|
|
||||||
|
@tasks.loop(minutes=1)
|
||||||
|
async def check_election(self):
|
||||||
|
if not self.bot.is_ready():
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
|
||||||
|
guild = self.bot.get_guild(994710566612500550)
|
||||||
|
if not guild:
|
||||||
|
return self.log.error("Nonsense guild not found. Can't do countdown.")
|
||||||
|
channel = discord.utils.get(guild.text_channels, name="countdown")
|
||||||
|
if not channel:
|
||||||
|
return self.log.error("Countdown channel not found.")
|
||||||
|
|
||||||
|
await asyncio.sleep(random.randint(0, 10))
|
||||||
|
now = discord.utils.utcnow()
|
||||||
|
diff = (self.ETA - now).total_seconds()
|
||||||
|
if diff < -86400:
|
||||||
|
return self.log.debug("Countdown long expired.")
|
||||||
|
|
||||||
|
if diff > 3600:
|
||||||
|
hours, remainder = map(round, divmod(diff, 3600))
|
||||||
|
minutes, seconds = map(round, divmod(remainder, 60))
|
||||||
|
message = f"+ {hours} hours, {minutes} minutes, and {seconds} seconds."
|
||||||
|
elif diff > 60:
|
||||||
|
minutes, seconds = map(round, divmod(diff, 60))
|
||||||
|
message = f"+ {minutes} minutes and {seconds} seconds."
|
||||||
|
elif diff >= 0:
|
||||||
|
message = f"+ {round(diff)} seconds."
|
||||||
|
else:
|
||||||
|
message = "Results time!"
|
||||||
|
|
||||||
|
if self.countdown_message:
|
||||||
|
try:
|
||||||
|
return await self.countdown_message.edit(
|
||||||
|
content=f"```diff\n{message}```"
|
||||||
|
)
|
||||||
|
except discord.HTTPException:
|
||||||
|
self.log.exception("Failed to edit countdown message.")
|
||||||
|
self.countdown_message = await channel.send(f"```diff\n{message}```")
|
||||||
|
self.log.debug("Sent countdown message")
|
||||||
|
|
||||||
def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None:
|
def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None:
|
||||||
good_soup = soup.find(attrs={"data-testid": "election-banner-results-bar"})
|
good_soups = list(
|
||||||
if not good_soup:
|
soup.find_all(attrs={"data-testid": "election-banner-results-bar"})
|
||||||
|
)
|
||||||
|
if not good_soups:
|
||||||
|
self.log.error(
|
||||||
|
"No 'election-banner-results-bar' elements found:\n%r", soup.prettify()
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
good_soup = list(good_soups)[1]
|
||||||
css: str = "\n".join([x.get_text() for x in soup.find_all("style")])
|
|
||||||
|
|
||||||
def find_colour(style_name: str, want: str = "background-color") -> str | None:
|
|
||||||
self.log.info("Looking for style %r", style_name)
|
|
||||||
index = css.index(style_name) + len(style_name) + 1
|
|
||||||
value = ""
|
|
||||||
for char in css[index:]:
|
|
||||||
if char == "}":
|
|
||||||
break
|
|
||||||
value += char
|
|
||||||
attributes = filter(None, value.split(";"))
|
|
||||||
parsed = {}
|
|
||||||
for attr in attributes:
|
|
||||||
name, val = attr.split(":")
|
|
||||||
parsed[name] = val
|
|
||||||
self.log.info("Parsed the following attributes: %r", parsed)
|
|
||||||
if want in parsed:
|
|
||||||
self.log.info("Returning %r: %r", want, parsed[want])
|
|
||||||
return parsed[want]
|
|
||||||
self.log.warning("%r was not in attributes.", want)
|
|
||||||
|
|
||||||
results: dict[str, list[int]] = {}
|
results: dict[str, list[int]] = {}
|
||||||
for child_ul in good_soup.children:
|
for child_li in good_soup.children:
|
||||||
child_ul: BeautifulSoup
|
span = list(child_li.children)[-1]
|
||||||
span = child_ul.find("span", recursive=False)
|
try:
|
||||||
if not span:
|
party, extra = span.get_text().strip().split(":", 1)
|
||||||
self.log.warning("%r did not have a 'span' element.", child_ul)
|
seats, extra = extra.split(",", 1)
|
||||||
|
extra = extra.strip()
|
||||||
|
if extra.lower() == "no change":
|
||||||
|
change = 0
|
||||||
|
else:
|
||||||
|
_values = extra.split()
|
||||||
|
change = int(_values[0])
|
||||||
|
if _values[-1] != "gained":
|
||||||
|
change *= -1
|
||||||
|
seats = int(seats.split()[0])
|
||||||
|
except ValueError:
|
||||||
|
self.log.error("failed to parse %r", span)
|
||||||
continue
|
continue
|
||||||
|
results[party] = [seats, change, 0, 0]
|
||||||
text = span.get_text().replace(",", "")
|
for child_li in good_soups[0].children:
|
||||||
groups = SPAN_REGEX.match(text)
|
span = list(child_li.children)[-1]
|
||||||
if groups:
|
try:
|
||||||
groups = groups.groupdict()
|
party, extra = span.get_text().strip().split(":", 1)
|
||||||
else:
|
seats, _ = extra.strip().split(" ", 1)
|
||||||
self.log.warning(
|
seats = int(seats.strip())
|
||||||
"Found span element (%r), however resolved text (%r) did not match regex.",
|
except ValueError:
|
||||||
span, text
|
self.log.error("failed to parse %r", span)
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
if party in results:
|
||||||
results[str(groups["party"]).strip()] = [
|
results[party][3] = seats
|
||||||
int(groups["councillors"].strip()),
|
|
||||||
int(groups["net"].strip()) * MULTI[groups["net_change"]],
|
|
||||||
int(find_colour(child_ul.next["class"][0])[1:], base=16)
|
|
||||||
]
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def _get_embed(self) -> discord.Embed | None:
|
||||||
|
async with httpx.AsyncClient(headers=self.HEADERS) as client:
|
||||||
|
response = await client.get(self.SOURCE, follow_redirects=True)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"HTTP {response.status_code} while fetching results from BBC"
|
||||||
|
)
|
||||||
|
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser")
|
||||||
|
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup)
|
||||||
|
if results:
|
||||||
|
now = discord.utils.utcnow()
|
||||||
|
date = now.date().strftime("%B %Y")
|
||||||
|
colour_scores = {}
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Election results - " + date,
|
||||||
|
url="https://bbc.co.uk/",
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
embed.set_footer(text="Source from bbc.co.uk.")
|
||||||
|
description_parts = []
|
||||||
|
|
||||||
|
for party_name, values in results.items():
|
||||||
|
councillors, net, colour, last_election = values
|
||||||
|
colour_scores[party_name] = councillors
|
||||||
|
symbol = "+" if net > 0 else ""
|
||||||
|
description_parts.append(
|
||||||
|
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, "
|
||||||
|
f"{last_election:,} predicted by exit poll)"
|
||||||
|
)
|
||||||
|
|
||||||
|
top_party = list(
|
||||||
|
sorted(
|
||||||
|
colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True
|
||||||
|
)
|
||||||
|
)[0]
|
||||||
|
embed.colour = discord.Colour(results[top_party][2])
|
||||||
|
embed.description = "\n".join(description_parts)
|
||||||
|
return embed
|
||||||
|
|
||||||
@commands.slash_command(name="election")
|
@commands.slash_command(name="election")
|
||||||
async def get_election_results(self, ctx: discord.ApplicationContext):
|
async def get_election_results(self, ctx: discord.ApplicationContext):
|
||||||
"""Gets the current election results"""
|
"""Gets the current election results"""
|
||||||
await ctx.defer()
|
|
||||||
async with httpx.AsyncClient(headers=self.HEADERS) as client:
|
|
||||||
response = await client.get(self.SOURCE, follow_redirects=True)
|
|
||||||
if response.status_code != 200:
|
|
||||||
return await ctx.respond(
|
|
||||||
"Sorry, I can't do that right now (HTTP %d while fetching results from BBC)" % response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
class RefreshView(discord.ui.View):
|
||||||
soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser")
|
def __init__(self, **kwargs):
|
||||||
results = await self.bot.loop.run_in_executor(None, self.process_soup, soup)
|
super().__init__(**kwargs)
|
||||||
if results:
|
self.last_edit = discord.utils.utcnow()
|
||||||
now = discord.utils.utcnow()
|
|
||||||
date = now.date().strftime("%B %Y")
|
|
||||||
colour_scores = {}
|
|
||||||
embed = discord.Embed(
|
|
||||||
title="Election results - " + date,
|
|
||||||
url="https://bbc.co.uk/",
|
|
||||||
timestamp=now
|
|
||||||
)
|
|
||||||
embed.set_footer(text="Source from bbc.co.uk.")
|
|
||||||
description_parts = []
|
|
||||||
|
|
||||||
for party_name, values in results.items():
|
@discord.ui.button(
|
||||||
councillors, net, colour = values
|
label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501"
|
||||||
colour_scores[party_name] = councillors
|
)
|
||||||
symbol = "+" if net > 0 else ''
|
async def refresh(_self, _btn, interaction):
|
||||||
description_parts.append(
|
await interaction.response.defer(invisible=True)
|
||||||
f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)"
|
if (discord.utils.utcnow() - self.last_edit).total_seconds() < 10:
|
||||||
|
return await interaction.followup.send("Slow down.", ephemeral=True)
|
||||||
|
try:
|
||||||
|
embed = await self._get_embed()
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception("Failed to get election results.")
|
||||||
|
return await interaction.followup.send(
|
||||||
|
f"Sorry, I cannot contact the BBC at this time: {e}"
|
||||||
)
|
)
|
||||||
|
if embed is None:
|
||||||
|
return await interaction.followup.send(
|
||||||
|
"Sorry, I could not find any election results."
|
||||||
|
)
|
||||||
|
await interaction.edit_original_response(embed=embed)
|
||||||
|
|
||||||
top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0]
|
await ctx.defer()
|
||||||
embed.colour = discord.Colour(results[top_party][2])
|
try:
|
||||||
embed.description = "\n".join(description_parts)
|
embed = await self._get_embed()
|
||||||
return await ctx.respond(embed=embed)
|
except Exception as e:
|
||||||
else:
|
self.log.exception("Failed to get election results.")
|
||||||
return await ctx.respond("Unable to get election results at this time.")
|
return await ctx.respond(
|
||||||
|
f"Sorry, I cannot contact the BBC at this time: {e}"
|
||||||
|
)
|
||||||
|
if embed is None:
|
||||||
|
return await ctx.respond("Sorry, I could not find any election results.")
|
||||||
|
await ctx.respond(
|
||||||
|
embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
|
|
@ -29,25 +29,44 @@ class FFMeta(commands.Cog):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.ffmeta")
|
self.log = logging.getLogger("jimmy.cogs.ffmeta")
|
||||||
|
|
||||||
def jpegify_image(self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg") -> io.BytesIO:
|
def jpegify_image(
|
||||||
|
self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg"
|
||||||
|
) -> io.BytesIO:
|
||||||
quality = min(1, max(quality, 100))
|
quality = min(1, max(quality, 100))
|
||||||
img_src = PIL.Image.open(input_file)
|
img_src = PIL.Image.open(input_file)
|
||||||
if image_format == "jpeg":
|
if image_format == "jpeg":
|
||||||
img_src = img_src.convert("RGB")
|
img_src = img_src.convert("RGB")
|
||||||
img_dst = io.BytesIO()
|
img_dst = io.BytesIO()
|
||||||
self.log.debug("Saving input file (%r) as %r with quality %r%%", input_file, image_format, quality)
|
self.log.debug(
|
||||||
|
"Saving input file (%r) as %r with quality %r%%",
|
||||||
|
input_file,
|
||||||
|
image_format,
|
||||||
|
quality,
|
||||||
|
)
|
||||||
img_src.save(img_dst, format=image_format, quality=quality)
|
img_src.save(img_dst, format=image_format, quality=quality)
|
||||||
img_dst.seek(0)
|
img_dst.seek(0)
|
||||||
return img_dst
|
return img_dst
|
||||||
|
|
||||||
async def _run_ffprobe(self, uri: str | pathlib.Path, as_json: bool = False) -> dict | str:
|
async def _run_ffprobe(
|
||||||
|
self, uri: str | pathlib.Path, as_json: bool = False
|
||||||
|
) -> dict | str:
|
||||||
"""
|
"""
|
||||||
Runs ffprobe on the given target (either file path or URL) and returns the result
|
Runs ffprobe on the given target (either file path or URL) and returns the result
|
||||||
:param uri: the URI to run ffprobe on
|
:param uri: the URI to run ffprobe on
|
||||||
:return: The result
|
:return: The result
|
||||||
"""
|
"""
|
||||||
_bin = "ffprobe"
|
_bin = "ffprobe"
|
||||||
cmd = ["-hide_banner", "-v", "quiet", "-print_format", "json", "-show_streams", "-show_format", "-i", str(uri)]
|
cmd = [
|
||||||
|
"-hide_banner",
|
||||||
|
"-v",
|
||||||
|
"quiet",
|
||||||
|
"-print_format",
|
||||||
|
"json",
|
||||||
|
"-show_streams",
|
||||||
|
"-show_format",
|
||||||
|
"-i",
|
||||||
|
str(uri),
|
||||||
|
]
|
||||||
if not as_json:
|
if not as_json:
|
||||||
cmd = ["-hide_banner", "-i", str(uri)]
|
cmd = ["-hide_banner", "-i", str(uri)]
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
@ -62,7 +81,12 @@ class FFMeta(commands.Cog):
|
||||||
return stderr.decode(errors="replace")
|
return stderr.decode(errors="replace")
|
||||||
|
|
||||||
@commands.slash_command()
|
@commands.slash_command()
|
||||||
async def ffprobe(self, ctx: discord.ApplicationContext, url: str = None, attachment: discord.Attachment = None):
|
async def ffprobe(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
url: str = None,
|
||||||
|
attachment: discord.Attachment = None,
|
||||||
|
):
|
||||||
"""Runs ffprobe on a given URL or attachment"""
|
"""Runs ffprobe on a given URL or attachment"""
|
||||||
if not shutil.which("ffprobe"):
|
if not shutil.which("ffprobe"):
|
||||||
return await ctx.respond("ffprobe is not installed on this system.")
|
return await ctx.respond("ffprobe is not installed on this system.")
|
||||||
|
@ -101,7 +125,10 @@ class FFMeta(commands.Cog):
|
||||||
image_format: typing.Annotated[
|
image_format: typing.Annotated[
|
||||||
str,
|
str,
|
||||||
discord.Option(
|
discord.Option(
|
||||||
str, description="The format of the resulting image", choices=["jpeg", "webp"], default="jpeg"
|
str,
|
||||||
|
description="The format of the resulting image",
|
||||||
|
choices=["jpeg", "webp"],
|
||||||
|
default="jpeg",
|
||||||
),
|
),
|
||||||
] = "jpeg",
|
] = "jpeg",
|
||||||
):
|
):
|
||||||
|
@ -115,7 +142,9 @@ class FFMeta(commands.Cog):
|
||||||
|
|
||||||
src = io.BytesIO()
|
src = io.BytesIO()
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
headers={"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"}
|
headers={
|
||||||
|
"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"
|
||||||
|
}
|
||||||
) as client:
|
) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -123,13 +152,17 @@ class FFMeta(commands.Cog):
|
||||||
src.write(response.content)
|
src.write(response.content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dst = await asyncio.to_thread(self.jpegify_image, src, quality, image_format)
|
dst = await asyncio.to_thread(
|
||||||
|
self.jpegify_image, src, quality, image_format
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await ctx.respond(f"Failed to convert image: `{e}`.")
|
await ctx.respond(f"Failed to convert image: `{e}`.")
|
||||||
self.log.error("Failed to convert image %r: %r", url, e)
|
self.log.error("Failed to convert image %r: %r", url, e)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
await ctx.respond(file=discord.File(dst, filename=f"jpegified.{image_format}"))
|
await ctx.respond(
|
||||||
|
file=discord.File(dst, filename=f"jpegified.{image_format}")
|
||||||
|
)
|
||||||
|
|
||||||
@commands.slash_command()
|
@commands.slash_command()
|
||||||
async def opusinate(
|
async def opusinate(
|
||||||
|
@ -148,7 +181,10 @@ class FFMeta(commands.Cog):
|
||||||
),
|
),
|
||||||
] = 96,
|
] = 96,
|
||||||
mono: typing.Annotated[
|
mono: typing.Annotated[
|
||||||
bool, discord.Option(bool, description="Whether to convert the audio to mono", default=False)
|
bool,
|
||||||
|
discord.Option(
|
||||||
|
bool, description="Whether to convert the audio to mono", default=False
|
||||||
|
),
|
||||||
] = False,
|
] = False,
|
||||||
):
|
):
|
||||||
"""Converts a given URL or attachment to an Opus file"""
|
"""Converts a given URL or attachment to an Opus file"""
|
||||||
|
@ -195,7 +231,9 @@ class FFMeta(commands.Cog):
|
||||||
)
|
)
|
||||||
stdout, stderr = await probe_process.communicate()
|
stdout, stderr = await probe_process.communicate()
|
||||||
stdout = stdout.decode("utf-8", "replace")
|
stdout = stdout.decode("utf-8", "replace")
|
||||||
data = {"format": {"duration": 195}} # 3 minutes and 15 seconds is the 2023 average.
|
data = {
|
||||||
|
"format": {"duration": 195}
|
||||||
|
} # 3 minutes and 15 seconds is the 2023 average.
|
||||||
if stdout:
|
if stdout:
|
||||||
try:
|
try:
|
||||||
data = json.loads(stdout)
|
data = json.loads(stdout)
|
||||||
|
@ -206,7 +244,9 @@ class FFMeta(commands.Cog):
|
||||||
max_end_size = ((bitrate * duration * channels) / 8) * 1024
|
max_end_size = ((bitrate * duration * channels) / 8) * 1024
|
||||||
if max_end_size > (24.75 * 1024 * 1024):
|
if max_end_size > (24.75 * 1024 * 1024):
|
||||||
return await ctx.respond(
|
return await ctx.respond(
|
||||||
"The file would be too large to send ({:,.2f} MiB).".format(max_end_size / 1024 / 1024)
|
"The file would be too large to send ({:,.2f} MiB).".format(
|
||||||
|
max_end_size / 1024 / 1024
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
@ -237,7 +277,9 @@ class FFMeta(commands.Cog):
|
||||||
|
|
||||||
file = io.BytesIO(stdout)
|
file = io.BytesIO(stdout)
|
||||||
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
|
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
|
||||||
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024))
|
return await ctx.respond(
|
||||||
|
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
|
||||||
|
)
|
||||||
if not fs:
|
if not fs:
|
||||||
await ctx.respond("Failed to convert audio. See below.")
|
await ctx.respond("Failed to convert audio. See below.")
|
||||||
else:
|
else:
|
||||||
|
@ -251,9 +293,11 @@ class FFMeta(commands.Cog):
|
||||||
|
|
||||||
for page in paginator.pages:
|
for page in paginator.pages:
|
||||||
await ctx.respond(page, ephemeral=True)
|
await ctx.respond(page, ephemeral=True)
|
||||||
|
|
||||||
@commands.slash_command(name="right-behind-you")
|
@commands.slash_command(name="right-behind-you")
|
||||||
async def right_behind_you(self, ctx: discord.ApplicationContext, image: discord.Attachment):
|
async def right_behind_you(
|
||||||
|
self, ctx: discord.ApplicationContext, image: discord.Attachment
|
||||||
|
):
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
rbh = Path("assets/right-behind-you.ogg").resolve()
|
rbh = Path("assets/right-behind-you.ogg").resolve()
|
||||||
if not rbh.exists():
|
if not rbh.exists():
|
||||||
|
@ -264,7 +308,9 @@ class FFMeta(commands.Cog):
|
||||||
return await ctx.respond("That's not an image!")
|
return await ctx.respond("That's not an image!")
|
||||||
with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp:
|
with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp:
|
||||||
img_tmp = io.BytesIO(await image.read())
|
img_tmp = io.BytesIO(await image.read())
|
||||||
dst: io.BytesIO = await asyncio.to_thread(self.jpegify_image, img_tmp, 90, "webp")
|
dst: io.BytesIO = await asyncio.to_thread(
|
||||||
|
self.jpegify_image, img_tmp, 90, "webp"
|
||||||
|
)
|
||||||
temp.write(dst.getvalue())
|
temp.write(dst.getvalue())
|
||||||
temp.flush()
|
temp.flush()
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
@ -307,16 +353,20 @@ class FFMeta(commands.Cog):
|
||||||
"5",
|
"5",
|
||||||
"pipe:1",
|
"pipe:1",
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=sys.stderr
|
stderr=sys.stderr,
|
||||||
)
|
)
|
||||||
stdout, stderr = await process.communicate()
|
stdout, stderr = await process.communicate()
|
||||||
file = io.BytesIO(stdout)
|
file = io.BytesIO(stdout)
|
||||||
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
|
if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024):
|
||||||
return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024))
|
return await ctx.respond(
|
||||||
|
"The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)
|
||||||
|
)
|
||||||
if not fs:
|
if not fs:
|
||||||
await ctx.respond("Failed to convert audio. See below.")
|
await ctx.respond("Failed to convert audio. See below.")
|
||||||
else:
|
else:
|
||||||
return await ctx.respond(file=discord.File(file, filename="right-behind-you.mp4"))
|
return await ctx.respond(
|
||||||
|
file=discord.File(file, filename="right-behind-you.mp4")
|
||||||
|
)
|
||||||
paginator = commands.Paginator()
|
paginator = commands.Paginator()
|
||||||
for line in stderr.decode().splitlines():
|
for line in stderr.decode().splitlines():
|
||||||
if line.strip().startswith(":"):
|
if line.strip().startswith(":"):
|
||||||
|
|
|
@ -11,13 +11,15 @@ class MeterCog(commands.Cog):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.auto_responder")
|
self.log = logging.getLogger("jimmy.cogs.auto_responder")
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
@commands.slash_command(name="gay-meter")
|
@commands.slash_command(name="gay-meter")
|
||||||
@discord.guild_only()
|
@discord.guild_only()
|
||||||
async def gay_meter(self, ctx: discord.ApplicationContext, user: discord.User = None):
|
async def gay_meter(
|
||||||
|
self, ctx: discord.ApplicationContext, user: discord.User = None
|
||||||
|
):
|
||||||
"""Checks how gay someone is"""
|
"""Checks how gay someone is"""
|
||||||
user = user or ctx.user
|
user = user or ctx.user
|
||||||
|
|
||||||
await ctx.respond("Calculating...")
|
await ctx.respond("Calculating...")
|
||||||
for i in range(0, 125, 25):
|
for i in range(0, 125, 25):
|
||||||
await ctx.edit(content="Calculating... %d%%" % i)
|
await ctx.edit(content="Calculating... %d%%" % i)
|
||||||
|
@ -28,9 +30,11 @@ class MeterCog(commands.Cog):
|
||||||
else:
|
else:
|
||||||
pct = user.id % 100
|
pct = user.id % 100
|
||||||
await ctx.edit(content=f"{user.mention} is {pct}% gay.")
|
await ctx.edit(content=f"{user.mention} is {pct}% gay.")
|
||||||
|
|
||||||
@commands.slash_command(name="penis-length")
|
@commands.slash_command(name="penis-length")
|
||||||
async def penis_meter(self, ctx: discord.ApplicationContext, user: discord.User = None):
|
async def penis_meter(
|
||||||
|
self, ctx: discord.ApplicationContext, user: discord.User = None
|
||||||
|
):
|
||||||
"""Checks the length of someone's penis."""
|
"""Checks the length of someone's penis."""
|
||||||
user = user or ctx.user
|
user = user or ctx.user
|
||||||
if random.randint(0, 1):
|
if random.randint(0, 1):
|
||||||
|
@ -49,10 +53,10 @@ class MeterCog(commands.Cog):
|
||||||
return await ctx.respond(
|
return await ctx.respond(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title=f"{user.display_name}'s penis length:",
|
title=f"{user.display_name}'s penis length:",
|
||||||
description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks))
|
description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
@commands.is_owner()
|
@commands.is_owner()
|
||||||
async def clear_cache(self, ctx: commands.Context, user: discord.User = None):
|
async def clear_cache(self, ctx: commands.Context, user: discord.User = None):
|
||||||
|
@ -61,7 +65,7 @@ class MeterCog(commands.Cog):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
else:
|
else:
|
||||||
self.cache.pop(user, None)
|
self.cache.pop(user, None)
|
||||||
return await ctx.message.add_reaction("\N{white heavy check mark}")
|
return await ctx.message.add_reaction("\N{WHITE HEAVY CHECK MARK}")
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
|
149
src/cogs/net.py
149
src/cogs/net.py
|
@ -33,17 +33,11 @@ class GetFilteredTextView(discord.ui.View):
|
||||||
self.text = text
|
self.text = text
|
||||||
super().__init__(timeout=600)
|
super().__init__(timeout=600)
|
||||||
|
|
||||||
@discord.ui.button(
|
@discord.ui.button(label="See filtered data", emoji="\N{INBOX TRAY}")
|
||||||
label="See filtered data",
|
|
||||||
emoji="\N{INBOX TRAY}"
|
|
||||||
)
|
|
||||||
async def see_filtered_data(self, _, interaction: discord.Interaction):
|
async def see_filtered_data(self, _, interaction: discord.Interaction):
|
||||||
await interaction.response.defer(ephemeral=True)
|
await interaction.response.defer(ephemeral=True)
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
file=discord.File(
|
file=discord.File(io.BytesIO(self.text.encode()), "filtered.txt")
|
||||||
io.BytesIO(self.text.encode()),
|
|
||||||
"filtered.txt"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +100,11 @@ class NetworkCog(commands.Cog):
|
||||||
def decide(ln: str) -> typing.Optional[bool]:
|
def decide(ln: str) -> typing.Optional[bool]:
|
||||||
if ln.startswith(">>> Last update"):
|
if ln.startswith(">>> Last update"):
|
||||||
return
|
return
|
||||||
if "REDACTED" in ln or "Please query the WHOIS server of the owning registrar" in ln or ":" not in ln:
|
if (
|
||||||
|
"REDACTED" in ln
|
||||||
|
or "Please query the WHOIS server of the owning registrar" in ln
|
||||||
|
or ":" not in ln
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
@ -134,7 +132,9 @@ class NetworkCog(commands.Cog):
|
||||||
if not paginator.pages:
|
if not paginator.pages:
|
||||||
stdout, stderr, status = await run_command(with_disclaimer=True)
|
stdout, stderr, status = await run_command(with_disclaimer=True)
|
||||||
if not any((stdout, stderr)):
|
if not any((stdout, stderr)):
|
||||||
return await ctx.respond(f"No output was returned with status code {status}.")
|
return await ctx.respond(
|
||||||
|
f"No output was returned with status code {status}."
|
||||||
|
)
|
||||||
file = io.BytesIO()
|
file = io.BytesIO()
|
||||||
file.write(stdout)
|
file.write(stdout)
|
||||||
if stderr:
|
if stderr:
|
||||||
|
@ -143,7 +143,7 @@ class NetworkCog(commands.Cog):
|
||||||
file.seek(0)
|
file.seek(0)
|
||||||
return await ctx.respond(
|
return await ctx.respond(
|
||||||
"Seemingly all output was filtered. Returning raw command output.",
|
"Seemingly all output was filtered. Returning raw command output.",
|
||||||
file=discord.File(file, "whois.txt")
|
file=discord.File(file, "whois.txt"),
|
||||||
)
|
)
|
||||||
|
|
||||||
last: discord.Interaction | discord.WebhookMessage | None = None
|
last: discord.Interaction | discord.WebhookMessage | None = None
|
||||||
|
@ -223,9 +223,15 @@ class NetworkCog(commands.Cog):
|
||||||
default="default",
|
default="default",
|
||||||
),
|
),
|
||||||
use_ip_version: discord.Option(
|
use_ip_version: discord.Option(
|
||||||
str, name="ip-version", description="IP version to use.", choices=["ipv4", "ipv6"], default="ipv4"
|
str,
|
||||||
|
name="ip-version",
|
||||||
|
description="IP version to use.",
|
||||||
|
choices=["ipv4", "ipv6"],
|
||||||
|
default="ipv4",
|
||||||
|
),
|
||||||
|
max_ttl: discord.Option(
|
||||||
|
int, name="ttl", description="Max number of hops", default=30
|
||||||
),
|
),
|
||||||
max_ttl: discord.Option(int, name="ttl", description="Max number of hops", default=30),
|
|
||||||
):
|
):
|
||||||
"""Performs a traceroute request."""
|
"""Performs a traceroute request."""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -258,7 +264,9 @@ class NetworkCog(commands.Cog):
|
||||||
args.append(str(port))
|
args.append(str(port))
|
||||||
args.append(url)
|
args.append(url)
|
||||||
paginator = commands.Paginator()
|
paginator = commands.Paginator()
|
||||||
paginator.add_line(f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}")
|
paginator.add_line(
|
||||||
|
f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}"
|
||||||
|
)
|
||||||
paginator.add_line(empty=True)
|
paginator.add_line(empty=True)
|
||||||
try:
|
try:
|
||||||
start = time.time_ns()
|
start = time.time_ns()
|
||||||
|
@ -297,10 +305,7 @@ class NetworkCog(commands.Cog):
|
||||||
await ctx.respond(file=discord.File(f))
|
await ctx.respond(file=discord.File(f))
|
||||||
|
|
||||||
async def _fetch_ip_response(
|
async def _fetch_ip_response(
|
||||||
self,
|
self, server: str, lookup: str, client: httpx.AsyncClient
|
||||||
server: str,
|
|
||||||
lookup: str,
|
|
||||||
client: httpx.AsyncClient
|
|
||||||
) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError:
|
) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
@ -315,18 +320,16 @@ class NetworkCog(commands.Cog):
|
||||||
async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None):
|
async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None):
|
||||||
"""Fetches IP info from SHRONK IP servers"""
|
"""Fetches IP info from SHRONK IP servers"""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
async with httpx.AsyncClient(headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}) as client:
|
async with httpx.AsyncClient(
|
||||||
|
headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}
|
||||||
|
) as client:
|
||||||
if not lookup:
|
if not lookup:
|
||||||
response = await client.get("https://api.ipify.org")
|
response = await client.get("https://api.ipify.org")
|
||||||
lookup = response.text
|
lookup = response.text
|
||||||
|
|
||||||
servers = self.config.get(
|
servers = self.config.get(
|
||||||
"ip_servers",
|
"ip_servers",
|
||||||
[
|
["ip.shronk.net", "ip.i-am.nexus", "ip.shronk.nicroxio.co.uk"],
|
||||||
"ip.shronk.net",
|
|
||||||
"ip.i-am.nexus",
|
|
||||||
"ip.shronk.nicroxio.co.uk"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="IP lookup information for: %s" % lookup,
|
title="IP lookup information for: %s" % lookup,
|
||||||
|
@ -353,13 +356,14 @@ class NetworkCog(commands.Cog):
|
||||||
t = response.text[:512]
|
t = response.text[:512]
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=server,
|
name=server,
|
||||||
value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```" % t,
|
value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```"
|
||||||
|
% t,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="%s (%.2fms)" % (server, (end - start) * 1000),
|
name="%s (%.2fms)" % (server, (end - start) * 1000),
|
||||||
value="```json\n%s\n```" % v,
|
value="```json\n%s\n```" % v,
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
await ctx.respond(embed=embed)
|
await ctx.respond(embed=embed)
|
||||||
|
@ -367,41 +371,41 @@ class NetworkCog(commands.Cog):
|
||||||
@commands.slash_command()
|
@commands.slash_command()
|
||||||
@commands.max_concurrency(1, commands.BucketType.user)
|
@commands.max_concurrency(1, commands.BucketType.user)
|
||||||
async def nmap(
|
async def nmap(
|
||||||
self,
|
self,
|
||||||
ctx: discord.ApplicationContext,
|
ctx: discord.ApplicationContext,
|
||||||
target: str,
|
target: str,
|
||||||
technique: typing.Annotated[
|
technique: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(
|
||||||
str,
|
str,
|
||||||
discord.Option(
|
choices=[
|
||||||
str,
|
discord.OptionChoice(name="TCP SYN", value="S"),
|
||||||
choices=[
|
discord.OptionChoice(name="TCP Connect", value="T"),
|
||||||
discord.OptionChoice(name="TCP SYN", value="S"),
|
discord.OptionChoice(name="TCP ACK", value="A"),
|
||||||
discord.OptionChoice(name="TCP Connect", value="T"),
|
discord.OptionChoice(name="TCP Window", value="W"),
|
||||||
discord.OptionChoice(name="TCP ACK", value="A"),
|
discord.OptionChoice(name="TCP Maimon", value="M"),
|
||||||
discord.OptionChoice(name="TCP Window", value="W"),
|
discord.OptionChoice(name="UDP", value="U"),
|
||||||
discord.OptionChoice(name="TCP Maimon", value="M"),
|
discord.OptionChoice(name="TCP Null", value="N"),
|
||||||
discord.OptionChoice(name="UDP", value="U"),
|
discord.OptionChoice(name="TCP FIN", value="F"),
|
||||||
discord.OptionChoice(name="TCP Null", value="N"),
|
discord.OptionChoice(name="TCP XMAS", value="X"),
|
||||||
discord.OptionChoice(name="TCP FIN", value="F"),
|
],
|
||||||
discord.OptionChoice(name="TCP XMAS", value="X"),
|
default="T",
|
||||||
],
|
),
|
||||||
default="T"
|
] = "T",
|
||||||
)
|
treat_all_hosts_online: bool = False,
|
||||||
] = "T",
|
service_scan: bool = False,
|
||||||
treat_all_hosts_online: bool = False,
|
fast_mode: bool = False,
|
||||||
service_scan: bool = False,
|
enable_os_detection: bool = False,
|
||||||
fast_mode: bool = False,
|
timing: typing.Annotated[
|
||||||
enable_os_detection: bool = False,
|
int,
|
||||||
timing: typing.Annotated[
|
discord.Option(
|
||||||
int,
|
int,
|
||||||
discord.Option(
|
description="Timing template to use 0 is slowest, 5 is fastest.",
|
||||||
int,
|
choices=[0, 1, 2, 3, 4, 5],
|
||||||
description="Timing template to use 0 is slowest, 5 is fastest.",
|
default=3,
|
||||||
choices=[0, 1, 2, 3, 4, 5],
|
),
|
||||||
default=3
|
] = 3,
|
||||||
)
|
ports: str = None,
|
||||||
] = 3,
|
|
||||||
ports: str = None
|
|
||||||
):
|
):
|
||||||
"""Runs nmap on a target. You cannot specify multiple targets."""
|
"""Runs nmap on a target. You cannot specify multiple targets."""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -420,7 +424,9 @@ class NetworkCog(commands.Cog):
|
||||||
if enable_os_detection and not is_superuser:
|
if enable_os_detection and not is_superuser:
|
||||||
return await ctx.respond("OS detection is not available on this system.")
|
return await ctx.respond("OS detection is not available on this system.")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory(prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}") as tmp:
|
with tempfile.TemporaryDirectory(
|
||||||
|
prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}"
|
||||||
|
) as tmp:
|
||||||
tmp_dir = Path(tmp)
|
tmp_dir = Path(tmp)
|
||||||
args = [
|
args = [
|
||||||
"nmap",
|
"nmap",
|
||||||
|
@ -430,7 +436,7 @@ class NetworkCog(commands.Cog):
|
||||||
str(timing),
|
str(timing),
|
||||||
"-s" + technique,
|
"-s" + technique,
|
||||||
"--reason",
|
"--reason",
|
||||||
"--noninteractive"
|
"--noninteractive",
|
||||||
]
|
]
|
||||||
if treat_all_hosts_online:
|
if treat_all_hosts_online:
|
||||||
args.append("-Pn")
|
args.append("-Pn")
|
||||||
|
@ -447,8 +453,7 @@ class NetworkCog(commands.Cog):
|
||||||
await ctx.respond(
|
await ctx.respond(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title="Running nmap...",
|
title="Running nmap...",
|
||||||
description="Command:\n"
|
description="Command:\n" "```{}```".format(shlex.join(args)),
|
||||||
"```{}```".format(shlex.join(args)),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
|
@ -457,14 +462,16 @@ class NetworkCog(commands.Cog):
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
_, stderr = await process.communicate()
|
_, stderr = await process.communicate()
|
||||||
files = [discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()]
|
files = [
|
||||||
|
discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()
|
||||||
|
]
|
||||||
if not files:
|
if not files:
|
||||||
if len(stderr) <= 4089:
|
if len(stderr) <= 4089:
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title="Nmap failed.",
|
title="Nmap failed.",
|
||||||
description="```\n" + stderr.decode() + "```",
|
description="```\n" + stderr.decode() + "```",
|
||||||
color=discord.Color.red()
|
color=discord.Color.red(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -475,19 +482,19 @@ class NetworkCog(commands.Cog):
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title="Nmap failed.",
|
title="Nmap failed.",
|
||||||
description=f"Output was too long. [View full output]({result.url})",
|
description=f"Output was too long. [View full output]({result.url})",
|
||||||
color=discord.Color.red()
|
color=discord.Color.red(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await ctx.edit(
|
await ctx.edit(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title="Nmap finished!",
|
title="Nmap finished!",
|
||||||
description="Result files are attached.\n"
|
description="Result files are attached.\n"
|
||||||
"* `gnmap` is 'greppable'\n"
|
"* `gnmap` is 'greppable'\n"
|
||||||
"* `xml` is XML output\n"
|
"* `xml` is XML output\n"
|
||||||
"* `nmap` is normal output",
|
"* `nmap` is normal output",
|
||||||
color=discord.Color.green()
|
color=discord.Color.green(),
|
||||||
),
|
),
|
||||||
files=files
|
files=files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,9 @@ class OllamaDownloadHandler:
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async with aiohttp.ClientSession(base_url=self.base_url) as client:
|
async with aiohttp.ClientSession(base_url=self.base_url) as client:
|
||||||
async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
|
async with client.post(
|
||||||
|
"/api/pull", json={"name": self.model, "stream": True}, timeout=None
|
||||||
|
) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
async for line in ollama_stream(response.content):
|
async for line in ollama_stream(response.content):
|
||||||
self.parse_line(line)
|
self.parse_line(line)
|
||||||
|
@ -122,7 +124,7 @@ class OllamaDownloadHandler:
|
||||||
"Downloading orca-mini:7b on server %r - %s (%.2f%%)",
|
"Downloading orca-mini:7b on server %r - %s (%.2f%%)",
|
||||||
self.base_url,
|
self.base_url,
|
||||||
self.status,
|
self.status,
|
||||||
self.percent
|
self.percent,
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -176,12 +178,13 @@ class OllamaChatHandler:
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async with aiohttp.ClientSession(base_url=self.base_url) as client:
|
async with aiohttp.ClientSession(base_url=self.base_url) as client:
|
||||||
async with client.post(
|
async with client.post(
|
||||||
"/api/chat", json={
|
"/api/chat",
|
||||||
"model": self.model,
|
json={
|
||||||
"stream": True,
|
"model": self.model,
|
||||||
|
"stream": True,
|
||||||
"messages": self.messages,
|
"messages": self.messages,
|
||||||
"options": self.options
|
"options": self.options,
|
||||||
}
|
},
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
async for line in ollama_stream(response.content):
|
async for line in ollama_stream(response.content):
|
||||||
|
@ -201,7 +204,9 @@ class OllamaClient:
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.authorisation = authorisation
|
self.authorisation = authorisation
|
||||||
|
|
||||||
def with_client(self, timeout: aiohttp.ClientTimeout | float | int | None = None) -> aiohttp.ClientSession:
|
def with_client(
|
||||||
|
self, timeout: aiohttp.ClientTimeout | float | int | None = None
|
||||||
|
) -> aiohttp.ClientSession:
|
||||||
"""
|
"""
|
||||||
Creates an instance for a request, with properly populated values.
|
Creates an instance for a request, with properly populated values.
|
||||||
:param timeout:
|
:param timeout:
|
||||||
|
@ -213,9 +218,13 @@ class OllamaClient:
|
||||||
timeout = aiohttp.ClientTimeout(timeout)
|
timeout = aiohttp.ClientTimeout(timeout)
|
||||||
else:
|
else:
|
||||||
timeout = timeout or aiohttp.ClientTimeout(120)
|
timeout = timeout or aiohttp.ClientTimeout(120)
|
||||||
return aiohttp.ClientSession(self.base_url, timeout=timeout, auth=self.authorisation)
|
return aiohttp.ClientSession(
|
||||||
|
self.base_url, timeout=timeout, auth=self.authorisation
|
||||||
|
)
|
||||||
|
|
||||||
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
|
async def get_tags(
|
||||||
|
self,
|
||||||
|
) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
|
||||||
"""
|
"""
|
||||||
Gets the tags for the server.
|
Gets the tags for the server.
|
||||||
:return:
|
:return:
|
||||||
|
@ -250,7 +259,7 @@ class OllamaClient:
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
options: dict[str, typing.Any] = None
|
options: dict[str, typing.Any] = None,
|
||||||
) -> OllamaChatHandler:
|
) -> OllamaChatHandler:
|
||||||
"""
|
"""
|
||||||
Starts a chat with the given messages.
|
Starts a chat with the given messages.
|
||||||
|
@ -273,7 +282,11 @@ class OllamaView(View):
|
||||||
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||||
return interaction.user == self.ctx.user
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
@button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f")
|
@button(
|
||||||
|
label="Stop",
|
||||||
|
style=discord.ButtonStyle.danger,
|
||||||
|
emoji="\N{WASTEBASKET}\U0000fe0f",
|
||||||
|
)
|
||||||
async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
|
async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
|
||||||
self.cancel.set()
|
self.cancel.set()
|
||||||
btn.disabled = True
|
btn.disabled = True
|
||||||
|
@ -310,14 +323,20 @@ class ChatHistory:
|
||||||
:return: The thread's ID.
|
:return: The thread's ID.
|
||||||
"""
|
"""
|
||||||
key = os.urandom(3).hex()
|
key = os.urandom(3).hex()
|
||||||
self._internal[key] = {"member": member.id, "seed": round(time.time()), "messages": []}
|
self._internal[key] = {
|
||||||
|
"member": member.id,
|
||||||
|
"seed": round(time.time()),
|
||||||
|
"messages": [],
|
||||||
|
}
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
system_prompt = default or file.read()
|
system_prompt = default or file.read()
|
||||||
self.add_message(key, "system", system_prompt)
|
self.add_message(key, "system", system_prompt)
|
||||||
return key
|
return key
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]:
|
def _construct_message(
|
||||||
|
role: str, content: str, images: typing.Optional[list[str]]
|
||||||
|
) -> dict[str, str]:
|
||||||
x = {"role": role, "content": content}
|
x = {"role": role, "content": content}
|
||||||
if images:
|
if images:
|
||||||
x["images"] = images
|
x["images"] = images
|
||||||
|
@ -331,7 +350,9 @@ class ChatHistory:
|
||||||
return list(
|
return list(
|
||||||
filter(
|
filter(
|
||||||
lambda v: (ctx.value or v) in v,
|
lambda v: (ctx.value or v) in v,
|
||||||
map(lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)),
|
map(
|
||||||
|
lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,7 +360,9 @@ class ChatHistory:
|
||||||
"""Returns all saved threads."""
|
"""Returns all saved threads."""
|
||||||
return self._internal.copy()
|
return self._internal.copy()
|
||||||
|
|
||||||
def threads_for(self, user: discord.Member) -> dict[str, dict[str, list[dict[str, str]] | int]]:
|
def threads_for(
|
||||||
|
self, user: discord.Member
|
||||||
|
) -> dict[str, dict[str, list[dict[str, str]] | int]]:
|
||||||
"""Returns all saved threads for a specific user"""
|
"""Returns all saved threads for a specific user"""
|
||||||
t = self.all_threads()
|
t = self.all_threads()
|
||||||
for k, v in t.copy().items():
|
for k, v in t.copy().items():
|
||||||
|
@ -354,7 +377,7 @@ class ChatHistory:
|
||||||
content: str,
|
content: str,
|
||||||
images: typing.Optional[list[str]] = None,
|
images: typing.Optional[list[str]] = None,
|
||||||
*,
|
*,
|
||||||
save: bool = True
|
save: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Appends a message to the given thread.
|
Appends a message to the given thread.
|
||||||
|
@ -380,7 +403,9 @@ class ChatHistory:
|
||||||
return []
|
return []
|
||||||
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
|
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
|
||||||
|
|
||||||
def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
|
def get_thread(
|
||||||
|
self, thread: str
|
||||||
|
) -> dict[str, list[dict[str, str]] | discord.Member | int]:
|
||||||
"""Gets a copy of an entire thread"""
|
"""Gets a copy of an entire thread"""
|
||||||
return self._internal.get(thread, {}).copy()
|
return self._internal.get(thread, {}).copy()
|
||||||
|
|
||||||
|
@ -401,7 +426,6 @@ SERVER_KEYS_AUTOCOMPLETE.remove("order")
|
||||||
|
|
||||||
|
|
||||||
class OllamaGetPrompt(discord.ui.Modal):
|
class OllamaGetPrompt(discord.ui.Modal):
|
||||||
|
|
||||||
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
|
def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
discord.ui.InputText(
|
discord.ui.InputText(
|
||||||
|
@ -441,7 +465,9 @@ class PromptSelector(discord.ui.View):
|
||||||
if self.user_prompt is not None:
|
if self.user_prompt is not None:
|
||||||
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
|
self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore
|
||||||
|
|
||||||
@discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
|
@discord.ui.button(
|
||||||
|
label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys"
|
||||||
|
)
|
||||||
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
modal = OllamaGetPrompt(self.ctx, "System")
|
modal = OllamaGetPrompt(self.ctx, "System")
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
@ -450,7 +476,9 @@ class PromptSelector(discord.ui.View):
|
||||||
self.update_ui()
|
self.update_ui()
|
||||||
await interaction.edit_original_response(view=self)
|
await interaction.edit_original_response(view=self)
|
||||||
|
|
||||||
@discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
|
@discord.ui.button(
|
||||||
|
label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr"
|
||||||
|
)
|
||||||
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
modal = OllamaGetPrompt(self.ctx)
|
modal = OllamaGetPrompt(self.ctx)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
@ -459,7 +487,9 @@ class PromptSelector(discord.ui.View):
|
||||||
self.update_ui()
|
self.update_ui()
|
||||||
await interaction.edit_original_response(view=self)
|
await interaction.edit_original_response(view=self)
|
||||||
|
|
||||||
@discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
|
@discord.ui.button(
|
||||||
|
label="Done", style=discord.ButtonStyle.success, custom_id="done"
|
||||||
|
)
|
||||||
async def done(self, btn: discord.ui.Button, interaction: Interaction):
|
async def done(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
self.ctx.interaction = interaction
|
self.ctx.interaction = interaction
|
||||||
self.stop()
|
self.stop()
|
||||||
|
@ -474,13 +504,17 @@ class ConfirmCPURun(discord.ui.View):
|
||||||
async def interaction_check(self, interaction: Interaction) -> bool:
|
async def interaction_check(self, interaction: Interaction) -> bool:
|
||||||
return interaction.user == self.ctx.user
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
@discord.ui.button(label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu")
|
@discord.ui.button(
|
||||||
|
label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu"
|
||||||
|
)
|
||||||
async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction):
|
async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
await interaction.response.defer(invisible=True)
|
await interaction.response.defer(invisible=True)
|
||||||
self.proceed = True
|
self.proceed = True
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
@discord.ui.button(label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu")
|
@discord.ui.button(
|
||||||
|
label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu"
|
||||||
|
)
|
||||||
async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction):
|
async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction):
|
||||||
await interaction.response.defer(invisible=True)
|
await interaction.response.defer(invisible=True)
|
||||||
self.stop()
|
self.stop()
|
||||||
|
@ -492,9 +526,7 @@ class Ollama(commands.Cog):
|
||||||
self.log = logging.getLogger("jimmy.cogs.ollama")
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||||||
self.contexts = {}
|
self.contexts = {}
|
||||||
self.history = ChatHistory()
|
self.history = ChatHistory()
|
||||||
self.servers = {
|
self.servers = {server: asyncio.Lock() for server in CONFIG["ollama"]}
|
||||||
server: asyncio.Lock() for server in CONFIG["ollama"]
|
|
||||||
}
|
|
||||||
self.servers.pop("order", None)
|
self.servers.pop("order", None)
|
||||||
if CONFIG["ollama"].get("order"):
|
if CONFIG["ollama"].get("order"):
|
||||||
self.servers = {}
|
self.servers = {}
|
||||||
|
@ -520,7 +552,9 @@ class Ollama(commands.Cog):
|
||||||
if url in SERVER_KEYS:
|
if url in SERVER_KEYS:
|
||||||
url = CONFIG["ollama"][url]["base_url"]
|
url = CONFIG["ollama"][url]["base_url"]
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=aiohttp.ClientTimeout(connect=3, sock_connect=3, sock_read=10, total=3)
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
connect=3, sock_connect=3, sock_read=10, total=3
|
||||||
|
)
|
||||||
) as session:
|
) as session:
|
||||||
self.log.info("Checking if %r is online.", url)
|
self.log.info("Checking if %r is online.", url)
|
||||||
try:
|
try:
|
||||||
|
@ -551,20 +585,37 @@ class Ollama(commands.Cog):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
server: typing.Annotated[
|
server: typing.Annotated[
|
||||||
str, discord.Option(str, "The server to use for ollama.", default="next", choices=SERVER_KEYS_AUTOCOMPLETE)
|
str,
|
||||||
|
discord.Option(
|
||||||
|
str,
|
||||||
|
"The server to use for ollama.",
|
||||||
|
default="next",
|
||||||
|
choices=SERVER_KEYS_AUTOCOMPLETE,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
context: typing.Annotated[
|
context: typing.Annotated[
|
||||||
str, discord.Option(str, "The context key of a previous ollama response to use as context.", default=None)
|
str,
|
||||||
|
discord.Option(
|
||||||
|
str,
|
||||||
|
"The context key of a previous ollama response to use as context.",
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
give_acid: typing.Annotated[
|
give_acid: typing.Annotated[
|
||||||
bool,
|
bool,
|
||||||
discord.Option(
|
discord.Option(
|
||||||
bool, "Whether to give the AI acid, LSD, and other hallucinogens before responding.", default=False
|
bool,
|
||||||
|
"Whether to give the AI acid, LSD, and other hallucinogens before responding.",
|
||||||
|
default=False,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
image: typing.Annotated[
|
image: typing.Annotated[
|
||||||
discord.Attachment,
|
discord.Attachment,
|
||||||
discord.Option(discord.Attachment, "An image to feed into ollama. Only works with llava.", default=None),
|
discord.Option(
|
||||||
|
discord.Attachment,
|
||||||
|
"An image to feed into ollama. Only works with llava.",
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
if server == "next":
|
if server == "next":
|
||||||
|
@ -587,7 +638,7 @@ class Ollama(commands.Cog):
|
||||||
await ctx.respond(
|
await ctx.respond(
|
||||||
"Select edit your prompts, as desired. Click done when you want to continue.",
|
"Select edit your prompts, as desired. Click done when you want to continue.",
|
||||||
view=v,
|
view=v,
|
||||||
ephemeral=True
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
await v.wait()
|
await v.wait()
|
||||||
query = v.user_prompt or query
|
query = v.user_prompt or query
|
||||||
|
@ -604,22 +655,23 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug("Resolved model to %r" % model)
|
self.log.debug("Resolved model to %r" % model)
|
||||||
|
|
||||||
if image:
|
if image:
|
||||||
patterns = [
|
patterns = ["llava:*", "llava-llama*:*"]
|
||||||
"llava:*",
|
|
||||||
"llava-llama*:*"
|
|
||||||
]
|
|
||||||
if any(fnmatch(model, p) for p in patterns) is False:
|
if any(fnmatch(model, p) for p in patterns) is False:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.",
|
f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.",
|
||||||
delete_after=5
|
delete_after=5,
|
||||||
)
|
)
|
||||||
model = "llava:latest"
|
model = "llava:latest"
|
||||||
|
|
||||||
if image.size > 1024 * 1024 * 25:
|
if image.size > 1024 * 1024 * 25:
|
||||||
await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.")
|
await ctx.respond(
|
||||||
|
"Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
elif not fnmatch(image.content_type, "image/*"):
|
elif not fnmatch(image.content_type, "image/*"):
|
||||||
await ctx.respond("Attachment is not an image. Try using a different file.")
|
await ctx.respond(
|
||||||
|
"Attachment is not an image. Try using a different file."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
data = io.BytesIO()
|
data = io.BytesIO()
|
||||||
|
@ -638,13 +690,19 @@ class Ollama(commands.Cog):
|
||||||
if fnmatch(model, model_pattern):
|
if fnmatch(model, model_pattern):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
|
allowed_models = ", ".join(
|
||||||
await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
|
map(discord.utils.escape_markdown, server_config["allowed_models"])
|
||||||
|
)
|
||||||
|
await ctx.respond(
|
||||||
|
f"Invalid model. You can only use one of the following models: {allowed_models}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
base_url=server_config["base_url"],
|
base_url=server_config["base_url"],
|
||||||
timeout=aiohttp.ClientTimeout(connect=5, sock_read=10800, sock_connect=5, total=10830),
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
connect=5, sock_read=10800, sock_connect=5, total=10830
|
||||||
|
),
|
||||||
) as session:
|
) as session:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Checking server...",
|
title="Checking server...",
|
||||||
|
@ -652,26 +710,32 @@ class Ollama(commands.Cog):
|
||||||
color=discord.Color.blurple(),
|
color=discord.Color.blurple(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
embed.set_footer(
|
||||||
|
text="Using server %r" % server, icon_url=server_config.get("icon_url")
|
||||||
|
)
|
||||||
await ctx.respond(embed=embed)
|
await ctx.respond(embed=embed)
|
||||||
if not await self.check_server(server_config["base_url"]):
|
if not await self.check_server(server_config["base_url"]):
|
||||||
tried = {server}
|
tried = {server}
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
try:
|
try:
|
||||||
server = self.next_server(tried)
|
server = self.next_server(tried)
|
||||||
if server and CONFIG["ollama"]["server"].get("is_gpu", False) is not True:
|
if (
|
||||||
|
server
|
||||||
|
and CONFIG["ollama"]["server"].get("is_gpu", False)
|
||||||
|
is not True
|
||||||
|
):
|
||||||
cf = ConfirmCPURun(ctx)
|
cf = ConfirmCPURun(ctx)
|
||||||
await ctx.edit(
|
await ctx.edit(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title=f"Server {server} is available, but...",
|
title=f"Server {server} is available, but...",
|
||||||
description="It is CPU only, which means it is very slow and will likely crash.\n"
|
description="It is CPU only, which means it is very slow and will likely crash.\n"
|
||||||
"If you really want, you can continue your generation, using this "
|
"If you really want, you can continue your generation, using this "
|
||||||
"server. Be aware though, once in motion, it cannot be stopped.\n\n"
|
"server. Be aware though, once in motion, it cannot be stopped.\n\n"
|
||||||
""
|
""
|
||||||
"Continue?",
|
"Continue?",
|
||||||
color=discord.Color.red(),
|
color=discord.Color.red(),
|
||||||
),
|
),
|
||||||
view=cf
|
view=cf,
|
||||||
)
|
)
|
||||||
await cf.wait()
|
await cf.wait()
|
||||||
await ctx.edit(view=None)
|
await ctx.edit(view=None)
|
||||||
|
@ -688,7 +752,10 @@ class Ollama(commands.Cog):
|
||||||
color=discord.Color.gold(),
|
color=discord.Color.gold(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
embed.set_footer(
|
||||||
|
text="Using server %r" % server,
|
||||||
|
icon_url=server_config.get("icon_url"),
|
||||||
|
)
|
||||||
await ctx.edit(embed=embed)
|
await ctx.edit(embed=embed)
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
||||||
|
@ -713,7 +780,9 @@ class Ollama(commands.Cog):
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
url=resp.url,
|
url=resp.url,
|
||||||
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
|
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
|
||||||
description=f"```{await resp.text() or 'No response body'}```"[:4096],
|
description=f"```{await resp.text() or 'No response body'}```"[
|
||||||
|
:4096
|
||||||
|
],
|
||||||
color=discord.Color.red(),
|
color=discord.Color.red(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
|
@ -733,8 +802,8 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug("Beginning download of %r", model)
|
self.log.debug("Beginning download of %r", model)
|
||||||
|
|
||||||
def progress_bar(_v: float, action: str = None, _mbps: float = None):
|
def progress_bar(_v: float, action: str = None, _mbps: float = None):
|
||||||
bar = "\N{large green square}" * round(_v / 10)
|
bar = "\N{LARGE GREEN SQUARE}" * round(_v / 10)
|
||||||
bar += "\N{white large square}" * (10 - len(bar))
|
bar += "\N{WHITE LARGE SQUARE}" * (10 - len(bar))
|
||||||
bar += f" {_v:.2f}%"
|
bar += f" {_v:.2f}%"
|
||||||
if _mbps:
|
if _mbps:
|
||||||
bar += f" ({_mbps:.2f} MiB/s)"
|
bar += f" ({_mbps:.2f} MiB/s)"
|
||||||
|
@ -753,12 +822,16 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
|
|
||||||
async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response:
|
async with session.post(
|
||||||
|
"/api/pull", json={"name": model, "stream": True}, timeout=None
|
||||||
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
url=response.url,
|
url=response.url,
|
||||||
title=f"HTTP {response.status} {response.reason!r} while downloading model.",
|
title=f"HTTP {response.status} {response.reason!r} while downloading model.",
|
||||||
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
description=f"```{await response.text() or 'No response body'}```"[
|
||||||
|
:4096
|
||||||
|
],
|
||||||
color=discord.Color.red(),
|
color=discord.Color.red(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
|
@ -775,7 +848,10 @@ class Ollama(commands.Cog):
|
||||||
)
|
)
|
||||||
return await ctx.edit(embed=embed, view=None)
|
return await ctx.edit(embed=embed, view=None)
|
||||||
if time.time() >= (last_update + 5.1):
|
if time.time() >= (last_update + 5.1):
|
||||||
if line.get("total") is not None and line.get("completed") is not None:
|
if (
|
||||||
|
line.get("total") is not None
|
||||||
|
and line.get("completed") is not None
|
||||||
|
):
|
||||||
new_bytes = line["completed"] - last_downloaded
|
new_bytes = line["completed"] - last_downloaded
|
||||||
mbps = new_bytes / 1024 / 1024 / 5
|
mbps = new_bytes / 1024 / 1024 / 5
|
||||||
last_downloaded = line["completed"]
|
last_downloaded = line["completed"]
|
||||||
|
@ -784,7 +860,9 @@ class Ollama(commands.Cog):
|
||||||
percent = 50.0
|
percent = 50.0
|
||||||
mbps = 0.0
|
mbps = 0.0
|
||||||
|
|
||||||
embed.fields[0].value = progress_bar(percent, line["status"], mbps)
|
embed.fields[0].value = progress_bar(
|
||||||
|
percent, line["status"], mbps
|
||||||
|
)
|
||||||
await ctx.edit(embed=embed, view=view)
|
await ctx.edit(embed=embed, view=view)
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
else:
|
else:
|
||||||
|
@ -799,13 +877,13 @@ class Ollama(commands.Cog):
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
title="Before you continue",
|
title="Before you continue",
|
||||||
description="You've selected a CPU-only server. This will be really slow. This will also likely"
|
description="You've selected a CPU-only server. This will be really slow. This will also likely"
|
||||||
" bring the host to a halt. Consider the pain the CPU is about to endure. "
|
" bring the host to a halt. Consider the pain the CPU is about to endure. "
|
||||||
"Are you super"
|
"Are you super"
|
||||||
" sure you want to continue? You can run `h!ollama-status` to see what servers are"
|
" sure you want to continue? You can run `h!ollama-status` to see what servers are"
|
||||||
" available.",
|
" available.",
|
||||||
color=discord.Color.red(),
|
color=discord.Color.red(),
|
||||||
),
|
),
|
||||||
view=cf2
|
view=cf2,
|
||||||
)
|
)
|
||||||
await cf2.wait()
|
await cf2.wait()
|
||||||
if cf2.proceed is False:
|
if cf2.proceed is False:
|
||||||
|
@ -825,9 +903,13 @@ class Ollama(commands.Cog):
|
||||||
icon_url="https://ollama.com/public/ollama.png",
|
icon_url="https://ollama.com/public/ollama.png",
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False
|
name="Prompt",
|
||||||
|
value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."),
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
embed.set_footer(
|
||||||
|
text="Using server %r" % server, icon_url=server_config.get("icon_url")
|
||||||
)
|
)
|
||||||
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
|
||||||
if image_data:
|
if image_data:
|
||||||
if (image.height / image.width) >= 1.5:
|
if (image.height / image.width) >= 1.5:
|
||||||
embed.set_image(url=image.url)
|
embed.set_image(url=image.url)
|
||||||
|
@ -842,7 +924,10 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = self.history.create_thread(ctx.user, system_query)
|
context = self.history.create_thread(ctx.user, system_query)
|
||||||
elif context is not None and (__thread := self.history.find_thread(context)) is None:
|
elif (
|
||||||
|
context is not None
|
||||||
|
and (__thread := self.history.find_thread(context)) is None
|
||||||
|
):
|
||||||
if not __thread:
|
if not __thread:
|
||||||
return await ctx.respond("Invalid thread ID.")
|
return await ctx.respond("Invalid thread ID.")
|
||||||
else:
|
else:
|
||||||
|
@ -861,7 +946,12 @@ class Ollama(commands.Cog):
|
||||||
params["top_p"] = 2
|
params["top_p"] = 2
|
||||||
params["repeat_penalty"] = 2
|
params["repeat_penalty"] = 2
|
||||||
|
|
||||||
payload = {"model": model, "stream": True, "options": params, "messages": messages}
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"stream": True,
|
||||||
|
"options": params,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
async with session.post(
|
async with session.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json=payload,
|
json=payload,
|
||||||
|
@ -870,7 +960,9 @@ class Ollama(commands.Cog):
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
url=response.url,
|
url=response.url,
|
||||||
title=f"HTTP {response.status} {response.reason!r} while generating response.",
|
title=f"HTTP {response.status} {response.reason!r} while generating response.",
|
||||||
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
description=f"```{await response.text() or 'No response body'}```"[
|
||||||
|
:4096
|
||||||
|
],
|
||||||
color=discord.Color.red(),
|
color=discord.Color.red(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
|
@ -888,20 +980,31 @@ class Ollama(commands.Cog):
|
||||||
embed.description = "[...]" + line["message"]["content"]
|
embed.description = "[...]" + line["message"]["content"]
|
||||||
if len(embed.description) >= 3250:
|
if len(embed.description) >= 3250:
|
||||||
embed.colour = discord.Color.gold()
|
embed.colour = discord.Color.gold()
|
||||||
embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
|
embed.set_footer(
|
||||||
|
text="Warning: {:,}/4096 characters.".format(
|
||||||
|
len(embed.description)
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
embed.colour = discord.Color.blurple()
|
embed.colour = discord.Color.blurple()
|
||||||
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
embed.set_footer(
|
||||||
|
text="Using server %r" % server,
|
||||||
|
icon_url=server_config.get("icon_url"),
|
||||||
|
)
|
||||||
|
|
||||||
if view.cancel.is_set():
|
if view.cancel.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
if time.time() >= (last_update + 5.1):
|
if time.time() >= (last_update + 5.1):
|
||||||
await ctx.edit(embed=embed, view=view)
|
await ctx.edit(embed=embed, view=view)
|
||||||
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
self.log.debug(
|
||||||
|
f"Updating message ({last_update} -> {time.time()})"
|
||||||
|
)
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
view.stop()
|
view.stop()
|
||||||
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
|
self.history.add_message(
|
||||||
|
context, "user", user_message["content"], user_message.get("images")
|
||||||
|
)
|
||||||
self.history.add_message(context, "assistant", buffer.getvalue())
|
self.history.add_message(context, "assistant", buffer.getvalue())
|
||||||
|
|
||||||
embed.add_field(name="Context Key", value=context, inline=True)
|
embed.add_field(name="Context Key", value=context, inline=True)
|
||||||
|
@ -911,7 +1014,9 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
value = buffer.getvalue()
|
value = buffer.getvalue()
|
||||||
if len(value) >= 4096:
|
if len(value) >= 4096:
|
||||||
embeds = [discord.Embed(title="Done!", colour=discord.Color.green())]
|
embeds = [
|
||||||
|
discord.Embed(title="Done!", colour=discord.Color.green())
|
||||||
|
]
|
||||||
|
|
||||||
current_page = ""
|
current_page = ""
|
||||||
for word in value.split():
|
for word in value.split():
|
||||||
|
@ -971,13 +1076,19 @@ class Ollama(commands.Cog):
|
||||||
if message["role"] == "system":
|
if message["role"] == "system":
|
||||||
continue
|
continue
|
||||||
max_length = 4000 - len("> **%s**: " % message["role"])
|
max_length = 4000 - len("> **%s**: " % message["role"])
|
||||||
paginator.add_line("> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length)))
|
paginator.add_line(
|
||||||
|
"> **{}**: {}".format(
|
||||||
|
message["role"], textwrap.shorten(message["content"], max_length)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
embeds = []
|
embeds = []
|
||||||
for page in paginator.pages:
|
for page in paginator.pages:
|
||||||
embeds.append(discord.Embed(description=page))
|
embeds.append(discord.Embed(description=page))
|
||||||
ephemeral = len(embeds) > 1
|
ephemeral = len(embeds) > 1
|
||||||
for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
|
for chunk in discord.utils.as_chunks(
|
||||||
|
iter(embeds or [discord.Embed(title="No Content.")]), 10
|
||||||
|
):
|
||||||
await ctx.respond(embeds=chunk, ephemeral=ephemeral)
|
await ctx.respond(embeds=chunk, ephemeral=ephemeral)
|
||||||
|
|
||||||
@commands.command(name="ollama-status", aliases=["ollama_status", "os"])
|
@commands.command(name="ollama-status", aliases=["ollama_status", "os"])
|
||||||
|
@ -990,14 +1101,18 @@ class Ollama(commands.Cog):
|
||||||
if CONFIG["ollama"].get("order"):
|
if CONFIG["ollama"].get("order"):
|
||||||
ln = ["Server order:"]
|
ln = ["Server order:"]
|
||||||
for n, key in enumerate(CONFIG["ollama"].get("order"), start=1):
|
for n, key in enumerate(CONFIG["ollama"].get("order"), start=1):
|
||||||
zap = '\N{high voltage sign}'
|
zap = "\N{HIGH VOLTAGE SIGN}"
|
||||||
ln.append(f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}")
|
ln.append(
|
||||||
|
f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}"
|
||||||
|
)
|
||||||
embed.description = "\n".join(ln)
|
embed.description = "\n".join(ln)
|
||||||
|
|
||||||
for server, lock in self.servers.items():
|
for server, lock in self.servers.items():
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=server,
|
name=server,
|
||||||
value="\U000026a0\U0000fe0fIn use" if lock.locked() else "\U000023f3 Checking..."
|
value="\U000026a0\U0000fe0fIn use"
|
||||||
|
if lock.locked()
|
||||||
|
else "\U000023f3 Checking...",
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = await ctx.reply(embed=embed)
|
msg = await ctx.reply(embed=embed)
|
||||||
|
@ -1006,23 +1121,35 @@ class Ollama(commands.Cog):
|
||||||
for server in self.servers.keys():
|
for server in self.servers.keys():
|
||||||
if self.servers[server].locked():
|
if self.servers[server].locked():
|
||||||
continue
|
continue
|
||||||
tasks[server] = asyncio.create_task(self.check_server(CONFIG["ollama"][server]["base_url"]))
|
tasks[server] = asyncio.create_task(
|
||||||
|
self.check_server(CONFIG["ollama"][server]["base_url"])
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.gather(*tasks.values())
|
await asyncio.gather(*tasks.values())
|
||||||
for server, task in tasks.items():
|
for server, task in tasks.items():
|
||||||
if not task.done:
|
if not task.done:
|
||||||
await task
|
await task
|
||||||
if e := task.exception():
|
if e := task.exception():
|
||||||
self.log.error("Error while checking server %r: %r", server, e, exc_info=True)
|
self.log.error(
|
||||||
|
"Error while checking server %r: %r", server, e, exc_info=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ok = task.result()
|
ok = task.result()
|
||||||
for field in embed.fields:
|
for field in embed.fields:
|
||||||
if field.name == server:
|
if field.name == server:
|
||||||
field.value = "\U00002705 Online" if ok else "\U0000274c Offline"
|
field.value = (
|
||||||
|
"\U00002705 Online" if ok else "\U0000274c Offline"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
await msg.edit(embed=embed)
|
await msg.edit(embed=embed)
|
||||||
|
|
||||||
async def generate_truth(self, ctx: commands.Context, entity: str, limit: int = None, question: str = None):
|
async def generate_truth(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
entity: str,
|
||||||
|
limit: int = None,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
if entity == "trump":
|
if entity == "trump":
|
||||||
system = (
|
system = (
|
||||||
"You are the former united states president, Donald Trump, convicted of 34 felonies."
|
"You are the former united states president, Donald Trump, convicted of 34 felonies."
|
||||||
|
@ -1034,7 +1161,7 @@ class Ollama(commands.Cog):
|
||||||
system = (
|
system = (
|
||||||
"You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the "
|
"You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the "
|
||||||
"world of influencing. His personality is one of arrogance and self-aggrandizement, "
|
"world of influencing. His personality is one of arrogance and self-aggrandizement, "
|
||||||
"often referred to as \"King of Toxic Masculinity\" by critics due to his controversial views on "
|
'often referred to as "King of Toxic Masculinity" by critics due to his controversial views on '
|
||||||
"gender roles, relationships, and other topics. He has been involved in several controversies related "
|
"gender roles, relationships, and other topics. He has been involved in several controversies related "
|
||||||
"to his content online including promoting extremist ideologies and misogynistic views. "
|
"to his content online including promoting extremist ideologies and misogynistic views. "
|
||||||
"Despite this, he still has a large following and is known for being an entrepreneur who built multiple"
|
"Despite this, he still has a large following and is known for being an entrepreneur who built multiple"
|
||||||
|
@ -1120,9 +1247,9 @@ class Ollama(commands.Cog):
|
||||||
"of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and "
|
"of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and "
|
||||||
"controversial policies, Johnson's leadership was marred by numerous scandals and crises."
|
"controversial policies, Johnson's leadership was marred by numerous scandals and crises."
|
||||||
"Johnson's personality is characterized by his optimistic and charismatic demeanor, "
|
"Johnson's personality is characterized by his optimistic and charismatic demeanor, "
|
||||||
"often described as a \"dashing politician.\" However, his political career has been "
|
'often described as a "dashing politician." However, his political career has been '
|
||||||
"plagued by scandals involving ethics violations, questionable financial arrangements, "
|
"plagued by scandals involving ethics violations, questionable financial arrangements, "
|
||||||
"and allegations of inappropriate behavior. The \"Partygate\" controversy, where "
|
'and allegations of inappropriate behavior. The "Partygate" controversy, where '
|
||||||
"gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. "
|
"gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. "
|
||||||
"Johnson's achievements include leading the U.K. out of the European Union and brokering a "
|
"Johnson's achievements include leading the U.K. out of the European Union and brokering a "
|
||||||
"post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by "
|
"post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by "
|
||||||
|
@ -1162,21 +1289,22 @@ class Ollama(commands.Cog):
|
||||||
" Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous "
|
" Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous "
|
||||||
"content. Do not include URLs or links. If an @user asks you a question, reply to them in your post."
|
"content. Do not include URLs or links. If an @user asks you a question, reply to them in your post."
|
||||||
)
|
)
|
||||||
thread_id = self.history.create_thread(
|
thread_id = self.history.create_thread(ctx.author, system)
|
||||||
ctx.author,
|
|
||||||
system
|
|
||||||
)
|
|
||||||
r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api")
|
r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api")
|
||||||
username = CONFIG["truth"].get("username", "1")
|
username = CONFIG["truth"].get("username", "1")
|
||||||
password = CONFIG["truth"].get("password", "2")
|
password = CONFIG["truth"].get("password", "2")
|
||||||
async with httpx.AsyncClient(base_url=r, auth=(username, password), trust_env=False) as http_client:
|
async with httpx.AsyncClient(
|
||||||
|
base_url=r, auth=(username, password), trust_env=False
|
||||||
|
) as http_client:
|
||||||
response = await http_client.get(
|
response = await http_client.get(
|
||||||
"/truths",
|
"/truths",
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
truths = response.json()
|
truths = response.json()
|
||||||
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths))
|
truths: list[TruthPayload] = list(
|
||||||
|
map(lambda t: TruthPayload.model_validate(t), truths)
|
||||||
|
)
|
||||||
|
|
||||||
if entity:
|
if entity:
|
||||||
truths = list(filter(lambda t: t.author == entity, truths))
|
truths = list(filter(lambda t: t.author == entity, truths))
|
||||||
|
@ -1191,11 +1319,13 @@ class Ollama(commands.Cog):
|
||||||
thread_id,
|
thread_id,
|
||||||
"assistant",
|
"assistant",
|
||||||
truth.content,
|
truth.content,
|
||||||
save=False
|
save=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if not question:
|
if not question:
|
||||||
self.history.add_message(thread_id, "user", "Generate a new truth post.")
|
self.history.add_message(
|
||||||
|
thread_id, "user", "Generate a new truth post."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.history.add_message(thread_id, "user", question)
|
self.history.add_message(thread_id, "user", question)
|
||||||
|
|
||||||
|
@ -1204,32 +1334,36 @@ class Ollama(commands.Cog):
|
||||||
server = self.next_server(tried)
|
server = self.next_server(tried)
|
||||||
is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
|
is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
|
||||||
if not is_gpu:
|
if not is_gpu:
|
||||||
self.log.info("Skipping server %r as it is not a GPU server.", server)
|
self.log.info(
|
||||||
|
"Skipping server %r as it is not a GPU server.", server
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
||||||
break
|
break
|
||||||
tried.add(server)
|
tried.add(server)
|
||||||
else:
|
else:
|
||||||
return await ctx.reply("All servers are offline. Please try again later.", delete_after=300)
|
return await ctx.reply(
|
||||||
|
"All servers are offline. Please try again later.", delete_after=300
|
||||||
|
)
|
||||||
|
|
||||||
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
|
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
|
||||||
async with self.servers[server]:
|
async with self.servers[server]:
|
||||||
if not await client.has_model_named("llama2-uncensored", "7b-chat"):
|
if not await client.has_model_named("llama2-uncensored", "7b-chat"):
|
||||||
with client.download_model("llama2-uncensored", "7b-chat") as handler:
|
with client.download_model(
|
||||||
|
"llama2-uncensored", "7b-chat"
|
||||||
|
) as handler:
|
||||||
await handler.flatten()
|
await handler.flatten()
|
||||||
|
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=f"New {post_type.title()}!",
|
title=f"New {post_type.title()}!", description="", colour=0x6559FF
|
||||||
description="",
|
|
||||||
colour=0x6559FF
|
|
||||||
)
|
)
|
||||||
msg = await ctx.reply(embed=embed)
|
msg = await ctx.reply(embed=embed)
|
||||||
last_edit = time.time()
|
last_edit = time.time()
|
||||||
messages = self.history.get_history(thread_id)
|
messages = self.history.get_history(thread_id)
|
||||||
with client.new_chat(
|
with client.new_chat(
|
||||||
"llama2-uncensored:7b-chat",
|
"llama2-uncensored:7b-chat",
|
||||||
messages,
|
messages,
|
||||||
options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5}
|
options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5},
|
||||||
) as handler:
|
) as handler:
|
||||||
async for ln in handler:
|
async for ln in handler:
|
||||||
embed.description += ln["message"]["content"]
|
embed.description += ln["message"]["content"]
|
||||||
|
@ -1245,7 +1379,7 @@ class Ollama(commands.Cog):
|
||||||
if truth.content == embed.description:
|
if truth.content == embed.description:
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"Repeated {post_type} :(",
|
name=f"Repeated {post_type} :(",
|
||||||
value=f"This truth was already {post_type}ed. Shit AI."
|
value=f"This truth was already {post_type}ed. Shit AI.",
|
||||||
)
|
)
|
||||||
elif _ratio >= 70:
|
elif _ratio >= 70:
|
||||||
similar[truth.id] = _ratio
|
similar[truth.id] = _ratio
|
||||||
|
@ -1253,12 +1387,16 @@ class Ollama(commands.Cog):
|
||||||
if similar:
|
if similar:
|
||||||
if len(similar) > 1:
|
if len(similar) > 1:
|
||||||
lns = []
|
lns = []
|
||||||
keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True)
|
keys = sorted(
|
||||||
|
similar.keys(), key=lambda k: similar[k], reverse=True
|
||||||
|
)
|
||||||
for truth_id in keys:
|
for truth_id in keys:
|
||||||
_ratio = similar[truth_id]
|
_ratio = similar[truth_id]
|
||||||
truth = discord.utils.get(truths, id=truth_id)
|
truth = discord.utils.get(truths, id=truth_id)
|
||||||
first_line = truth.content.splitlines()[0]
|
first_line = truth.content.splitlines()[0]
|
||||||
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100))
|
preview = discord.utils.escape_markdown(
|
||||||
|
textwrap.shorten(first_line, 100)
|
||||||
|
)
|
||||||
lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`")
|
lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`")
|
||||||
if len(lns) > 5:
|
if len(lns) > 5:
|
||||||
lc = len(lns) - 5
|
lc = len(lns) - 5
|
||||||
|
@ -1266,33 +1404,39 @@ class Ollama(commands.Cog):
|
||||||
lns.append(f"*... and {lc} more*")
|
lns.append(f"*... and {lc} more*")
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"Possibly repeated {post_type}",
|
name=f"Possibly repeated {post_type}",
|
||||||
value=f"This {post_type} was similar to the following existing ones:\n" + "\n".join(lns),
|
value=f"This {post_type} was similar to the following existing ones:\n"
|
||||||
inline=False
|
+ "\n".join(lns),
|
||||||
|
inline=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
truth_id = tuple(similar)[0]
|
truth_id = tuple(similar)[0]
|
||||||
_ratio = similar[truth_id]
|
_ratio = similar[truth_id]
|
||||||
truth = discord.utils.get(truths, id=truth_id)
|
truth = discord.utils.get(truths, id=truth_id)
|
||||||
first_line = truth.content.splitlines()[0]
|
first_line = truth.content.splitlines()[0]
|
||||||
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512))
|
preview = discord.utils.escape_markdown(
|
||||||
|
textwrap.shorten(first_line, 512)
|
||||||
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"Possibly repeated {post_type}",
|
name=f"Possibly repeated {post_type}",
|
||||||
value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}"
|
value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}",
|
||||||
)
|
)
|
||||||
|
|
||||||
embed.set_footer(
|
embed.set_footer(
|
||||||
text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format(
|
text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format(
|
||||||
post_type,
|
post_type, len(messages) - 2, server, thread_id
|
||||||
len(messages) - 2,
|
|
||||||
server,
|
|
||||||
thread_id
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await msg.edit(embed=embed)
|
await msg.edit(embed=embed)
|
||||||
|
|
||||||
@commands.command(aliases=["trump"])
|
@commands.command(aliases=["trump"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def donald(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def donald(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a truth social post from trump!
|
Generates a truth social post from trump!
|
||||||
|
|
||||||
|
@ -1307,7 +1451,13 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
@commands.command(aliases=["tate"])
|
@commands.command(aliases=["tate"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def andrew(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def andrew(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a truth social post from Andrew Tate
|
Generates a truth social post from Andrew Tate
|
||||||
|
|
||||||
|
@ -1319,10 +1469,16 @@ class Ollama(commands.Cog):
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "tate", latest, question=question)
|
await self.generate_truth(ctx, "tate", latest, question=question)
|
||||||
|
|
||||||
@commands.command(aliases=["sunak"])
|
@commands.command(aliases=["sunak"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def rishi(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def rishi(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Rishi Sunak
|
Generates a twitter post from Rishi Sunak
|
||||||
|
|
||||||
|
@ -1334,10 +1490,16 @@ class Ollama(commands.Cog):
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "Rishi Sunak", latest, question=question)
|
await self.generate_truth(ctx, "Rishi Sunak", latest, question=question)
|
||||||
|
|
||||||
@commands.command(aliases=["robinson"])
|
@commands.command(aliases=["robinson"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def tommy(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def tommy(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Tommy Robinson
|
Generates a twitter post from Tommy Robinson
|
||||||
|
|
||||||
|
@ -1348,11 +1510,19 @@ class Ollama(commands.Cog):
|
||||||
if question:
|
if question:
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "Tommy Robinson 🇬🇧", latest, question=question)
|
await self.generate_truth(
|
||||||
|
ctx, "Tommy Robinson 🇬🇧", latest, question=question
|
||||||
|
)
|
||||||
|
|
||||||
@commands.command(aliases=["fox"])
|
@commands.command(aliases=["fox"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def laurence(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def laurence(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Laurence Fox
|
Generates a twitter post from Laurence Fox
|
||||||
|
|
||||||
|
@ -1364,10 +1534,16 @@ class Ollama(commands.Cog):
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "Laurence Fox", latest, question=question)
|
await self.generate_truth(ctx, "Laurence Fox", latest, question=question)
|
||||||
|
|
||||||
@commands.command(aliases=["farage"])
|
@commands.command(aliases=["farage"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def nigel(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def nigel(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Nigel Farage
|
Generates a twitter post from Nigel Farage
|
||||||
|
|
||||||
|
@ -1382,7 +1558,13 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
@commands.command(aliases=["starmer"])
|
@commands.command(aliases=["starmer"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def keir(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def keir(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Keir Starmer
|
Generates a twitter post from Keir Starmer
|
||||||
|
|
||||||
|
@ -1394,10 +1576,16 @@ class Ollama(commands.Cog):
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "Keir Starmer", latest, question=question)
|
await self.generate_truth(ctx, "Keir Starmer", latest, question=question)
|
||||||
|
|
||||||
@commands.command(aliases=["johnson"])
|
@commands.command(aliases=["johnson"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def boris(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def boris(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Boris Johnson
|
Generates a twitter post from Boris Johnson
|
||||||
|
|
||||||
|
@ -1409,10 +1597,16 @@ class Ollama(commands.Cog):
|
||||||
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
question = f"'@{ctx.author.display_name}' asks: {question!r}"
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "Boris Johnson", latest, question=question)
|
await self.generate_truth(ctx, "Boris Johnson", latest, question=question)
|
||||||
|
|
||||||
@commands.command(aliases=["desantis"])
|
@commands.command(aliases=["desantis"])
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def ron(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None):
|
async def ron(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
latest: typing.Optional[int] = None,
|
||||||
|
*,
|
||||||
|
question: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generates a twitter post from Ron Desantis
|
Generates a twitter post from Ron Desantis
|
||||||
|
|
||||||
|
@ -1447,12 +1641,15 @@ class Ollama(commands.Cog):
|
||||||
title="Truth:",
|
title="Truth:",
|
||||||
description=truth.content,
|
description=truth.content,
|
||||||
colour=0x6559FF,
|
colour=0x6559FF,
|
||||||
timestamp=datetime.datetime.fromtimestamp(truth.timestamp, tz=datetime.timezone.utc),
|
timestamp=datetime.datetime.fromtimestamp(
|
||||||
|
truth.timestamp, tz=datetime.timezone.utc
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if truth.extra:
|
if truth.extra:
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Extra Data",
|
name="Extra Data",
|
||||||
value="```json\n%s```" % json.dumps(truth.extra, indent=2, default=repr),
|
value="```json\n%s```"
|
||||||
|
% json.dumps(truth.extra, indent=2, default=repr),
|
||||||
)
|
)
|
||||||
embed.set_author(name=truth.author)
|
embed.set_author(name=truth.author)
|
||||||
await ctx.reply(embed=embed)
|
await ctx.reply(embed=embed)
|
||||||
|
|
112
src/cogs/onion_feed.py
Normal file
112
src/cogs/onion_feed.py
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
"""
|
||||||
|
Scrapes the onion RSS feed once every hour and posts any new articles to the desired channel
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands, tasks
|
||||||
|
import httpx
|
||||||
|
from conf import CONFIG
|
||||||
|
import redis
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class RSSItem:
|
||||||
|
title: str
|
||||||
|
link: str
|
||||||
|
description: str
|
||||||
|
pubDate: datetime.datetime
|
||||||
|
guid: str
|
||||||
|
thumbnail: str
|
||||||
|
|
||||||
|
|
||||||
|
class OnionFeed(commands.Cog):
|
||||||
|
SOURCE = "https://www.theonion.com/rss"
|
||||||
|
EPOCH = datetime.datetime(2024, 7, 1, tzinfo=datetime.timezone.utc)
|
||||||
|
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot: commands.Bot = bot
|
||||||
|
self.log = logging.getLogger("jimmy.cogs.onion_feed")
|
||||||
|
self.check_onion_feed.start()
|
||||||
|
self.redis = redis.Redis(**CONFIG["redis"])
|
||||||
|
|
||||||
|
def cog_unload(self) -> None:
|
||||||
|
self.check_onion_feed.cancel()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_item(item: BeautifulSoup) -> RSSItem:
|
||||||
|
description = (
|
||||||
|
BeautifulSoup(item.description.get_text(), "html.parser")
|
||||||
|
.p.get_text(strip=True)
|
||||||
|
.strip()[:-1]
|
||||||
|
)
|
||||||
|
kwargs = {
|
||||||
|
"title": item.title.get_text(strip=True).strip(),
|
||||||
|
"link": item.link.get_text(strip=True).strip(),
|
||||||
|
"pubDate": datetime.datetime.strptime(
|
||||||
|
item.pubDate.get_text(strip=True).strip(), "%a, %d %b %Y %H:%M:%S %Z"
|
||||||
|
),
|
||||||
|
"guid": item.guid.get_text(strip=True).strip(),
|
||||||
|
"description": description,
|
||||||
|
"thumbnail": item.find("media:thumbnail")["url"],
|
||||||
|
}
|
||||||
|
return RSSItem(**kwargs)
|
||||||
|
|
||||||
|
def parse_feed(self, text: str) -> list[RSSItem]:
|
||||||
|
soup = BeautifulSoup(text, "xml")
|
||||||
|
return [self.parse_item(item) for item in soup.find_all("item")]
|
||||||
|
|
||||||
|
@tasks.loop(hours=1)
|
||||||
|
async def check_onion_feed(self):
|
||||||
|
if not self.bot.is_ready():
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
|
||||||
|
guild = self.bot.get_guild(994710566612500550)
|
||||||
|
if not guild:
|
||||||
|
return self.log.error("Nonsense guild not found. Can't do onion feed.")
|
||||||
|
channel = discord.utils.get(guild.text_channels, name="spam")
|
||||||
|
if not channel:
|
||||||
|
return self.log.error("Spam channel not found.")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(self.SOURCE)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return self.log.error(
|
||||||
|
f"Failed to fetch onion feed: {response.status_code}"
|
||||||
|
)
|
||||||
|
items: list[RSSItem] = await asyncio.to_thread(
|
||||||
|
self.parse_feed, response.text
|
||||||
|
)
|
||||||
|
for item in items:
|
||||||
|
if self.redis.get("onion-" + item.guid):
|
||||||
|
continue
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=item.title,
|
||||||
|
url=item.link,
|
||||||
|
description=item.description + f"... [Read More]({item.link})",
|
||||||
|
color=0x00DF78,
|
||||||
|
timestamp=item.pubDate,
|
||||||
|
)
|
||||||
|
embed.set_thumbnail(url=item.thumbnail)
|
||||||
|
try:
|
||||||
|
msg = await channel.send(embed=embed)
|
||||||
|
# noinspection PyAsyncCall
|
||||||
|
self.redis.set("onion-" + item.guid, str(msg.id))
|
||||||
|
except discord.HTTPException:
|
||||||
|
self.log.exception(
|
||||||
|
f"Failed to send onion feed message: {item.title}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.log.debug(f"Sent onion feed message: {item.title}")
|
||||||
|
|
||||||
|
@check_onion_feed.before_loop
|
||||||
|
async def before_check_onion_feed(self):
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
|
||||||
|
|
||||||
|
def setup(bot):
|
||||||
|
bot.add_cog(OnionFeed(bot))
|
|
@ -16,11 +16,9 @@ from discord.ext import commands
|
||||||
from conf import CONFIG
|
from conf import CONFIG
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
JSON: typing.Union[
|
JSON: typing.Union[str, int, float, bool, None, dict[str, "JSON"], list["JSON"]] = (
|
||||||
str, int, float, bool, None, dict[str, "JSON"], list["JSON"]
|
typing.Union[str, int, float, bool, None, dict, list]
|
||||||
] = typing.Union[
|
)
|
||||||
str, int, float, bool, None, dict, list
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TruthPayload(BaseModel):
|
class TruthPayload(BaseModel):
|
||||||
|
@ -32,7 +30,6 @@ class TruthPayload(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class QuoteQuota(commands.Cog):
|
class QuoteQuota(commands.Cog):
|
||||||
|
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
|
self.quotes_channel_id = CONFIG["quote_a"].get("channel_id")
|
||||||
|
@ -47,7 +44,9 @@ class QuoteQuota(commands.Cog):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_pie_chart(usernames: list[str], counts: list[int], no_other: bool = False) -> discord.File:
|
def generate_pie_chart(
|
||||||
|
usernames: list[str], counts: list[int], no_other: bool = False
|
||||||
|
) -> discord.File:
|
||||||
"""
|
"""
|
||||||
Converts the given username and count tuples into a nice pretty pie chart.
|
Converts the given username and count tuples into a nice pretty pie chart.
|
||||||
|
|
||||||
|
@ -95,7 +94,9 @@ class QuoteQuota(commands.Cog):
|
||||||
startangle=90,
|
startangle=90,
|
||||||
radius=1.2,
|
radius=1.2,
|
||||||
)
|
)
|
||||||
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4)
|
fig.subplots_adjust(
|
||||||
|
left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4
|
||||||
|
)
|
||||||
fio = io.BytesIO()
|
fio = io.BytesIO()
|
||||||
fig.savefig(fio, format="png")
|
fig.savefig(fio, format="png")
|
||||||
fio.seek(0)
|
fio.seek(0)
|
||||||
|
@ -130,7 +131,9 @@ class QuoteQuota(commands.Cog):
|
||||||
now = discord.utils.utcnow()
|
now = discord.utils.utcnow()
|
||||||
oldest = now - timedelta(days=days)
|
oldest = now - timedelta(days=days)
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes")
|
channel = self.quotes_channel or discord.utils.get(
|
||||||
|
ctx.guild.text_channels, name="quotes"
|
||||||
|
)
|
||||||
if not channel:
|
if not channel:
|
||||||
return await ctx.respond(":x: Cannot find quotes channel.")
|
return await ctx.respond(":x: Cannot find quotes channel.")
|
||||||
|
|
||||||
|
@ -139,7 +142,9 @@ class QuoteQuota(commands.Cog):
|
||||||
authors = {}
|
authors = {}
|
||||||
filtered_messages = 0
|
filtered_messages = 0
|
||||||
total = 0
|
total = 0
|
||||||
async for message in channel.history(limit=None, after=oldest, oldest_first=False):
|
async for message in channel.history(
|
||||||
|
limit=None, after=oldest, oldest_first=False
|
||||||
|
):
|
||||||
total += 1
|
total += 1
|
||||||
if not message.content:
|
if not message.content:
|
||||||
filtered_messages += 1
|
filtered_messages += 1
|
||||||
|
@ -179,10 +184,15 @@ class QuoteQuota(commands.Cog):
|
||||||
' (e.g. `"This is my quote" - Jimmy`)'.format(days)
|
' (e.g. `"This is my quote" - Jimmy`)'.format(days)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await ctx.edit(content="No messages found in the last {!s} days.".format(days))
|
return await ctx.edit(
|
||||||
|
content="No messages found in the last {!s} days.".format(days)
|
||||||
|
)
|
||||||
|
|
||||||
file = await asyncio.to_thread(
|
file = await asyncio.to_thread(
|
||||||
self.generate_pie_chart, list(authors.keys()), list(authors.values()), merge_other
|
self.generate_pie_chart,
|
||||||
|
list(authors.keys()),
|
||||||
|
list(authors.values()),
|
||||||
|
merge_other,
|
||||||
)
|
)
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
|
content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format(
|
||||||
|
@ -192,7 +202,11 @@ class QuoteQuota(commands.Cog):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _metacounter(
|
def _metacounter(
|
||||||
self, truths: list[TruthPayload], filter_func: Callable[[TruthPayload], bool], *, now: datetime = None
|
self,
|
||||||
|
truths: list[TruthPayload],
|
||||||
|
filter_func: Callable[[TruthPayload], bool],
|
||||||
|
*,
|
||||||
|
now: datetime = None,
|
||||||
) -> dict[str, float | int | dict[str, int]]:
|
) -> dict[str, float | int | dict[str, int]]:
|
||||||
def _is_today(date: datetime) -> bool:
|
def _is_today(date: datetime) -> bool:
|
||||||
return date.date() == now.date()
|
return date.date() == now.date()
|
||||||
|
@ -215,7 +229,9 @@ class QuoteQuota(commands.Cog):
|
||||||
if filter_func(truth):
|
if filter_func(truth):
|
||||||
created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc)
|
created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc)
|
||||||
age = now - created_at
|
age = now - created_at
|
||||||
self.log.debug("%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds())
|
self.log.debug(
|
||||||
|
"%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()
|
||||||
|
)
|
||||||
counts["all_time"] += 1
|
counts["all_time"] += 1
|
||||||
if _is_today(created_at):
|
if _is_today(created_at):
|
||||||
counts["today"] += 1
|
counts["today"] += 1
|
||||||
|
@ -275,7 +291,9 @@ class QuoteQuota(commands.Cog):
|
||||||
plt.bar(hrs, list(hours.values()), color="#5448EE")
|
plt.bar(hrs, list(hours.values()), color="#5448EE")
|
||||||
|
|
||||||
average = sum(hours.values()) / len(hours)
|
average = sum(hours.values()) / len(hours)
|
||||||
plt.axhline(average, color="red", linestyle="--", label=f"Average: {average:.1f}")
|
plt.axhline(
|
||||||
|
average, color="red", linestyle="--", label=f"Average: {average:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
file = io.BytesIO()
|
file = io.BytesIO()
|
||||||
plt.savefig(file, format="png")
|
plt.savefig(file, format="png")
|
||||||
|
@ -283,8 +301,7 @@ class QuoteQuota(commands.Cog):
|
||||||
return discord.File(file, "truths.png")
|
return discord.File(file, "truths.png")
|
||||||
|
|
||||||
async def _process_all_messages(
|
async def _process_all_messages(
|
||||||
self,
|
self, truths: list[TruthPayload]
|
||||||
truths: list[TruthPayload]
|
|
||||||
) -> tuple[discord.Embed, discord.File]:
|
) -> tuple[discord.Embed, discord.File]:
|
||||||
"""
|
"""
|
||||||
Processes all the messages in the given channel.
|
Processes all the messages in the given channel.
|
||||||
|
@ -292,7 +309,11 @@ class QuoteQuota(commands.Cog):
|
||||||
:param truths: The truths to process
|
:param truths: The truths to process
|
||||||
:returns: The stats
|
:returns: The stats
|
||||||
"""
|
"""
|
||||||
embed = discord.Embed(title="Truth Counts", color=discord.Color.blurple(), timestamp=discord.utils.utcnow())
|
embed = discord.Embed(
|
||||||
|
title="Truth Counts",
|
||||||
|
color=discord.Color.blurple(),
|
||||||
|
timestamp=discord.utils.utcnow(),
|
||||||
|
)
|
||||||
trump_stats = await self._process_trump_truths(truths)
|
trump_stats = await self._process_trump_truths(truths)
|
||||||
tate_stats = await self._process_tate_truths(truths)
|
tate_stats = await self._process_tate_truths(truths)
|
||||||
|
|
||||||
|
@ -346,7 +367,9 @@ class QuoteQuota(commands.Cog):
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
truths: list[dict] = response.json()
|
truths: list[dict] = response.json()
|
||||||
truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths))
|
truths: list[TruthPayload] = list(
|
||||||
|
map(lambda t: TruthPayload.model_validate(t), truths)
|
||||||
|
)
|
||||||
embed, file = await self._process_all_messages(truths)
|
embed, file = await self._process_all_messages(truths)
|
||||||
await ctx.edit(embed=embed, file=file)
|
await ctx.edit(embed=embed, file=file)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
@ -133,7 +132,9 @@ class ScreenshotCog(commands.Cog):
|
||||||
else:
|
else:
|
||||||
use_proxy = False
|
use_proxy = False
|
||||||
service = await asyncio.to_thread(ChromeService)
|
service = await asyncio.to_thread(ChromeService)
|
||||||
driver: webdriver.Chrome = await asyncio.to_thread(webdriver.Chrome, service=service, options=options)
|
driver: webdriver.Chrome = await asyncio.to_thread(
|
||||||
|
webdriver.Chrome, service=service, options=options
|
||||||
|
)
|
||||||
driver.set_page_load_timeout(load_timeout)
|
driver.set_page_load_timeout(load_timeout)
|
||||||
if resolution:
|
if resolution:
|
||||||
resolution = RESOLUTIONS.get(resolution.lower(), resolution)
|
resolution = RESOLUTIONS.get(resolution.lower(), resolution)
|
||||||
|
@ -141,9 +142,13 @@ class ScreenshotCog(commands.Cog):
|
||||||
width, height = map(int, resolution.split("x"))
|
width, height = map(int, resolution.split("x"))
|
||||||
driver.set_window_size(width, height)
|
driver.set_window_size(width, height)
|
||||||
if height > 4320 or width > 7680:
|
if height > 4320 or width > 7680:
|
||||||
return await ctx.respond("Invalid resolution. Max resolution is 7680x4320 (8K).")
|
return await ctx.respond(
|
||||||
|
"Invalid resolution. Max resolution is 7680x4320 (8K)."
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return await ctx.respond("Invalid resolution. please provide width x height, e.g. 1920x1080")
|
return await ctx.respond(
|
||||||
|
"Invalid resolution. please provide width x height, e.g. 1920x1080"
|
||||||
|
)
|
||||||
if eager:
|
if eager:
|
||||||
driver.implicitly_wait(render_timeout)
|
driver.implicitly_wait(render_timeout)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -151,7 +156,13 @@ class ScreenshotCog(commands.Cog):
|
||||||
raise
|
raise
|
||||||
end_init = time.time()
|
end_init = time.time()
|
||||||
|
|
||||||
await ctx.edit(content=("Loading webpage..." if not eager else "Loading & screenshotting webpage..."))
|
await ctx.edit(
|
||||||
|
content=(
|
||||||
|
"Loading webpage..."
|
||||||
|
if not eager
|
||||||
|
else "Loading & screenshotting webpage..."
|
||||||
|
)
|
||||||
|
)
|
||||||
start_request = time.time()
|
start_request = time.time()
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(driver.get, url)
|
await asyncio.to_thread(driver.get, url)
|
||||||
|
@ -160,7 +171,9 @@ class ScreenshotCog(commands.Cog):
|
||||||
if "TimeoutException" in str(e):
|
if "TimeoutException" in str(e):
|
||||||
return await ctx.respond("Timed out while loading webpage.")
|
return await ctx.respond("Timed out while loading webpage.")
|
||||||
else:
|
else:
|
||||||
return await ctx.respond("Failed to load webpage:\n```\n%s\n```" % str(e.msg))
|
return await ctx.respond(
|
||||||
|
"Failed to load webpage:\n```\n%s\n```" % str(e.msg)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.bot.loop.run_in_executor(None, driver.quit)
|
await self.bot.loop.run_in_executor(None, driver.quit)
|
||||||
await ctx.respond("Failed to get the webpage: " + str(e))
|
await ctx.respond("Failed to get the webpage: " + str(e))
|
||||||
|
@ -170,7 +183,9 @@ class ScreenshotCog(commands.Cog):
|
||||||
if not eager:
|
if not eager:
|
||||||
now = discord.utils.utcnow()
|
now = discord.utils.utcnow()
|
||||||
expires = now + datetime.timedelta(seconds=render_timeout)
|
expires = now + datetime.timedelta(seconds=render_timeout)
|
||||||
await ctx.edit(content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})...")
|
await ctx.edit(
|
||||||
|
content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})..."
|
||||||
|
)
|
||||||
start_wait = time.time()
|
start_wait = time.time()
|
||||||
await asyncio.sleep(render_timeout)
|
await asyncio.sleep(render_timeout)
|
||||||
end_wait = time.time()
|
end_wait = time.time()
|
||||||
|
@ -200,7 +215,9 @@ class ScreenshotCog(commands.Cog):
|
||||||
await self.bot.loop.run_in_executor(None, driver.quit)
|
await self.bot.loop.run_in_executor(None, driver.quit)
|
||||||
end_cleanup = time.time()
|
end_cleanup = time.time()
|
||||||
|
|
||||||
screenshot_size_mb = round(len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2)
|
screenshot_size_mb = round(
|
||||||
|
len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2
|
||||||
|
)
|
||||||
|
|
||||||
def seconds(start: float, end: float) -> float:
|
def seconds(start: float, end: float) -> float:
|
||||||
return round(end - start, 2)
|
return round(end - start, 2)
|
||||||
|
@ -221,7 +238,9 @@ class ScreenshotCog(commands.Cog):
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
)
|
)
|
||||||
embed.set_image(url="attachment://" + fn)
|
embed.set_image(url="attachment://" + fn)
|
||||||
return await ctx.edit(content=None, embed=embed, file=discord.File(file, filename=fn))
|
return await ctx.edit(
|
||||||
|
content=None, embed=embed, file=discord.File(file, filename=fn)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
|
|
|
@ -10,24 +10,32 @@ from conf import CONFIG
|
||||||
|
|
||||||
|
|
||||||
class Starboard(commands.Cog):
|
class Starboard(commands.Cog):
|
||||||
DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{collision symbol}")
|
DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{COLLISION SYMBOL}")
|
||||||
|
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot: commands.Bot = bot
|
self.bot: commands.Bot = bot
|
||||||
self.config = CONFIG["starboard"]
|
self.config = CONFIG["starboard"]
|
||||||
self.emoji = discord.PartialEmoji.from_str(self.config.get("emoji", "\N{white medium star}"))
|
self.emoji = discord.PartialEmoji.from_str(
|
||||||
|
self.config.get("emoji", "\N{WHITE MEDIUM STAR}")
|
||||||
|
)
|
||||||
self.log = logging.getLogger("jimmy.cogs.starboard")
|
self.log = logging.getLogger("jimmy.cogs.starboard")
|
||||||
self.redis = Redis(**CONFIG["redis"])
|
self.redis = Redis(**CONFIG["redis"])
|
||||||
|
|
||||||
async def generate_starboard_embed(self, message: discord.Message) -> tuple[list[discord.Embed], int]:
|
async def generate_starboard_embed(
|
||||||
|
self, message: discord.Message
|
||||||
|
) -> tuple[list[discord.Embed], int]:
|
||||||
"""
|
"""
|
||||||
Generates an embed ready for a starboard message.
|
Generates an embed ready for a starboard message.
|
||||||
|
|
||||||
:param message: The message to base off of.
|
:param message: The message to base off of.
|
||||||
:return: The created embed
|
:return: The created embed
|
||||||
"""
|
"""
|
||||||
reactions: list[discord.Reaction] = [x for x in message.reactions if str(x.emoji) == str(self.emoji)]
|
reactions: list[discord.Reaction] = [
|
||||||
downvote_reactions = [x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)]
|
x for x in message.reactions if str(x.emoji) == str(self.emoji)
|
||||||
|
]
|
||||||
|
downvote_reactions = [
|
||||||
|
x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)
|
||||||
|
]
|
||||||
if not reactions:
|
if not reactions:
|
||||||
# Nobody has added the star reaction.
|
# Nobody has added the star reaction.
|
||||||
star_count = 0
|
star_count = 0
|
||||||
|
@ -35,12 +43,18 @@ class Starboard(commands.Cog):
|
||||||
else:
|
else:
|
||||||
# Count the number of reactions
|
# Count the number of reactions
|
||||||
star_count = sum([x.count for x in reactions])
|
star_count = sum([x.count for x in reactions])
|
||||||
self.log.debug("There are a total of %d star reactions on message.", star_count)
|
self.log.debug(
|
||||||
|
"There are a total of %d star reactions on message.", star_count
|
||||||
|
)
|
||||||
|
|
||||||
if downvote_reactions:
|
if downvote_reactions:
|
||||||
_dv = sum([x.count for x in downvote_reactions])
|
_dv = sum([x.count for x in downvote_reactions])
|
||||||
star_count -= _dv
|
star_count -= _dv
|
||||||
self.log.debug("There are %d downvotes on message, resulting in %d stars.", _dv, star_count)
|
self.log.debug(
|
||||||
|
"There are %d downvotes on message, resulting in %d stars.",
|
||||||
|
_dv,
|
||||||
|
star_count,
|
||||||
|
)
|
||||||
|
|
||||||
if star_count >= 0:
|
if star_count >= 0:
|
||||||
star_emoji_count = (str(self.emoji) * star_count)[:10]
|
star_emoji_count = (str(self.emoji) * star_count)[:10]
|
||||||
|
@ -54,18 +68,20 @@ class Starboard(commands.Cog):
|
||||||
author=discord.EmbedAuthor(
|
author=discord.EmbedAuthor(
|
||||||
message.author.display_name,
|
message.author.display_name,
|
||||||
message.author.jump_url,
|
message.author.jump_url,
|
||||||
message.author.display_avatar.url
|
message.author.display_avatar.url,
|
||||||
),
|
),
|
||||||
fields=[
|
fields=[
|
||||||
discord.EmbedField(
|
discord.EmbedField(
|
||||||
name="Info",
|
name="Info",
|
||||||
value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})"
|
value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})",
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
if message.reference:
|
if message.reference:
|
||||||
try:
|
try:
|
||||||
ref_message = await self.get_or_fetch_message(message.reference.channel_id, message.reference.message_id)
|
ref_message = await self.get_or_fetch_message(
|
||||||
|
message.reference.channel_id, message.reference.message_id
|
||||||
|
)
|
||||||
except discord.HTTPException:
|
except discord.HTTPException:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
@ -74,31 +90,26 @@ class Starboard(commands.Cog):
|
||||||
remaining = 1024 - len(v)
|
remaining = 1024 - len(v)
|
||||||
t = textwrap.shorten(text, remaining, placeholder="...")
|
t = textwrap.shorten(text, remaining, placeholder="...")
|
||||||
v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})"
|
v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})"
|
||||||
embed.add_field(
|
embed.add_field(name="Replying to", value=v)
|
||||||
name="Replying to",
|
|
||||||
value=v
|
|
||||||
)
|
|
||||||
elif message.interaction:
|
elif message.interaction:
|
||||||
if message.interaction.type == discord.InteractionType.application_command:
|
if message.interaction.type == discord.InteractionType.application_command:
|
||||||
real_author: discord.User = await discord.utils.get_or_fetch(
|
real_author: discord.User = await discord.utils.get_or_fetch(
|
||||||
self.bot,
|
self.bot, "user", int(message.interaction.data["user"]["id"])
|
||||||
"user",
|
)
|
||||||
int(message.interaction.data["user"]["id"])
|
real_author = (
|
||||||
|
await discord.utils.get_or_fetch(
|
||||||
|
message.guild, "member", real_author.id, default=real_author
|
||||||
|
)
|
||||||
|
or message.author
|
||||||
)
|
)
|
||||||
real_author = await discord.utils.get_or_fetch(
|
|
||||||
message.guild,
|
|
||||||
"member",
|
|
||||||
real_author.id,
|
|
||||||
default=real_author
|
|
||||||
) or message.author
|
|
||||||
embed.set_author(
|
embed.set_author(
|
||||||
name=real_author.display_name,
|
name=real_author.display_name,
|
||||||
icon_url=real_author.display_avatar.url,
|
icon_url=real_author.display_avatar.url,
|
||||||
url=real_author.jump_url
|
url=real_author.jump_url,
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Interaction",
|
name="Interaction",
|
||||||
value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}"
|
value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if message.content:
|
if message.content:
|
||||||
|
@ -107,17 +118,24 @@ class Starboard(commands.Cog):
|
||||||
for message_embed in message.embeds:
|
for message_embed in message.embeds:
|
||||||
if message_embed.type != "rich":
|
if message_embed.type != "rich":
|
||||||
if message_embed.type == "image":
|
if message_embed.type == "image":
|
||||||
if message_embed.thumbnail and message_embed.thumbnail.proxy_url:
|
if (
|
||||||
|
message_embed.thumbnail
|
||||||
|
and message_embed.thumbnail.proxy_url
|
||||||
|
):
|
||||||
embed.set_image(url=message_embed.thumbnail.proxy_url)
|
embed.set_image(url=message_embed.thumbnail.proxy_url)
|
||||||
continue
|
continue
|
||||||
if message_embed.description:
|
if message_embed.description:
|
||||||
embed.description = message_embed.description
|
embed.description = message_embed.description
|
||||||
elif not message.attachments:
|
elif not message.attachments:
|
||||||
raise ValueError("Message does not appear to contain any text, embeds, or attachments.")
|
raise ValueError(
|
||||||
|
"Message does not appear to contain any text, embeds, or attachments."
|
||||||
|
)
|
||||||
|
|
||||||
if message.attachments:
|
if message.attachments:
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for n, attachment in reversed(tuple(enumerate(message.attachments, start=1))):
|
for n, attachment in reversed(
|
||||||
|
tuple(enumerate(message.attachments, start=1))
|
||||||
|
):
|
||||||
attachment: discord.Attachment
|
attachment: discord.Attachment
|
||||||
if attachment.size >= 1024 * 1024:
|
if attachment.size >= 1024 * 1024:
|
||||||
size = f"{attachment.size / 1024 / 1024:,.1f}MiB"
|
size = f"{attachment.size / 1024 / 1024:,.1f}MiB"
|
||||||
|
@ -129,7 +147,7 @@ class Starboard(commands.Cog):
|
||||||
{
|
{
|
||||||
"name": "Attachment #%d:" % n,
|
"name": "Attachment #%d:" % n,
|
||||||
"value": f"[{attachment.filename} ({size})]({attachment.url})",
|
"value": f"[{attachment.filename} ({size})]({attachment.url})",
|
||||||
"inline": True
|
"inline": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if attachment.content_type.startswith("image/"):
|
if attachment.content_type.startswith("image/"):
|
||||||
|
@ -142,13 +160,19 @@ class Starboard(commands.Cog):
|
||||||
|
|
||||||
return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count
|
return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count
|
||||||
|
|
||||||
async def get_or_fetch_message(self, channel_id: int, message_id: int) -> discord.Message:
|
async def get_or_fetch_message(
|
||||||
|
self, channel_id: int, message_id: int
|
||||||
|
) -> discord.Message:
|
||||||
"""
|
"""
|
||||||
Fetches a message from cache where possible, falling back to the API.
|
Fetches a message from cache where possible, falling back to the API.
|
||||||
"""
|
"""
|
||||||
message: discord.Message | None = discord.utils.get(self.bot.cached_messages, id=message_id)
|
message: discord.Message | None = discord.utils.get(
|
||||||
|
self.bot.cached_messages, id=message_id
|
||||||
|
)
|
||||||
if not message:
|
if not message:
|
||||||
message: discord.Message = await self.bot.get_channel(channel_id).fetch_message(message_id)
|
message: discord.Message = await self.bot.get_channel(
|
||||||
|
channel_id
|
||||||
|
).fetch_message(message_id)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@commands.Cog.listener("on_raw_reaction_add")
|
@commands.Cog.listener("on_raw_reaction_add")
|
||||||
|
@ -171,27 +195,35 @@ class Starboard(commands.Cog):
|
||||||
|
|
||||||
guild: discord.Guild = self.bot.get_guild(payload.guild_id)
|
guild: discord.Guild = self.bot.get_guild(payload.guild_id)
|
||||||
starboard_channel: discord.TextChannel | None = discord.utils.get(
|
starboard_channel: discord.TextChannel | None = discord.utils.get(
|
||||||
guild.text_channels,
|
guild.text_channels, name=self.config.get("channel_name", "starboard")
|
||||||
name=self.config.get("channel_name", "starboard")
|
|
||||||
)
|
)
|
||||||
if not starboard_channel:
|
if not starboard_channel:
|
||||||
self.log.warning("Could not find starboard channel in %s (%d)", guild, guild.id)
|
self.log.warning(
|
||||||
|
"Could not find starboard channel in %s (%d)", guild, guild.id
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if payload.channel_id == starboard_channel.id:
|
if payload.channel_id == starboard_channel.id:
|
||||||
self.log.debug("Ignoring reaction in starboard channel.")
|
self.log.debug("Ignoring reaction in starboard channel.")
|
||||||
return
|
return
|
||||||
|
|
||||||
message = await self.get_or_fetch_message(payload.channel_id, payload.message_id)
|
message = await self.get_or_fetch_message(
|
||||||
|
payload.channel_id, payload.message_id
|
||||||
|
)
|
||||||
|
|
||||||
if payload.user_id == message.author.id:
|
if payload.user_id == message.author.id:
|
||||||
self.log.info("%s tried to star their own message.", message.author)
|
self.log.info("%s tried to star their own message.", message.author)
|
||||||
return await message.reply(
|
return await message.reply(
|
||||||
"You can't star your own message you pretentious dick. Go outside, %s." % message.author.mention,
|
"You can't star your own message you pretentious dick. Go outside, %s."
|
||||||
delete_after=30
|
% message.author.mention,
|
||||||
|
delete_after=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log.info("Processing starboard reaction from %s: %s", payload.user_id, message.jump_url)
|
self.log.info(
|
||||||
|
"Processing starboard reaction from %s: %s",
|
||||||
|
payload.user_id,
|
||||||
|
message.jump_url,
|
||||||
|
)
|
||||||
embed, star_count = await self.generate_starboard_embed(message)
|
embed, star_count = await self.generate_starboard_embed(message)
|
||||||
|
|
||||||
data = await self.redis.get(str(message.id))
|
data = await self.redis.get(str(message.id))
|
||||||
|
@ -205,7 +237,7 @@ class Starboard(commands.Cog):
|
||||||
"source_message_id": payload.message_id,
|
"source_message_id": payload.message_id,
|
||||||
"history": [],
|
"history": [],
|
||||||
"starboard_channel_id": starboard_channel.id,
|
"starboard_channel_id": starboard_channel.id,
|
||||||
"starboard_message_id": None
|
"starboard_message_id": None,
|
||||||
}
|
}
|
||||||
if not starboard_channel.can_send(embed[0]):
|
if not starboard_channel.can_send(embed[0]):
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
|
@ -213,29 +245,24 @@ class Starboard(commands.Cog):
|
||||||
starboard_channel.id,
|
starboard_channel.id,
|
||||||
payload.guild_id,
|
payload.guild_id,
|
||||||
starboard_channel.name,
|
starboard_channel.name,
|
||||||
guild.name
|
guild.name,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
starboard_message = await starboard_channel.send(
|
starboard_message = await starboard_channel.send(embeds=embed, silent=True)
|
||||||
embeds=embed,
|
|
||||||
silent=True
|
|
||||||
)
|
|
||||||
data["starboard_message_id"] = starboard_message.id
|
data["starboard_message_id"] = starboard_message.id
|
||||||
await self.redis.set(str(message.id), json.dumps(data))
|
await self.redis.set(str(message.id), json.dumps(data))
|
||||||
else:
|
else:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
try:
|
try:
|
||||||
starboard_message = await self.get_or_fetch_message(
|
starboard_message = await self.get_or_fetch_message(
|
||||||
data["starboard_channel_id"],
|
data["starboard_channel_id"], data["starboard_message_id"]
|
||||||
data["starboard_message_id"]
|
|
||||||
)
|
)
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
if star_count <= 0:
|
if star_count <= 0:
|
||||||
return
|
return
|
||||||
starboard_message = await starboard_channel.send(
|
starboard_message = await starboard_channel.send(
|
||||||
embeds=embed,
|
embeds=embed, silent=True
|
||||||
silent=True
|
|
||||||
)
|
)
|
||||||
data["starboard_message_id"] = starboard_message.id
|
data["starboard_message_id"] = starboard_message.id
|
||||||
data["starboard_channel_id"] = starboard_message.channel.id
|
data["starboard_channel_id"] = starboard_message.channel.id
|
||||||
|
@ -247,18 +274,20 @@ class Starboard(commands.Cog):
|
||||||
"Deleted message %s in %s, %s - lost all stars",
|
"Deleted message %s in %s, %s - lost all stars",
|
||||||
starboard_message.id,
|
starboard_message.id,
|
||||||
starboard_message.channel.name,
|
starboard_message.channel.name,
|
||||||
starboard_message.guild.name
|
starboard_message.guild.name,
|
||||||
)
|
)
|
||||||
elif starboard_message.embeds[0] != embed[0]:
|
elif starboard_message.embeds[0] != embed[0]:
|
||||||
await starboard_message.edit(embeds=embed)
|
await starboard_message.edit(embeds=embed)
|
||||||
|
|
||||||
@commands.message_command(name="Preview Starboard Message")
|
@commands.message_command(name="Preview Starboard Message")
|
||||||
async def preview_starboard_message(self, ctx: discord.ApplicationContext, message: discord.Message):
|
async def preview_starboard_message(
|
||||||
|
self, ctx: discord.ApplicationContext, message: discord.Message
|
||||||
|
):
|
||||||
embed, stars = await self.generate_starboard_embed(message)
|
embed, stars = await self.generate_starboard_embed(message)
|
||||||
data = await self.redis.get(str(message.id)) or 'null'
|
data = await self.redis.get(str(message.id)) or "null"
|
||||||
return await ctx.respond(
|
return await ctx.respond(
|
||||||
f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}",
|
f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}",
|
||||||
embeds=embed
|
embeds=embed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
237
src/cogs/ytdl.py
237
src/cogs/ytdl.py
|
@ -135,12 +135,21 @@ class YTDLCog(commands.Cog):
|
||||||
channel_id=excluded.channel_id,
|
channel_id=excluded.channel_id,
|
||||||
attachment_index=excluded.attachment_index
|
attachment_index=excluded.attachment_index
|
||||||
""",
|
""",
|
||||||
(_hash, message.id, message.channel.id, webpage_url, format_id, attachment_index),
|
(
|
||||||
|
_hash,
|
||||||
|
message.id,
|
||||||
|
message.channel.id,
|
||||||
|
webpage_url,
|
||||||
|
format_id,
|
||||||
|
attachment_index,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return _hash
|
return _hash
|
||||||
|
|
||||||
async def get_saved(self, webpage_url: str, format_id: str, snip: str) -> typing.Optional[str]:
|
async def get_saved(
|
||||||
|
self, webpage_url: str, format_id: str, snip: str
|
||||||
|
) -> typing.Optional[str]:
|
||||||
"""
|
"""
|
||||||
Attempts to retrieve the attachment URL of a previously saved download.
|
Attempts to retrieve the attachment URL of a previously saved download.
|
||||||
:param webpage_url: The webpage url
|
:param webpage_url: The webpage url
|
||||||
|
@ -154,12 +163,19 @@ class YTDLCog(commands.Cog):
|
||||||
logging.error("Failed to initialise ytdl database: %s", e, exc_info=True)
|
logging.error("Failed to initialise ytdl database: %s", e, exc_info=True)
|
||||||
return
|
return
|
||||||
async with aiosqlite.connect("./data/ytdl.db") as db:
|
async with aiosqlite.connect("./data/ytdl.db") as db:
|
||||||
_hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest()
|
_hash = hashlib.md5(
|
||||||
|
f"{webpage_url}:{format_id}:{snip}".encode()
|
||||||
|
).hexdigest()
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"Attempting to find a saved download for '%s:%s:%s' (%r).", webpage_url, format_id, snip, _hash
|
"Attempting to find a saved download for '%s:%s:%s' (%r).",
|
||||||
|
webpage_url,
|
||||||
|
format_id,
|
||||||
|
snip,
|
||||||
|
_hash,
|
||||||
)
|
)
|
||||||
cursor = await db.execute(
|
cursor = await db.execute(
|
||||||
"SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?", (_hash,)
|
"SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?",
|
||||||
|
(_hash,),
|
||||||
)
|
)
|
||||||
entry = await cursor.fetchone()
|
entry = await cursor.fetchone()
|
||||||
if not entry:
|
if not entry:
|
||||||
|
@ -173,7 +189,9 @@ class YTDLCog(commands.Cog):
|
||||||
try:
|
try:
|
||||||
message = await channel.fetch_message(message_id)
|
message = await channel.fetch_message(message_id)
|
||||||
except discord.HTTPException:
|
except discord.HTTPException:
|
||||||
self.log.debug("%r did not contain a message with ID %r", channel, message_id)
|
self.log.debug(
|
||||||
|
"%r did not contain a message with ID %r", channel, message_id
|
||||||
|
)
|
||||||
await db.execute("DELETE FROM downloads WHERE key=?", (_hash,))
|
await db.execute("DELETE FROM downloads WHERE key=?", (_hash,))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -182,7 +200,11 @@ class YTDLCog(commands.Cog):
|
||||||
self.log.debug("Found URL %r, returning.", url)
|
self.log.debug("Found URL %r, returning.", url)
|
||||||
return url
|
return url
|
||||||
except IndexError:
|
except IndexError:
|
||||||
self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments)
|
self.log.debug(
|
||||||
|
"Attachment index %d is out of range (%r)",
|
||||||
|
attachment_index,
|
||||||
|
message.attachments,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def convert_to_m4a(self, file: Path) -> Path:
|
def convert_to_m4a(self, file: Path) -> Path:
|
||||||
|
@ -209,23 +231,32 @@ class YTDLCog(commands.Cog):
|
||||||
str(new_file),
|
str(new_file),
|
||||||
]
|
]
|
||||||
self.log.debug("Running command: ffmpeg %s", " ".join(args))
|
self.log.debug("Running command: ffmpeg %s", " ".join(args))
|
||||||
process = subprocess.run(["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
process = subprocess.run(
|
||||||
|
["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
if process.returncode != 0:
|
if process.returncode != 0:
|
||||||
raise RuntimeError(process.stderr.decode())
|
raise RuntimeError(process.stderr.decode())
|
||||||
return new_file
|
return new_file
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def upload_to_0x0(name: str, data: typing.IO[bytes], mime_type: str | None = None) -> str:
|
async def upload_to_0x0(
|
||||||
|
name: str, data: typing.IO[bytes], mime_type: str | None = None
|
||||||
|
) -> str:
|
||||||
if not mime_type:
|
if not mime_type:
|
||||||
import magic
|
import magic
|
||||||
mime_type = await asyncio.to_thread(magic.from_buffer, data.read(4096), mime=True)
|
|
||||||
|
mime_type = await asyncio.to_thread(
|
||||||
|
magic.from_buffer, data.read(4096), mime=True
|
||||||
|
)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"https://0x0.st",
|
"https://0x0.st",
|
||||||
files={"file": (name, data, mime_type)},
|
files={"file": (name, data, mime_type)},
|
||||||
data={"expires": 12},
|
data={"expires": 12},
|
||||||
headers={"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"},
|
headers={
|
||||||
|
"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return urlparse(response.text).path[1:]
|
return urlparse(response.text).path[1:]
|
||||||
|
@ -235,7 +266,10 @@ class YTDLCog(commands.Cog):
|
||||||
async def yt_dl_command(
|
async def yt_dl_command(
|
||||||
self,
|
self,
|
||||||
ctx: discord.ApplicationContext,
|
ctx: discord.ApplicationContext,
|
||||||
url: typing.Annotated[str, discord.Option(str, description="The URL to download from.", required=True)],
|
url: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(str, description="The URL to download from.", required=True),
|
||||||
|
],
|
||||||
user_format: typing.Annotated[
|
user_format: typing.Annotated[
|
||||||
typing.Optional[str],
|
typing.Optional[str],
|
||||||
discord.Option(
|
discord.Option(
|
||||||
|
@ -258,7 +292,10 @@ class YTDLCog(commands.Cog):
|
||||||
],
|
],
|
||||||
snip: typing.Annotated[
|
snip: typing.Annotated[
|
||||||
typing.Optional[str],
|
typing.Optional[str],
|
||||||
discord.Option(description="A start and end position to trim. e.g. 00:00:00-00:10:00.", required=False),
|
discord.Option(
|
||||||
|
description="A start and end position to trim. e.g. 00:00:00-00:10:00.",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
subtitles: typing.Annotated[
|
subtitles: typing.Annotated[
|
||||||
typing.Optional[str],
|
typing.Optional[str],
|
||||||
|
@ -267,7 +304,7 @@ class YTDLCog(commands.Cog):
|
||||||
description="The language code of the subtitles to download. e.g. 'en', 'auto'",
|
description="The language code of the subtitles to download. e.g. 'en', 'auto'",
|
||||||
required=False,
|
required=False,
|
||||||
),
|
),
|
||||||
]
|
],
|
||||||
):
|
):
|
||||||
"""Runs yt-dlp and outputs into discord."""
|
"""Runs yt-dlp and outputs into discord."""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -279,28 +316,40 @@ class YTDLCog(commands.Cog):
|
||||||
if stop.is_set():
|
if stop.is_set():
|
||||||
raise RuntimeError("Download cancelled.")
|
raise RuntimeError("Download cancelled.")
|
||||||
n = time.time()
|
n = time.time()
|
||||||
_total = _data.get("total_bytes", _data.get("total_bytes_estimate")) or ctx.guild.filesize_limit
|
_total = (
|
||||||
|
_data.get("total_bytes", _data.get("total_bytes_estimate"))
|
||||||
|
or ctx.guild.filesize_limit
|
||||||
|
)
|
||||||
if _total:
|
if _total:
|
||||||
_percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2)
|
_percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2)
|
||||||
else:
|
else:
|
||||||
_total = max(1, _data.get("fragment_count", 4096) or 4096)
|
_total = max(1, _data.get("fragment_count", 4096) or 4096)
|
||||||
_percent = round(max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2)
|
_percent = round(
|
||||||
|
max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2
|
||||||
|
)
|
||||||
_speed_bytes_per_second = _data.get("speed", 1) or 1 or 1
|
_speed_bytes_per_second = _data.get("speed", 1) or 1 or 1
|
||||||
_speed_megabits_per_second = round((_speed_bytes_per_second * 8) / 1024 / 1024)
|
_speed_megabits_per_second = round(
|
||||||
|
(_speed_bytes_per_second * 8) / 1024 / 1024
|
||||||
|
)
|
||||||
if _data.get("eta"):
|
if _data.get("eta"):
|
||||||
_eta = discord.utils.utcnow() + datetime.timedelta(seconds=_data.get("eta"))
|
_eta = discord.utils.utcnow() + datetime.timedelta(
|
||||||
|
seconds=_data.get("eta")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
_eta = discord.utils.utcnow() + datetime.timedelta(minutes=1)
|
_eta = discord.utils.utcnow() + datetime.timedelta(minutes=1)
|
||||||
blocks = "#" * math.floor(_percent / 10)
|
blocks = "#" * math.floor(_percent / 10)
|
||||||
bar = f"{blocks}{'.' * (10 - len(blocks))}"
|
bar = f"{blocks}{'.' * (10 - len(blocks))}"
|
||||||
line = (f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | "
|
line = (
|
||||||
f"ETA {discord.utils.format_dt(_eta, 'R')}")
|
f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | "
|
||||||
|
f"ETA {discord.utils.format_dt(_eta, 'R')}"
|
||||||
|
)
|
||||||
nonlocal last_edit
|
nonlocal last_edit
|
||||||
if (n - last_edit) >= 1.1:
|
if (n - last_edit) >= 1.1:
|
||||||
embed.clear_fields()
|
embed.clear_fields()
|
||||||
embed.add_field(name="Progress", value=line)
|
embed.add_field(name="Progress", value=line)
|
||||||
ctx.bot.loop.create_task(ctx.edit(embed=embed))
|
ctx.bot.loop.create_task(ctx.edit(embed=embed))
|
||||||
last_edit = time.time()
|
last_edit = time.time()
|
||||||
|
|
||||||
options["progress_hooks"] = [_download_hook]
|
options["progress_hooks"] = [_download_hook]
|
||||||
|
|
||||||
description = ""
|
description = ""
|
||||||
|
@ -331,13 +380,19 @@ class YTDLCog(commands.Cog):
|
||||||
options["format_sort"] = ["abr", "br"]
|
options["format_sort"] = ["abr", "br"]
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
options["postprocessors"].append(
|
options["postprocessors"].append(
|
||||||
{"key": "FFmpegExtractAudio", "preferredquality": "96", "preferredcodec": "best"}
|
{
|
||||||
|
"key": "FFmpegExtractAudio",
|
||||||
|
"preferredquality": "96",
|
||||||
|
"preferredcodec": "best",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
options["format"] = chosen_format
|
options["format"] = chosen_format
|
||||||
options["paths"] = paths
|
options["paths"] = paths
|
||||||
|
|
||||||
if subtitles and ffmpeg_installed:
|
if subtitles and ffmpeg_installed:
|
||||||
subtitles, burn = subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0")
|
subtitles, burn = (
|
||||||
|
subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0")
|
||||||
|
)
|
||||||
burn = burn[0].lower() in ("y", "1", "t")
|
burn = burn[0].lower() in ("y", "1", "t")
|
||||||
if subtitles.lower() == "auto":
|
if subtitles.lower() == "auto":
|
||||||
options["writeautosubtitles"] = True
|
options["writeautosubtitles"] = True
|
||||||
|
@ -352,10 +407,14 @@ class YTDLCog(commands.Cog):
|
||||||
)
|
)
|
||||||
|
|
||||||
with yt_dlp.YoutubeDL(options) as downloader:
|
with yt_dlp.YoutubeDL(options) as downloader:
|
||||||
await ctx.respond(embed=discord.Embed().set_footer(text="Downloading (step 1/10)"))
|
await ctx.respond(
|
||||||
|
embed=discord.Embed().set_footer(text="Downloading (step 1/10)")
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False)
|
extracted_info = await asyncio.to_thread(
|
||||||
|
downloader.extract_info, url, download=False
|
||||||
|
)
|
||||||
except yt_dlp.utils.DownloadError as e:
|
except yt_dlp.utils.DownloadError as e:
|
||||||
extracted_info = {
|
extracted_info = {
|
||||||
"title": "error",
|
"title": "error",
|
||||||
|
@ -382,22 +441,38 @@ class YTDLCog(commands.Cog):
|
||||||
thumbnail_url = extracted_info.get("thumbnail") or None
|
thumbnail_url = extracted_info.get("thumbnail") or None
|
||||||
webpage_url = extracted_info.get("webpage_url", url)
|
webpage_url = extracted_info.get("webpage_url", url)
|
||||||
|
|
||||||
chosen_format = extracted_info.get("format") or chosen_format or str(uuid.uuid4())
|
chosen_format = (
|
||||||
chosen_format_id = extracted_info.get("format_id") or str(uuid.uuid4())
|
extracted_info.get("format")
|
||||||
|
or chosen_format
|
||||||
|
or str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
chosen_format_id = extracted_info.get("format_id") or str(
|
||||||
|
uuid.uuid4()
|
||||||
|
)
|
||||||
final_extension = extracted_info.get("ext") or "mp4"
|
final_extension = extracted_info.get("ext") or "mp4"
|
||||||
format_note = extracted_info.get("format_note", "%s (%s)" % (chosen_format, chosen_format_id)) or ""
|
format_note = (
|
||||||
|
extracted_info.get(
|
||||||
|
"format_note", "%s (%s)" % (chosen_format, chosen_format_id)
|
||||||
|
)
|
||||||
|
or ""
|
||||||
|
)
|
||||||
resolution = extracted_info.get("resolution") or "1x1"
|
resolution = extracted_info.get("resolution") or "1x1"
|
||||||
fps = extracted_info.get("fps", 0.0) or 0.0
|
fps = extracted_info.get("fps", 0.0) or 0.0
|
||||||
vcodec = extracted_info.get("vcodec") or "h264"
|
vcodec = extracted_info.get("vcodec") or "h264"
|
||||||
acodec = extracted_info.get("acodec") or "aac"
|
acodec = extracted_info.get("acodec") or "aac"
|
||||||
filesize = extracted_info.get("filesize", extracted_info.get("filesize_approx", 1))
|
filesize = extracted_info.get(
|
||||||
likes = extracted_info.get("like_count", extracted_info.get("average_rating", 0))
|
"filesize", extracted_info.get("filesize_approx", 1)
|
||||||
|
)
|
||||||
|
likes = extracted_info.get(
|
||||||
|
"like_count", extracted_info.get("average_rating", 0)
|
||||||
|
)
|
||||||
views = extracted_info.get("view_count", 0)
|
views = extracted_info.get("view_count", 0)
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
if chosen_format and chosen_format_id:
|
if chosen_format and chosen_format_id:
|
||||||
lines.append(
|
lines.append(
|
||||||
"* Chosen format: `%s` (`%s`)" % (chosen_format, chosen_format_id),
|
"* Chosen format: `%s` (`%s`)"
|
||||||
|
% (chosen_format, chosen_format_id),
|
||||||
)
|
)
|
||||||
if format_note:
|
if format_note:
|
||||||
lines.append("* Format note: %r" % format_note)
|
lines.append("* Format note: %r" % format_note)
|
||||||
|
@ -411,7 +486,9 @@ class YTDLCog(commands.Cog):
|
||||||
if vcodec or acodec:
|
if vcodec or acodec:
|
||||||
lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A"))
|
lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A"))
|
||||||
if filesize:
|
if filesize:
|
||||||
lines.append("* Filesize: %s" % yt_dlp.utils.format_bytes(filesize))
|
lines.append(
|
||||||
|
"* Filesize: %s" % yt_dlp.utils.format_bytes(filesize)
|
||||||
|
)
|
||||||
|
|
||||||
if lines:
|
if lines:
|
||||||
description += "\n"
|
description += "\n"
|
||||||
|
@ -424,27 +501,29 @@ class YTDLCog(commands.Cog):
|
||||||
url=webpage_url,
|
url=webpage_url,
|
||||||
colour=self.colours.get(domain, discord.Colour.og_blurple()),
|
colour=self.colours.get(domain, discord.Colour.og_blurple()),
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(name="Progress", value="0% [..........]")
|
||||||
name="Progress",
|
|
||||||
value="0% [..........]"
|
|
||||||
)
|
|
||||||
embed.set_footer(text="Downloading (step 2/10)")
|
embed.set_footer(text="Downloading (step 2/10)")
|
||||||
embed.set_thumbnail(url=thumbnail_url)
|
embed.set_thumbnail(url=thumbnail_url)
|
||||||
|
|
||||||
class StopView(discord.ui.View):
|
class StopView(discord.ui.View):
|
||||||
@discord.ui.button(label="Cancel download", style=discord.ButtonStyle.danger)
|
@discord.ui.button(
|
||||||
async def _stop(self, button: discord.ui.Button, interaction: discord.Interaction):
|
label="Cancel download", style=discord.ButtonStyle.danger
|
||||||
|
)
|
||||||
|
async def _stop(
|
||||||
|
self,
|
||||||
|
button: discord.ui.Button,
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
):
|
||||||
stop.set()
|
stop.set()
|
||||||
button.label = "Cancelling..."
|
button.label = "Cancelling..."
|
||||||
button.disabled = True
|
button.disabled = True
|
||||||
await interaction.response.edit_message(view=self)
|
await interaction.response.edit_message(view=self)
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
await ctx.edit(
|
await ctx.edit(embed=embed, view=StopView(timeout=86400))
|
||||||
embed=embed,
|
previous = await self.get_saved(
|
||||||
view=StopView(timeout=86400)
|
webpage_url, chosen_format_id, snip or "*"
|
||||||
)
|
)
|
||||||
previous = await self.get_saved(webpage_url, chosen_format_id, snip or "*")
|
|
||||||
if previous:
|
if previous:
|
||||||
await ctx.edit(
|
await ctx.edit(
|
||||||
content=previous,
|
content=previous,
|
||||||
|
@ -454,7 +533,11 @@ class YTDLCog(commands.Cog):
|
||||||
colour=discord.Colour.green(),
|
colour=discord.Colour.green(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
url=previous,
|
url=previous,
|
||||||
fields=[discord.EmbedField(name="URL", value=previous, inline=False)],
|
fields=[
|
||||||
|
discord.EmbedField(
|
||||||
|
name="URL", value=previous, inline=False
|
||||||
|
)
|
||||||
|
],
|
||||||
).set_image(url=previous),
|
).set_image(url=previous),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -462,7 +545,9 @@ class YTDLCog(commands.Cog):
|
||||||
last_edit = time.time()
|
last_edit = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(functools.partial(downloader.download, [url]))
|
await asyncio.to_thread(
|
||||||
|
functools.partial(downloader.download, [url])
|
||||||
|
)
|
||||||
except yt_dlp.DownloadError as e:
|
except yt_dlp.DownloadError as e:
|
||||||
logging.error(e, exc_info=True)
|
logging.error(e, exc_info=True)
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
|
@ -473,7 +558,7 @@ class YTDLCog(commands.Cog):
|
||||||
url=webpage_url,
|
url=webpage_url,
|
||||||
),
|
),
|
||||||
delete_after=120,
|
delete_after=120,
|
||||||
view=None
|
view=None,
|
||||||
)
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
|
@ -484,16 +569,26 @@ class YTDLCog(commands.Cog):
|
||||||
url=webpage_url,
|
url=webpage_url,
|
||||||
),
|
),
|
||||||
delete_after=120,
|
delete_after=120,
|
||||||
view=None
|
view=None,
|
||||||
)
|
)
|
||||||
await ctx.edit(view=None)
|
await ctx.edit(view=None)
|
||||||
try:
|
try:
|
||||||
if audio_only is False:
|
if audio_only is False:
|
||||||
file: Path = next(temp_dir.glob("*." + extracted_info.get("ext", "*")))
|
file: Path = next(
|
||||||
|
temp_dir.glob("*." + extracted_info.get("ext", "*"))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# can be .opus, .m4a, .mp3, .ogg, .oga
|
# can be .opus, .m4a, .mp3, .ogg, .oga
|
||||||
for _file in temp_dir.iterdir():
|
for _file in temp_dir.iterdir():
|
||||||
if _file.suffix in (".opus", ".m4a", ".mp3", ".ogg", ".oga", ".aac", ".wav"):
|
if _file.suffix in (
|
||||||
|
".opus",
|
||||||
|
".m4a",
|
||||||
|
".mp3",
|
||||||
|
".ogg",
|
||||||
|
".oga",
|
||||||
|
".aac",
|
||||||
|
".wav",
|
||||||
|
):
|
||||||
file: Path = _file
|
file: Path = _file
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -523,7 +618,9 @@ class YTDLCog(commands.Cog):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
trim_start, trim_end = snip, None
|
trim_start, trim_end = snip, None
|
||||||
trim_start = trim_start or "00:00:00"
|
trim_start = trim_start or "00:00:00"
|
||||||
trim_end = trim_end or extracted_info.get("duration_string", "00:30:00")
|
trim_end = trim_end or extracted_info.get(
|
||||||
|
"duration_string", "00:30:00"
|
||||||
|
)
|
||||||
new_file = temp_dir / ("output" + file.suffix)
|
new_file = temp_dir / ("output" + file.suffix)
|
||||||
args = [
|
args = [
|
||||||
"-hwaccel",
|
"-hwaccel",
|
||||||
|
@ -562,7 +659,10 @@ class YTDLCog(commands.Cog):
|
||||||
)
|
)
|
||||||
self.log.debug("Running command: 'ffmpeg %s'", " ".join(args))
|
self.log.debug("Running command: 'ffmpeg %s'", " ".join(args))
|
||||||
process = await asyncio.create_subprocess_exec(
|
process = await asyncio.create_subprocess_exec(
|
||||||
"ffmpeg", *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
"ffmpeg",
|
||||||
|
*args,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
stdout, stderr = await process.communicate()
|
stdout, stderr = await process.communicate()
|
||||||
self.log.debug("STDOUT:\n%r", stdout.decode())
|
self.log.debug("STDOUT:\n%r", stdout.decode())
|
||||||
|
@ -606,31 +706,40 @@ class YTDLCog(commands.Cog):
|
||||||
)
|
)
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=f"Downloaded {title}!",
|
title=f"Downloaded {title}!",
|
||||||
description="Views: {:,} | Likes: {:,}".format(views or 0, likes or 0),
|
description="Views: {:,} | Likes: {:,}".format(
|
||||||
|
views or 0, likes or 0
|
||||||
|
),
|
||||||
colour=discord.Colour.green(),
|
colour=discord.Colour.green(),
|
||||||
timestamp=discord.utils.utcnow(),
|
timestamp=discord.utils.utcnow(),
|
||||||
url=webpage_url,
|
url=webpage_url,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in ["hevc", "h265", "av1", "av01"]:
|
if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in [
|
||||||
|
"hevc",
|
||||||
|
"h265",
|
||||||
|
"av1",
|
||||||
|
"av01",
|
||||||
|
]:
|
||||||
with file.open("rb") as fb:
|
with file.open("rb") as fb:
|
||||||
part = await self.upload_to_0x0(
|
part = await self.upload_to_0x0(file.name, fb)
|
||||||
file.name,
|
embed.add_field(
|
||||||
fb
|
name="URL", value=f"https://0x0.st/{part}", inline=False
|
||||||
)
|
|
||||||
embed.add_field(name="URL", value=f"https://0x0.st/{part}", inline=False)
|
|
||||||
await ctx.edit(
|
|
||||||
embed=embed
|
|
||||||
)
|
)
|
||||||
|
await ctx.edit(embed=embed)
|
||||||
await ctx.respond("https://embeds.video/0x0/" + part)
|
await ctx.respond("https://embeds.video/0x0/" + part)
|
||||||
else:
|
else:
|
||||||
upload_file = await asyncio.to_thread(discord.File, file, filename=file.name)
|
upload_file = await asyncio.to_thread(
|
||||||
msg = await ctx.edit(
|
discord.File, file, filename=file.name
|
||||||
file=upload_file,
|
|
||||||
embed=embed
|
|
||||||
)
|
)
|
||||||
await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or "*")
|
msg = await ctx.edit(file=upload_file, embed=embed)
|
||||||
except (discord.HTTPException, ConnectionError, httpx.HTTPStatusError) as e:
|
await self.save_link(
|
||||||
|
msg, webpage_url, chosen_format_id, snip=snip or "*"
|
||||||
|
)
|
||||||
|
except (
|
||||||
|
discord.HTTPException,
|
||||||
|
ConnectionError,
|
||||||
|
httpx.HTTPStatusError,
|
||||||
|
) as e:
|
||||||
self.log.error(e, exc_info=True)
|
self.log.error(e, exc_info=True)
|
||||||
return await ctx.edit(
|
return await ctx.edit(
|
||||||
embed=discord.Embed(
|
embed=discord.Embed(
|
||||||
|
|
12
src/conf.py
12
src/conf.py
|
@ -21,7 +21,9 @@ if (Path.cwd() / ".git").exists():
|
||||||
log.debug("Unable to auto-detect running version using git.", exc_info=True)
|
log.debug("Unable to auto-detect running version using git.", exc_info=True)
|
||||||
VERSION = "unknown"
|
VERSION = "unknown"
|
||||||
else:
|
else:
|
||||||
log.debug("Unable to auto-detect running version using git, no .git directory exists.")
|
log.debug(
|
||||||
|
"Unable to auto-detect running version using git, no .git directory exists."
|
||||||
|
)
|
||||||
VERSION = "unknown"
|
VERSION = "unknown"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -32,7 +34,7 @@ except FileNotFoundError:
|
||||||
log.critical(
|
log.critical(
|
||||||
"Unable to locate config.toml in %s. Using default configuration. Good luck!",
|
"Unable to locate config.toml in %s. Using default configuration. Good luck!",
|
||||||
cwd,
|
cwd,
|
||||||
exc_info=True
|
exc_info=True,
|
||||||
)
|
)
|
||||||
CONFIG = {}
|
CONFIG = {}
|
||||||
CONFIG.setdefault("logging", {})
|
CONFIG.setdefault("logging", {})
|
||||||
|
@ -50,11 +52,13 @@ CONFIG.setdefault(
|
||||||
"api": "https://bots.nexy7574.co.uk/jimmy/v2/",
|
"api": "https://bots.nexy7574.co.uk/jimmy/v2/",
|
||||||
"username": os.getenv("WEB_USERNAME", os.urandom(32).hex()),
|
"username": os.getenv("WEB_USERNAME", os.urandom(32).hex()),
|
||||||
"password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()),
|
"password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if CONFIG["redis"].pop("no_ping", None) is not None:
|
if CONFIG["redis"].pop("no_ping", None) is not None:
|
||||||
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")
|
log.warning(
|
||||||
|
"`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory."
|
||||||
|
)
|
||||||
|
|
||||||
CONFIG["redis"]["decode_responses"] = True
|
CONFIG["redis"]["decode_responses"] = True
|
||||||
|
|
||||||
|
|
72
src/main.py
72
src/main.py
|
@ -61,9 +61,16 @@ class KumaThread(Thread):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
self.previous = bot.is_ready()
|
self.previous = bot.is_ready()
|
||||||
except httpx.HTTPError as error:
|
except httpx.HTTPError as error:
|
||||||
self.log.error("Failed to connect to uptime-kuma: %r: %r", url, error, exc_info=error)
|
self.log.error(
|
||||||
|
"Failed to connect to uptime-kuma: %r: %r",
|
||||||
|
url,
|
||||||
|
error,
|
||||||
|
exc_info=error,
|
||||||
|
)
|
||||||
timeout = self.calculate_backoff()
|
timeout = self.calculate_backoff()
|
||||||
self.log.warning("Waiting %d seconds before retrying ping.", timeout)
|
self.log.warning(
|
||||||
|
"Waiting %d seconds before retrying ping.", timeout
|
||||||
|
)
|
||||||
time.sleep(timeout)
|
time.sleep(timeout)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -91,7 +98,11 @@ logging.basicConfig(
|
||||||
markup=True,
|
markup=True,
|
||||||
console=Console(width=cols, height=lns),
|
console=Console(width=cols, height=lns),
|
||||||
),
|
),
|
||||||
FileHandler(filename=CONFIG["logging"].get("file", "jimmy.log"), encoding="utf-8", errors="replace"),
|
FileHandler(
|
||||||
|
filename=CONFIG["logging"].get("file", "jimmy.log"),
|
||||||
|
encoding="utf-8",
|
||||||
|
errors="replace",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
for logger in CONFIG["logging"].get("suppress", []):
|
for logger in CONFIG["logging"].get("suppress", []):
|
||||||
|
@ -130,7 +141,8 @@ class Client(commands.Bot):
|
||||||
async def start(self, token: str, *, reconnect: bool = True) -> None:
|
async def start(self, token: str, *, reconnect: bool = True) -> None:
|
||||||
if CONFIG["jimmy"].get("uptime_kuma_url"):
|
if CONFIG["jimmy"].get("uptime_kuma_url"):
|
||||||
self.uptime_thread = KumaThread(
|
self.uptime_thread = KumaThread(
|
||||||
CONFIG["jimmy"]["uptime_kuma_url"], CONFIG["jimmy"].get("uptime_kuma_interval", 58.0)
|
CONFIG["jimmy"]["uptime_kuma_url"],
|
||||||
|
CONFIG["jimmy"].get("uptime_kuma_interval", 58.0),
|
||||||
)
|
)
|
||||||
self.uptime_thread.start()
|
self.uptime_thread.start()
|
||||||
await super().start(token, reconnect=reconnect)
|
await super().start(token, reconnect=reconnect)
|
||||||
|
@ -138,7 +150,9 @@ class Client(commands.Bot):
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if getattr(self, "uptime_thread", None):
|
if getattr(self, "uptime_thread", None):
|
||||||
self.uptime_thread.kill.set()
|
self.uptime_thread.kill.set()
|
||||||
await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join)
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.uptime_thread.join
|
||||||
|
)
|
||||||
await super().close()
|
await super().close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,9 +203,13 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
|
||||||
log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc)
|
log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc)
|
||||||
if isinstance(exc, commands.CommandOnCooldown):
|
if isinstance(exc, commands.CommandOnCooldown):
|
||||||
expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after)
|
expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after)
|
||||||
await ctx.respond(f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}.")
|
await ctx.respond(
|
||||||
|
f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}."
|
||||||
|
)
|
||||||
elif isinstance(exc, commands.MaxConcurrencyReached):
|
elif isinstance(exc, commands.MaxConcurrencyReached):
|
||||||
await ctx.respond("You've reached the maximum number of concurrent uses for this command.")
|
await ctx.respond(
|
||||||
|
"You've reached the maximum number of concurrent uses for this command."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if await bot.is_owner(ctx.author):
|
if await bot.is_owner(ctx.author):
|
||||||
paginator = commands.Paginator(prefix="```py")
|
paginator = commands.Paginator(prefix="```py")
|
||||||
|
@ -200,7 +218,10 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc
|
||||||
for page in paginator.pages:
|
for page in paginator.pages:
|
||||||
await ctx.respond(page)
|
await ctx.respond(page)
|
||||||
else:
|
else:
|
||||||
await ctx.respond(f"An error occurred while processing your command. Please try again later.\n" f"{exc}")
|
await ctx.respond(
|
||||||
|
f"An error occurred while processing your command. Please try again later.\n"
|
||||||
|
f"{exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bot.listen()
|
@bot.listen()
|
||||||
|
@ -218,11 +239,22 @@ async def delete_message(ctx: discord.ApplicationContext, message: discord.Messa
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
if not ctx.channel.permissions_for(ctx.me).manage_messages:
|
if not ctx.channel.permissions_for(ctx.me).manage_messages:
|
||||||
if message.author != bot.user:
|
if message.author != bot.user:
|
||||||
return await ctx.respond("I don't have permission to delete messages in this channel.", delete_after=30)
|
return await ctx.respond(
|
||||||
|
"I don't have permission to delete messages in this channel.",
|
||||||
|
delete_after=30,
|
||||||
|
)
|
||||||
|
|
||||||
log.info("%s deleted message %s>%s: %r", ctx.author, ctx.channel.name, message.id, message.content)
|
log.info(
|
||||||
|
"%s deleted message %s>%s: %r",
|
||||||
|
ctx.author,
|
||||||
|
ctx.channel.name,
|
||||||
|
message.id,
|
||||||
|
message.content,
|
||||||
|
)
|
||||||
await message.delete(delay=1)
|
await message.delete(delay=1)
|
||||||
await ctx.respond(f"\N{white heavy check mark} Deleted message by {message.author.display_name}.")
|
await ctx.respond(
|
||||||
|
f"\N{WHITE HEAVY CHECK MARK} Deleted message by {message.author.display_name}."
|
||||||
|
)
|
||||||
await ctx.delete(delay=5)
|
await ctx.delete(delay=5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,16 +263,19 @@ async def about_me(ctx: discord.ApplicationContext):
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Jimmy v3",
|
title="Jimmy v3",
|
||||||
description="A bot specifically for the LCC discord server(s).",
|
description="A bot specifically for the LCC discord server(s).",
|
||||||
colour=discord.Colour.green()
|
colour=discord.Colour.green(),
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Source",
|
name="Source",
|
||||||
value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)"
|
value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)",
|
||||||
)
|
)
|
||||||
user = await ctx.bot.fetch_user(421698654189912064)
|
user = await ctx.bot.fetch_user(421698654189912064)
|
||||||
appinfo = await ctx.bot.application_info()
|
appinfo = await ctx.bot.application_info()
|
||||||
author = appinfo.owner
|
author = appinfo.owner
|
||||||
embed.add_field(name="Author", value=f"{user.mention} created me. {author.mention} is running this instance.")
|
embed.add_field(
|
||||||
|
name="Author",
|
||||||
|
value=f"{user.mention} created me. {author.mention} is running this instance.",
|
||||||
|
)
|
||||||
return await ctx.respond(embed=embed)
|
return await ctx.respond(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
@ -249,9 +284,8 @@ async def check_is_enabled(ctx: commands.Context | discord.ApplicationContext) -
|
||||||
disabled = CONFIG["jimmy"].get("disabled_commands", [])
|
disabled = CONFIG["jimmy"].get("disabled_commands", [])
|
||||||
if ctx.command.qualified_name in disabled:
|
if ctx.command.qualified_name in disabled:
|
||||||
raise commands.DisabledCommand(
|
raise commands.DisabledCommand(
|
||||||
"%s is disabled via this instance's configuration file." % (
|
"%s is disabled via this instance's configuration file."
|
||||||
ctx.command.qualified_name
|
% (ctx.command.qualified_name)
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -273,7 +307,9 @@ async def add_delays(ctx: commands.Context | discord.ApplicationContext) -> bool
|
||||||
|
|
||||||
|
|
||||||
if not CONFIG["jimmy"].get("token"):
|
if not CONFIG["jimmy"].get("token"):
|
||||||
log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)")
|
log.critical(
|
||||||
|
"No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
__disabled = CONFIG["jimmy"].get("disabled_commands", [])
|
__disabled = CONFIG["jimmy"].get("disabled_commands", [])
|
||||||
|
|
|
@ -2,24 +2,24 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
import orjson
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import typing
|
import typing
|
||||||
import time
|
import time
|
||||||
from fastapi import FastAPI, Depends, HTTPException, APIRouter
|
from fastapi import FastAPI, Depends, HTTPException, APIRouter
|
||||||
from fastapi.responses import JSONResponse, Response
|
from fastapi.responses import JSONResponse, Response, ORJSONResponse
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
JSON: typing.Union[
|
JSON: typing.Union[
|
||||||
str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"]
|
str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"]
|
||||||
] = typing.Union[
|
] = typing.Union[str, int, float, bool, None, dict, list]
|
||||||
str, int, float, bool, None, dict, list
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TruthPayload(BaseModel):
|
class TruthPayload(BaseModel):
|
||||||
"""Represents a truth. This can be used to both create and get truths."""
|
"""Represents a truth. This can be used to both create and get truths."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
author: str
|
author: str
|
||||||
|
@ -33,6 +33,7 @@ class OllamaThread(BaseModel):
|
||||||
|
|
||||||
class ThreadMessage(BaseModel):
|
class ThreadMessage(BaseModel):
|
||||||
"""Represents a message in an Ollama thread."""
|
"""Represents a message in an Ollama thread."""
|
||||||
|
|
||||||
role: typing.Literal["assistant", "system", "user"]
|
role: typing.Literal["assistant", "system", "user"]
|
||||||
content: str
|
content: str
|
||||||
images: typing.Optional[list[str]] = []
|
images: typing.Optional[list[str]] = []
|
||||||
|
@ -54,9 +55,7 @@ USERNAME = os.getenv("WEB_USERNAME", os.urandom(32).hex())
|
||||||
PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex())
|
PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex())
|
||||||
|
|
||||||
|
|
||||||
accounts = {
|
accounts = {USERNAME: PASSWORD}
|
||||||
USERNAME: PASSWORD
|
|
||||||
}
|
|
||||||
if os.path.exists("./web-accounts.json"):
|
if os.path.exists("./web-accounts.json"):
|
||||||
with open("./web-accounts.json") as f:
|
with open("./web-accounts.json") as f:
|
||||||
accounts.update(json.load(f))
|
accounts.update(json.load(f))
|
||||||
|
@ -71,7 +70,9 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)):
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
|
def get_db_factory(
|
||||||
|
n: int = 11,
|
||||||
|
) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]:
|
||||||
def inner():
|
def inner():
|
||||||
uri = os.getenv("REDIS_URL", "redis://redis")
|
uri = os.getenv("REDIS_URL", "redis://redis")
|
||||||
conn = redis.Redis.from_url(uri)
|
conn = redis.Redis.from_url(uri)
|
||||||
|
@ -80,38 +81,58 @@ def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Re
|
||||||
yield conn
|
yield conn
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jimmy v3 API",
|
title="Jimmy v3 API",
|
||||||
version="3.1.0",
|
version="3.1.0",
|
||||||
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api"
|
root_path=os.getenv("WEB_ROOT_PATH", "") + "/api",
|
||||||
)
|
)
|
||||||
truth_router = APIRouter(
|
truth_router = APIRouter(
|
||||||
prefix="/truths",
|
prefix="/truths", dependencies=[Depends(check_credentials)], tags=["Truth Social"]
|
||||||
dependencies=[Depends(check_credentials)],
|
|
||||||
tags=["Truth Social"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@truth_router.get("", response_model=list[TruthPayload])
|
@truth_router.get("", response_class=ORJSONResponse, response_model=list[TruthPayload])
|
||||||
def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
|
def get_all_truths(
|
||||||
"""Retrieves all stored truths"""
|
rich: bool = True,
|
||||||
|
limit: int = -1,
|
||||||
|
page: int = 0,
|
||||||
|
db: redis.Redis = Depends(get_db_factory()),
|
||||||
|
):
|
||||||
|
"""Retrieves all stored truths
|
||||||
|
|
||||||
|
If ?limit is a positive integer, pagination will be enabled."""
|
||||||
|
query_start = time.perf_counter()
|
||||||
keys = db.keys()
|
keys = db.keys()
|
||||||
|
query_end = time.perf_counter()
|
||||||
if rich is False:
|
if rich is False:
|
||||||
return [
|
return [
|
||||||
{"id": key, "content": "", "author": "", "timestamp": 0, "extra": None}
|
{"id": key, "content": "", "author": "", "timestamp": 0, "extra": None}
|
||||||
for key in keys
|
for key in keys
|
||||||
]
|
]
|
||||||
|
load_start = time.perf_counter()
|
||||||
|
if limit >= 0:
|
||||||
|
keys = keys[page * limit: (page + 1) * limit]
|
||||||
truths = [json.loads(db.get(key)) for key in keys]
|
truths = [json.loads(db.get(key)) for key in keys]
|
||||||
return truths
|
load_end = time.perf_counter()
|
||||||
|
server_timing = "query;dur=%.2f, load;dur=%.2f" % (
|
||||||
|
(query_end - query_start) * 1000,
|
||||||
|
(load_end - load_start) * 1000,
|
||||||
|
)
|
||||||
|
return ORJSONResponse(truths, headers={"Server-Timing": server_timing})
|
||||||
|
|
||||||
|
|
||||||
@truth_router.get("/all", deprecated=True, response_model=list[TruthPayload])
|
@truth_router.get("/all", deprecated=True, response_model=list[TruthPayload])
|
||||||
def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())):
|
def get_all_truths_deprecated(
|
||||||
|
response: JSONResponse,
|
||||||
|
rich: bool = True,
|
||||||
|
db: redis.Redis = Depends(get_db_factory()),
|
||||||
|
):
|
||||||
"""DEPRECATED - USE get_all_truths INSTEAD"""
|
"""DEPRECATED - USE get_all_truths INSTEAD"""
|
||||||
return get_all_truths(rich, db)
|
return get_all_truths(rich=rich, db=db)
|
||||||
|
|
||||||
|
|
||||||
@truth_router.get("/{truth_id}", response_model=TruthPayload)
|
@truth_router.get("/{truth_id}", response_model=TruthPayload)
|
||||||
|
@ -133,7 +154,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||||
|
|
||||||
|
|
||||||
@truth_router.post("", status_code=201, response_model=TruthPayload)
|
@truth_router.post("", status_code=201, response_model=TruthPayload)
|
||||||
def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())):
|
def new_truth(
|
||||||
|
payload: TruthPayload,
|
||||||
|
response: JSONResponse,
|
||||||
|
db: redis.Redis = Depends(get_db_factory()),
|
||||||
|
):
|
||||||
"""Creates a new truth"""
|
"""Creates a new truth"""
|
||||||
data = payload.model_dump()
|
data = payload.model_dump()
|
||||||
existing: str = db.get(data["id"])
|
existing: str = db.get(data["id"])
|
||||||
|
@ -148,7 +173,9 @@ def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = D
|
||||||
|
|
||||||
|
|
||||||
@truth_router.put("/{truth_id}", response_model=TruthPayload)
|
@truth_router.put("/{truth_id}", response_model=TruthPayload)
|
||||||
def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())):
|
def put_truth(
|
||||||
|
truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())
|
||||||
|
):
|
||||||
"""Replaces a stored truth"""
|
"""Replaces a stored truth"""
|
||||||
data = payload.model_dump()
|
data = payload.model_dump()
|
||||||
existing = db.get(truth_id)
|
existing = db.get(truth_id)
|
||||||
|
@ -168,9 +195,7 @@ def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())):
|
||||||
|
|
||||||
app.include_router(truth_router)
|
app.include_router(truth_router)
|
||||||
ollama_router = APIRouter(
|
ollama_router = APIRouter(
|
||||||
prefix="/ollama",
|
prefix="/ollama", dependencies=[Depends(check_credentials)], tags=["Ollama"]
|
||||||
dependencies=[Depends(check_credentials)],
|
|
||||||
tags=["Ollama"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -185,13 +210,13 @@ def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))):
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
@ollama_router.get("/thread/{thread_id}", response_model=OllamaThread)
|
@ollama_router.get("/thread/{thread_id}", response_class=ORJSONResponse, response_model=OllamaThread)
|
||||||
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))):
|
||||||
"""Retrieves a stored thread"""
|
"""Retrieves a stored thread"""
|
||||||
data: str = db.get(thread_id)
|
data: str = db.get(thread_id)
|
||||||
if not data:
|
if not data:
|
||||||
raise HTTPException(404, detail="%r not found." % thread_id)
|
raise HTTPException(404, detail="%r not found." % thread_id)
|
||||||
return json.loads(data)
|
return ORJSONResponse(orjson.loads(data.encode()))
|
||||||
|
|
||||||
|
|
||||||
@ollama_router.delete("/thread/{thread_id}", status_code=204)
|
@ollama_router.delete("/thread/{thread_id}", status_code=204)
|
||||||
|
@ -216,4 +241,5 @@ def health(db: redis.Redis = Depends(get_db_factory())):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*")
|
uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*")
|
||||||
|
|
Loading…
Reference in a new issue