From c47a38bfb95a42ae41244f6521d707aced94b917 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Wed, 5 Jun 2024 02:56:03 +0100 Subject: [PATCH] Make use of the API for truth social commands --- config.example.toml | 5 +++ src/cogs/ollama.py | 43 +++++++++--------- src/cogs/quote_quota.py | 99 +++++++++++++++++++++++++---------------- src/conf.py | 8 ++++ src/server.py | 2 +- 5 files changed, 94 insertions(+), 63 deletions(-) diff --git a/config.example.toml b/config.example.toml index 400b59e..89e5f59 100644 --- a/config.example.toml +++ b/config.example.toml @@ -88,3 +88,8 @@ enabled = false # disables starboard entirely emoji = "⭐" # the emoji to use. Defaults to the plain star. whitelist = [994710566612500550] # An array of server IDs to whitelist. Omitted means all servers. channel_name = "starboard" # the channel name to search for. Globally, not per-server. + +[truth] # for truth-social commands (/truths, h!trump, etc) +api = "https://bots.nexy7574.co.uk/jimmy/v2" # base URL +username = "jimmy" # the username to use +password = "password" # the password to use diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 51aabef..ea41170 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -12,6 +12,7 @@ from fnmatch import fnmatch import aiohttp import discord +import httpx import redis from discord import Interaction from discord.ext import commands @@ -1014,7 +1015,7 @@ class Ollama(commands.Cog): @commands.command() @commands.guild_only() - async def trump(self, ctx: commands.Context, max_history: int = 100): + async def trump(self, ctx: commands.Context): async with ctx.channel.typing(): thread_id = self.history.create_thread( ctx.author, @@ -1025,25 +1026,21 @@ class Ollama(commands.Cog): "under 4000 characters. Write only the content to be posted, do not include any pleasantries." " Write using the style of a twitter or facebook post. Do not repeat a previous post." ) - channel = discord.utils.get(ctx.guild.text_channels, name="spam") - oldest = discord.Object(1229487300505894964) - async for message in channel.history(limit=max_history): - if message.created_at <= oldest.created_at: - break - if message.author.id == 1101439218334576742 and len(message.embeds): - embed = message.embeds[0] - if not embed.description or embed.description.strip() == "NEW TRUTH": - continue - if embed.type == "rich" and embed.colour and embed.colour.value == 0x5448EE: - await asyncio.to_thread( - functools.partial( - self.history.add_message, - thread_id, - "assistant", - embed.description, - save=False - ) - ) + async with httpx.AsyncClient() as client: + r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2") + username = CONFIG["truth"].get("username", "1") + password = CONFIG["truth"].get("password", "2") + response = await client.get( + r + "/api/truths/all", + timeout=60, + auth=(username, password), + ) + response.raise_for_status() + truths = response.json() + for truth in truths: + if truth["author"] == "trump": + self.history.add_message(thread_id, "assistant", truth["content"], save=False) + break self.history.add_message(thread_id, "user", "Generate a new truth post.") tried = set() @@ -1078,8 +1075,8 @@ class Ollama(commands.Cog): await msg.edit(embed=embed) last_edit = time.time() - for _message in self.history.get_history(thread_id): - if _message["content"] == embed.description: + for truth in truths: + if truth["content"] == embed.description: embed.add_field( name="Repeated truth :(", value="This truth was already truthed. Shit AI." @@ -1087,7 +1084,7 @@ class Ollama(commands.Cog): break embed.set_footer( text="Finished generating truth based off of {:,} messages, using server {!r} | {!s}".format( - len(messages), + len(messages) - 2, server, thread_id ) diff --git a/src/cogs/quote_quota.py b/src/cogs/quote_quota.py index ec946d6..8065e0c 100644 --- a/src/cogs/quote_quota.py +++ b/src/cogs/quote_quota.py @@ -2,7 +2,11 @@ import asyncio import io import logging import re -from datetime import datetime, timedelta +import time +import typing + +import httpx +from datetime import datetime, timedelta, timezone from typing import Annotated, Callable import discord @@ -10,6 +14,21 @@ import matplotlib.pyplot as plt from discord.ext import commands from conf import CONFIG +from pydantic import BaseModel, Field + +JSON: typing.Union[ + str, int, float, bool, None, dict[str, "JSON"], list["JSON"] +] = typing.Union[ + str, int, float, bool, None, dict, list +] + + +class TruthPayload(BaseModel): + id: str + content: str + author: typing.Literal["trump", "tate"] = Field(pattern=r"^(trump|tate)$") + timestamp: float = Field(default_factory=time.time, ge=0) + extra: typing.Optional[JSON] = None class QuoteQuota(commands.Cog): @@ -173,7 +192,7 @@ class QuoteQuota(commands.Cog): ) def _metacounter( - self, messages: list[discord.Message], filter_func: Callable[[discord.Message], bool], *, now: datetime = None + self, truths: list[TruthPayload], filter_func: Callable[[TruthPayload], bool], *, now: datetime = None ) -> dict[str, float | int | dict[str, int]]: def _is_today(date: datetime) -> bool: return date.date() == now.date() @@ -192,20 +211,21 @@ class QuoteQuota(commands.Cog): } for i in range(24): counts["hours"][str(i)] = 0 - for message in messages: - if filter_func(message): - age = now - message.created_at - self.log.debug("%r was a truth (%.2f seconds ago).", message.id, age.total_seconds()) + for truth in truths: + if filter_func(truth): + created_at = datetime.fromtimestamp(truth.timestamp, tz=timezone.utc) + age = now - created_at + self.log.debug("%r was a truth (%.2f seconds ago).", truth.id, age.total_seconds()) counts["all_time"] += 1 - if _is_today(message.created_at): + if _is_today(created_at): counts["today"] += 1 - if message.created_at > now - timedelta(hours=1): + if created_at > now - timedelta(hours=1): counts["hour"] += 1 - if message.created_at > now - timedelta(days=1): + if created_at > now - timedelta(days=1): counts["day"] += 1 - if message.created_at > now - timedelta(days=7): + if created_at > now - timedelta(days=7): counts["week"] += 1 - counts["hours"][str(message.created_at.hour)] += 1 + counts["hours"][str(created_at.hour)] += 1 counts["per_minute"] = counts["hour"] / 60 counts["per_hour"] = counts["day"] / 24 @@ -213,42 +233,31 @@ class QuoteQuota(commands.Cog): self.log.info("Total truth counts: %r", counts) return counts - async def _process_trump_truths(self, messages: list[discord.Message]): + async def _process_trump_truths(self, truths: list[TruthPayload]): """ Processes the given messages to count the number of posts by Donald Trump. - :param messages: The messages to process + :param truths: The truths to process :returns: The stats """ - def is_truth(msg: discord.Message) -> bool: - if msg.author.id == 1101439218334576742: - for __t_e in msg.embeds: - if __t_e.type == "rich" and __t_e.colour is not None and __t_e.colour.value == 0x5448EE: - self.log.debug("Found tagged rich trump truth embed: %r", msg.id) - return True - return False + def is_truth(truth: TruthPayload) -> bool: + return truth.author == "trump" - return self._metacounter(messages, is_truth) + return self._metacounter(truths, is_truth) - async def _process_tate_truths(self, messages: list[discord.Message]): + async def _process_tate_truths(self, truths: list[TruthPayload]): """ Processes the given messages to count the number of posts by Andrew Tate. - :param messages: The messages to process + :param truths: The messages to process :returns: The stats """ - def is_truth(msg: discord.Message) -> bool: - if msg.author.id == 1229496078726860921: - # All the tate truths are already tagged. - for __t_e in msg.embeds: - if __t_e.type == "rich" and __t_e.colour.value == 0x5448EE: - self.log.debug("Found tagged rich tate truth embed %r", __t_e) - return True - return False + def is_truth(truth: TruthPayload) -> bool: + return truth.author == "tate" - return self._metacounter(messages, is_truth) + return self._metacounter(truths, is_truth) @staticmethod def _generate_truth_frequency_graph(hours: dict[str, int]) -> discord.File: @@ -273,7 +282,11 @@ class QuoteQuota(commands.Cog): file.seek(0) return discord.File(file, "truths.png") - async def _process_all_messages(self, channel: discord.TextChannel) -> tuple[discord.Embed, discord.File]: + async def _process_all_messages( + self, + channel: discord.TextChannel, + truths: list + ) -> tuple[discord.Embed, discord.File]: """ Processes all the messages in the given channel. @@ -281,11 +294,8 @@ class QuoteQuota(commands.Cog): :returns: The stats """ embed = discord.Embed(title="Truth Counts", color=discord.Color.blurple(), timestamp=discord.utils.utcnow()) - messages: list[discord.Message] = await channel.history( - limit=None, after=discord.Object(1229487065117233203) - ).flatten() - trump_stats = await self._process_trump_truths(messages) - tate_stats = await self._process_tate_truths(messages) + trump_stats = await self._process_trump_truths(truths) + tate_stats = await self._process_tate_truths(truths) embed.add_field( name="Donald Trump", @@ -331,7 +341,18 @@ class QuoteQuota(commands.Cog): timestamp=now, ) await ctx.respond(embed=embed) - embed, file = await self._process_all_messages(channel) + async with httpx.AsyncClient() as client: + r = CONFIG["truth"].get("api", "https://bots.nexy7574.co.uk/jimmy/v2") + username = CONFIG["truth"].get("username", "1") + password = CONFIG["truth"].get("password", "2") + response = await client.get( + r + "/api/truths/all", + timeout=60, + auth=(username, password), + ) + response.raise_for_status() + truths = response.json() + embed, file = await self._process_all_messages(channel, truths) await ctx.edit(embed=embed, file=file) diff --git a/src/conf.py b/src/conf.py index b2a4030..198bf57 100644 --- a/src/conf.py +++ b/src/conf.py @@ -44,6 +44,14 @@ CONFIG.setdefault("network", {}) CONFIG.setdefault("quote_a", {"channel": None}) CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True}) CONFIG.setdefault("starboard", {}) +CONFIG.setdefault( + "truth", + { + "api": "https://bots.nexy7574.co.uk/jimmy/v2/", + "username": os.getenv("WEB_USERNAME", os.urandom(32).hex()), + "password": os.getenv("WEB_PASSWORD", os.urandom(32).hex()), + } +) if CONFIG["redis"].pop("no_ping", None) is not None: log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.") diff --git a/src/server.py b/src/server.py index bbc245e..c46499e 100644 --- a/src/server.py +++ b/src/server.py @@ -9,7 +9,7 @@ import time from fastapi import FastAPI, Depends, HTTPException from fastapi.responses import JSONResponse, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field JSON: typing.Union[ str, int, float, bool, None, typing.Dict[str, "JSON"], typing.List["JSON"]