college-bot-v2/src/cogs/ollama.py

517 lines
21 KiB
Python
Raw Normal View History

2024-01-10 10:13:37 +00:00
import asyncio
import json
import logging
2024-01-10 15:59:13 +00:00
import os
2024-01-09 14:49:29 +00:00
import textwrap
import time
import typing
2024-01-10 16:10:45 +00:00
import base64
2024-01-09 14:49:29 +00:00
import io
2024-01-12 09:53:57 +00:00
import humanize
2024-01-10 16:04:58 +00:00
2024-01-10 10:13:37 +00:00
from discord.ui import View, button
from fnmatch import fnmatch
import aiohttp
import discord
from discord.ext import commands
2024-01-06 21:56:18 +00:00
from conf import CONFIG
2024-01-12 10:00:12 +00:00
def get_time_spent(nanoseconds: int) -> str:
    """
    Formats a duration given in nanoseconds as a human-readable string.

    e.g. ``125_000_000_000`` ns -> ``"2 minutes, 5 seconds"``.

    :param nanoseconds: The duration in nanoseconds (as reported by ollama's
        timing fields, e.g. ``total_duration``).
    :return: A comma-separated string with the largest unit first. Durations
        that round to zero produce an empty string.
    """
    # Round to whole seconds FIRST, then split into units. The previous
    # implementation rounded each unit only at display time, after divmod,
    # which could show impossible values such as "60 seconds" (for 59.6s)
    # or "0 seconds" (for sub-half-second durations).
    total_seconds = round(nanoseconds / 1e9)
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)

    result = []
    if seconds:
        result.append(f"{seconds} {'second' if seconds == 1 else 'seconds'}")
    if minutes:
        result.append(f"{minutes} {'minute' if minutes == 1 else 'minutes'}")
    if hours:
        result.append(f"{hours} {'hour' if hours == 1 else 'hours'}")
    # Units were appended smallest-first; display largest-first.
    return ", ".join(reversed(result))
2024-01-10 10:13:37 +00:00
class OllamaView(View):
    """A view holding a single "Stop" button used to cancel a running generation."""

    def __init__(self, ctx: discord.ApplicationContext):
        # Disable the button after an hour of inactivity so it cannot dangle forever.
        super().__init__(timeout=3600, disable_on_timeout=True)
        self.ctx = ctx
        # Set when the invoker presses "Stop"; streaming loops poll this to abort.
        self.cancel = asyncio.Event()

    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        # Only the user who invoked the command may press the button.
        return interaction.user == self.ctx.user

    @button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f")
    async def _stop(self, stop_button: discord.ui.Button, interaction: discord.Interaction):
        """Grey out the button, signal cancellation, and stop listening."""
        stop_button.disabled = True
        self.cancel.set()
        await interaction.response.edit_message(view=self)
        self.stop()
2024-01-12 15:39:39 +00:00
class ChatHistory:
    """
    In-memory store of chat threads sent to ollama's chat endpoint.

    Each thread is keyed by a short random hex string and holds the member who
    started it plus the ordered list of system/user/assistant messages.
    """

    def __init__(self):
        # Maps thread key -> {"member": <member>, "messages": [message dict, ...]}
        self._internal = {}

    def create_thread(self, member: "discord.Member") -> str:
        """
        Creates a thread seeded with the system prompt, returns its ID.

        :param member: The member the thread belongs to.
        :return: The new thread's key (6 hex characters).
        """
        key = os.urandom(3).hex()
        self._internal[key] = {
            "member": member,
            "messages": []
        }
        # Explicit encoding: the platform default may not be UTF-8.
        with open("./assets/ollama-prompt.txt", encoding="utf-8") as file:
            system_prompt = file.read()
        self.add_message(key, "system", system_prompt)
        return key

    @staticmethod
    def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]:
        """Builds one ollama chat-message dict; ``images`` is only included when given."""
        message = {
            "role": role,
            "content": content
        }
        if images:
            message["images"] = images
        return message

    def add_message(
        self,
        thread: str,
        role: typing.Literal["user", "assistant", "system"],
        content: str,
        images: typing.Optional[list[str]] = None
    ) -> None:
        """
        Appends a message to the given thread.

        :param thread: The thread's ID.
        :param role: The author of the message.
        :param content: The message's actual content.
        :param images: Any images that were attached to the message, in base64.
        :return: None
        """
        self._internal[thread]["messages"].append(self._construct_message(role, content, images))

    def get_history(self, thread: str) -> list[dict[str, str]]:
        """
        Gets a shallow copy of a thread's message history.

        The copy protects the stored list from appends by the caller, but the
        message dicts themselves are shared — do not mutate them.
        """
        return self._internal[thread]["messages"].copy()
# All configured ollama server names, in config order, used for round-robin selection.
SERVER_KEYS = list(CONFIG["ollama"])
2024-01-11 14:41:07 +00:00
class Ollama(commands.Cog):
    """
    Cog providing the /ollama slash command: chat with an LLM hosted on one of
    the ollama servers configured under CONFIG["ollama"].
    """

    def __init__(self, bot: commands.Bot):
        self.bot = bot
        self.log = logging.getLogger("jimmy.cogs.ollama")
        # Round-robin cursor into SERVER_KEYS (see next_server()).
        self.last_server = 0
        # Context keys from previous responses; only membership is tested below.
        # NOTE(review): nothing visible in this file ever writes to this dict —
        # confirm whether it is populated elsewhere or is dead state.
        self.contexts = {}
        # Conversation threads (system/user/assistant messages) keyed by context key.
        self.history = ChatHistory()

    def next_server(self, increment: bool = True) -> str:
        """Returns the next server key.

        :param increment: When False, returns the current server without advancing
            the round-robin cursor.
        """
        if increment:
            self.last_server += 1
        return SERVER_KEYS[self.last_server % len(SERVER_KEYS)]

    async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
        """
        Yields each JSON object from ollama's newline-delimited streaming body.

        Lines that fail to parse as JSON are logged and skipped, not raised.
        """
        async for line in iterator:
            original_line = line
            # "replace" keeps the stream alive even on malformed UTF-8.
            line = line.decode("utf-8", "replace").strip()
            try:
                line = json.loads(line)
            except json.JSONDecodeError:
                self.log.warning("Unable to decode JSON: %r", original_line)
                continue
            else:
                self.log.debug("Decoded JSON %r -> %r", original_line, line)
            yield line

    async def check_server(self, url: str) -> bool:
        """Checks that a server is online and responding (GET <url>/api/tags)."""
        # Short 10-second total timeout: this is a liveness probe, not a real request.
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(10)) as session:
            self.log.debug("Checking if %r is online.", url)
            try:
                async with session.get(url + "/api/tags") as resp:
                    self.log.debug("%r is online.", resp.url.host)
                    return resp.ok
            except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
                self.log.warning("%r is offline.", url, exc_info=True)
                return False

    # /ollama — stream a chat completion from a configured ollama server into an
    # embed, editing it as chunks arrive. Deliberately NO docstring on the
    # callback: pycord would surface it as the slash command's description.
    @commands.slash_command()
    async def ollama(
        self,
        ctx: discord.ApplicationContext,
        query: typing.Annotated[
            str,
            discord.Option(
                str,
                "The query to feed into ollama. Not the system prompt.",
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                str,
                # NOTE(review): the description says 'llama2-uncensored:latest' but
                # the actual default below is the 7b-chat tag — confirm intent.
                "The model to use for ollama. Defaults to 'llama2-uncensored:latest'.",
                default="llama2-uncensored:7b-chat"
            )
        ],
        server: typing.Annotated[
            str,
            discord.Option(
                str,
                "The server to use for ollama.",
                default="next",
                choices=SERVER_KEYS
            )
        ],
        context: typing.Annotated[
            str,
            discord.Option(
                str,
                "The context key of a previous ollama response to use as context.",
                default=None
            )
        ],
        give_acid: typing.Annotated[
            bool,
            discord.Option(
                bool,
                "Whether to give the AI acid, LSD, and other hallucinogens before responding.",
                default=False
            )
        ],
        image: typing.Annotated[
            discord.Attachment,
            discord.Option(
                discord.Attachment,
                "An image to feed into ollama. Only works with llava.",
                default=None
            )
        ]
    ):
        # Reject unknown context keys up front.
        if context is not None:
            if context not in self.contexts:
                await ctx.respond("Invalid context key.")
                return

        await ctx.defer()

        # Normalise the model name to lowercase and ensure it carries a ":tag".
        model = model.casefold()
        try:
            model, tag = model.split(":", 1)
            model = model + ":" + tag
            self.log.debug("Model %r already has a tag", model)
        except ValueError:
            model = model + ":latest"
            self.log.debug("Resolved model to %r" % model)

        if image:
            # Only llava handles images; switch to it automatically when needed.
            if fnmatch(model, "llava:*") is False:
                await ctx.respond(
                    "You can only use images with llava. Switching model to `llava:latest`.",
                    delete_after=5
                )
                model = "llava:latest"

            if image.size > 1024 * 1024 * 25:
                await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.")
                return
            elif not fnmatch(image.content_type, "image/*"):
                await ctx.respond("Attachment is not an image. Try using a different file.")
                return
            else:
                # Download the attachment and base64-encode it for the ollama API.
                data = io.BytesIO()
                await image.save(data)
                data.seek(0)
                image_data = base64.b64encode(data.read()).decode("utf-8")
        else:
            image_data = None

        # Resolve "next" via round-robin; otherwise validate the explicit choice.
        if server == "next":
            server = self.next_server()
        elif server not in CONFIG["ollama"]:
            await ctx.respond("Invalid server")
            return
        server_config = CONFIG["ollama"][server]

        # The model must match at least one of this server's allow-list patterns.
        for model_pattern in server_config["allowed_models"]:
            if fnmatch(model, model_pattern):
                break
        else:
            allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
            await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
            return

        # Timeout of 0 disables the total timeout — streaming can take minutes.
        async with aiohttp.ClientSession(
            base_url=server_config["base_url"],
            timeout=aiohttp.ClientTimeout(0)
        ) as session:
            embed = discord.Embed(
                title="Checking server...",
                description=f"Checking that specified model and tag ({model}) are available on the server.",
                color=discord.Color.blurple(),
                timestamp=discord.utils.utcnow()
            )
            embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
            await ctx.respond(embed=embed)

            # If the chosen server is down, walk the rotation (up to 10 tries).
            if not await self.check_server(server_config["base_url"]):
                for i in range(10):
                    server = self.next_server()
                    embed = discord.Embed(
                        title="Server was offline. Trying next server.",
                        description=f"Trying server {server}...",
                        color=discord.Color.gold(),
                        timestamp=discord.utils.utcnow()
                    )
                    embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
                    await ctx.edit(embed=embed)
                    await asyncio.sleep(1)
                    if await self.check_server(CONFIG["ollama"][server]["base_url"]):
                        # NOTE(review): `session` above was created with the ORIGINAL
                        # server's base_url; after failing over, server_config changes
                        # but requests through `session` still target the old host.
                        server_config = CONFIG["ollama"][server]
                        break
                else:
                    # for/else: every candidate in the rotation was offline.
                    embed = discord.Embed(
                        title="All servers are offline.",
                        description="Please try again later.",
                        color=discord.Color.red(),
                        timestamp=discord.utils.utcnow()
                    )
                    embed.set_footer(text="Unable to continue.")
                    return await ctx.edit(embed=embed)

            try:
                self.log.debug("Connecting to %r", server_config["base_url"])
                # /api/show answers 200 if the model exists locally, 404 if not.
                async with session.post("/api/show", json={"name": model}) as resp:
                    self.log.debug("%r responded.", server_config["base_url"])
                    if resp.status not in [404, 200]:
                        embed = discord.Embed(
                            url=resp.url,
                            title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
                            description=f"```{await resp.text() or 'No response body'}```"[:4096],
                            color=discord.Color.red(),
                            timestamp=discord.utils.utcnow()
                        )
                        embed.set_footer(text="Unable to continue.")
                        return await ctx.edit(embed=embed)
            except aiohttp.ClientConnectionError as e:
                embed = discord.Embed(
                    title="Connection error while checking for model.",
                    description=f"```{e}```"[:4096],
                    color=discord.Color.red(),
                    timestamp=discord.utils.utcnow()
                )
                embed.set_footer(text="Unable to continue.")
                return await ctx.edit(embed=embed)

            # 404 -> the model is not on the server yet; pull it with progress UI.
            if resp.status == 404:
                self.log.debug("Beginning download of %r", model)

                def progress_bar(_v: float, action: str = None):
                    # 10-segment bar: green squares for progress, white for the rest.
                    bar = "\N{large green square}" * round(_v / 10)
                    bar += "\N{white large square}" * (10 - len(bar))
                    bar += f" {_v:.2f}%"
                    if action:
                        return f"{action} {bar}"
                    return bar

                embed = discord.Embed(
                    title=f"Downloading {model!r}",
                    description=f"Downloading {model!r} from {server_config['base_url']}",
                    color=discord.Color.blurple(),
                    timestamp=discord.utils.utcnow()
                )
                embed.add_field(name="Progress", value=progress_bar(0))
                await ctx.edit(embed=embed)
                last_update = time.time()

                async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response:
                    if response.status != 200:
                        embed = discord.Embed(
                            url=response.url,
                            title=f"HTTP {response.status} {response.reason!r} while downloading model.",
                            description=f"```{await response.text() or 'No response body'}```"[:4096],
                            color=discord.Color.red(),
                            timestamp=discord.utils.utcnow()
                        )
                        embed.set_footer(text="Unable to continue.")
                        return await ctx.edit(embed=embed)

                    view = OllamaView(ctx)
                    async for line in self.ollama_stream(response.content):
                        if view.cancel.is_set():
                            embed = discord.Embed(
                                title="Download cancelled.",
                                colour=discord.Colour.red(),
                                timestamp=discord.utils.utcnow()
                            )
                            return await ctx.edit(embed=embed, view=None)
                        # Throttle edits to roughly one per 5s (Discord rate limits).
                        if time.time() >= (last_update + 5.1):
                            if line.get("total") is not None and line.get("completed") is not None:
                                percent = (line["completed"] / line["total"]) * 100
                            else:
                                # Chunks without totals (e.g. manifest steps): fake 50%.
                                percent = 50.0
                            embed.fields[0].value = progress_bar(percent, line["status"])
                            await ctx.edit(embed=embed, view=view)
                            last_update = time.time()
            else:
                self.log.debug("Model %r already exists on server.", model)

            # NOTE(review): `key` is only used in a debug log line below; actual
            # context keys come from ChatHistory.create_thread().
            key = os.urandom(6).hex()

            embed = discord.Embed(
                title="Generating response...",
                description=">>> ",
                color=discord.Color.blurple(),
                timestamp=discord.utils.utcnow()
            )
            embed.set_author(
                name=model,
                url="https://ollama.ai/library/" + model.split(":")[0],
                icon_url="https://ollama.ai/public/ollama.png"
            )
            embed.add_field(
                name="Prompt",
                value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."),
                inline=False
            )
            embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
            if image_data:
                # Tall images render better as the main image; wide ones as thumbnail.
                if (image.height / image.width) >= 1.5:
                    embed.set_image(url=image.url)
                else:
                    embed.set_thumbnail(url=image.url)
            view = OllamaView(ctx)
            try:
                await ctx.edit(embed=embed, view=view)
            except discord.NotFound:
                # Original interaction response may have expired; send a new one.
                await ctx.respond(embed=embed, view=view)
            self.log.debug("Beginning to generate response with key %r.", key)

            params = {}
            if give_acid is True:
                # Deliberately absurd sampling values to derange the output.
                params["temperature"] = 500
                params["top_k"] = 500
                params["top_p"] = 500

            # NOTE(review): this creates a NEW thread both when no context was given
            # AND when the supplied key is valid (`in self.contexts`) — the second
            # clause looks inverted; confirm whether `not in` was intended.
            if context is None or context in self.contexts:
                context = self.history.create_thread(ctx.user)
            messages = self.history.get_history(context)
            user_message = {
                "role": "user",
                "content": query
            }
            if image_data:
                user_message["images"] = [image_data]
            messages.append(user_message)

            payload = {
                "model": model,
                "stream": True,
                "options": params,
                "messages": messages
            }
            async with session.post(
                "/api/chat",
                json=payload,
            ) as response:
                if response.status != 200:
                    embed = discord.Embed(
                        url=response.url,
                        title=f"HTTP {response.status} {response.reason!r} while generating response.",
                        description=f"```{await response.text() or 'No response body'}```"[:4096],
                        color=discord.Color.red(),
                        timestamp=discord.utils.utcnow()
                    )
                    embed.set_footer(text="Unable to continue.")
                    return await ctx.edit(embed=embed)

                last_update = time.time()
                buffer = io.StringIO()
                if not view.cancel.is_set():
                    async for line in self.ollama_stream(response.content):
                        # NOTE(review): ollama's /api/chat streams chunks shaped like
                        # {"message": {"role": ..., "content": ...}} — `line["assistant"]`
                        # looks like it should be `line["message"]["content"]`; confirm
                        # against the ollama version these servers run.
                        buffer.write(line["assistant"])
                        embed.description += line["assistant"]
                        embed.timestamp = discord.utils.utcnow()
                        if len(embed.description) >= 4096:
                            # NOTE(review): the duplicated `embed.description =` is
                            # redundant; this also resets the visible text to just the
                            # latest chunk once the 4096-char embed limit is reached.
                            embed.description = embed.description = "..." + line["assistant"]

                        if view.cancel.is_set():
                            break

                        # Throttle edits to roughly one per 5s (Discord rate limits).
                        if time.time() >= (last_update + 5.1):
                            await ctx.edit(embed=embed, view=view)
                            self.log.debug(f"Updating message ({last_update} -> {time.time()})")
                            last_update = time.time()

                view.stop()
                # NOTE(review): `user_message["images"]` raises KeyError when no image
                # was attached — the key is only set above when image_data is truthy;
                # `user_message.get("images")` would be safer.
                self.history.add_message(context, "user", user_message["content"], user_message["images"])
                self.history.add_message(context, "assistant", buffer.getvalue())
                embed.add_field(name="Context Key", value=context, inline=True)
                self.log.debug("Ollama finished consuming.")
                embed.title = "Done!"
                embed.colour = discord.Color.green()

                value = buffer.getvalue()
                # Responses over the 4096-char embed limit get re-paginated word-by-word.
                if len(value) >= 4096:
                    embeds = [discord.Embed(title="Done!", colour=discord.Color.green())]
                    current_page = ""
                    for word in value.split():
                        if len(current_page) + len(word) >= 4096:
                            embeds.append(discord.Embed(description=current_page))
                            current_page = ""
                        current_page += word + " "
                    else:
                        # for/else: flush the final partial page after the loop.
                        embeds.append(discord.Embed(description=current_page))
                    await ctx.edit(embeds=embeds, view=None)
                else:
                    await ctx.edit(embed=embed, view=None)

                # NOTE(review): `line` may be unbound here if the stream yielded no
                # chunks (or cancel was set before the loop ran) — would NameError.
                if line.get("done"):
                    total_duration = get_time_spent(line["total_duration"])
                    load_duration = get_time_spent(line["load_duration"])
                    prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
                    eval_duration = get_time_spent(line["eval_duration"])
                    embed = discord.Embed(
                        title="Timings",
                        description=f"Total: {total_duration}\nLoad: {load_duration}\n"
                                    f"Prompt Eval: {prompt_eval_duration}\nEval: {eval_duration}\n"
                                    f"Prompt Tokens: {line['prompt_eval_count']:,}\n"
                                    f"Response Tokens: {line['eval_count']:,}",
                        color=discord.Color.blurple(),
                        timestamp=discord.utils.utcnow()
                    )
                    return await ctx.respond(embed=embed, ephemeral=True)
def setup(bot):
    """Extension entry point: registers the Ollama cog on the bot."""
    bot.add_cog(Ollama(bot))