Add impersonate command
All checks were successful
Build and Publish college-bot-v2 / build_and_publish (push) Successful in 22s
All checks were successful
Build and Publish college-bot-v2 / build_and_publish (push) Successful in 22s
This commit is contained in:
parent
8b21954443
commit
c33bce67cb
2 changed files with 137 additions and 1 deletions
|
@ -1190,6 +1190,142 @@ class Ollama(commands.Cog):
|
||||||
"""
|
"""
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
await self.generate_truth(ctx, "tate", latest)
|
await self.generate_truth(ctx, "tate", latest)
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
@commands.guild_only()
|
||||||
|
async def impersonate(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
user: discord.Member,
|
||||||
|
search: int = 500,
|
||||||
|
channels: commands.Greedy[discord.TextChannel] = None,
|
||||||
|
query: str = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Impersonates a user in a server.
|
||||||
|
|
||||||
|
<user> - The user to impersonate.
|
||||||
|
<search> - The number of messages to search through. This is NOT the number of messages from the user
|
||||||
|
<channels> - The channels to search through. If not provided, the current channel will be searched.
|
||||||
|
<query> - A query to ask
|
||||||
|
"""
|
||||||
|
search = max(10, min(search, 1000))
|
||||||
|
channels = channels or [ctx.channel]
|
||||||
|
messages = []
|
||||||
|
msg = await ctx.reply("Gathering messages, please wait. This may take a while.")
|
||||||
|
for channel in channels:
|
||||||
|
async for message in channel.history(limit=search):
|
||||||
|
if message.author == user and message.content:
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return await msg.edit(content="No messages found.")
|
||||||
|
|
||||||
|
await msg.edit(content="Preparing AI...")
|
||||||
|
|
||||||
|
system = (
|
||||||
|
"You are {0.mention!r}, a member of the discord server {0.guild.name!r}. You have been provided with"
|
||||||
|
" your message history. You should generate a new message based on the content of the messages provided. "
|
||||||
|
"You should not repeat a message. You should write a message that is under 4000 characters. "
|
||||||
|
"You should not include any pleasentries, only the message that you would send. Your messages have been "
|
||||||
|
"found in the following channels: {1}.".format(user, ", ".join(map(str, channels)))
|
||||||
|
)
|
||||||
|
thread_id = self.history.create_thread(ctx.author, system)
|
||||||
|
for message in messages:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
functools.partial(
|
||||||
|
self.history.add_message,
|
||||||
|
thread_id,
|
||||||
|
"assistant",
|
||||||
|
message.content,
|
||||||
|
save=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.history.add_message(thread_id, "user", "new message:")
|
||||||
|
|
||||||
|
tried = set()
|
||||||
|
for _ in range(10):
|
||||||
|
server = self.next_server(tried)
|
||||||
|
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
||||||
|
break
|
||||||
|
tried.add(server)
|
||||||
|
else:
|
||||||
|
return msg.edit(content="All servers are offline. Please try again later.", delete_after=300)
|
||||||
|
|
||||||
|
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
|
||||||
|
async with self.servers[server]:
|
||||||
|
if not await client.has_model_named("llama2-uncensored", "7b-chat"):
|
||||||
|
with client.download_model("llama2-uncensored", "7b-chat") as handler:
|
||||||
|
await handler.flatten()
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="New message!",
|
||||||
|
description="",
|
||||||
|
colour=0x6559FF
|
||||||
|
)
|
||||||
|
embed.set_author(name=user.display_name, icon_url=user.display_avatar.url)
|
||||||
|
msg = await ctx.reply(embed=embed)
|
||||||
|
last_edit = time.time()
|
||||||
|
messages = self.history.get_history(thread_id)
|
||||||
|
with client.new_chat("llama2-uncensored:7b-chat", messages) as handler:
|
||||||
|
await msg.edit(content="Generating message...")
|
||||||
|
async for ln in handler:
|
||||||
|
embed.description += ln["message"]["content"]
|
||||||
|
if len(embed.description) >= 4000:
|
||||||
|
break
|
||||||
|
if (time.time() - last_edit) >= 2.5:
|
||||||
|
await msg.edit(embed=embed)
|
||||||
|
last_edit = time.time()
|
||||||
|
|
||||||
|
similar = {}
|
||||||
|
for message in messages:
|
||||||
|
_ratio = fuzz.ratio(message.content, embed.description)
|
||||||
|
if message.content == embed.description:
|
||||||
|
embed.add_field(
|
||||||
|
name="Repeated message :(",
|
||||||
|
value="This message was already sent. Shit AI."
|
||||||
|
)
|
||||||
|
elif _ratio >= 70:
|
||||||
|
similar[message.id] = _ratio
|
||||||
|
|
||||||
|
if similar:
|
||||||
|
if len(similar) > 1:
|
||||||
|
lns = []
|
||||||
|
keys = sorted(similar.keys(), key=lambda k: similar[k], reverse=True)
|
||||||
|
for message_id in keys:
|
||||||
|
_ratio = similar[message_id]
|
||||||
|
message = discord.utils.get(messages, id=message_id)
|
||||||
|
first_line = message.content.splitlines()[0]
|
||||||
|
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 100))
|
||||||
|
lns.append(f"* [{_ratio}% - `{preview}`]({message.jump_url})")
|
||||||
|
if len(lns) > 5:
|
||||||
|
lc = len(lns) - 5
|
||||||
|
lns = lns[:5]
|
||||||
|
lns.append(f"*... and {lc} more*")
|
||||||
|
embed.add_field(
|
||||||
|
name="Possibly repeated message",
|
||||||
|
value="This message was similar to the following existing ones:\n" + "\n".join(lns),
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message_id = tuple(similar)[0]
|
||||||
|
_ratio = similar[message_id]
|
||||||
|
message = discord.utils.get(messages, id=message_id)
|
||||||
|
first_line = message.content.splitlines()[0]
|
||||||
|
preview = discord.utils.escape_markdown(textwrap.shorten(first_line, 512))
|
||||||
|
embed.add_field(
|
||||||
|
name="Possibly repeated message",
|
||||||
|
value=f"This message was {_ratio}% similar to `{message_id}`.\n>>> {preview}"
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.set_footer(
|
||||||
|
text="Finished generating message based off of {:,} messages, using server {!r} | {!s}".format(
|
||||||
|
len(messages) - 2,
|
||||||
|
server,
|
||||||
|
thread_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await msg.edit(content=None, embed=embed)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
|
|
|
@ -26,7 +26,7 @@ JSON: typing.Union[
|
||||||
class TruthPayload(BaseModel):
|
class TruthPayload(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
author: typing.Literal["trump", "tate"] = Field(pattern=r"^(trump|tate)$")
|
author: str
|
||||||
timestamp: float = Field(default_factory=time.time, ge=0)
|
timestamp: float = Field(default_factory=time.time, ge=0)
|
||||||
extra: typing.Optional[JSON] = None
|
extra: typing.Optional[JSON] = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue