diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 5f08e60..1190d2f 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -1190,6 +1190,142 @@ class Ollama(commands.Cog): """ async with ctx.channel.typing(): await self.generate_truth(ctx, "tate", latest) + + @commands.command() + @commands.guild_only() + async def impersonate( + self, + ctx: commands.Context, + user: discord.Member, + search: int = 500, + channels: commands.Greedy[discord.TextChannel] = None, + query: str = None + ): + """ + Impersonates a user in a server. + + - The user to impersonate. + - The number of messages to search through. This is NOT the number of messages from the user + - The channels to search through. If not provided, the current channel will be searched. + - A query to ask + """ + search = max(10, min(search, 1000)) + channels = channels or [ctx.channel] + messages = [] + msg = await ctx.reply("Gathering messages, please wait. This may take a while.") + for channel in channels: + async for message in channel.history(limit=search): + if message.author == user and message.content: + messages.append(message) + + if not messages: + return await msg.edit(content="No messages found.") + + await msg.edit(content="Preparing AI...") + + system = ( + "You are {0.mention!r}, a member of the discord server {0.guild.name!r}. You have been provided with" + " your message history. You should generate a new message based on the content of the messages provided. " + "You should not repeat a message. You should write a message that is under 4000 characters. " + "You should not include any pleasentries, only the message that you would send. Your messages have been " + "found in the following channels: {1}.".format(user, ", ".join(map(str, channels))) + ) + thread_id = self.history.create_thread(ctx.author, system) + for message in messages: + await asyncio.to_thread( + functools.partial( + self.history.add_message, + thread_id, + "assistant", + message.content, + save=False + ) + ) + self.history.add_message(thread_id, "user", "new message:") + + tried = set() + for _ in range(10): + server = self.next_server(tried) + if await self.check_server(CONFIG["ollama"][server]["base_url"]): + break + tried.add(server) + else: + return msg.edit(content="All servers are offline. Please try again later.", delete_after=300) + + client = OllamaClient(CONFIG["ollama"][server]["base_url"]) + async with self.servers[server]: + if not await client.has_model_named("llama2-uncensored", "7b-chat"): + with client.download_model("llama2-uncensored", "7b-chat") as handler: + await handler.flatten() + + embed = discord.Embed( + title="New message!", + description="", + colour=0x6559FF + ) + embed.set_author(name=user.display_name, icon_url=user.display_avatar.url) + msg = await ctx.reply(embed=embed) + last_edit = time.time() + messages = self.history.get_history(thread_id) + with client.new_chat("llama2-uncensored:7b-chat", messages) as handler: + await msg.edit(content="Generating message...") + async for ln in handler: + embed.description += ln["message"]["content"] + if len(embed.description) >= 4000: + break + if (time.time() - last_edit) >= 2.5: + await msg.edit(embed=embed) + last_edit = time.time() + + similar = {} + for message in messages: + _ratio = fuzz.ratio(message.content, embed.description) + if message.content == embed.description: + embed.add_field( + name="Repeated message :(", + value="This message was already sent. Shit AI." + ) + elif _ratio >= 70: + similar[message.id] = _ratio + + if similar: + if len(similar) > 1: + lns = [] + keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True) + for message_id in keys: + _ratio = similar[message_id] + message = discord.utils.get(messages, id=message_id) + first_line = message.content.splitlines()[0] + preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100)) + lns.append(f"* [{_ratio}% - `{preview}`]({message.jump_url})") + if len(lns) > 5: + lc = len(lns) - 5 + lns = lns[:5] + lns.append(f"*... and {lc} more*") + embed.add_field( + name="Possibly repeated message", + value="This message was similar to the following existing ones:\n" + "\n".join(lns), + inline=False + ) + else: + message_id = tuple(similar)[0] + _ratio = similar[message_id] + message = discord.utils.get(messages, id=message_id) + first_line = message.content.splitlines()[0] + preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512)) + embed.add_field( + name="Possibly repeated message", + value=f"This message was {_ratio}% similar to `{message_id}`.\n>>> {preview}" + ) + + embed.set_footer( + text="Finished generating message based off of {:,} messages, using server {!r} | {!s}".format( + len(messages) - 2, + server, + thread_id + ) + ) + await msg.edit(content=None, embed=embed) @commands.command() @commands.guild_only() diff --git a/src/cogs/quote_quota.py b/src/cogs/quote_quota.py index 67464a5..a39781f 100644 --- a/src/cogs/quote_quota.py +++ b/src/cogs/quote_quota.py @@ -26,7 +26,7 @@ JSON: typing.Union[ class TruthPayload(BaseModel): id: str content: str - author: typing.Literal["trump", "tate"] = Field(pattern=r"^(trump|tate)$") + author: str timestamp: float = Field(default_factory=time.time, ge=0) extra: typing.Optional[JSON] = None