Add impersonate command
All checks were successful
Build and Publish college-bot-v2 / build_and_publish (push) Successful in 22s

This commit is contained in:
Nexus 2024-06-08 20:04:32 +01:00
parent 8b21954443
commit c33bce67cb
2 changed files with 137 additions and 1 deletions

View file

@ -1190,6 +1190,142 @@ class Ollama(commands.Cog):
""" """
async with ctx.channel.typing(): async with ctx.channel.typing():
await self.generate_truth(ctx, "tate", latest) await self.generate_truth(ctx, "tate", latest)
@commands.command()
@commands.guild_only()
async def impersonate(
self,
ctx: commands.Context,
user: discord.Member,
search: int = 500,
channels: commands.Greedy[discord.TextChannel] = None,
query: str = None
):
"""
Impersonates a user in a server.
<user> - The user to impersonate.
<search> - The number of messages to search through. This is NOT the number of messages from the user
<channels> - The channels to search through. If not provided, the current channel will be searched.
<query> - A query to ask
"""
search = max(10, min(search, 1000))
channels = channels or [ctx.channel]
messages = []
msg = await ctx.reply("Gathering messages, please wait. This may take a while.")
for channel in channels:
async for message in channel.history(limit=search):
if message.author == user and message.content:
messages.append(message)
if not messages:
return await msg.edit(content="No messages found.")
await msg.edit(content="Preparing AI...")
system = (
"You are {0.mention!r}, a member of the discord server {0.guild.name!r}. You have been provided with"
" your message history. You should generate a new message based on the content of the messages provided. "
"You should not repeat a message. You should write a message that is under 4000 characters. "
"You should not include any pleasentries, only the message that you would send. Your messages have been "
"found in the following channels: {1}.".format(user, ", ".join(map(str, channels)))
)
thread_id = self.history.create_thread(ctx.author, system)
for message in messages:
await asyncio.to_thread(
functools.partial(
self.history.add_message,
thread_id,
"assistant",
message.content,
save=False
)
)
self.history.add_message(thread_id, "user", "new message:")
tried = set()
for _ in range(10):
server = self.next_server(tried)
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
break
tried.add(server)
else:
return msg.edit(content="All servers are offline. Please try again later.", delete_after=300)
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
async with self.servers[server]:
if not await client.has_model_named("llama2-uncensored", "7b-chat"):
with client.download_model("llama2-uncensored", "7b-chat") as handler:
await handler.flatten()
embed = discord.Embed(
title="New message!",
description="",
colour=0x6559FF
)
embed.set_author(name=user.display_name, icon_url=user.display_avatar.url)
msg = await ctx.reply(embed=embed)
last_edit = time.time()
messages = self.history.get_history(thread_id)
with client.new_chat("llama2-uncensored:7b-chat", messages) as handler:
await msg.edit(content="Generating message...")
async for ln in handler:
embed.description += ln["message"]["content"]
if len(embed.description) >= 4000:
break
if (time.time() - last_edit) >= 2.5:
await msg.edit(embed=embed)
last_edit = time.time()
similar = {}
for message in messages:
_ratio = fuzz.ratio(message.content, embed.description)
if message.content == embed.description:
embed.add_field(
name="Repeated message :(",
value="This message was already sent. Shit AI."
)
elif _ratio >= 70:
similar[message.id] = _ratio
if similar:
if len(similar) > 1:
lns = []
keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True)
for message_id in keys:
_ratio = similar[message_id]
message = discord.utils.get(messages, id=message_id)
first_line = message.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100))
lns.append(f"* [{_ratio}% - `{preview}`]({message.jump_url})")
if len(lns) > 5:
lc = len(lns) - 5
lns = lns[:5]
lns.append(f"*... and {lc} more*")
embed.add_field(
name="Possibly repeated message",
value="This message was similar to the following existing ones:\n" + "\n".join(lns),
inline=False
)
else:
message_id = tuple(similar)[0]
_ratio = similar[message_id]
message = discord.utils.get(messages, id=message_id)
first_line = message.content.splitlines()[0]
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512))
embed.add_field(
name="Possibly repeated message",
value=f"This message was {_ratio}% similar to `{message_id}`.\n>>> {preview}"
)
embed.set_footer(
text="Finished generating message based off of {:,} messages, using server {!r} | {!s}".format(
len(messages) - 2,
server,
thread_id
)
)
await msg.edit(content=None, embed=embed)
@commands.command() @commands.command()
@commands.guild_only() @commands.guild_only()

View file

@ -26,7 +26,7 @@ JSON: typing.Union[
class TruthPayload(BaseModel): class TruthPayload(BaseModel):
id: str id: str
content: str content: str
author: typing.Literal["trump", "tate"] = Field(pattern=r"^(trump|tate)$") author: str
timestamp: float = Field(default_factory=time.time, ge=0) timestamp: float = Field(default_factory=time.time, ge=0)
extra: typing.Optional[JSON] = None extra: typing.Optional[JSON] = None