From 0ea24df64e40c863399824a4fb3bf0a3c3cd2c56 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Thu, 4 Jul 2024 01:53:08 +0100 Subject: [PATCH 01/24] Add election countdown --- src/cogs/election.py | 61 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index b490787..9a71cc4 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -4,12 +4,13 @@ This module is only meant to be loaded during election times. import asyncio import datetime import logging +import random import re import discord import httpx from bs4 import BeautifulSoup -from discord.ext import commands +from discord.ext import commands, tasks SPAN_REGEX = re.compile( @@ -39,10 +40,66 @@ class ElectionCog(commands.Cog): "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 " "Safari/537.36", } + ETA = datetime.datetime( + 2024, + 7, + 4, + 23, + 30, + tzinfo=datetime.datetime.now().astimezone().tzinfo + ) def __init__(self, bot): - self.bot = bot + self.bot: commands.Bot = bot self.log = logging.getLogger("jimmy.cogs.election") + self.countdown_message = None + self.check_election.start() + + def cog_unload(self) -> None: + self.check_election.cancel() + + @tasks.loop(minutes=1) + async def check_election(self): + if not self.bot.is_ready(): + await self.bot.wait_until_ready() + + guild = self.bot.get_guild(994710566612500550) + if not guild: + return self.log.error("Nonsense guild not found. Can't do countdown.") + channel = discord.utils.get(guild.text_channels, name="countdown") + if not channel: + return self.log.error("Countdown channel not found.") + + await asyncio.sleep(random.randint(0, 10)) + now = discord.utils.utcnow() + diff = (self.ETA - now).total_seconds() + if diff < -86400: + return self.log.debug("Countdown long expired.") + + if diff > 3600: + hours, remainder = map(round, divmod(diff, 3600)) + minutes, seconds = map(round, divmod(remainder, 60)) + message = f"+ {hours} hours, {minutes} minutes, and {seconds} seconds." + elif diff > 60: + minutes, seconds = map(round, divmod(diff, 60)) + message = f"+ {minutes} minutes and {seconds} seconds." + elif diff >= 0: + message = f"+ {round(diff)} seconds." + else: + message = "Results time!" + + if self.countdown_message: + try: + return await self.countdown_message.edit( + content=f"```diff\n{message}```" + ) + except discord.HTTPException: + self.log.exception("Failed to edit countdown message.") + self.countdown_message = await channel.send( + f"```diff\n{message}```" + ) + self.log.debug("Sent countdown message") + def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: good_soup = soup.find(attrs={"data-testid": "election-banner-results-bar"}) From 5401cafb6f6b7ebd90b7543311c8c8c3e0019944 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Thu, 4 Jul 2024 01:54:14 +0100 Subject: [PATCH 02/24] Fix missing import --- src/cogs/election.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cogs/election.py b/src/cogs/election.py index 7f55190..cd59bf7 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -4,6 +4,7 @@ This module is only meant to be loaded during election times. import asyncio import logging import random +import datetime import re import discord From a8e43fff03100273afb48dede59489b6350ebb68 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:11:56 +0100 Subject: [PATCH 03/24] Patch together the election command for 2024 --- src/cogs/election.py | 73 ++++++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index b490787..f8bb28e 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -45,55 +45,42 @@ class ElectionCog(commands.Cog): self.log = logging.getLogger("jimmy.cogs.election") def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: - good_soup = soup.find(attrs={"data-testid": "election-banner-results-bar"}) - if not good_soup: + good_soups = list(soup.find_all(attrs={"data-testid": "election-banner-results-bar"})) + if not good_soups: return - - css: str = "\n".join([x.get_text() for x in soup.find_all("style")]) - - def find_colour(style_name: str, want: str = "background-color") -> str | None: - self.log.info("Looking for style %r", style_name) - index = css.index(style_name) + len(style_name) + 1 - value = "" - for char in css[index:]: - if char == "}": - break - value += char - attributes = filter(None, value.split(";")) - parsed = {} - for attr in attributes: - name, val = attr.split(":") - parsed[name] = val - self.log.info("Parsed the following attributes: %r", parsed) - if want in parsed: - self.log.info("Returning %r: %r", want, parsed[want]) - return parsed[want] - self.log.warning("%r was not in attributes.", want) + good_soup = list(good_soups)[1] results: dict[str, list[int]] = {} - for child_ul in good_soup.children: - child_ul: BeautifulSoup - span = child_ul.find("span", recursive=False) - if not span: - self.log.warning("%r did not have a 'span' element.", child_ul) + for child_li in good_soup.children: + try: + party, extra = child_li.get_text().strip().split(":", 1) + seats, extra = extra.split(",", 1) + seats = int(seats.split()[0]) + change = -1 + except ValueError: + self.log.error("failed to parse %r", child_li) continue + results[party] = [seats, change, 0] + # if not span: + # self.log.warning("%r did not have a 'span' element.", child_ul) + # continue - text = span.get_text().replace(",", "") - groups = SPAN_REGEX.match(text) - if groups: - groups = groups.groupdict() - else: - self.log.warning( - "Found span element (%r), however resolved text (%r) did not match regex.", - span, text - ) - continue + # text = span.get_text().replace(",", "") + # groups = SPAN_REGEX.match(text) + # if groups: + # groups = groups.groupdict() + # else: + # self.log.warning( + # "Found span element (%r), however resolved text (%r) did not match regex.", + # span, text + # ) + # continue - results[str(groups["party"]).strip()] = [ - int(groups["councillors"].strip()), - int(groups["net"].strip()) * MULTI[groups["net_change"]], - int(find_colour(child_ul.next["class"][0])[1:], base=16) - ] + # results[str(groups["party"]).strip()] = [ + # int(groups["councillors"].strip()), + # int(groups["net"].strip()) * MULTI[groups["net_change"]], + # int(find_colour(child_ul.next["class"][0])[1:], base=16) + # ] return results @commands.slash_command(name="election") From a3d03d344658c12e293413a38ac73160e6270cbe Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:18:23 +0100 Subject: [PATCH 04/24] Fix parser --- src/cogs/election.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 06ed705..b556d5d 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -100,7 +100,6 @@ class ElectionCog(commands.Cog): ) self.log.debug("Sent countdown message") - def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: good_soups = list(soup.find_all(attrs={"data-testid": "election-banner-results-bar"})) if not good_soups: @@ -109,13 +108,13 @@ class ElectionCog(commands.Cog): results: dict[str, list[int]] = {} for child_li in good_soup.children: + span = list(child_li.children)[-1] try: - party, extra = child_li.get_text().strip().split(":", 1) + party, extra = span.get_text().strip().split(":", 1) seats, extra = extra.split(",", 1) - seats = int(seats.split()[0]) - change = -1 + seats = change = int(seats.split()[0]) except ValueError: - self.log.error("failed to parse %r", child_li) + self.log.error("failed to parse %r", span) continue results[party] = [seats, change, 0] # if not span: From 0071ea19ac1791388025060a09ac9ef605ea4061 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:22:02 +0100 Subject: [PATCH 05/24] log when souping goes wrong --- src/cogs/election.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cogs/election.py b/src/cogs/election.py index b556d5d..38e9c31 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -103,6 +103,7 @@ class ElectionCog(commands.Cog): def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: good_soups = list(soup.find_all(attrs={"data-testid": "election-banner-results-bar"})) if not good_soups: + self.log.error("No 'election-banner-results-bar' elements found:\n%r", soup.prettify()) return good_soup = list(good_soups)[1] From 54c49727dcfeb0f556aa6f1f8657faa69f0955ae Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:22:17 +0100 Subject: [PATCH 06/24] Disable the timer --- src/cogs/election.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 38e9c31..7cb47f2 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -53,7 +53,7 @@ class ElectionCog(commands.Cog): self.bot: commands.Bot = bot self.log = logging.getLogger("jimmy.cogs.election") self.countdown_message = None - self.check_election.start() + # self.check_election.start() def cog_unload(self) -> None: self.check_election.cancel() From a8f5ae145a20907d0f2828bf4ebd5adf1d1654fc Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:27:59 +0100 Subject: [PATCH 07/24] Properly report gains and losses --- src/cogs/election.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 7cb47f2..d35fb1e 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -113,7 +113,15 @@ class ElectionCog(commands.Cog): try: party, extra = span.get_text().strip().split(":", 1) seats, extra = extra.split(",", 1) - seats = change = int(seats.split()[0]) + extra = extra.strip() + if extra.lower() == "no change": + change = 0 + else: + _values = extra.split() + change = int(_values[0]) + if _values[-1] != "gained": + change *= -1 + seats = int(seats.split()[0]) except ValueError: self.log.error("failed to parse %r", span) continue From dbd06bfe38146d317d60ec637054abe8b268114e Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:34:14 +0100 Subject: [PATCH 08/24] Add the refresh button --- src/cogs/election.py | 85 +++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index d35fb1e..633cdcc 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -2,6 +2,8 @@ This module is only meant to be loaded during election times. """ import asyncio +from calendar import c +from contextlib import asynccontextmanager import logging import random import datetime @@ -148,46 +150,63 @@ class ElectionCog(commands.Cog): # ] return results - @commands.slash_command(name="election") - async def get_election_results(self, ctx: discord.ApplicationContext): - """Gets the current election results""" - await ctx.defer() + async def _get_embed(self) -> discord.Embed | None: async with httpx.AsyncClient(headers=self.HEADERS) as client: response = await client.get(self.SOURCE, follow_redirects=True) if response.status_code != 200: - return await ctx.respond( - "Sorry, I can't do that right now (HTTP %d while fetching results from BBC)" % response.status_code - ) - - # noinspection PyTypeChecker + raise RuntimeError(f"HTTP {response.status_code} while fetching results from BBC") soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser") - results = await self.bot.loop.run_in_executor(None, self.process_soup, soup) - if results: - now = discord.utils.utcnow() - date = now.date().strftime("%B %Y") - colour_scores = {} - embed = discord.Embed( - title="Election results - " + date, - url="https://bbc.co.uk/", - timestamp=now + results = await self.bot.loop.run_in_executor(None, self.process_soup, soup) + if results: + now = discord.utils.utcnow() + date = now.date().strftime("%B %Y") + colour_scores = {} + embed = discord.Embed( + title="Election results - " + date, + url="https://bbc.co.uk/", + timestamp=now + ) + embed.set_footer(text="Source from bbc.co.uk.") + description_parts = [] + + for party_name, values in results.items(): + councillors, net, colour = values + colour_scores[party_name] = councillors + symbol = "+" if net > 0 else '' + description_parts.append( + f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)" ) - embed.set_footer(text="Source from bbc.co.uk.") - description_parts = [] - for party_name, values in results.items(): - councillors, net, colour = values - colour_scores[party_name] = councillors - symbol = "+" if net > 0 else '' - description_parts.append( - f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)" - ) + top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] + embed.colour = discord.Colour(results[top_party][2]) + embed.description = "\n".join(description_parts) + return embed - top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] - embed.colour = discord.Colour(results[top_party][2]) - embed.description = "\n".join(description_parts) - return await ctx.respond(embed=embed) - else: - return await ctx.respond("Unable to get election results at this time.") + @commands.slash_command(name="election") + async def get_election_results(self, ctx: discord.ApplicationContext): + """Gets the current election results""" + class RefreshView(discord.ui.View): + @discord.ui.button(label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501") + async def refresh(_self, _btn, interaction): + await interaction.response.defer(invisible=True) + try: + embed = await self._get_embed() + except Exception as e: + self.log.exception("Failed to get election results.") + return await interaction.followup.send(f"Sorry, I cannot contact the BBC at this time: {e}") + if embed is None: + return await interaction.followup.send("Sorry, I could not find any election results.") + await interaction.edit_original_response(embed=embed) + + await ctx.defer() + try: + embed = await self._get_embed() + except Exception as e: + self.log.exception("Failed to get election results.") + return await ctx.respond(f"Sorry, I cannot contact the BBC at this time: {e}") + if embed is None: + return await ctx.respond("Sorry, I could not find any election results.") + await ctx.respond(embed=embed, view=RefreshView()) def setup(bot): From 5a5752eae7b9301fcb5a75852b400fa020e468f0 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:35:10 +0100 Subject: [PATCH 09/24] Extend refresh timeout --- src/cogs/election.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 633cdcc..be1476a 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -206,7 +206,7 @@ class ElectionCog(commands.Cog): return await ctx.respond(f"Sorry, I cannot contact the BBC at this time: {e}") if embed is None: return await ctx.respond("Sorry, I could not find any election results.") - await ctx.respond(embed=embed, view=RefreshView()) + await ctx.respond(embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True)) def setup(bot): From 0874ad78e0eb17d72641ee14957dc4ae5ecd792a Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:42:12 +0100 Subject: [PATCH 10/24] Include 2019 in counts --- src/cogs/election.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index be1476a..44b0c3f 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -127,7 +127,7 @@ class ElectionCog(commands.Cog): except ValueError: self.log.error("failed to parse %r", span) continue - results[party] = [seats, change, 0] + results[party] = [seats, change, 0, 0] # if not span: # self.log.warning("%r did not have a 'span' element.", child_ul) # continue @@ -148,6 +148,17 @@ class ElectionCog(commands.Cog): # int(groups["net"].strip()) * MULTI[groups["net_change"]], # int(find_colour(child_ul.next["class"][0])[1:], base=16) # ] + for child_li in good_soups[1].children: + span = list(child_li.children)[-1] + try: + party, extra = span.get_text().strip().split(":", 1) + seats, _ = extra.split(" ", 1) + seats = int(seats) + except ValueError: + self.log.error("failed to parse %r", span) + continue + if party in results: + results[party][3] = seats return results async def _get_embed(self) -> discord.Embed | None: @@ -170,11 +181,11 @@ class ElectionCog(commands.Cog): description_parts = [] for party_name, values in results.items(): - councillors, net, colour = values + councillors, net, colour, last_election = values colour_scores[party_name] = councillors symbol = "+" if net > 0 else '' description_parts.append( - f"**{party_name}**: {symbol}{net:,} ({councillors:,} total)" + f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, {last_election:,} in 2019)" ) top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] From b8f73f3befd2971d572e4696d134cf4f358d50a3 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:46:07 +0100 Subject: [PATCH 11/24] fix 2019 parsing --- src/cogs/election.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 44b0c3f..67df7f8 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -128,32 +128,12 @@ class ElectionCog(commands.Cog): self.log.error("failed to parse %r", span) continue results[party] = [seats, change, 0, 0] - # if not span: - # self.log.warning("%r did not have a 'span' element.", child_ul) - # continue - - # text = span.get_text().replace(",", "") - # groups = SPAN_REGEX.match(text) - # if groups: - # groups = groups.groupdict() - # else: - # self.log.warning( - # "Found span element (%r), however resolved text (%r) did not match regex.", - # span, text - # ) - # continue - - # results[str(groups["party"]).strip()] = [ - # int(groups["councillors"].strip()), - # int(groups["net"].strip()) * MULTI[groups["net_change"]], - # int(find_colour(child_ul.next["class"][0])[1:], base=16) - # ] for child_li in good_soups[1].children: span = list(child_li.children)[-1] try: party, extra = span.get_text().strip().split(":", 1) - seats, _ = extra.split(" ", 1) - seats = int(seats) + seats, _ = extra.strip().split(" ", 1) + seats = int(seats.strip()) except ValueError: self.log.error("failed to parse %r", span) continue From ee3c3df28919a41e0d94f964b4ebc5c54be21b0f Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:48:29 +0100 Subject: [PATCH 12/24] wrong index --- src/cogs/election.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 67df7f8..a805156 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -128,7 +128,7 @@ class ElectionCog(commands.Cog): self.log.error("failed to parse %r", span) continue results[party] = [seats, change, 0, 0] - for child_li in good_soups[1].children: + for child_li in good_soups[0].children: span = list(child_li.children)[-1] try: party, extra = span.get_text().strip().split(":", 1) From 5d6123dffc9554b948e60851b6c3abd20cbac27d Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:51:52 +0100 Subject: [PATCH 13/24] Ratelimit refreshes --- src/cogs/election.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cogs/election.py b/src/cogs/election.py index a805156..c409a21 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -180,6 +180,9 @@ class ElectionCog(commands.Cog): @discord.ui.button(label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501") async def refresh(_self, _btn, interaction): await interaction.response.defer(invisible=True) + if interaction.message.edited_at: + if (discord.utils.utcnow() - interaction.message.edited_at).total_seconds() < 5: + return await interaction.followup.send("Slow down.", ephemeral=True) try: embed = await self._get_embed() except Exception as e: From 495d10c78d6889149c36ddf5240531dde2767b28 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 00:52:52 +0100 Subject: [PATCH 14/24] Correct embed --- src/cogs/election.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index c409a21..57c85d7 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -165,7 +165,8 @@ class ElectionCog(commands.Cog): colour_scores[party_name] = councillors symbol = "+" if net > 0 else '' description_parts.append( - f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, {last_election:,} in 2019)" + f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, " + f"{last_election:,} predicted by exit poll)" ) top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] From 34511e87eadaa7b48602cf06df2ad1fbd8173c1c Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 01:03:33 +0100 Subject: [PATCH 15/24] Fix cooldown, again --- src/cogs/election.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 57c85d7..825ec21 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -166,7 +166,7 @@ class ElectionCog(commands.Cog): symbol = "+" if net > 0 else '' description_parts.append( f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, " - f"{last_election:,} predicted by exit poll)" + f"{last_election:,} predicted by exit poll)" ) top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] @@ -178,12 +178,15 @@ class ElectionCog(commands.Cog): async def get_election_results(self, ctx: discord.ApplicationContext): """Gets the current election results""" class RefreshView(discord.ui.View): + def __init__(**kwargs): + super().__init__(**kwargs) + self.last_edit = discord.utils.utcnow() + @discord.ui.button(label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501") async def refresh(_self, _btn, interaction): await interaction.response.defer(invisible=True) - if interaction.message.edited_at: - if (discord.utils.utcnow() - interaction.message.edited_at).total_seconds() < 5: - return await interaction.followup.send("Slow down.", ephemeral=True) + if (discord.utils.utcnow() - self.last_edit).total_seconds() < 10: + return await interaction.followup.send("Slow down.", ephemeral=True) try: embed = await self._get_embed() except Exception as e: From 00b80a929cfc2c18999664e461ce646ad836b452 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Fri, 5 Jul 2024 01:06:10 +0100 Subject: [PATCH 16/24] fuck --- src/cogs/election.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cogs/election.py b/src/cogs/election.py index 825ec21..f206f4f 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -178,7 +178,7 @@ class ElectionCog(commands.Cog): async def get_election_results(self, ctx: discord.ApplicationContext): """Gets the current election results""" class RefreshView(discord.ui.View): - def __init__(**kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) self.last_edit = discord.utils.utcnow() From d856c102608a31457e6d3d724c68d67bee391408 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Mon, 8 Jul 2024 01:53:31 +0100 Subject: [PATCH 17/24] Add onion feed to cogs --- src/cogs/onion_feed.py | 100 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/cogs/onion_feed.py diff --git a/src/cogs/onion_feed.py b/src/cogs/onion_feed.py new file mode 100644 index 0000000..18c6b08 --- /dev/null +++ b/src/cogs/onion_feed.py @@ -0,0 +1,100 @@ +""" +Scrapes the onion RSS feed once every hour and posts any new articles to the desired channel +""" +import asyncio +import dataclasses +import datetime +import logging +from bs4 import BeautifulSoup +import discord +from discord.ext import commands, tasks +import httpx +from conf import CONFIG +import redis + + +@dataclasses.dataclass +class RSSItem: + title: str + link: str + description: str + pubDate: datetime.datetime + guid: str + thumbnail: str + + +class OnionFeed(commands.Cog): + SOURCE = "https://www.theonion.com/rss" + EPOCH = datetime.datetime(2024, 7, 7, tzinfo=datetime.timezone.utc) + + def __init__(self, bot): + self.bot: commands.Bot = bot + self.log = logging.getLogger("jimmy.cogs.onion_feed") + self.check_onion_feed.start() + self.redis = redis.Redis(**CONFIG["redis"]) + + def cog_unload(self) -> None: + self.check_onion_feed.cancel() + + @staticmethod + def parse_item(item: BeautifulSoup) -> RSSItem: + kwargs = { + "title": item.title.get_text(strip=True).strip(), + "link": item.link.get_text(strip=True).strip(), + "pubDate": datetime.datetime.strptime( + item.pubDate.get_text(strip=True).strip(), "%a, %d %b %Y %H:%M:%S %Z" + ), + "guid": item.guid.get_text(strip=True).strip(), + "description": BeautifulSoup(item.description.get_text()).p.get_text(strip=True).strip()[:-1], + "thumbnail": item.find("media:thumbnail")["url"], + } + return RSSItem(**kwargs) + + def parse_feed(self, text: str) -> list[RSSItem]: + soup = BeautifulSoup(text, "xml") + return [self.parse_item(item) for item in soup.find_all("item")] + + @tasks.loop(hours=1) + async def check_onion_feed(self): + if not self.bot.is_ready(): + await self.bot.wait_until_ready() + + guild = self.bot.get_guild(994710566612500550) + if not guild: + return self.log.error("Nonsense guild not found. Can't do onion feed.") + channel = discord.utils.get(guild.text_channels, name="spam") + if not channel: + return self.log.error("Spam channel not found.") + + async with httpx.AsyncClient() as client: + response = await client.get(self.SOURCE) + if response.status_code != 200: + return self.log.error(f"Failed to fetch onion feed: {response.status_code}") + items: list[RSSItem] = await asyncio.to_thread(self.parse_feed, response.text) + for item in items: + if self.redis.get("onion-" + item.guid): + continue + embed = discord.Embed( + title=item.title, + url=item.link, + description=item.description + f"... [Read More]({item.link})", + color=0x00DF78, + timestamp=item.pubDate, + ) + embed.set_thumbnail(url=item.thumbnail) + try: + msg = await channel.send(embed=embed) + # noinspection PyAsyncCall + self.redis.set("onion-" + item.guid, str(msg.id)) + except discord.HTTPException: + self.log.exception(f"Failed to send onion feed message: {item.title}") + else: + self.log.debug(f"Sent onion feed message: {item.title}") + + @check_onion_feed.before_loop + async def before_check_onion_feed(self): + await self.bot.wait_until_ready() + + +def setup(bot): + bot.add_cog(OnionFeed(bot)) From c61807071888624db7fa9fd498f4d270140ee879 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Mon, 8 Jul 2024 01:56:20 +0100 Subject: [PATCH 18/24] Fix unreliable HTML parsing --- src/cogs/onion_feed.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cogs/onion_feed.py b/src/cogs/onion_feed.py index 18c6b08..a1e5ca9 100644 --- a/src/cogs/onion_feed.py +++ b/src/cogs/onion_feed.py @@ -25,7 +25,7 @@ class RSSItem: class OnionFeed(commands.Cog): SOURCE = "https://www.theonion.com/rss" - EPOCH = datetime.datetime(2024, 7, 7, tzinfo=datetime.timezone.utc) + EPOCH = datetime.datetime(2024, 7, 1, tzinfo=datetime.timezone.utc) def __init__(self, bot): self.bot: commands.Bot = bot @@ -38,6 +38,7 @@ class OnionFeed(commands.Cog): @staticmethod def parse_item(item: BeautifulSoup) -> RSSItem: + description = BeautifulSoup(item.description.get_text(), "html.parser").p.get_text(strip=True).strip()[:-1] kwargs = { "title": item.title.get_text(strip=True).strip(), "link": item.link.get_text(strip=True).strip(), @@ -45,7 +46,7 @@ class OnionFeed(commands.Cog): item.pubDate.get_text(strip=True).strip(), "%a, %d %b %Y %H:%M:%S %Z" ), "guid": item.guid.get_text(strip=True).strip(), - "description": BeautifulSoup(item.description.get_text()).p.get_text(strip=True).strip()[:-1], + "description": description, "thumbnail": item.find("media:thumbnail")["url"], } return RSSItem(**kwargs) From 608dd46afa31365d2bb6e0db1e109063cc5fde07 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Wed, 10 Jul 2024 23:20:24 +0100 Subject: [PATCH 19/24] Enable better compression for auto-transcoding --- src/cogs/auto_responder.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/cogs/auto_responder.py b/src/cogs/auto_responder.py index 03f8329..f21e168 100644 --- a/src/cogs/auto_responder.py +++ b/src/cogs/auto_responder.py @@ -68,11 +68,13 @@ class AutoResponder(commands.Cog): @staticmethod @typing.overload - def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: ... + def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: + ... @staticmethod @typing.overload - def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: ... + def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: + ... @staticmethod def extract_links(text: str, *domains: str, raw: bool = False) -> list[str | ParseResult]: @@ -147,6 +149,7 @@ class AutoResponder(commands.Cog): return update_reaction("\N{TIMER CLOCK}\U0000fe0f") streams = info.get("streams", []) hwaccel = True + maxrate = "5M" for stream in streams: self.log.info("Found stream: %s", stream.get("codec_name")) if stream.get("codec_name") == "hevc": @@ -159,8 +162,12 @@ class AutoResponder(commands.Cog): hwaccel = False break else: - self.log.info("No HEVC streams found in %s", uri) - return update_reaction() + if int(info["format"]["size"]) >= 25 * 1024 * 1024: + self.log.warning("%s is too large to render in discord, compressing", uri) + maxrate = "1M" + else: + self.log.info("No HEVC streams found in %s", uri) + return update_reaction() extension = pathlib.Path(uri).suffix with tempfile.NamedTemporaryFile(suffix=extension) as tmp_dl: self.log.info("Downloading %r to %r", uri, tmp_dl.name) @@ -195,10 +202,8 @@ class AutoResponder(commands.Cog): tmp_dl.name, "-c:v", "libx264", - "-crf", - "25", "-maxrate", - "5M", + maxrate, "-minrate", "100K", "-bufsize", @@ -216,7 +221,7 @@ class AutoResponder(commands.Cog): "-movflags", "faststart", "-profile:v", - "main", + "high", "-y", "-hide_banner", ] From 5fbd425452c21fd1526c805a2ffc71df47000e58 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Jul 2024 01:33:48 +0100 Subject: [PATCH 20/24] enable pagination in /truths --- src/server.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/server.py b/src/server.py index 973423a..6e2b245 100644 --- a/src/server.py +++ b/src/server.py @@ -96,7 +96,12 @@ truth_router = APIRouter( @truth_router.get("", response_model=list[TruthPayload]) -def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory())): +def get_all_truths( + rich: bool = True, + limit: int = -1, + page: int = 0, + db: redis.Redis = Depends(get_db_factory()) +): """Retrieves all stored truths""" keys = db.keys() if rich is False: @@ -105,13 +110,15 @@ def get_all_truths(rich: bool = True, db: redis.Redis = Depends(get_db_factory() for key in keys ] truths = [json.loads(db.get(key)) for key in keys] + if limit >= 0: + return truths[page * limit:(page + 1) * limit] return truths @truth_router.get("/all", deprecated=True, response_model=list[TruthPayload]) def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())): """DEPRECATED - USE get_all_truths INSTEAD""" - return get_all_truths(rich, db) + return get_all_truths(rich=rich, db=db) @truth_router.get("/{truth_id}", response_model=TruthPayload) From ad6a85062f7481ee5db6b16154a27a81dfbc373e Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Jul 2024 01:39:16 +0100 Subject: [PATCH 21/24] Add pagination to docstring" --- src/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/server.py b/src/server.py index 6e2b245..36bc95a 100644 --- a/src/server.py +++ b/src/server.py @@ -102,7 +102,9 @@ def get_all_truths( page: int = 0, db: redis.Redis = Depends(get_db_factory()) ): - """Retrieves all stored truths""" + """Retrieves all stored truths + + If ?limit is a positive integer, pagination will be enabled.""" keys = db.keys() if rich is False: return [ From d7d18056229d6801ca8c68e47e82bd343dcac278 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Jul 2024 01:53:39 +0100 Subject: [PATCH 22/24] Reformat + lint + use orjson --- requirements.txt | 1 + src/cogs/auto_responder.py | 136 +++++++---- src/cogs/election.py | 70 +++--- src/cogs/ffmeta.py | 88 +++++-- src/cogs/gay_meter.py | 20 +- src/cogs/net.py | 149 ++++++------ src/cogs/ollama.py | 461 ++++++++++++++++++++++++++----------- src/cogs/onion_feed.py | 19 +- src/cogs/quote_quota.py | 61 +++-- src/cogs/screenshot.py | 37 ++- src/cogs/starboard.py | 137 ++++++----- src/cogs/ytdl.py | 237 ++++++++++++++----- src/conf.py | 12 +- src/main.py | 72 ++++-- src/server.py | 61 +++-- 15 files changed, 1064 insertions(+), 497 deletions(-) diff --git a/requirements.txt b/requirements.txt index 073051b..b642250 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ aiofiles~=23.2 fuzzywuzzy[speedup]~=0.18 tortoise-orm[asyncpg]~=0.21 superpaste @ git+https://github.com/nexy7574/superpaste.git@e31eca6 +orjson~=3.10 diff --git a/src/cogs/auto_responder.py b/src/cogs/auto_responder.py index f21e168..2daaa95 100644 --- a/src/cogs/auto_responder.py +++ b/src/cogs/auto_responder.py @@ -33,11 +33,16 @@ class AutoResponder(commands.Cog): raise ValueError( "overpowered_by is set, however members intent is disabled, making it useless." ) - if self.config.get("overrule_offline_superiors", True) is True and bot.intents.presences is False: + if ( + self.config.get("overrule_offline_superiors", True) is True + and bot.intents.presences is False + ): raise ValueError( "overrule_offline_superiors is enabled, however presences intent is not!" ) - self.config.setdefault("transcoding", {"enabled": True, "hevc": True, "on_demand": True}) + self.config.setdefault( + "transcoding", {"enabled": True, "hevc": True, "on_demand": True} + ) self.lmgtfy_cache = [] @property @@ -50,7 +55,9 @@ class AutoResponder(commands.Cog): @property def allow_on_demand_transcoding(self) -> bool: - return self.transcoding_enabled and self.config["transcoding"].get("on_demand", True) + return self.transcoding_enabled and self.config["transcoding"].get( + "on_demand", True + ) def overpowered(self, guild: discord.Guild | None) -> bool: if not guild: @@ -68,16 +75,20 @@ class AutoResponder(commands.Cog): @staticmethod @typing.overload - def extract_links(text: str, *domains: str, raw: typing.Literal[True] = False) -> list[ParseResult]: - ... + def extract_links( + text: str, *domains: str, raw: typing.Literal[True] = False + ) -> list[ParseResult]: ... @staticmethod @typing.overload - def extract_links(text: str, *domains: str, raw: typing.Literal[False] = False) -> list[str]: - ... + def extract_links( + text: str, *domains: str, raw: typing.Literal[False] = False + ) -> list[str]: ... @staticmethod - def extract_links(text: str, *domains: str, raw: bool = False) -> list[str | ParseResult]: + def extract_links( + text: str, *domains: str, raw: bool = False + ) -> list[str | ParseResult]: """ Extracts all links from a given text. @@ -126,10 +137,13 @@ class AutoResponder(commands.Cog): if not update: return if last_reaction is not None: - _ = asyncio.create_task(update.remove_reaction(last_reaction, self.bot.user)) + _ = asyncio.create_task( + update.remove_reaction(last_reaction, self.bot.user) + ) if new: _ = asyncio.create_task(update.add_reaction(new)) last_reaction = new + self.log.info("Waiting for transcode lock to release") async with self.transcode_lock: cog = FFMeta(self.bot) @@ -163,7 +177,9 @@ class AutoResponder(commands.Cog): break else: if int(info["format"]["size"]) >= 25 * 1024 * 1024: - self.log.warning("%s is too large to render in discord, compressing", uri) + self.log.warning( + "%s is too large to render in discord, compressing", uri + ) maxrate = "1M" else: self.log.info("No HEVC streams found in %s", uri) @@ -235,7 +251,9 @@ class AutoResponder(commands.Cog): stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate() - self.log.info("finished transcode with return code %d", process.returncode) + self.log.info( + "finished transcode with return code %d", process.returncode + ) self.log.debug("stdout: %r", stdout.decode) self.log.debug("stderr: %r", stderr.decode) update_reaction() @@ -248,12 +266,9 @@ class AutoResponder(commands.Cog): ) self._cooldown_transcode() return discord.File(tmp_path), tmp_path - + async def transcode_hevc_to_h264( - self, - message: discord.Message, - *domains: str, - additional: Iterable[str] = None + self, message: discord.Message, *domains: str, additional: Iterable[str] = None ) -> None: if not shutil.which("ffmpeg"): self.log.error("ffmpeg not installed") @@ -288,7 +303,9 @@ class AutoResponder(commands.Cog): self.log.info("Found link to transcode: %r", link) try: async with message.channel.typing(): - _r = await self._transcode_hevc_to_h264(link, update=message) + _r = await self._transcode_hevc_to_h264( + link, update=message + ) if not _r: continue file, _p = _r @@ -296,7 +313,9 @@ class AutoResponder(commands.Cog): if _p.stat().st_size <= 24.5 * 1024 * 1024: await message.add_reaction("\N{OUTBOX TRAY}") await message.reply(file=file) - await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user) + await message.remove_reaction( + "\N{OUTBOX TRAY}", self.bot.user + ) else: await message.add_reaction("\N{OUTBOX TRAY}") self.log.warning( @@ -306,16 +325,31 @@ class AutoResponder(commands.Cog): ) if _p.stat().st_size <= 510 * 1024 * 1024: file.fp.seek(0) - self.log.info("Trying to upload file to pastebin.") + self.log.info( + "Trying to upload file to pastebin." + ) async with httpx.AsyncClient() as client: response = await client.post( "https://0x0.st", - files={"file": (_p.name, file.fp, "video/mp4")}, - headers={"User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)"}, + files={ + "file": ( + _p.name, + file.fp, + "video/mp4", + ) + }, + headers={ + "User-Agent": "CollegeBot (matrix: @nex:nexy7574.co.uk)" + }, + ) + await message.remove_reaction( + "\N{OUTBOX TRAY}", self.bot.user ) - await message.remove_reaction("\N{OUTBOX TRAY}", self.bot.user) if response.status_code == 200: - await message.reply("https://embeds.video/" + response.text.strip()) + await message.reply( + "https://embeds.video/" + + response.text.strip() + ) else: await message.add_reaction("\N{BUG}") response.raise_for_status() @@ -326,12 +360,16 @@ class AutoResponder(commands.Cog): except Exception as e: self.log.error("Failed to transcode %r: %r", link, e) - async def copy_ncfe_docs(self, message: discord.Message, links: list[ParseResult]) -> None: + async def copy_ncfe_docs( + self, message: discord.Message, links: list[ParseResult] + ) -> None: files = [] if self.config.get("download_pdfs", True) is False: self.log.debug("Download PDFs is disabled in config, disengaging.") return - self.log.info("Preparing to download: %s", ", ".join(map(ParseResult.geturl, links))) + self.log.info( + "Preparing to download: %s", ", ".join(map(ParseResult.geturl, links)) + ) async with aiohttp.ClientSession() as session: for link in set(links): if link.path.endswith(".pdf"): @@ -343,7 +381,7 @@ class AutoResponder(commands.Cog): "Failed to download %s: HTTP %d - %r", link, response.status, - await response.text() + await response.text(), ) continue async for chunk in response.content.iter_any(): @@ -356,7 +394,9 @@ class AutoResponder(commands.Cog): self.log.warning("File was too large to upload. Skipping.") continue p = pathlib.Path(link.path).name - file = discord.File(buffer, filename=p, description="Copy of %s" % link.geturl()) + file = discord.File( + buffer, filename=p, description="Copy of %s" % link.geturl() + ) files.append(file) for file in files: await message.reply(file=file) @@ -384,27 +424,38 @@ class AutoResponder(commands.Cog): self.log.info("Got VHS reaction, scanning for transcode") extra = [] if reaction.message.attachments: - extra = [attachment.url for attachment in reaction.message.attachments] + extra = [ + attachment.url for attachment in reaction.message.attachments + ] if self.allow_on_demand_transcoding: - await self.transcode_hevc_to_h264(reaction.message, additional=extra) + await self.transcode_hevc_to_h264( + reaction.message, additional=extra + ) elif str(reaction.emoji) == "\U0001f310": if reaction.message.id not in self.lmgtfy_cache: - url = "https://lmddgtfy.net/?q=" + quote_plus(reaction.message.content) - m = await reaction.message.reply(f"[Here's the answer to your question]({url})") + url = "https://lmddgtfy.net/?q=" + quote_plus( + reaction.message.content + ) + m = await reaction.message.reply( + f"[Here's the answer to your question]({url})" + ) await m.edit(suppress=True) self.lmgtfy_cache.append(reaction.message.id) - - elif str(reaction.emoji)[0] == "\N{wastebasket}": - if reaction.message.channel.permissions_for(reaction.message.guild.me).manage_messages: - self.log.info("Deleting message %s (Wastebasket)" % reaction.message.jump_url) + + elif str(reaction.emoji)[0] == "\N{WASTEBASKET}": + if reaction.message.channel.permissions_for( + reaction.message.guild.me + ).manage_messages: + self.log.info( + "Deleting message %s (Wastebasket)" % reaction.message.jump_url + ) await reaction.message.delete( - reason="%s requested deletion of message" % user, - delay=0.2 + reason="%s requested deletion of message" % user, delay=0.2 ) else: self.log.warning( "Unable to delete message %s (wastebasket) - missing permissions", - reaction.message.jump_url + reaction.message.jump_url, ) @commands.Cog.listener("on_raw_reaction_add") @@ -420,8 +471,13 @@ class AutoResponder(commands.Cog): _e = discord.PartialEmoji.from_str(str(payload.emoji)) reaction = discord.Reaction( message=message, - data={"emoji": _e, "count": 1, "me": payload.user_id == self.bot.user.id, "burst": False}, - emoji=payload.emoji + data={ + "emoji": _e, + "count": 1, + "me": payload.user_id == self.bot.user.id, + "burst": False, + }, + emoji=payload.emoji, ) user = self.bot.get_user(payload.user_id) await self.on_reaction_add(reaction, user) diff --git a/src/cogs/election.py b/src/cogs/election.py index f206f4f..32d5119 100644 --- a/src/cogs/election.py +++ b/src/cogs/election.py @@ -1,9 +1,8 @@ """ This module is only meant to be loaded during election times. """ + import asyncio -from calendar import c -from contextlib import asynccontextmanager import logging import random import datetime @@ -18,37 +17,29 @@ from discord.ext import commands, tasks SPAN_REGEX = re.compile( r"^(?P\D+)(?P[0-9,]+)\scouncillors\s(?P[0-9,]+)\scouncillors\s(?P(gained|lost))$" ) -MULTI: dict[str, int] = { - "gained": 1, - "lost": -1 -} +MULTI: dict[str, int] = {"gained": 1, "lost": -1} class ElectionCog(commands.Cog): SOURCE = "https://bbc.com/" HEADERS = { "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp," - "image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8", "Priority": "u=0, i", - "Sec-Ch-Ua": "\"Chromium\";v=\"124\", \"Google Chrome\";v=\"124\", \"Not-A.Brand\";v=\"99\"", + "Sec-Ch-Ua": '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"', "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Platform": "\"Linux\"", + "Sec-Ch-Ua-Platform": '"Linux"', "Sec-Fetch-Dest": "document", "Sec-Fetch-Mode": "navigate", "Sec-Fetch-Site": "none", "Sec-Fetch-User": "?1", "Upgrade-Insecure-Requests": "1", "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 " - "Safari/537.36", + "Safari/537.36", } ETA = datetime.datetime( - 2024, - 7, - 4, - 23, - 30, - tzinfo=datetime.datetime.now().astimezone().tzinfo + 2024, 7, 4, 23, 30, tzinfo=datetime.datetime.now().astimezone().tzinfo ) def __init__(self, bot): @@ -97,15 +88,17 @@ class ElectionCog(commands.Cog): ) except discord.HTTPException: self.log.exception("Failed to edit countdown message.") - self.countdown_message = await channel.send( - f"```diff\n{message}```" - ) + self.countdown_message = await channel.send(f"```diff\n{message}```") self.log.debug("Sent countdown message") def process_soup(self, soup: BeautifulSoup) -> dict[str, list[int]] | None: - good_soups = list(soup.find_all(attrs={"data-testid": "election-banner-results-bar"})) + good_soups = list( + soup.find_all(attrs={"data-testid": "election-banner-results-bar"}) + ) if not good_soups: - self.log.error("No 'election-banner-results-bar' elements found:\n%r", soup.prettify()) + self.log.error( + "No 'election-banner-results-bar' elements found:\n%r", soup.prettify() + ) return good_soup = list(good_soups)[1] @@ -145,7 +138,9 @@ class ElectionCog(commands.Cog): async with httpx.AsyncClient(headers=self.HEADERS) as client: response = await client.get(self.SOURCE, follow_redirects=True) if response.status_code != 200: - raise RuntimeError(f"HTTP {response.status_code} while fetching results from BBC") + raise RuntimeError( + f"HTTP {response.status_code} while fetching results from BBC" + ) soup = await asyncio.to_thread(BeautifulSoup, response.text, "html.parser") results = await self.bot.loop.run_in_executor(None, self.process_soup, soup) if results: @@ -155,7 +150,7 @@ class ElectionCog(commands.Cog): embed = discord.Embed( title="Election results - " + date, url="https://bbc.co.uk/", - timestamp=now + timestamp=now, ) embed.set_footer(text="Source from bbc.co.uk.") description_parts = [] @@ -163,13 +158,17 @@ class ElectionCog(commands.Cog): for party_name, values in results.items(): councillors, net, colour, last_election = values colour_scores[party_name] = councillors - symbol = "+" if net > 0 else '' + symbol = "+" if net > 0 else "" description_parts.append( f"**{party_name}**: {symbol}{net:,} ({councillors:,} total, " f"{last_election:,} predicted by exit poll)" ) - top_party = list(sorted(colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True))[0] + top_party = list( + sorted( + colour_scores.keys(), key=lambda k: colour_scores[k], reverse=True + ) + )[0] embed.colour = discord.Colour(results[top_party][2]) embed.description = "\n".join(description_parts) return embed @@ -177,12 +176,15 @@ class ElectionCog(commands.Cog): @commands.slash_command(name="election") async def get_election_results(self, ctx: discord.ApplicationContext): """Gets the current election results""" + class RefreshView(discord.ui.View): def __init__(self, **kwargs): super().__init__(**kwargs) self.last_edit = discord.utils.utcnow() - @discord.ui.button(label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501") + @discord.ui.button( + label="Refresh", style=discord.ButtonStyle.primary, emoji="\U0001f501" + ) async def refresh(_self, _btn, interaction): await interaction.response.defer(invisible=True) if (discord.utils.utcnow() - self.last_edit).total_seconds() < 10: @@ -191,9 +193,13 @@ class ElectionCog(commands.Cog): embed = await self._get_embed() except Exception as e: self.log.exception("Failed to get election results.") - return await interaction.followup.send(f"Sorry, I cannot contact the BBC at this time: {e}") + return await interaction.followup.send( + f"Sorry, I cannot contact the BBC at this time: {e}" + ) if embed is None: - return await interaction.followup.send("Sorry, I could not find any election results.") + return await interaction.followup.send( + "Sorry, I could not find any election results." + ) await interaction.edit_original_response(embed=embed) await ctx.defer() @@ -201,10 +207,14 @@ class ElectionCog(commands.Cog): embed = await self._get_embed() except Exception as e: self.log.exception("Failed to get election results.") - return await ctx.respond(f"Sorry, I cannot contact the BBC at this time: {e}") + return await ctx.respond( + f"Sorry, I cannot contact the BBC at this time: {e}" + ) if embed is None: return await ctx.respond("Sorry, I could not find any election results.") - await ctx.respond(embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True)) + await ctx.respond( + embed=embed, view=RefreshView(timeout=3600, disable_on_timeout=True) + ) def setup(bot): diff --git a/src/cogs/ffmeta.py b/src/cogs/ffmeta.py index 06bda53..aa0fbb9 100644 --- a/src/cogs/ffmeta.py +++ b/src/cogs/ffmeta.py @@ -27,25 +27,44 @@ class FFMeta(commands.Cog): self.bot = bot self.log = logging.getLogger("jimmy.cogs.ffmeta") - def jpegify_image(self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg") -> io.BytesIO: + def jpegify_image( + self, input_file: io.BytesIO, quality: int = 50, image_format: str = "jpeg" + ) -> io.BytesIO: quality = min(1, max(quality, 100)) img_src = PIL.Image.open(input_file) if image_format == "jpeg": img_src = img_src.convert("RGB") img_dst = io.BytesIO() - self.log.debug("Saving input file (%r) as %r with quality %r%%", input_file, image_format, quality) + self.log.debug( + "Saving input file (%r) as %r with quality %r%%", + input_file, + image_format, + quality, + ) img_src.save(img_dst, format=image_format, quality=quality) img_dst.seek(0) return img_dst - async def _run_ffprobe(self, uri: str | pathlib.Path, as_json: bool = False) -> dict | str: + async def _run_ffprobe( + self, uri: str | pathlib.Path, as_json: bool = False + ) -> dict | str: """ Runs ffprobe on the given target (either file path or URL) and returns the result :param uri: the URI to run ffprobe on :return: The result """ _bin = "ffprobe" - cmd = ["-hide_banner", "-v", "quiet", "-print_format", "json", "-show_streams", "-show_format", "-i", str(uri)] + cmd = [ + "-hide_banner", + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + "-show_format", + "-i", + str(uri), + ] if not as_json: cmd = ["-hide_banner", "-i", str(uri)] process = await asyncio.create_subprocess_exec( @@ -60,7 +79,12 @@ class FFMeta(commands.Cog): return stderr.decode(errors="replace") @commands.slash_command() - async def ffprobe(self, ctx: discord.ApplicationContext, url: str = None, attachment: discord.Attachment = None): + async def ffprobe( + self, + ctx: discord.ApplicationContext, + url: str = None, + attachment: discord.Attachment = None, + ): """Runs ffprobe on a given URL or attachment""" if not shutil.which("ffprobe"): return await ctx.respond("ffprobe is not installed on this system.") @@ -99,7 +123,10 @@ class FFMeta(commands.Cog): image_format: typing.Annotated[ str, discord.Option( - str, description="The format of the resulting image", choices=["jpeg", "webp"], default="jpeg" + str, + description="The format of the resulting image", + choices=["jpeg", "webp"], + default="jpeg", ), ] = "jpeg", ): @@ -113,7 +140,9 @@ class FFMeta(commands.Cog): src = io.BytesIO() async with httpx.AsyncClient( - headers={"User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)"} + headers={ + "User-Agent": f"DiscordBot (Jimmy, v2, {VERSION}, +https://github.com/nexy7574/college-bot-v2)" + } ) as client: response = await client.get(url) if response.status_code != 200: @@ -121,13 +150,17 @@ class FFMeta(commands.Cog): src.write(response.content) try: - dst = await asyncio.to_thread(self.jpegify_image, src, quality, image_format) + dst = await asyncio.to_thread( + self.jpegify_image, src, quality, image_format + ) except Exception as e: await ctx.respond(f"Failed to convert image: `{e}`.") self.log.error("Failed to convert image %r: %r", url, e) return else: - await ctx.respond(file=discord.File(dst, filename=f"jpegified.{image_format}")) + await ctx.respond( + file=discord.File(dst, filename=f"jpegified.{image_format}") + ) @commands.slash_command() async def opusinate( @@ -146,7 +179,10 @@ class FFMeta(commands.Cog): ), ] = 96, mono: typing.Annotated[ - bool, discord.Option(bool, description="Whether to convert the audio to mono", default=False) + bool, + discord.Option( + bool, description="Whether to convert the audio to mono", default=False + ), ] = False, ): """Converts a given URL or attachment to an Opus file""" @@ -193,7 +229,9 @@ class FFMeta(commands.Cog): ) stdout, stderr = await probe_process.communicate() stdout = stdout.decode("utf-8", "replace") - data = {"format": {"duration": 195}} # 3 minutes and 15 seconds is the 2023 average. + data = { + "format": {"duration": 195} + } # 3 minutes and 15 seconds is the 2023 average. if stdout: try: data = json.loads(stdout) @@ -204,7 +242,9 @@ class FFMeta(commands.Cog): max_end_size = ((bitrate * duration * channels) / 8) * 1024 if max_end_size > (24.75 * 1024 * 1024): return await ctx.respond( - "The file would be too large to send ({:,.2f} MiB).".format(max_end_size / 1024 / 1024) + "The file would be too large to send ({:,.2f} MiB).".format( + max_end_size / 1024 / 1024 + ) ) process = await asyncio.create_subprocess_exec( @@ -235,7 +275,9 @@ class FFMeta(commands.Cog): file = io.BytesIO(stdout) if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024): - return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)) + return await ctx.respond( + "The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024) + ) if not fs: await ctx.respond("Failed to convert audio. See below.") else: @@ -249,9 +291,11 @@ class FFMeta(commands.Cog): for page in paginator.pages: await ctx.respond(page, ephemeral=True) - + @commands.slash_command(name="right-behind-you") - async def right_behind_you(self, ctx: discord.ApplicationContext, image: discord.Attachment): + async def right_behind_you( + self, ctx: discord.ApplicationContext, image: discord.Attachment + ): await ctx.defer() rbh = Path("assets/right-behind-you.ogg").resolve() if not rbh.exists(): @@ -262,7 +306,9 @@ class FFMeta(commands.Cog): return await ctx.respond("That's not an image!") with tempfile.NamedTemporaryFile(suffix=Path(image.filename).suffix) as temp: img_tmp = io.BytesIO(await image.read()) - dst: io.BytesIO = await asyncio.to_thread(self.jpegify_image, img_tmp, 90, "webp") + dst: io.BytesIO = await asyncio.to_thread( + self.jpegify_image, img_tmp, 90, "webp" + ) temp.write(dst.getvalue()) temp.flush() process = await asyncio.create_subprocess_exec( @@ -305,16 +351,20 @@ class FFMeta(commands.Cog): "5", "pipe:1", stdout=asyncio.subprocess.PIPE, - stderr=sys.stderr + stderr=sys.stderr, ) stdout, stderr = await process.communicate() file = io.BytesIO(stdout) if (fs := len(file.getvalue())) > (24.75 * 1024 * 1024): - return await ctx.respond("The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024)) + return await ctx.respond( + "The file is too large to send ({:,.2f} MiB).".format(fs / 1024 / 1024) + ) if not fs: await ctx.respond("Failed to convert audio. See below.") else: - return await ctx.respond(file=discord.File(file, filename="right-behind-you.mp4")) + return await ctx.respond( + file=discord.File(file, filename="right-behind-you.mp4") + ) paginator = commands.Paginator() for line in stderr.decode().splitlines(): if line.strip().startswith(":"): diff --git a/src/cogs/gay_meter.py b/src/cogs/gay_meter.py index 6c73923..00d89a6 100644 --- a/src/cogs/gay_meter.py +++ b/src/cogs/gay_meter.py @@ -11,13 +11,15 @@ class MeterCog(commands.Cog): self.bot = bot self.log = logging.getLogger("jimmy.cogs.auto_responder") self.cache = {} - + @commands.slash_command(name="gay-meter") @discord.guild_only() - async def gay_meter(self, ctx: discord.ApplicationContext, user: discord.User = None): + async def gay_meter( + self, ctx: discord.ApplicationContext, user: discord.User = None + ): """Checks how gay someone is""" user = user or ctx.user - + await ctx.respond("Calculating...") for i in range(0, 125, 25): await ctx.edit(content="Calculating... %d%%" % i) @@ -28,9 +30,11 @@ class MeterCog(commands.Cog): else: pct = user.id % 100 await ctx.edit(content=f"{user.mention} is {pct}% gay.") - + @commands.slash_command(name="penis-length") - async def penis_meter(self, ctx: discord.ApplicationContext, user: discord.User = None): + async def penis_meter( + self, ctx: discord.ApplicationContext, user: discord.User = None + ): """Checks the length of someone's penis.""" user = user or ctx.user if random.randint(0, 1): @@ -49,10 +53,10 @@ class MeterCog(commands.Cog): return await ctx.respond( embed=discord.Embed( title=f"{user.display_name}'s penis length:", - description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)) + description="%d cm (%.2f%s)\n%s" % (pct, inch, im, "".join(chunks)), ) ) - + @commands.command() @commands.is_owner() async def clear_cache(self, ctx: commands.Context, user: discord.User = None): @@ -61,7 +65,7 @@ class MeterCog(commands.Cog): self.cache = {} else: self.cache.pop(user, None) - return await ctx.message.add_reaction("\N{white heavy check mark}") + return await ctx.message.add_reaction("\N{WHITE HEAVY CHECK MARK}") def setup(bot): diff --git a/src/cogs/net.py b/src/cogs/net.py index 150148f..034fc61 100644 --- a/src/cogs/net.py +++ b/src/cogs/net.py @@ -33,17 +33,11 @@ class GetFilteredTextView(discord.ui.View): self.text = text super().__init__(timeout=600) - @discord.ui.button( - label="See filtered data", - emoji="\N{INBOX TRAY}" - ) + @discord.ui.button(label="See filtered data", emoji="\N{INBOX TRAY}") async def see_filtered_data(self, _, interaction: discord.Interaction): await interaction.response.defer(ephemeral=True) await interaction.followup.send( - file=discord.File( - io.BytesIO(self.text.encode()), - "filtered.txt" - ) + file=discord.File(io.BytesIO(self.text.encode()), "filtered.txt") ) @@ -106,7 +100,11 @@ class NetworkCog(commands.Cog): def decide(ln: str) -> typing.Optional[bool]: if ln.startswith(">>> Last update"): return - if "REDACTED" in ln or "Please query the WHOIS server of the owning registrar" in ln or ":" not in ln: + if ( + "REDACTED" in ln + or "Please query the WHOIS server of the owning registrar" in ln + or ":" not in ln + ): return False else: return True @@ -134,7 +132,9 @@ class NetworkCog(commands.Cog): if not paginator.pages: stdout, stderr, status = await run_command(with_disclaimer=True) if not any((stdout, stderr)): - return await ctx.respond(f"No output was returned with status code {status}.") + return await ctx.respond( + f"No output was returned with status code {status}." + ) file = io.BytesIO() file.write(stdout) if stderr: @@ -143,7 +143,7 @@ class NetworkCog(commands.Cog): file.seek(0) return await ctx.respond( "Seemingly all output was filtered. Returning raw command output.", - file=discord.File(file, "whois.txt") + file=discord.File(file, "whois.txt"), ) last: discord.Interaction | discord.WebhookMessage | None = None @@ -223,9 +223,15 @@ class NetworkCog(commands.Cog): default="default", ), use_ip_version: discord.Option( - str, name="ip-version", description="IP version to use.", choices=["ipv4", "ipv6"], default="ipv4" + str, + name="ip-version", + description="IP version to use.", + choices=["ipv4", "ipv6"], + default="ipv4", + ), + max_ttl: discord.Option( + int, name="ttl", description="Max number of hops", default=30 ), - max_ttl: discord.Option(int, name="ttl", description="Max number of hops", default=30), ): """Performs a traceroute request.""" await ctx.defer() @@ -258,7 +264,9 @@ class NetworkCog(commands.Cog): args.append(str(port)) args.append(url) paginator = commands.Paginator() - paginator.add_line(f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}") + paginator.add_line( + f"Running command: {' '.join(args[3 if args[0] == 'sudo' else 0:])}" + ) paginator.add_line(empty=True) try: start = time.time_ns() @@ -297,10 +305,7 @@ class NetworkCog(commands.Cog): await ctx.respond(file=discord.File(f)) async def _fetch_ip_response( - self, - server: str, - lookup: str, - client: httpx.AsyncClient + self, server: str, lookup: str, client: httpx.AsyncClient ) -> tuple[dict, float] | httpx.HTTPError | ConnectionError | json.JSONDecodeError: try: start = time.perf_counter() @@ -315,18 +320,16 @@ class NetworkCog(commands.Cog): async def get_ip_address(self, ctx: discord.ApplicationContext, lookup: str = None): """Fetches IP info from SHRONK IP servers""" await ctx.defer() - async with httpx.AsyncClient(headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"}) as client: + async with httpx.AsyncClient( + headers={"User-Agent": "Mozilla/5.0 Jimmy/v2"} + ) as client: if not lookup: response = await client.get("https://api.ipify.org") lookup = response.text servers = self.config.get( "ip_servers", - [ - "ip.shronk.net", - "ip.i-am.nexus", - "ip.shronk.nicroxio.co.uk" - ] + ["ip.shronk.net", "ip.i-am.nexus", "ip.shronk.nicroxio.co.uk"], ) embed = discord.Embed( title="IP lookup information for: %s" % lookup, @@ -353,13 +356,14 @@ class NetworkCog(commands.Cog): t = response.text[:512] embed.add_field( name=server, - value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```" % t, + value=f"An error occurred while parsing the data: {e}\nData: ```\n%s\n```" + % t, ) else: embed.add_field( name="%s (%.2fms)" % (server, (end - start) * 1000), value="```json\n%s\n```" % v, - inline=False + inline=False, ) await ctx.respond(embed=embed) @@ -367,41 +371,41 @@ class NetworkCog(commands.Cog): @commands.slash_command() @commands.max_concurrency(1, commands.BucketType.user) async def nmap( - self, - ctx: discord.ApplicationContext, - target: str, - technique: typing.Annotated[ + self, + ctx: discord.ApplicationContext, + target: str, + technique: typing.Annotated[ + str, + discord.Option( str, - discord.Option( - str, - choices=[ - discord.OptionChoice(name="TCP SYN", value="S"), - discord.OptionChoice(name="TCP Connect", value="T"), - discord.OptionChoice(name="TCP ACK", value="A"), - discord.OptionChoice(name="TCP Window", value="W"), - discord.OptionChoice(name="TCP Maimon", value="M"), - discord.OptionChoice(name="UDP", value="U"), - discord.OptionChoice(name="TCP Null", value="N"), - discord.OptionChoice(name="TCP FIN", value="F"), - discord.OptionChoice(name="TCP XMAS", value="X"), - ], - default="T" - ) - ] = "T", - treat_all_hosts_online: bool = False, - service_scan: bool = False, - fast_mode: bool = False, - enable_os_detection: bool = False, - timing: typing.Annotated[ + choices=[ + discord.OptionChoice(name="TCP SYN", value="S"), + discord.OptionChoice(name="TCP Connect", value="T"), + discord.OptionChoice(name="TCP ACK", value="A"), + discord.OptionChoice(name="TCP Window", value="W"), + discord.OptionChoice(name="TCP Maimon", value="M"), + discord.OptionChoice(name="UDP", value="U"), + discord.OptionChoice(name="TCP Null", value="N"), + discord.OptionChoice(name="TCP FIN", value="F"), + discord.OptionChoice(name="TCP XMAS", value="X"), + ], + default="T", + ), + ] = "T", + treat_all_hosts_online: bool = False, + service_scan: bool = False, + fast_mode: bool = False, + enable_os_detection: bool = False, + timing: typing.Annotated[ + int, + discord.Option( int, - discord.Option( - int, - description="Timing template to use 0 is slowest, 5 is fastest.", - choices=[0, 1, 2, 3, 4, 5], - default=3 - ) - ] = 3, - ports: str = None + description="Timing template to use 0 is slowest, 5 is fastest.", + choices=[0, 1, 2, 3, 4, 5], + default=3, + ), + ] = 3, + ports: str = None, ): """Runs nmap on a target. You cannot specify multiple targets.""" await ctx.defer() @@ -420,7 +424,9 @@ class NetworkCog(commands.Cog): if enable_os_detection and not is_superuser: return await ctx.respond("OS detection is not available on this system.") - with tempfile.TemporaryDirectory(prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}") as tmp: + with tempfile.TemporaryDirectory( + prefix=f"nmap-{ctx.user.id}-{discord.utils.utcnow().timestamp():.0f}" + ) as tmp: tmp_dir = Path(tmp) args = [ "nmap", @@ -430,7 +436,7 @@ class NetworkCog(commands.Cog): str(timing), "-s" + technique, "--reason", - "--noninteractive" + "--noninteractive", ] if treat_all_hosts_online: args.append("-Pn") @@ -447,8 +453,7 @@ class NetworkCog(commands.Cog): await ctx.respond( embed=discord.Embed( title="Running nmap...", - description="Command:\n" - "```{}```".format(shlex.join(args)), + description="Command:\n" "```{}```".format(shlex.join(args)), ) ) process = await asyncio.create_subprocess_exec( @@ -457,14 +462,16 @@ class NetworkCog(commands.Cog): stderr=asyncio.subprocess.PIPE, ) _, stderr = await process.communicate() - files = [discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir()] + files = [ + discord.File(x, filename=x.name + ".txt") for x in tmp_dir.iterdir() + ] if not files: if len(stderr) <= 4089: return await ctx.edit( embed=discord.Embed( title="Nmap failed.", description="```\n" + stderr.decode() + "```", - color=discord.Color.red() + color=discord.Color.red(), ) ) @@ -475,19 +482,19 @@ class NetworkCog(commands.Cog): embed=discord.Embed( title="Nmap failed.", description=f"Output was too long. [View full output]({result.url})", - color=discord.Color.red() + color=discord.Color.red(), ) ) await ctx.edit( embed=discord.Embed( title="Nmap finished!", description="Result files are attached.\n" - "* `gnmap` is 'greppable'\n" - "* `xml` is XML output\n" - "* `nmap` is normal output", - color=discord.Color.green() + "* `gnmap` is 'greppable'\n" + "* `xml` is XML output\n" + "* `nmap` is normal output", + color=discord.Color.green(), ), - files=files + files=files, ) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 5223256..4a35435 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -107,7 +107,9 @@ class OllamaDownloadHandler: async def __aiter__(self): async with aiohttp.ClientSession(base_url=self.base_url) as client: - async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response: + async with client.post( + "/api/pull", json={"name": self.model, "stream": True}, timeout=None + ) as response: response.raise_for_status() async for line in ollama_stream(response.content): self.parse_line(line) @@ -122,7 +124,7 @@ class OllamaDownloadHandler: "Downloading orca-mini:7b on server %r - %s (%.2f%%)", self.base_url, self.status, - self.percent + self.percent, ) return self @@ -176,12 +178,13 @@ class OllamaChatHandler: async def __aiter__(self): async with aiohttp.ClientSession(base_url=self.base_url) as client: async with client.post( - "/api/chat", json={ - "model": self.model, - "stream": True, + "/api/chat", + json={ + "model": self.model, + "stream": True, "messages": self.messages, - "options": self.options - } + "options": self.options, + }, ) as response: response.raise_for_status() async for line in ollama_stream(response.content): @@ -201,7 +204,9 @@ class OllamaClient: self.base_url = base_url self.authorisation = authorisation - def with_client(self, timeout: aiohttp.ClientTimeout | float | int | None = None) -> aiohttp.ClientSession: + def with_client( + self, timeout: aiohttp.ClientTimeout | float | int | None = None + ) -> aiohttp.ClientSession: """ Creates an instance for a request, with properly populated values. :param timeout: @@ -213,9 +218,13 @@ class OllamaClient: timeout = aiohttp.ClientTimeout(timeout) else: timeout = timeout or aiohttp.ClientTimeout(120) - return aiohttp.ClientSession(self.base_url, timeout=timeout, auth=self.authorisation) + return aiohttp.ClientSession( + self.base_url, timeout=timeout, auth=self.authorisation + ) - async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]: + async def get_tags( + self, + ) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]: """ Gets the tags for the server. :return: @@ -250,7 +259,7 @@ class OllamaClient: self, model: str, messages: list[dict[str, str]], - options: dict[str, typing.Any] = None + options: dict[str, typing.Any] = None, ) -> OllamaChatHandler: """ Starts a chat with the given messages. @@ -273,7 +282,11 @@ class OllamaView(View): async def interaction_check(self, interaction: discord.Interaction) -> bool: return interaction.user == self.ctx.user - @button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f") + @button( + label="Stop", + style=discord.ButtonStyle.danger, + emoji="\N{WASTEBASKET}\U0000fe0f", + ) async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction): self.cancel.set() btn.disabled = True @@ -310,14 +323,20 @@ class ChatHistory: :return: The thread's ID. """ key = os.urandom(3).hex() - self._internal[key] = {"member": member.id, "seed": round(time.time()), "messages": []} + self._internal[key] = { + "member": member.id, + "seed": round(time.time()), + "messages": [], + } with open("./assets/ollama-prompt.txt") as file: system_prompt = default or file.read() self.add_message(key, "system", system_prompt) return key @staticmethod - def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]: + def _construct_message( + role: str, content: str, images: typing.Optional[list[str]] + ) -> dict[str, str]: x = {"role": role, "content": content} if images: x["images"] = images @@ -331,7 +350,9 @@ class ChatHistory: return list( filter( lambda v: (ctx.value or v) in v, - map(lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user)), + map( + lambda d: list(d.keys()), instance.threads_for(ctx.interaction.user) + ), ) ) @@ -339,7 +360,9 @@ class ChatHistory: """Returns all saved threads.""" return self._internal.copy() - def threads_for(self, user: discord.Member) -> dict[str, dict[str, list[dict[str, str]] | int]]: + def threads_for( + self, user: discord.Member + ) -> dict[str, dict[str, list[dict[str, str]] | int]]: """Returns all saved threads for a specific user""" t = self.all_threads() for k, v in t.copy().items(): @@ -354,7 +377,7 @@ class ChatHistory: content: str, images: typing.Optional[list[str]] = None, *, - save: bool = True + save: bool = True, ) -> None: """ Appends a message to the given thread. @@ -380,7 +403,9 @@ class ChatHistory: return [] return self._internal[thread]["messages"].copy() # copy() makes it immutable. - def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]: + def get_thread( + self, thread: str + ) -> dict[str, list[dict[str, str]] | discord.Member | int]: """Gets a copy of an entire thread""" return self._internal.get(thread, {}).copy() @@ -401,7 +426,6 @@ SERVER_KEYS_AUTOCOMPLETE.remove("order") class OllamaGetPrompt(discord.ui.Modal): - def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"): super().__init__( discord.ui.InputText( @@ -441,7 +465,9 @@ class PromptSelector(discord.ui.View): if self.user_prompt is not None: self.get_item("usr").style = discord.ButtonStyle.secondary # type: ignore - @discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys") + @discord.ui.button( + label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys" + ) async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction): modal = OllamaGetPrompt(self.ctx, "System") await interaction.response.send_modal(modal) @@ -450,7 +476,9 @@ class PromptSelector(discord.ui.View): self.update_ui() await interaction.edit_original_response(view=self) - @discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr") + @discord.ui.button( + label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr" + ) async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction): modal = OllamaGetPrompt(self.ctx) await interaction.response.send_modal(modal) @@ -459,7 +487,9 @@ class PromptSelector(discord.ui.View): self.update_ui() await interaction.edit_original_response(view=self) - @discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done") + @discord.ui.button( + label="Done", style=discord.ButtonStyle.success, custom_id="done" + ) async def done(self, btn: discord.ui.Button, interaction: Interaction): self.ctx.interaction = interaction self.stop() @@ -474,13 +504,17 @@ class ConfirmCPURun(discord.ui.View): async def interaction_check(self, interaction: Interaction) -> bool: return interaction.user == self.ctx.user - @discord.ui.button(label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu") + @discord.ui.button( + label="Run on CPU", style=discord.ButtonStyle.primary, custom_id="cpu" + ) async def run_on_cpu(self, btn: discord.ui.Button, interaction: Interaction): await interaction.response.defer(invisible=True) self.proceed = True self.stop() - @discord.ui.button(label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu") + @discord.ui.button( + label="Abort", style=discord.ButtonStyle.primary, custom_id="gpu" + ) async def run_on_gpu(self, btn: discord.ui.Button, interaction: Interaction): await interaction.response.defer(invisible=True) self.stop() @@ -492,9 +526,7 @@ class Ollama(commands.Cog): self.log = logging.getLogger("jimmy.cogs.ollama") self.contexts = {} self.history = ChatHistory() - self.servers = { - server: asyncio.Lock() for server in CONFIG["ollama"] - } + self.servers = {server: asyncio.Lock() for server in CONFIG["ollama"]} self.servers.pop("order", None) if CONFIG["ollama"].get("order"): self.servers = {} @@ -520,7 +552,9 @@ class Ollama(commands.Cog): if url in SERVER_KEYS: url = CONFIG["ollama"][url]["base_url"] async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(connect=3, sock_connect=3, sock_read=10, total=3) + timeout=aiohttp.ClientTimeout( + connect=3, sock_connect=3, sock_read=10, total=3 + ) ) as session: self.log.info("Checking if %r is online.", url) try: @@ -551,20 +585,37 @@ class Ollama(commands.Cog): ), ], server: typing.Annotated[ - str, discord.Option(str, "The server to use for ollama.", default="next", choices=SERVER_KEYS_AUTOCOMPLETE) + str, + discord.Option( + str, + "The server to use for ollama.", + default="next", + choices=SERVER_KEYS_AUTOCOMPLETE, + ), ], context: typing.Annotated[ - str, discord.Option(str, "The context key of a previous ollama response to use as context.", default=None) + str, + discord.Option( + str, + "The context key of a previous ollama response to use as context.", + default=None, + ), ], give_acid: typing.Annotated[ bool, discord.Option( - bool, "Whether to give the AI acid, LSD, and other hallucinogens before responding.", default=False + bool, + "Whether to give the AI acid, LSD, and other hallucinogens before responding.", + default=False, ), ], image: typing.Annotated[ discord.Attachment, - discord.Option(discord.Attachment, "An image to feed into ollama. Only works with llava.", default=None), + discord.Option( + discord.Attachment, + "An image to feed into ollama. Only works with llava.", + default=None, + ), ], ): if server == "next": @@ -587,7 +638,7 @@ class Ollama(commands.Cog): await ctx.respond( "Select edit your prompts, as desired. Click done when you want to continue.", view=v, - ephemeral=True + ephemeral=True, ) await v.wait() query = v.user_prompt or query @@ -604,22 +655,23 @@ class Ollama(commands.Cog): self.log.debug("Resolved model to %r" % model) if image: - patterns = [ - "llava:*", - "llava-llama*:*" - ] + patterns = ["llava:*", "llava-llama*:*"] if any(fnmatch(model, p) for p in patterns) is False: await ctx.send( - f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.", - delete_after=5 + f"{ctx.user.mention}: You can only use images with llava. Switching model to `llava:latest`.", + delete_after=5, ) model = "llava:latest" if image.size > 1024 * 1024 * 25: - await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.") + await ctx.respond( + "Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it." + ) return elif not fnmatch(image.content_type, "image/*"): - await ctx.respond("Attachment is not an image. Try using a different file.") + await ctx.respond( + "Attachment is not an image. Try using a different file." + ) return else: data = io.BytesIO() @@ -638,13 +690,19 @@ class Ollama(commands.Cog): if fnmatch(model, model_pattern): break else: - allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"])) - await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}") + allowed_models = ", ".join( + map(discord.utils.escape_markdown, server_config["allowed_models"]) + ) + await ctx.respond( + f"Invalid model. You can only use one of the following models: {allowed_models}" + ) return async with aiohttp.ClientSession( base_url=server_config["base_url"], - timeout=aiohttp.ClientTimeout(connect=5, sock_read=10800, sock_connect=5, total=10830), + timeout=aiohttp.ClientTimeout( + connect=5, sock_read=10800, sock_connect=5, total=10830 + ), ) as session: embed = discord.Embed( title="Checking server...", @@ -652,26 +710,32 @@ class Ollama(commands.Cog): color=discord.Color.blurple(), timestamp=discord.utils.utcnow(), ) - embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) + embed.set_footer( + text="Using server %r" % server, icon_url=server_config.get("icon_url") + ) await ctx.respond(embed=embed) if not await self.check_server(server_config["base_url"]): tried = {server} for i in range(10): try: server = self.next_server(tried) - if server and CONFIG["ollama"]["server"].get("is_gpu", False) is not True: + if ( + server + and CONFIG["ollama"]["server"].get("is_gpu", False) + is not True + ): cf = ConfirmCPURun(ctx) await ctx.edit( embed=discord.Embed( title=f"Server {server} is available, but...", description="It is CPU only, which means it is very slow and will likely crash.\n" - "If you really want, you can continue your generation, using this " - "server. Be aware though, once in motion, it cannot be stopped.\n\n" - "" - "Continue?", + "If you really want, you can continue your generation, using this " + "server. Be aware though, once in motion, it cannot be stopped.\n\n" + "" + "Continue?", color=discord.Color.red(), ), - view=cf + view=cf, ) await cf.wait() await ctx.edit(view=None) @@ -688,7 +752,10 @@ class Ollama(commands.Cog): color=discord.Color.gold(), timestamp=discord.utils.utcnow(), ) - embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) + embed.set_footer( + text="Using server %r" % server, + icon_url=server_config.get("icon_url"), + ) await ctx.edit(embed=embed) await asyncio.sleep(1) if await self.check_server(CONFIG["ollama"][server]["base_url"]): @@ -713,7 +780,9 @@ class Ollama(commands.Cog): embed = discord.Embed( url=resp.url, title=f"HTTP {resp.status} {resp.reason!r} while checking for model.", - description=f"```{await resp.text() or 'No response body'}```"[:4096], + description=f"```{await resp.text() or 'No response body'}```"[ + :4096 + ], color=discord.Color.red(), timestamp=discord.utils.utcnow(), ) @@ -733,8 +802,8 @@ class Ollama(commands.Cog): self.log.debug("Beginning download of %r", model) def progress_bar(_v: float, action: str = None, _mbps: float = None): - bar = "\N{large green square}" * round(_v / 10) - bar += "\N{white large square}" * (10 - len(bar)) + bar = "\N{LARGE GREEN SQUARE}" * round(_v / 10) + bar += "\N{WHITE LARGE SQUARE}" * (10 - len(bar)) bar += f" {_v:.2f}%" if _mbps: bar += f" ({_mbps:.2f} MiB/s)" @@ -753,12 +822,16 @@ class Ollama(commands.Cog): last_update = time.time() - async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response: + async with session.post( + "/api/pull", json={"name": model, "stream": True}, timeout=None + ) as response: if response.status != 200: embed = discord.Embed( url=response.url, title=f"HTTP {response.status} {response.reason!r} while downloading model.", - description=f"```{await response.text() or 'No response body'}```"[:4096], + description=f"```{await response.text() or 'No response body'}```"[ + :4096 + ], color=discord.Color.red(), timestamp=discord.utils.utcnow(), ) @@ -775,7 +848,10 @@ class Ollama(commands.Cog): ) return await ctx.edit(embed=embed, view=None) if time.time() >= (last_update + 5.1): - if line.get("total") is not None and line.get("completed") is not None: + if ( + line.get("total") is not None + and line.get("completed") is not None + ): new_bytes = line["completed"] - last_downloaded mbps = new_bytes / 1024 / 1024 / 5 last_downloaded = line["completed"] @@ -784,7 +860,9 @@ class Ollama(commands.Cog): percent = 50.0 mbps = 0.0 - embed.fields[0].value = progress_bar(percent, line["status"], mbps) + embed.fields[0].value = progress_bar( + percent, line["status"], mbps + ) await ctx.edit(embed=embed, view=view) last_update = time.time() else: @@ -799,13 +877,13 @@ class Ollama(commands.Cog): embed=discord.Embed( title="Before you continue", description="You've selected a CPU-only server. This will be really slow. This will also likely" - " bring the host to a halt. Consider the pain the CPU is about to endure. " - "Are you super" - " sure you want to continue? You can run `h!ollama-status` to see what servers are" - " available.", + " bring the host to a halt. Consider the pain the CPU is about to endure. " + "Are you super" + " sure you want to continue? You can run `h!ollama-status` to see what servers are" + " available.", color=discord.Color.red(), ), - view=cf2 + view=cf2, ) await cf2.wait() if cf2.proceed is False: @@ -825,9 +903,13 @@ class Ollama(commands.Cog): icon_url="https://ollama.com/public/ollama.png", ) embed.add_field( - name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False + name="Prompt", + value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), + inline=False, + ) + embed.set_footer( + text="Using server %r" % server, icon_url=server_config.get("icon_url") ) - embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) if image_data: if (image.height / image.width) >= 1.5: embed.set_image(url=image.url) @@ -842,7 +924,10 @@ class Ollama(commands.Cog): if context is None: context = self.history.create_thread(ctx.user, system_query) - elif context is not None and (__thread := self.history.find_thread(context)) is None: + elif ( + context is not None + and (__thread := self.history.find_thread(context)) is None + ): if not __thread: return await ctx.respond("Invalid thread ID.") else: @@ -861,7 +946,12 @@ class Ollama(commands.Cog): params["top_p"] = 2 params["repeat_penalty"] = 2 - payload = {"model": model, "stream": True, "options": params, "messages": messages} + payload = { + "model": model, + "stream": True, + "options": params, + "messages": messages, + } async with session.post( "/api/chat", json=payload, @@ -870,7 +960,9 @@ class Ollama(commands.Cog): embed = discord.Embed( url=response.url, title=f"HTTP {response.status} {response.reason!r} while generating response.", - description=f"```{await response.text() or 'No response body'}```"[:4096], + description=f"```{await response.text() or 'No response body'}```"[ + :4096 + ], color=discord.Color.red(), timestamp=discord.utils.utcnow(), ) @@ -888,20 +980,31 @@ class Ollama(commands.Cog): embed.description = "[...]" + line["message"]["content"] if len(embed.description) >= 3250: embed.colour = discord.Color.gold() - embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description))) + embed.set_footer( + text="Warning: {:,}/4096 characters.".format( + len(embed.description) + ) + ) else: embed.colour = discord.Color.blurple() - embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) + embed.set_footer( + text="Using server %r" % server, + icon_url=server_config.get("icon_url"), + ) if view.cancel.is_set(): break if time.time() >= (last_update + 5.1): await ctx.edit(embed=embed, view=view) - self.log.debug(f"Updating message ({last_update} -> {time.time()})") + self.log.debug( + f"Updating message ({last_update} -> {time.time()})" + ) last_update = time.time() view.stop() - self.history.add_message(context, "user", user_message["content"], user_message.get("images")) + self.history.add_message( + context, "user", user_message["content"], user_message.get("images") + ) self.history.add_message(context, "assistant", buffer.getvalue()) embed.add_field(name="Context Key", value=context, inline=True) @@ -911,7 +1014,9 @@ class Ollama(commands.Cog): value = buffer.getvalue() if len(value) >= 4096: - embeds = [discord.Embed(title="Done!", colour=discord.Color.green())] + embeds = [ + discord.Embed(title="Done!", colour=discord.Color.green()) + ] current_page = "" for word in value.split(): @@ -971,13 +1076,19 @@ class Ollama(commands.Cog): if message["role"] == "system": continue max_length = 4000 - len("> **%s**: " % message["role"]) - paginator.add_line("> **{}**: {}".format(message["role"], textwrap.shorten(message["content"], max_length))) + paginator.add_line( + "> **{}**: {}".format( + message["role"], textwrap.shorten(message["content"], max_length) + ) + ) embeds = [] for page in paginator.pages: embeds.append(discord.Embed(description=page)) ephemeral = len(embeds) > 1 - for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10): + for chunk in discord.utils.as_chunks( + iter(embeds or [discord.Embed(title="No Content.")]), 10 + ): await ctx.respond(embeds=chunk, ephemeral=ephemeral) @commands.command(name="ollama-status", aliases=["ollama_status", "os"]) @@ -990,14 +1101,18 @@ class Ollama(commands.Cog): if CONFIG["ollama"].get("order"): ln = ["Server order:"] for n, key in enumerate(CONFIG["ollama"].get("order"), start=1): - zap = '\N{high voltage sign}' - ln.append(f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}") + zap = "\N{HIGH VOLTAGE SIGN}" + ln.append( + f"{n}. {key!r} {f'({zap})' if CONFIG['ollama'][key].get('is_gpu') else ''}" + ) embed.description = "\n".join(ln) for server, lock in self.servers.items(): embed.add_field( name=server, - value="\U000026a0\U0000fe0fIn use" if lock.locked() else "\U000023f3 Checking..." + value="\U000026a0\U0000fe0fIn use" + if lock.locked() + else "\U000023f3 Checking...", ) msg = await ctx.reply(embed=embed) @@ -1006,23 +1121,35 @@ class Ollama(commands.Cog): for server in self.servers.keys(): if self.servers[server].locked(): continue - tasks[server] = asyncio.create_task(self.check_server(CONFIG["ollama"][server]["base_url"])) + tasks[server] = asyncio.create_task( + self.check_server(CONFIG["ollama"][server]["base_url"]) + ) await asyncio.gather(*tasks.values()) for server, task in tasks.items(): if not task.done: await task if e := task.exception(): - self.log.error("Error while checking server %r: %r", server, e, exc_info=True) + self.log.error( + "Error while checking server %r: %r", server, e, exc_info=True + ) else: ok = task.result() for field in embed.fields: if field.name == server: - field.value = "\U00002705 Online" if ok else "\U0000274c Offline" + field.value = ( + "\U00002705 Online" if ok else "\U0000274c Offline" + ) break await msg.edit(embed=embed) - async def generate_truth(self, ctx: commands.Context, entity: str, limit: int = None, question: str = None): + async def generate_truth( + self, + ctx: commands.Context, + entity: str, + limit: int = None, + question: str = None, + ): if entity == "trump": system = ( "You are the former united states president, Donald Trump, convicted of 34 felonies." @@ -1034,7 +1161,7 @@ class Ollama(commands.Cog): system = ( "You are the 'Top G' Andrew Tate. Andrew Tate is a former professional kickboxer who turned to the " "world of influencing. His personality is one of arrogance and self-aggrandizement, " - "often referred to as \"King of Toxic Masculinity\" by critics due to his controversial views on " + 'often referred to as "King of Toxic Masculinity" by critics due to his controversial views on ' "gender roles, relationships, and other topics. He has been involved in several controversies related " "to his content online including promoting extremist ideologies and misogynistic views. " "Despite this, he still has a large following and is known for being an entrepreneur who built multiple" @@ -1120,9 +1247,9 @@ class Ollama(commands.Cog): "of the United Kingdom from 2019 to 2022. Known for his flamboyant personality and " "controversial policies, Johnson's leadership was marred by numerous scandals and crises." "Johnson's personality is characterized by his optimistic and charismatic demeanor, " - "often described as a \"dashing politician.\" However, his political career has been " + 'often described as a "dashing politician." However, his political career has been ' "plagued by scandals involving ethics violations, questionable financial arrangements, " - "and allegations of inappropriate behavior. The \"Partygate\" controversy, where " + 'and allegations of inappropriate behavior. The "Partygate" controversy, where ' "gatherings at Downing Street violated COVID-19 regulations, significantly damaged his reputation. " "Johnson's achievements include leading the U.K. out of the European Union and brokering a " "post-Brexit trade deal with the United States. However, his tumultuous tenure was punctuated by " @@ -1162,21 +1289,22 @@ class Ollama(commands.Cog): " Write using the style of a twitter or facebook post. Do not repeat a previous post or any previous " "content. Do not include URLs or links. If an @user asks you a question, reply to them in your post." ) - thread_id = self.history.create_thread( - ctx.author, - system - ) + thread_id = self.history.create_thread(ctx.author, system) r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2/api") username = CONFIG["truth"].get("username", "1") password = CONFIG["truth"].get("password", "2") - async with httpx.AsyncClient(base_url=r, auth=(username, password), trust_env=False) as http_client: + async with httpx.AsyncClient( + base_url=r, auth=(username, password), trust_env=False + ) as http_client: response = await http_client.get( "/truths", timeout=60, ) response.raise_for_status() truths = response.json() - truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths)) + truths: list[TruthPayload] = list( + map(lambda t: TruthPayload.model_validate(t), truths) + ) if entity: truths = list(filter(lambda t: t.author == entity, truths)) @@ -1191,11 +1319,13 @@ class Ollama(commands.Cog): thread_id, "assistant", truth.content, - save=False + save=False, ) ) if not question: - self.history.add_message(thread_id, "user", "Generate a new truth post.") + self.history.add_message( + thread_id, "user", "Generate a new truth post." + ) else: self.history.add_message(thread_id, "user", question) @@ -1204,32 +1334,36 @@ class Ollama(commands.Cog): server = self.next_server(tried) is_gpu = CONFIG["ollama"][server].get("is_gpu", False) if not is_gpu: - self.log.info("Skipping server %r as it is not a GPU server.", server) + self.log.info( + "Skipping server %r as it is not a GPU server.", server + ) continue if await self.check_server(CONFIG["ollama"][server]["base_url"]): break tried.add(server) else: - return await ctx.reply("All servers are offline. Please try again later.", delete_after=300) + return await ctx.reply( + "All servers are offline. Please try again later.", delete_after=300 + ) client = OllamaClient(CONFIG["ollama"][server]["base_url"]) async with self.servers[server]: if not await client.has_model_named("llama2-uncensored", "7b-chat"): - with client.download_model("llama2-uncensored", "7b-chat") as handler: + with client.download_model( + "llama2-uncensored", "7b-chat" + ) as handler: await handler.flatten() embed = discord.Embed( - title=f"New {post_type.title()}!", - description="", - colour=0x6559FF + title=f"New {post_type.title()}!", description="", colour=0x6559FF ) msg = await ctx.reply(embed=embed) last_edit = time.time() messages = self.history.get_history(thread_id) with client.new_chat( - "llama2-uncensored:7b-chat", - messages, - options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5} + "llama2-uncensored:7b-chat", + messages, + options={"num_ctx": 4096, "num_predict": 128, "temperature": 1.5}, ) as handler: async for ln in handler: embed.description += ln["message"]["content"] @@ -1245,7 +1379,7 @@ class Ollama(commands.Cog): if truth.content == embed.description: embed.add_field( name=f"Repeated {post_type} :(", - value=f"This truth was already {post_type}ed. Shit AI." + value=f"This truth was already {post_type}ed. Shit AI.", ) elif _ratio >= 70: similar[truth.id] = _ratio @@ -1253,12 +1387,16 @@ class Ollama(commands.Cog): if similar: if len(similar) > 1: lns = [] - keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True) + keys = sorted( + similar.keys(), key=lambda k: similar[k], reverse=True + ) for truth_id in keys: _ratio = similar[truth_id] truth = discord.utils.get(truths, id=truth_id) first_line = truth.content.splitlines()[0] - preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100)) + preview = discord.utils.escape_markdown( + textwrap.shorten(first_line, 100) + ) lns.append(f"* `{truth_id}`: {_ratio}% - `{preview}`") if len(lns) > 5: lc = len(lns) - 5 @@ -1266,33 +1404,39 @@ class Ollama(commands.Cog): lns.append(f"*... and {lc} more*") embed.add_field( name=f"Possibly repeated {post_type}", - value=f"This {post_type} was similar to the following existing ones:\n" + "\n".join(lns), - inline=False + value=f"This {post_type} was similar to the following existing ones:\n" + + "\n".join(lns), + inline=False, ) else: truth_id = tuple(similar)[0] _ratio = similar[truth_id] truth = discord.utils.get(truths, id=truth_id) first_line = truth.content.splitlines()[0] - preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512)) + preview = discord.utils.escape_markdown( + textwrap.shorten(first_line, 512) + ) embed.add_field( name=f"Possibly repeated {post_type}", - value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}" + value=f"This {post_type} was {_ratio}% similar to `{truth_id}`.\n>>> {preview}", ) embed.set_footer( text="Finished generating {} based off of {:,} messages, using server {!r} | {!s}".format( - post_type, - len(messages) - 2, - server, - thread_id + post_type, len(messages) - 2, server, thread_id ) ) await msg.edit(embed=embed) @commands.command(aliases=["trump"]) @commands.guild_only() - async def donald(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def donald( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a truth social post from trump! @@ -1307,7 +1451,13 @@ class Ollama(commands.Cog): @commands.command(aliases=["tate"]) @commands.guild_only() - async def andrew(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def andrew( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a truth social post from Andrew Tate @@ -1319,10 +1469,16 @@ class Ollama(commands.Cog): question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): await self.generate_truth(ctx, "tate", latest, question=question) - + @commands.command(aliases=["sunak"]) @commands.guild_only() - async def rishi(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def rishi( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Rishi Sunak @@ -1334,10 +1490,16 @@ class Ollama(commands.Cog): question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): await self.generate_truth(ctx, "Rishi Sunak", latest, question=question) - + @commands.command(aliases=["robinson"]) @commands.guild_only() - async def tommy(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def tommy( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Tommy Robinson @@ -1348,11 +1510,19 @@ class Ollama(commands.Cog): if question: question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): - await self.generate_truth(ctx, "Tommy Robinson 🇬🇧", latest, question=question) - + await self.generate_truth( + ctx, "Tommy Robinson 🇬🇧", latest, question=question + ) + @commands.command(aliases=["fox"]) @commands.guild_only() - async def laurence(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def laurence( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Laurence Fox @@ -1364,10 +1534,16 @@ class Ollama(commands.Cog): question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): await self.generate_truth(ctx, "Laurence Fox", latest, question=question) - + @commands.command(aliases=["farage"]) @commands.guild_only() - async def nigel(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def nigel( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Nigel Farage @@ -1382,7 +1558,13 @@ class Ollama(commands.Cog): @commands.command(aliases=["starmer"]) @commands.guild_only() - async def keir(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def keir( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Keir Starmer @@ -1394,10 +1576,16 @@ class Ollama(commands.Cog): question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): await self.generate_truth(ctx, "Keir Starmer", latest, question=question) - + @commands.command(aliases=["johnson"]) @commands.guild_only() - async def boris(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def boris( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Boris Johnson @@ -1409,10 +1597,16 @@ class Ollama(commands.Cog): question = f"'@{ctx.author.display_name}' asks: {question!r}" async with ctx.channel.typing(): await self.generate_truth(ctx, "Boris Johnson", latest, question=question) - + @commands.command(aliases=["desantis"]) @commands.guild_only() - async def ron(self, ctx: commands.Context, latest: typing.Optional[int] = None, *, question: str = None): + async def ron( + self, + ctx: commands.Context, + latest: typing.Optional[int] = None, + *, + question: str = None, + ): """ Generates a twitter post from Ron Desantis @@ -1447,12 +1641,15 @@ class Ollama(commands.Cog): title="Truth:", description=truth.content, colour=0x6559FF, - timestamp=datetime.datetime.fromtimestamp(truth.timestamp, tz=datetime.timezone.utc), + timestamp=datetime.datetime.fromtimestamp( + truth.timestamp, tz=datetime.timezone.utc + ), ) if truth.extra: embed.add_field( name="Extra Data", - value="```json\n%s```" % json.dumps(truth.extra, indent=2, default=repr), + value="```json\n%s```" + % json.dumps(truth.extra, indent=2, default=repr), ) embed.set_author(name=truth.author) await ctx.reply(embed=embed) diff --git a/src/cogs/onion_feed.py b/src/cogs/onion_feed.py index a1e5ca9..bc0008b 100644 --- a/src/cogs/onion_feed.py +++ b/src/cogs/onion_feed.py @@ -1,6 +1,7 @@ """ Scrapes the onion RSS feed once every hour and posts any new articles to the desired channel """ + import asyncio import dataclasses import datetime @@ -38,7 +39,11 @@ class OnionFeed(commands.Cog): @staticmethod def parse_item(item: BeautifulSoup) -> RSSItem: - description = BeautifulSoup(item.description.get_text(), "html.parser").p.get_text(strip=True).strip()[:-1] + description = ( + BeautifulSoup(item.description.get_text(), "html.parser") + .p.get_text(strip=True) + .strip()[:-1] + ) kwargs = { "title": item.title.get_text(strip=True).strip(), "link": item.link.get_text(strip=True).strip(), @@ -70,8 +75,12 @@ class OnionFeed(commands.Cog): async with httpx.AsyncClient() as client: response = await client.get(self.SOURCE) if response.status_code != 200: - return self.log.error(f"Failed to fetch onion feed: {response.status_code}") - items: list[RSSItem] = await asyncio.to_thread(self.parse_feed, response.text) + return self.log.error( + f"Failed to fetch onion feed: {response.status_code}" + ) + items: list[RSSItem] = await asyncio.to_thread( + self.parse_feed, response.text + ) for item in items: if self.redis.get("onion-" + item.guid): continue @@ -88,7 +97,9 @@ class OnionFeed(commands.Cog): # noinspection PyAsyncCall self.redis.set("onion-" + item.guid, str(msg.id)) except discord.HTTPException: - self.log.exception(f"Failed to send onion feed message: {item.title}") + self.log.exception( + f"Failed to send onion feed message: {item.title}" + ) else: self.log.debug(f"Sent onion feed message: {item.title}") diff --git a/src/cogs/quote_quota.py b/src/cogs/quote_quota.py index 8cf9a03..b895356 100644 --- a/src/cogs/quote_quota.py +++ b/src/cogs/quote_quota.py @@ -16,11 +16,9 @@ from discord.ext import commands from conf import CONFIG from pydantic import BaseModel, Field -JSON: typing.Union[ - str, int, float, bool, None, dict[str, "JSON"], list["JSON"] -] = typing.Union[ - str, int, float, bool, None, dict, list -] +JSON: typing.Union[str, int, float, bool, None, dict[str, "JSON"], list["JSON"]] = ( + typing.Union[str, int, float, bool, None, dict, list] +) class TruthPayload(BaseModel): @@ -32,7 +30,6 @@ class TruthPayload(BaseModel): class QuoteQuota(commands.Cog): - def __init__(self, bot): self.bot = bot self.quotes_channel_id = CONFIG["quote_a"].get("channel_id") @@ -47,7 +44,9 @@ class QuoteQuota(commands.Cog): return c @staticmethod - def generate_pie_chart(usernames: list[str], counts: list[int], no_other: bool = False) -> discord.File: + def generate_pie_chart( + usernames: list[str], counts: list[int], no_other: bool = False + ) -> discord.File: """ Converts the given username and count tuples into a nice pretty pie chart. @@ -95,7 +94,9 @@ class QuoteQuota(commands.Cog): startangle=90, radius=1.2, ) - fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4) + fig.subplots_adjust( + left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.4 + ) fio = io.BytesIO() fig.savefig(fio, format="png") fio.seek(0) @@ -130,7 +131,9 @@ class QuoteQuota(commands.Cog): now = discord.utils.utcnow() oldest = now - timedelta(days=days) await ctx.defer() - channel = self.quotes_channel or discord.utils.get(ctx.guild.text_channels, name="quotes") + channel = self.quotes_channel or discord.utils.get( + ctx.guild.text_channels, name="quotes" + ) if not channel: return await ctx.respond(":x: Cannot find quotes channel.") @@ -139,7 +142,9 @@ class QuoteQuota(commands.Cog): authors = {} filtered_messages = 0 total = 0 - async for message in channel.history(limit=None, after=oldest, oldest_first=False): + async for message in channel.history( + limit=None, after=oldest, oldest_first=False + ): total += 1 if not message.content: filtered_messages += 1 @@ -179,10 +184,15 @@ class QuoteQuota(commands.Cog): ' (e.g. `"This is my quote" - Jimmy`)'.format(days) ) else: - return await ctx.edit(content="No messages found in the last {!s} days.".format(days)) + return await ctx.edit( + content="No messages found in the last {!s} days.".format(days) + ) file = await asyncio.to_thread( - self.generate_pie_chart, list(authors.keys()), list(authors.values()), merge_other + self.generate_pie_chart, + list(authors.keys()), + list(authors.values()), + merge_other, ) return await ctx.edit( content="{:,} messages (out of {:,}) were filtered (didn't follow format?)".format( @@ -192,7 +202,11 @@ class QuoteQuota(commands.Cog): ) def _metacounter( - self, truths: list[TruthPayload], filter_func: Callable[[TruthPayload], bool], *, now: datetime = None + self, + truths: list[TruthPayload], + filter_func: Callable[[TruthPayload], bool], + *, + now: datetime = None, ) -> dict[str, float | int | dict[str, int]]: def _is_today(date: datetime) -> bool: return date.date() == now.date() @@ -215,7 +229,9 @@ class QuoteQuota(commands.Cog): if filter_func(truth): created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc) age = now - created_at - self.log.debug("%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()) + self.log.debug( + "%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds() + ) counts["all_time"] += 1 if _is_today(created_at): counts["today"] += 1 @@ -275,7 +291,9 @@ class QuoteQuota(commands.Cog): plt.bar(hrs, list(hours.values()), color="#5448EE") average = sum(hours.values()) / len(hours) - plt.axhline(average, color="red", linestyle="--", label=f"Average: {average:.1f}") + plt.axhline( + average, color="red", linestyle="--", label=f"Average: {average:.1f}" + ) file = io.BytesIO() plt.savefig(file, format="png") @@ -283,8 +301,7 @@ class QuoteQuota(commands.Cog): return discord.File(file, "truths.png") async def _process_all_messages( - self, - truths: list[TruthPayload] + self, truths: list[TruthPayload] ) -> tuple[discord.Embed, discord.File]: """ Processes all the messages in the given channel. @@ -292,7 +309,11 @@ class QuoteQuota(commands.Cog): :param truths: The truths to process :returns: The stats """ - embed = discord.Embed(title="Truth Counts", color=discord.Color.blurple(), timestamp=discord.utils.utcnow()) + embed = discord.Embed( + title="Truth Counts", + color=discord.Color.blurple(), + timestamp=discord.utils.utcnow(), + ) trump_stats = await self._process_trump_truths(truths) tate_stats = await self._process_tate_truths(truths) @@ -346,7 +367,9 @@ class QuoteQuota(commands.Cog): ) response.raise_for_status() truths: list[dict] = response.json() - truths: list[TruthPayload] = list(map(lambda t: TruthPayload.model_validate(t), truths)) + truths: list[TruthPayload] = list( + map(lambda t: TruthPayload.model_validate(t), truths) + ) embed, file = await self._process_all_messages(truths) await ctx.edit(embed=embed, file=file) diff --git a/src/cogs/screenshot.py b/src/cogs/screenshot.py index 8d1affd..19d902a 100644 --- a/src/cogs/screenshot.py +++ b/src/cogs/screenshot.py @@ -1,5 +1,4 @@ import asyncio -import copy import datetime import io import logging @@ -133,7 +132,9 @@ class ScreenshotCog(commands.Cog): else: use_proxy = False service = await asyncio.to_thread(ChromeService) - driver: webdriver.Chrome = await asyncio.to_thread(webdriver.Chrome, service=service, options=options) + driver: webdriver.Chrome = await asyncio.to_thread( + webdriver.Chrome, service=service, options=options + ) driver.set_page_load_timeout(load_timeout) if resolution: resolution = RESOLUTIONS.get(resolution.lower(), resolution) @@ -141,9 +142,13 @@ class ScreenshotCog(commands.Cog): width, height = map(int, resolution.split("x")) driver.set_window_size(width, height) if height > 4320 or width > 7680: - return await ctx.respond("Invalid resolution. Max resolution is 7680x4320 (8K).") + return await ctx.respond( + "Invalid resolution. Max resolution is 7680x4320 (8K)." + ) except ValueError: - return await ctx.respond("Invalid resolution. please provide width x height, e.g. 1920x1080") + return await ctx.respond( + "Invalid resolution. please provide width x height, e.g. 1920x1080" + ) if eager: driver.implicitly_wait(render_timeout) except Exception as e: @@ -151,7 +156,13 @@ class ScreenshotCog(commands.Cog): raise end_init = time.time() - await ctx.edit(content=("Loading webpage..." if not eager else "Loading & screenshotting webpage...")) + await ctx.edit( + content=( + "Loading webpage..." + if not eager + else "Loading & screenshotting webpage..." + ) + ) start_request = time.time() try: await asyncio.to_thread(driver.get, url) @@ -160,7 +171,9 @@ class ScreenshotCog(commands.Cog): if "TimeoutException" in str(e): return await ctx.respond("Timed out while loading webpage.") else: - return await ctx.respond("Failed to load webpage:\n```\n%s\n```" % str(e.msg)) + return await ctx.respond( + "Failed to load webpage:\n```\n%s\n```" % str(e.msg) + ) except Exception as e: await self.bot.loop.run_in_executor(None, driver.quit) await ctx.respond("Failed to get the webpage: " + str(e)) @@ -170,7 +183,9 @@ class ScreenshotCog(commands.Cog): if not eager: now = discord.utils.utcnow() expires = now + datetime.timedelta(seconds=render_timeout) - await ctx.edit(content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})...") + await ctx.edit( + content=f"Rendering (expires {discord.utils.format_dt(expires, 'R')})..." + ) start_wait = time.time() await asyncio.sleep(render_timeout) end_wait = time.time() @@ -200,7 +215,9 @@ class ScreenshotCog(commands.Cog): await self.bot.loop.run_in_executor(None, driver.quit) end_cleanup = time.time() - screenshot_size_mb = round(len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2) + screenshot_size_mb = round( + len(await asyncio.to_thread(file.getvalue)) / 1024 / 1024, 2 + ) def seconds(start: float, end: float) -> float: return round(end - start, 2) @@ -221,7 +238,9 @@ class ScreenshotCog(commands.Cog): timestamp=discord.utils.utcnow(), ) embed.set_image(url="attachment://" + fn) - return await ctx.edit(content=None, embed=embed, file=discord.File(file, filename=fn)) + return await ctx.edit( + content=None, embed=embed, file=discord.File(file, filename=fn) + ) def setup(bot): diff --git a/src/cogs/starboard.py b/src/cogs/starboard.py index 021619c..a6c9e36 100644 --- a/src/cogs/starboard.py +++ b/src/cogs/starboard.py @@ -10,24 +10,32 @@ from conf import CONFIG class Starboard(commands.Cog): - DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{collision symbol}") + DOWNVOTE_EMOJI = discord.PartialEmoji.from_str("\N{COLLISION SYMBOL}") def __init__(self, bot): self.bot: commands.Bot = bot self.config = CONFIG["starboard"] - self.emoji = discord.PartialEmoji.from_str(self.config.get("emoji", "\N{white medium star}")) + self.emoji = discord.PartialEmoji.from_str( + self.config.get("emoji", "\N{WHITE MEDIUM STAR}") + ) self.log = logging.getLogger("jimmy.cogs.starboard") self.redis = Redis(**CONFIG["redis"]) - async def generate_starboard_embed(self, message: discord.Message) -> tuple[list[discord.Embed], int]: + async def generate_starboard_embed( + self, message: discord.Message + ) -> tuple[list[discord.Embed], int]: """ Generates an embed ready for a starboard message. :param message: The message to base off of. :return: The created embed """ - reactions: list[discord.Reaction] = [x for x in message.reactions if str(x.emoji) == str(self.emoji)] - downvote_reactions = [x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI)] + reactions: list[discord.Reaction] = [ + x for x in message.reactions if str(x.emoji) == str(self.emoji) + ] + downvote_reactions = [ + x for x in message.reactions if str(x.emoji) == str(self.DOWNVOTE_EMOJI) + ] if not reactions: # Nobody has added the star reaction. star_count = 0 @@ -35,12 +43,18 @@ class Starboard(commands.Cog): else: # Count the number of reactions star_count = sum([x.count for x in reactions]) - self.log.debug("There are a total of %d star reactions on message.", star_count) + self.log.debug( + "There are a total of %d star reactions on message.", star_count + ) if downvote_reactions: _dv = sum([x.count for x in downvote_reactions]) star_count -= _dv - self.log.debug("There are %d downvotes on message, resulting in %d stars.", _dv, star_count) + self.log.debug( + "There are %d downvotes on message, resulting in %d stars.", + _dv, + star_count, + ) if star_count >= 0: star_emoji_count = (str(self.emoji) * star_count)[:10] @@ -54,18 +68,20 @@ class Starboard(commands.Cog): author=discord.EmbedAuthor( message.author.display_name, message.author.jump_url, - message.author.display_avatar.url + message.author.display_avatar.url, ), fields=[ discord.EmbedField( name="Info", - value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})" + value=f"[Stars: {star_emoji_count} ({star_count:,})]({message.jump_url})", ) - ] + ], ) if message.reference: try: - ref_message = await self.get_or_fetch_message(message.reference.channel_id, message.reference.message_id) + ref_message = await self.get_or_fetch_message( + message.reference.channel_id, message.reference.message_id + ) except discord.HTTPException: pass else: @@ -74,31 +90,26 @@ class Starboard(commands.Cog): remaining = 1024 - len(v) t = textwrap.shorten(text, remaining, placeholder="...") v = f"[{ref_message.author.display_name}'s message: {t}]({ref_message.jump_url})" - embed.add_field( - name="Replying to", - value=v - ) + embed.add_field(name="Replying to", value=v) elif message.interaction: if message.interaction.type == discord.InteractionType.application_command: real_author: discord.User = await discord.utils.get_or_fetch( - self.bot, - "user", - int(message.interaction.data["user"]["id"]) + self.bot, "user", int(message.interaction.data["user"]["id"]) + ) + real_author = ( + await discord.utils.get_or_fetch( + message.guild, "member", real_author.id, default=real_author + ) + or message.author ) - real_author = await discord.utils.get_or_fetch( - message.guild, - "member", - real_author.id, - default=real_author - ) or message.author embed.set_author( name=real_author.display_name, icon_url=real_author.display_avatar.url, - url=real_author.jump_url + url=real_author.jump_url, ) embed.add_field( name="Interaction", - value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}" + value=f"Command `/{message.interaction.data['name']}` of {message.author.mention}", ) if message.content: @@ -107,17 +118,24 @@ class Starboard(commands.Cog): for message_embed in message.embeds: if message_embed.type != "rich": if message_embed.type == "image": - if message_embed.thumbnail and message_embed.thumbnail.proxy_url: + if ( + message_embed.thumbnail + and message_embed.thumbnail.proxy_url + ): embed.set_image(url=message_embed.thumbnail.proxy_url) continue if message_embed.description: embed.description = message_embed.description elif not message.attachments: - raise ValueError("Message does not appear to contain any text, embeds, or attachments.") + raise ValueError( + "Message does not appear to contain any text, embeds, or attachments." + ) if message.attachments: new_fields = [] - for n, attachment in reversed(tuple(enumerate(message.attachments, start=1))): + for n, attachment in reversed( + tuple(enumerate(message.attachments, start=1)) + ): attachment: discord.Attachment if attachment.size >= 1024 * 1024: size = f"{attachment.size / 1024 / 1024:,.1f}MiB" @@ -129,7 +147,7 @@ class Starboard(commands.Cog): { "name": "Attachment #%d:" % n, "value": f"[{attachment.filename} ({size})]({attachment.url})", - "inline": True + "inline": True, } ) if attachment.content_type.startswith("image/"): @@ -142,13 +160,19 @@ class Starboard(commands.Cog): return [embed, *filter(lambda e: e.type == "rich", message.embeds)], star_count - async def get_or_fetch_message(self, channel_id: int, message_id: int) -> discord.Message: + async def get_or_fetch_message( + self, channel_id: int, message_id: int + ) -> discord.Message: """ Fetches a message from cache where possible, falling back to the API. """ - message: discord.Message | None = discord.utils.get(self.bot.cached_messages, id=message_id) + message: discord.Message | None = discord.utils.get( + self.bot.cached_messages, id=message_id + ) if not message: - message: discord.Message = await self.bot.get_channel(channel_id).fetch_message(message_id) + message: discord.Message = await self.bot.get_channel( + channel_id + ).fetch_message(message_id) return message @commands.Cog.listener("on_raw_reaction_add") @@ -171,27 +195,35 @@ class Starboard(commands.Cog): guild: discord.Guild = self.bot.get_guild(payload.guild_id) starboard_channel: discord.TextChannel | None = discord.utils.get( - guild.text_channels, - name=self.config.get("channel_name", "starboard") + guild.text_channels, name=self.config.get("channel_name", "starboard") ) if not starboard_channel: - self.log.warning("Could not find starboard channel in %s (%d)", guild, guild.id) + self.log.warning( + "Could not find starboard channel in %s (%d)", guild, guild.id + ) return if payload.channel_id == starboard_channel.id: self.log.debug("Ignoring reaction in starboard channel.") return - message = await self.get_or_fetch_message(payload.channel_id, payload.message_id) + message = await self.get_or_fetch_message( + payload.channel_id, payload.message_id + ) if payload.user_id == message.author.id: self.log.info("%s tried to star their own message.", message.author) return await message.reply( - "You can't star your own message you pretentious dick. Go outside, %s." % message.author.mention, - delete_after=30 + "You can't star your own message you pretentious dick. Go outside, %s." + % message.author.mention, + delete_after=30, ) - self.log.info("Processing starboard reaction from %s: %s", payload.user_id, message.jump_url) + self.log.info( + "Processing starboard reaction from %s: %s", + payload.user_id, + message.jump_url, + ) embed, star_count = await self.generate_starboard_embed(message) data = await self.redis.get(str(message.id)) @@ -205,7 +237,7 @@ class Starboard(commands.Cog): "source_message_id": payload.message_id, "history": [], "starboard_channel_id": starboard_channel.id, - "starboard_message_id": None + "starboard_message_id": None, } if not starboard_channel.can_send(embed[0]): self.log.warning( @@ -213,29 +245,24 @@ class Starboard(commands.Cog): starboard_channel.id, payload.guild_id, starboard_channel.name, - guild.name + guild.name, ) return - starboard_message = await starboard_channel.send( - embeds=embed, - silent=True - ) + starboard_message = await starboard_channel.send(embeds=embed, silent=True) data["starboard_message_id"] = starboard_message.id await self.redis.set(str(message.id), json.dumps(data)) else: data = json.loads(data) try: starboard_message = await self.get_or_fetch_message( - data["starboard_channel_id"], - data["starboard_message_id"] + data["starboard_channel_id"], data["starboard_message_id"] ) except discord.NotFound: if star_count <= 0: return starboard_message = await starboard_channel.send( - embeds=embed, - silent=True + embeds=embed, silent=True ) data["starboard_message_id"] = starboard_message.id data["starboard_channel_id"] = starboard_message.channel.id @@ -247,18 +274,20 @@ class Starboard(commands.Cog): "Deleted message %s in %s, %s - lost all stars", starboard_message.id, starboard_message.channel.name, - starboard_message.guild.name + starboard_message.guild.name, ) elif starboard_message.embeds[0] != embed[0]: await starboard_message.edit(embeds=embed) @commands.message_command(name="Preview Starboard Message") - async def preview_starboard_message(self, ctx: discord.ApplicationContext, message: discord.Message): + async def preview_starboard_message( + self, ctx: discord.ApplicationContext, message: discord.Message + ): embed, stars = await self.generate_starboard_embed(message) - data = await self.redis.get(str(message.id)) or 'null' + data = await self.redis.get(str(message.id)) or "null" return await ctx.respond( f"```json\n{json.dumps(json.loads(data), indent=4)}\n```\nStars: {stars:,}", - embeds=embed + embeds=embed, ) diff --git a/src/cogs/ytdl.py b/src/cogs/ytdl.py index 833b1f2..9dc901f 100644 --- a/src/cogs/ytdl.py +++ b/src/cogs/ytdl.py @@ -135,12 +135,21 @@ class YTDLCog(commands.Cog): channel_id=excluded.channel_id, attachment_index=excluded.attachment_index """, - (_hash, message.id, message.channel.id, webpage_url, format_id, attachment_index), + ( + _hash, + message.id, + message.channel.id, + webpage_url, + format_id, + attachment_index, + ), ) await db.commit() return _hash - async def get_saved(self, webpage_url: str, format_id: str, snip: str) -> typing.Optional[str]: + async def get_saved( + self, webpage_url: str, format_id: str, snip: str + ) -> typing.Optional[str]: """ Attempts to retrieve the attachment URL of a previously saved download. :param webpage_url: The webpage url @@ -154,12 +163,19 @@ class YTDLCog(commands.Cog): logging.error("Failed to initialise ytdl database: %s", e, exc_info=True) return async with aiosqlite.connect("./data/ytdl.db") as db: - _hash = hashlib.md5(f"{webpage_url}:{format_id}:{snip}".encode()).hexdigest() + _hash = hashlib.md5( + f"{webpage_url}:{format_id}:{snip}".encode() + ).hexdigest() self.log.debug( - "Attempting to find a saved download for '%s:%s:%s' (%r).", webpage_url, format_id, snip, _hash + "Attempting to find a saved download for '%s:%s:%s' (%r).", + webpage_url, + format_id, + snip, + _hash, ) cursor = await db.execute( - "SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?", (_hash,) + "SELECT message_id, channel_id, attachment_index FROM downloads WHERE key=?", + (_hash,), ) entry = await cursor.fetchone() if not entry: @@ -173,7 +189,9 @@ class YTDLCog(commands.Cog): try: message = await channel.fetch_message(message_id) except discord.HTTPException: - self.log.debug("%r did not contain a message with ID %r", channel, message_id) + self.log.debug( + "%r did not contain a message with ID %r", channel, message_id + ) await db.execute("DELETE FROM downloads WHERE key=?", (_hash,)) return @@ -182,7 +200,11 @@ class YTDLCog(commands.Cog): self.log.debug("Found URL %r, returning.", url) return url except IndexError: - self.log.debug("Attachment index %d is out of range (%r)", attachment_index, message.attachments) + self.log.debug( + "Attachment index %d is out of range (%r)", + attachment_index, + message.attachments, + ) return def convert_to_m4a(self, file: Path) -> Path: @@ -209,23 +231,32 @@ class YTDLCog(commands.Cog): str(new_file), ] self.log.debug("Running command: ffmpeg %s", " ".join(args)) - process = subprocess.run(["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.run( + ["ffmpeg", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) if process.returncode != 0: raise RuntimeError(process.stderr.decode()) return new_file @staticmethod - async def upload_to_0x0(name: str, data: typing.IO[bytes], mime_type: str | None = None) -> str: + async def upload_to_0x0( + name: str, data: typing.IO[bytes], mime_type: str | None = None + ) -> str: if not mime_type: import magic - mime_type = await asyncio.to_thread(magic.from_buffer, data.read(4096), mime=True) + + mime_type = await asyncio.to_thread( + magic.from_buffer, data.read(4096), mime=True + ) data.seek(0) async with httpx.AsyncClient() as client: response = await client.post( "https://0x0.st", files={"file": (name, data, mime_type)}, data={"expires": 12}, - headers={"User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)"}, + headers={ + "User-Agent": "CollegeBot (see: https://gist.i-am.nexus/nex/f63fcb9eb389401caf66d1dfc3c7570c)" + }, ) if response.status_code == 200: return urlparse(response.text).path[1:] @@ -235,7 +266,10 @@ class YTDLCog(commands.Cog): async def yt_dl_command( self, ctx: discord.ApplicationContext, - url: typing.Annotated[str, discord.Option(str, description="The URL to download from.", required=True)], + url: typing.Annotated[ + str, + discord.Option(str, description="The URL to download from.", required=True), + ], user_format: typing.Annotated[ typing.Optional[str], discord.Option( @@ -258,7 +292,10 @@ class YTDLCog(commands.Cog): ], snip: typing.Annotated[ typing.Optional[str], - discord.Option(description="A start and end position to trim. e.g. 00:00:00-00:10:00.", required=False), + discord.Option( + description="A start and end position to trim. e.g. 00:00:00-00:10:00.", + required=False, + ), ], subtitles: typing.Annotated[ typing.Optional[str], @@ -267,7 +304,7 @@ class YTDLCog(commands.Cog): description="The language code of the subtitles to download. e.g. 'en', 'auto'", required=False, ), - ] + ], ): """Runs yt-dlp and outputs into discord.""" await ctx.defer() @@ -279,28 +316,40 @@ class YTDLCog(commands.Cog): if stop.is_set(): raise RuntimeError("Download cancelled.") n = time.time() - _total = _data.get("total_bytes", _data.get("total_bytes_estimate")) or ctx.guild.filesize_limit + _total = ( + _data.get("total_bytes", _data.get("total_bytes_estimate")) + or ctx.guild.filesize_limit + ) if _total: _percent = round((_data.get("downloaded_bytes") or 0) / _total * 100, 2) else: _total = max(1, _data.get("fragment_count", 4096) or 4096) - _percent = round(max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2) + _percent = round( + max(_data.get("fragment_index", 1) or 1, 1) / _total * 100, 2 + ) _speed_bytes_per_second = _data.get("speed", 1) or 1 or 1 - _speed_megabits_per_second = round((_speed_bytes_per_second * 8) / 1024 / 1024) + _speed_megabits_per_second = round( + (_speed_bytes_per_second * 8) / 1024 / 1024 + ) if _data.get("eta"): - _eta = discord.utils.utcnow() + datetime.timedelta(seconds=_data.get("eta")) + _eta = discord.utils.utcnow() + datetime.timedelta( + seconds=_data.get("eta") + ) else: _eta = discord.utils.utcnow() + datetime.timedelta(minutes=1) blocks = "#" * math.floor(_percent / 10) bar = f"{blocks}{'.' * (10 - len(blocks))}" - line = (f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | " - f"ETA {discord.utils.format_dt(_eta, 'R')}") + line = ( + f"{_percent}% [{bar}] | {_speed_megabits_per_second}Mbps | " + f"ETA {discord.utils.format_dt(_eta, 'R')}" + ) nonlocal last_edit if (n - last_edit) >= 1.1: embed.clear_fields() embed.add_field(name="Progress", value=line) ctx.bot.loop.create_task(ctx.edit(embed=embed)) last_edit = time.time() + options["progress_hooks"] = [_download_hook] description = "" @@ -331,13 +380,19 @@ class YTDLCog(commands.Cog): options["format_sort"] = ["abr", "br"] # noinspection PyTypeChecker options["postprocessors"].append( - {"key": "FFmpegExtractAudio", "preferredquality": "96", "preferredcodec": "best"} + { + "key": "FFmpegExtractAudio", + "preferredquality": "96", + "preferredcodec": "best", + } ) options["format"] = chosen_format options["paths"] = paths if subtitles and ffmpeg_installed: - subtitles, burn = subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0") + subtitles, burn = ( + subtitles.split("+", 1) if "+" in subtitles else (subtitles, "0") + ) burn = burn[0].lower() in ("y", "1", "t") if subtitles.lower() == "auto": options["writeautosubtitles"] = True @@ -352,10 +407,14 @@ class YTDLCog(commands.Cog): ) with yt_dlp.YoutubeDL(options) as downloader: - await ctx.respond(embed=discord.Embed().set_footer(text="Downloading (step 1/10)")) + await ctx.respond( + embed=discord.Embed().set_footer(text="Downloading (step 1/10)") + ) try: # noinspection PyTypeChecker - extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False) + extracted_info = await asyncio.to_thread( + downloader.extract_info, url, download=False + ) except yt_dlp.utils.DownloadError as e: extracted_info = { "title": "error", @@ -382,22 +441,38 @@ class YTDLCog(commands.Cog): thumbnail_url = extracted_info.get("thumbnail") or None webpage_url = extracted_info.get("webpage_url", url) - chosen_format = extracted_info.get("format") or chosen_format or str(uuid.uuid4()) - chosen_format_id = extracted_info.get("format_id") or str(uuid.uuid4()) + chosen_format = ( + extracted_info.get("format") + or chosen_format + or str(uuid.uuid4()) + ) + chosen_format_id = extracted_info.get("format_id") or str( + uuid.uuid4() + ) final_extension = extracted_info.get("ext") or "mp4" - format_note = extracted_info.get("format_note", "%s (%s)" % (chosen_format, chosen_format_id)) or "" + format_note = ( + extracted_info.get( + "format_note", "%s (%s)" % (chosen_format, chosen_format_id) + ) + or "" + ) resolution = extracted_info.get("resolution") or "1x1" fps = extracted_info.get("fps", 0.0) or 0.0 vcodec = extracted_info.get("vcodec") or "h264" acodec = extracted_info.get("acodec") or "aac" - filesize = extracted_info.get("filesize", extracted_info.get("filesize_approx", 1)) - likes = extracted_info.get("like_count", extracted_info.get("average_rating", 0)) + filesize = extracted_info.get( + "filesize", extracted_info.get("filesize_approx", 1) + ) + likes = extracted_info.get( + "like_count", extracted_info.get("average_rating", 0) + ) views = extracted_info.get("view_count", 0) lines = [] if chosen_format and chosen_format_id: lines.append( - "* Chosen format: `%s` (`%s`)" % (chosen_format, chosen_format_id), + "* Chosen format: `%s` (`%s`)" + % (chosen_format, chosen_format_id), ) if format_note: lines.append("* Format note: %r" % format_note) @@ -411,7 +486,9 @@ class YTDLCog(commands.Cog): if vcodec or acodec: lines.append("%s+%s" % (vcodec or "N/A", acodec or "N/A")) if filesize: - lines.append("* Filesize: %s" % yt_dlp.utils.format_bytes(filesize)) + lines.append( + "* Filesize: %s" % yt_dlp.utils.format_bytes(filesize) + ) if lines: description += "\n" @@ -424,27 +501,29 @@ class YTDLCog(commands.Cog): url=webpage_url, colour=self.colours.get(domain, discord.Colour.og_blurple()), ) - embed.add_field( - name="Progress", - value="0% [..........]" - ) + embed.add_field(name="Progress", value="0% [..........]") embed.set_footer(text="Downloading (step 2/10)") embed.set_thumbnail(url=thumbnail_url) class StopView(discord.ui.View): - @discord.ui.button(label="Cancel download", style=discord.ButtonStyle.danger) - async def _stop(self, button: discord.ui.Button, interaction: discord.Interaction): + @discord.ui.button( + label="Cancel download", style=discord.ButtonStyle.danger + ) + async def _stop( + self, + button: discord.ui.Button, + interaction: discord.Interaction, + ): stop.set() button.label = "Cancelling..." button.disabled = True await interaction.response.edit_message(view=self) self.stop() - await ctx.edit( - embed=embed, - view=StopView(timeout=86400) + await ctx.edit(embed=embed, view=StopView(timeout=86400)) + previous = await self.get_saved( + webpage_url, chosen_format_id, snip or "*" ) - previous = await self.get_saved(webpage_url, chosen_format_id, snip or "*") if previous: await ctx.edit( content=previous, @@ -454,7 +533,11 @@ class YTDLCog(commands.Cog): colour=discord.Colour.green(), timestamp=discord.utils.utcnow(), url=previous, - fields=[discord.EmbedField(name="URL", value=previous, inline=False)], + fields=[ + discord.EmbedField( + name="URL", value=previous, inline=False + ) + ], ).set_image(url=previous), ) return @@ -462,7 +545,9 @@ class YTDLCog(commands.Cog): last_edit = time.time() try: - await asyncio.to_thread(functools.partial(downloader.download, [url])) + await asyncio.to_thread( + functools.partial(downloader.download, [url]) + ) except yt_dlp.DownloadError as e: logging.error(e, exc_info=True) return await ctx.edit( @@ -473,7 +558,7 @@ class YTDLCog(commands.Cog): url=webpage_url, ), delete_after=120, - view=None + view=None, ) except RuntimeError: return await ctx.edit( @@ -484,16 +569,26 @@ class YTDLCog(commands.Cog): url=webpage_url, ), delete_after=120, - view=None + view=None, ) await ctx.edit(view=None) try: if audio_only is False: - file: Path = next(temp_dir.glob("*." + extracted_info.get("ext", "*"))) + file: Path = next( + temp_dir.glob("*." + extracted_info.get("ext", "*")) + ) else: # can be .opus, .m4a, .mp3, .ogg, .oga for _file in temp_dir.iterdir(): - if _file.suffix in (".opus", ".m4a", ".mp3", ".ogg", ".oga", ".aac", ".wav"): + if _file.suffix in ( + ".opus", + ".m4a", + ".mp3", + ".ogg", + ".oga", + ".aac", + ".wav", + ): file: Path = _file break else: @@ -523,7 +618,9 @@ class YTDLCog(commands.Cog): except ValueError: trim_start, trim_end = snip, None trim_start = trim_start or "00:00:00" - trim_end = trim_end or extracted_info.get("duration_string", "00:30:00") + trim_end = trim_end or extracted_info.get( + "duration_string", "00:30:00" + ) new_file = temp_dir / ("output" + file.suffix) args = [ "-hwaccel", @@ -562,7 +659,10 @@ class YTDLCog(commands.Cog): ) self.log.debug("Running command: 'ffmpeg %s'", " ".join(args)) process = await asyncio.create_subprocess_exec( - "ffmpeg", *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + "ffmpeg", + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate() self.log.debug("STDOUT:\n%r", stdout.decode()) @@ -606,31 +706,40 @@ class YTDLCog(commands.Cog): ) embed = discord.Embed( title=f"Downloaded {title}!", - description="Views: {:,} | Likes: {:,}".format(views or 0, likes or 0), + description="Views: {:,} | Likes: {:,}".format( + views or 0, likes or 0 + ), colour=discord.Colour.green(), timestamp=discord.utils.utcnow(), url=webpage_url, ) try: - if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in ["hevc", "h265", "av1", "av01"]: + if size_bytes >= (20 * 1024 * 1024) or vcodec.lower() in [ + "hevc", + "h265", + "av1", + "av01", + ]: with file.open("rb") as fb: - part = await self.upload_to_0x0( - file.name, - fb - ) - embed.add_field(name="URL", value=f"https://0x0.st/{part}", inline=False) - await ctx.edit( - embed=embed + part = await self.upload_to_0x0(file.name, fb) + embed.add_field( + name="URL", value=f"https://0x0.st/{part}", inline=False ) + await ctx.edit(embed=embed) await ctx.respond("https://embeds.video/0x0/" + part) else: - upload_file = await asyncio.to_thread(discord.File, file, filename=file.name) - msg = await ctx.edit( - file=upload_file, - embed=embed + upload_file = await asyncio.to_thread( + discord.File, file, filename=file.name ) - await self.save_link(msg, webpage_url, chosen_format_id, snip=snip or "*") - except (discord.HTTPException, ConnectionError, httpx.HTTPStatusError) as e: + msg = await ctx.edit(file=upload_file, embed=embed) + await self.save_link( + msg, webpage_url, chosen_format_id, snip=snip or "*" + ) + except ( + discord.HTTPException, + ConnectionError, + httpx.HTTPStatusError, + ) as e: self.log.error(e, exc_info=True) return await ctx.edit( embed=discord.Embed( diff --git a/src/conf.py b/src/conf.py index 198bf57..d4ae7c8 100644 --- a/src/conf.py +++ b/src/conf.py @@ -21,7 +21,9 @@ if (Path.cwd() / ".git").exists(): log.debug("Unable to auto-detect running version using git.", exc_info=True) VERSION = "unknown" else: - log.debug("Unable to auto-detect running version using git, no .git directory exists.") + log.debug( + "Unable to auto-detect running version using git, no .git directory exists." + ) VERSION = "unknown" try: @@ -32,7 +34,7 @@ except FileNotFoundError: log.critical( "Unable to locate config.toml in %s. Using default configuration. Good luck!", cwd, - exc_info=True + exc_info=True, ) CONFIG = {} CONFIG.setdefault("logging", {}) @@ -50,11 +52,13 @@ CONFIG.setdefault( "api": "https://bots.nexy7574.co.uk/jimmy/v2/", "username": os.getenv("WEB_USERNAME", os.urandom(32).hex()), "password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()), - } + }, ) if CONFIG["redis"].pop("no_ping", None) is not None: - log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.") + log.warning( + "`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory." + ) CONFIG["redis"]["decode_responses"] = True diff --git a/src/main.py b/src/main.py index 5f4d200..1abe4ff 100644 --- a/src/main.py +++ b/src/main.py @@ -61,9 +61,16 @@ class KumaThread(Thread): response.raise_for_status() self.previous = bot.is_ready() except httpx.HTTPError as error: - self.log.error("Failed to connect to uptime-kuma: %r: %r", url, error, exc_info=error) + self.log.error( + "Failed to connect to uptime-kuma: %r: %r", + url, + error, + exc_info=error, + ) timeout = self.calculate_backoff() - self.log.warning("Waiting %d seconds before retrying ping.", timeout) + self.log.warning( + "Waiting %d seconds before retrying ping.", timeout + ) time.sleep(timeout) continue @@ -91,7 +98,11 @@ logging.basicConfig( markup=True, console=Console(width=cols, height=lns), ), - FileHandler(filename=CONFIG["logging"].get("file", "jimmy.log"), encoding="utf-8", errors="replace"), + FileHandler( + filename=CONFIG["logging"].get("file", "jimmy.log"), + encoding="utf-8", + errors="replace", + ), ], ) for logger in CONFIG["logging"].get("suppress", []): @@ -130,7 +141,8 @@ class Client(commands.Bot): async def start(self, token: str, *, reconnect: bool = True) -> None: if CONFIG["jimmy"].get("uptime_kuma_url"): self.uptime_thread = KumaThread( - CONFIG["jimmy"]["uptime_kuma_url"], CONFIG["jimmy"].get("uptime_kuma_interval", 58.0) + CONFIG["jimmy"]["uptime_kuma_url"], + CONFIG["jimmy"].get("uptime_kuma_interval", 58.0), ) self.uptime_thread.start() await super().start(token, reconnect=reconnect) @@ -138,7 +150,9 @@ class Client(commands.Bot): async def close(self) -> None: if getattr(self, "uptime_thread", None): self.uptime_thread.kill.set() - await asyncio.get_event_loop().run_in_executor(None, self.uptime_thread.join) + await asyncio.get_event_loop().run_in_executor( + None, self.uptime_thread.join + ) await super().close() @@ -189,9 +203,13 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc log.error(f"Error in {ctx.command} from {ctx.author} in {ctx.guild}", exc_info=exc) if isinstance(exc, commands.CommandOnCooldown): expires = discord.utils.utcnow() + datetime.timedelta(seconds=exc.retry_after) - await ctx.respond(f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}.") + await ctx.respond( + f"Command on cooldown. Try again {discord.utils.format_dt(expires, style='R')}." + ) elif isinstance(exc, commands.MaxConcurrencyReached): - await ctx.respond("You've reached the maximum number of concurrent uses for this command.") + await ctx.respond( + "You've reached the maximum number of concurrent uses for this command." + ) else: if await bot.is_owner(ctx.author): paginator = commands.Paginator(prefix="```py") @@ -200,7 +218,10 @@ async def on_application_command_error(ctx: discord.ApplicationContext, exc: Exc for page in paginator.pages: await ctx.respond(page) else: - await ctx.respond(f"An error occurred while processing your command. Please try again later.\n" f"{exc}") + await ctx.respond( + f"An error occurred while processing your command. Please try again later.\n" + f"{exc}" + ) @bot.listen() @@ -218,11 +239,22 @@ async def delete_message(ctx: discord.ApplicationContext, message: discord.Messa await ctx.defer() if not ctx.channel.permissions_for(ctx.me).manage_messages: if message.author != bot.user: - return await ctx.respond("I don't have permission to delete messages in this channel.", delete_after=30) + return await ctx.respond( + "I don't have permission to delete messages in this channel.", + delete_after=30, + ) - log.info("%s deleted message %s>%s: %r", ctx.author, ctx.channel.name, message.id, message.content) + log.info( + "%s deleted message %s>%s: %r", + ctx.author, + ctx.channel.name, + message.id, + message.content, + ) await message.delete(delay=1) - await ctx.respond(f"\N{white heavy check mark} Deleted message by {message.author.display_name}.") + await ctx.respond( + f"\N{WHITE HEAVY CHECK MARK} Deleted message by {message.author.display_name}." + ) await ctx.delete(delay=5) @@ -231,16 +263,19 @@ async def about_me(ctx: discord.ApplicationContext): embed = discord.Embed( title="Jimmy v3", description="A bot specifically for the LCC discord server(s).", - colour=discord.Colour.green() + colour=discord.Colour.green(), ) embed.add_field( name="Source", - value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)" + value="[Source code is available under AGPLv3.](https://git.i-am.nexus/nex/college-bot-v2)", ) user = await ctx.bot.fetch_user(421698654189912064) appinfo = await ctx.bot.application_info() author = appinfo.owner - embed.add_field(name="Author", value=f"{user.mention} created me. {author.mention} is running this instance.") + embed.add_field( + name="Author", + value=f"{user.mention} created me. {author.mention} is running this instance.", + ) return await ctx.respond(embed=embed) @@ -249,9 +284,8 @@ async def check_is_enabled(ctx: commands.Context | discord.ApplicationContext) - disabled = CONFIG["jimmy"].get("disabled_commands", []) if ctx.command.qualified_name in disabled: raise commands.DisabledCommand( - "%s is disabled via this instance's configuration file." % ( - ctx.command.qualified_name - ) + "%s is disabled via this instance's configuration file." + % (ctx.command.qualified_name) ) return True @@ -273,7 +307,9 @@ async def add_delays(ctx: commands.Context | discord.ApplicationContext) -> bool if not CONFIG["jimmy"].get("token"): - log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)") + log.critical( + "No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)" + ) sys.exit(1) __disabled = CONFIG["jimmy"].get("disabled_commands", []) diff --git a/src/server.py b/src/server.py index 36bc95a..f5480e3 100644 --- a/src/server.py +++ b/src/server.py @@ -2,24 +2,24 @@ import json import redis +import orjson import os import secrets import typing import time from fastapi import FastAPI, Depends, HTTPException, APIRouter -from fastapi.responses import JSONResponse, Response +from fastapi.responses import JSONResponse, Response, ORJSONResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials from pydantic import BaseModel, Field JSON: typing.Union[ str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"] -] = typing.Union[ - str, int, float, bool, None, dict, list -] +] = typing.Union[str, int, float, bool, None, dict, list] class TruthPayload(BaseModel): """Represents a truth. This can be used to both create and get truths.""" + id: str content: str author: str @@ -33,6 +33,7 @@ class OllamaThread(BaseModel): class ThreadMessage(BaseModel): """Represents a message in an Ollama thread.""" + role: typing.Literal["assistant", "system", "user"] content: str images: typing.Optional[list[str]] = [] @@ -54,9 +55,7 @@ USERNAME = os.getenv("WEB_USERNAME", os.urandom(32).hex()) PASSWORD = os.getenv("WEB_PASSWORD", os.urandom(32).hex()) -accounts = { - USERNAME: PASSWORD -} +accounts = {USERNAME: PASSWORD} if os.path.exists("./web-accounts.json"): with open("./web-accounts.json") as f: accounts.update(json.load(f)) @@ -71,7 +70,9 @@ def check_credentials(credentials: HTTPBasicCredentials = Depends(security)): return credentials -def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]: +def get_db_factory( + n: int = 11, +) -> typing.Callable[[], typing.Generator[redis.Redis, None, None]]: def inner(): uri = os.getenv("REDIS_URL", "redis://redis") conn = redis.Redis.from_url(uri) @@ -80,30 +81,29 @@ def get_db_factory(n: int = 11) -> typing.Callable[[], typing.Generator[redis.Re yield conn finally: conn.close() + return inner app = FastAPI( title="Jimmy v3 API", version="3.1.0", - root_path=os.getenv("WEB_ROOT_PATH", "") + "/api" + root_path=os.getenv("WEB_ROOT_PATH", "") + "/api", ) truth_router = APIRouter( - prefix="/truths", - dependencies=[Depends(check_credentials)], - tags=["Truth Social"] + prefix="/truths", dependencies=[Depends(check_credentials)], tags=["Truth Social"] ) -@truth_router.get("", response_model=list[TruthPayload]) +@truth_router.get("", response_class=ORJSONResponse, response_model=list[TruthPayload]) def get_all_truths( rich: bool = True, limit: int = -1, page: int = 0, - db: redis.Redis = Depends(get_db_factory()) + db: redis.Redis = Depends(get_db_factory()), ): """Retrieves all stored truths - + If ?limit is a positive integer, pagination will be enabled.""" keys = db.keys() if rich is False: @@ -113,12 +113,18 @@ def get_all_truths( ] truths = [json.loads(db.get(key)) for key in keys] if limit >= 0: - return truths[page * limit:(page + 1) * limit] - return truths + selection = truths[page * limit: (page + 1) * limit] + else: + selection = truths + return ORJSONResponse(selection) @truth_router.get("/all", deprecated=True, response_model=list[TruthPayload]) -def get_all_truths_deprecated(response: JSONResponse, rich: bool = True, db: redis.Redis = Depends(get_db_factory())): +def get_all_truths_deprecated( + response: JSONResponse, + rich: bool = True, + db: redis.Redis = Depends(get_db_factory()), +): """DEPRECATED - USE get_all_truths INSTEAD""" return get_all_truths(rich=rich, db=db) @@ -142,7 +148,11 @@ def head_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): @truth_router.post("", status_code=201, response_model=TruthPayload) -def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = Depends(get_db_factory())): +def new_truth( + payload: TruthPayload, + response: JSONResponse, + db: redis.Redis = Depends(get_db_factory()), +): """Creates a new truth""" data = payload.model_dump() existing: str = db.get(data["id"]) @@ -157,7 +167,9 @@ def new_truth(payload: TruthPayload, response: JSONResponse, db: redis.Redis = D @truth_router.put("/{truth_id}", response_model=TruthPayload) -def put_truth(truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory())): +def put_truth( + truth_id: str, payload: TruthPayload, db: redis.Redis = Depends(get_db_factory()) +): """Replaces a stored truth""" data = payload.model_dump() existing = db.get(truth_id) @@ -177,9 +189,7 @@ def delete_truth(truth_id: str, db: redis.Redis = Depends(get_db_factory())): app.include_router(truth_router) ollama_router = APIRouter( - prefix="/ollama", - dependencies=[Depends(check_credentials)], - tags=["Ollama"] + prefix="/ollama", dependencies=[Depends(check_credentials)], tags=["Ollama"] ) @@ -194,13 +204,13 @@ def get_ollama_threads(db: redis.Redis = Depends(get_db_factory(0))): return keys -@ollama_router.get("/thread/{thread_id}", response_model=OllamaThread) +@ollama_router.get("/thread/{thread_id}", response_class=ORJSONResponse, response_model=OllamaThread) def get_ollama_thread(thread_id: str, db: redis.Redis = Depends(get_db_factory(0))): """Retrieves a stored thread""" data: str = db.get(thread_id) if not data: raise HTTPException(404, detail="%r not found." % thread_id) - return json.loads(data) + return ORJSONResponse(orjson.loads(data.encode())) @ollama_router.delete("/thread/{thread_id}", status_code=204) @@ -225,4 +235,5 @@ def health(db: redis.Redis = Depends(get_db_factory())): if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="0.0.0.0", port=1111, forwarded_allow_ips="*") From 301e7b7417a312920d86433e79855493de0d3f85 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Jul 2024 02:06:06 +0100 Subject: [PATCH 23/24] add server-timing --- src/server.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/server.py b/src/server.py index f5480e3..5b39ccf 100644 --- a/src/server.py +++ b/src/server.py @@ -105,18 +105,28 @@ def get_all_truths( """Retrieves all stored truths If ?limit is a positive integer, pagination will be enabled.""" + query_start = time.perf_counter() keys = db.keys() + query_end = time.perf_counter() if rich is False: return [ {"id": key, "content": "", "author": "", "timestamp": 0, "extra": None} for key in keys ] + load_start = time.perf_counter() truths = [json.loads(db.get(key)) for key in keys] + load_end = time.perf_counter() if limit >= 0: selection = truths[page * limit: (page + 1) * limit] else: selection = truths - return ORJSONResponse(selection) + whittle_end = time.perf_counter() + server_timing = "query;dur=%.2f, load;dur=%.2f, whittle;dur=%.2f" % ( + (query_end - query_start) * 1000, + (load_end - load_start) * 1000, + (whittle_end - load_end) * 1000, + ) + return ORJSONResponse(selection, headers={"Server-Timing": server_timing}) @truth_router.get("/all", deprecated=True, response_model=list[TruthPayload]) From b0325a26b0dafc1fbbe3bc5b0ba89d5885673e23 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Jul 2024 02:09:32 +0100 Subject: [PATCH 24/24] Speed up truth querying --- src/server.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/server.py b/src/server.py index 5b39ccf..1e02219 100644 --- a/src/server.py +++ b/src/server.py @@ -114,19 +114,15 @@ def get_all_truths( for key in keys ] load_start = time.perf_counter() + if limit >= 0: + keys = keys[page * limit: (page + 1) * limit] truths = [json.loads(db.get(key)) for key in keys] load_end = time.perf_counter() - if limit >= 0: - selection = truths[page * limit: (page + 1) * limit] - else: - selection = truths - whittle_end = time.perf_counter() - server_timing = "query;dur=%.2f, load;dur=%.2f, whittle;dur=%.2f" % ( + server_timing = "query;dur=%.2f, load;dur=%.2f" % ( (query_end - query_start) * 1000, (load_end - load_start) * 1000, - (whittle_end - load_end) * 1000, ) - return ORJSONResponse(selection, headers={"Server-Timing": server_timing}) + return ORJSONResponse(truths, headers={"Server-Timing": server_timing}) @truth_router.get("/all", deprecated=True, response_model=list[TruthPayload])