253 lines
10 KiB
Python
253 lines
10 KiB
Python
|
import collections
|
||
|
import json
|
||
|
import logging
|
||
|
import time
|
||
|
import typing
|
||
|
from fnmatch import fnmatch
|
||
|
|
||
|
import aiohttp
|
||
|
import discord
|
||
|
from discord.ext import commands
|
||
|
from ..conf import CONFIG
|
||
|
|
||
|
|
||
|
# Names of the configured ollama servers, in configuration order. The first
# entry is used as the default choice for the slash command's `server` option.
SERVER_KEYS = list(CONFIG["ollama"])
|
||
|
|
||
|
class Ollama(commands.Cog):
|
||
|
def __init__(self, bot: commands.Bot):
|
||
|
self.bot = bot
|
||
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||
|
|
||
|
async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
|
||
|
async for line in iterator:
|
||
|
original_line = line
|
||
|
line = line.decode("utf-8", "replace").strip()
|
||
|
try:
|
||
|
line = json.loads(line)
|
||
|
except json.JSONDecodeError:
|
||
|
self.log.warning("Unable to decode JSON: %r", original_line)
|
||
|
continue
|
||
|
else:
|
||
|
self.log.debug("Decoded JSON %r -> %r", original_line, line)
|
||
|
yield line
|
||
|
|
||
|
@commands.slash_command()
|
||
|
async def ollama(
|
||
|
self,
|
||
|
ctx: discord.ApplicationContext,
|
||
|
query: typing.Annotated[
|
||
|
str,
|
||
|
discord.Option(
|
||
|
str,
|
||
|
"The query to feed into ollama. Not the system prompt.",
|
||
|
default=None
|
||
|
)
|
||
|
],
|
||
|
model: typing.Annotated[
|
||
|
str,
|
||
|
discord.Option(
|
||
|
str,
|
||
|
"The model to use for ollama. Defaults to 'llama2-uncensored:latest'.",
|
||
|
default="llama2-uncensored:latest"
|
||
|
)
|
||
|
],
|
||
|
server: typing.Annotated[
|
||
|
str,
|
||
|
discord.Option(
|
||
|
str,
|
||
|
"The server to use for ollama.",
|
||
|
default=SERVER_KEYS[0],
|
||
|
choices=SERVER_KEYS
|
||
|
)
|
||
|
],
|
||
|
):
|
||
|
with open("./assets/ollama-prompt.txt") as file:
|
||
|
system_prompt = file.read()
|
||
|
|
||
|
if query is None:
|
||
|
class InputPrompt(discord.ui.Modal):
|
||
|
def __init__(self, is_owner: bool):
|
||
|
super().__init__(
|
||
|
discord.ui.InputText(
|
||
|
label="User Prompt",
|
||
|
placeholder="Enter prompt",
|
||
|
min_length=1,
|
||
|
max_length=4000,
|
||
|
style=discord.InputTextStyle.long,
|
||
|
),
|
||
|
title="Enter prompt",
|
||
|
timeout=120,
|
||
|
)
|
||
|
if is_owner:
|
||
|
self.add_item(
|
||
|
discord.ui.InputText(
|
||
|
label="System Prompt",
|
||
|
placeholder="Enter prompt",
|
||
|
min_length=1,
|
||
|
max_length=4000,
|
||
|
style=discord.InputTextStyle.long,
|
||
|
value=system_prompt,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
self.user_prompt = None
|
||
|
self.system_prompt = system_prompt
|
||
|
|
||
|
async def callback(self, interaction: discord.Interaction):
|
||
|
self.user_prompt = self.children[0].value
|
||
|
if len(self.children) > 1:
|
||
|
self.system_prompt = self.children[1].value
|
||
|
await interaction.response.defer()
|
||
|
self.stop()
|
||
|
|
||
|
modal = InputPrompt(await self.bot.is_owner(ctx.author))
|
||
|
await ctx.send_modal(modal)
|
||
|
await modal.wait()
|
||
|
query = modal.user_prompt
|
||
|
if not modal.user_prompt:
|
||
|
return
|
||
|
system_prompt = modal.system_prompt or system_prompt
|
||
|
else:
|
||
|
await ctx.defer()
|
||
|
|
||
|
model = model.casefold()
|
||
|
try:
|
||
|
model, tag = model.split(":", 1)
|
||
|
model = model + ":" + tag
|
||
|
self.log.debug("Model %r already has a tag")
|
||
|
except ValueError:
|
||
|
model = model + ":latest"
|
||
|
self.log.debug("Resolved model to %r" % model)
|
||
|
|
||
|
if server not in CONFIG["ollama"]:
|
||
|
await ctx.respond("Invalid server")
|
||
|
return
|
||
|
|
||
|
server_config = CONFIG["ollama"][server]
|
||
|
for model_pattern in server_config["allowed_models"]:
|
||
|
if fnmatch(model, model_pattern):
|
||
|
break
|
||
|
else:
|
||
|
allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
|
||
|
await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
|
||
|
return
|
||
|
|
||
|
async with aiohttp.ClientSession(
|
||
|
base_url=server_config["base_url"],
|
||
|
) as session:
|
||
|
embed = discord.Embed(
|
||
|
title="Checking server...",
|
||
|
description=f"Checking that specified model and tag ({model}) are available on the server.",
|
||
|
color=discord.Color.blurple(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
await ctx.respond(embed=embed)
|
||
|
|
||
|
try:
|
||
|
async with session.post("/show", json={"name": model}) as resp:
|
||
|
if resp.status not in [404, 200]:
|
||
|
embed = discord.Embed(
|
||
|
url=resp.url,
|
||
|
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
|
||
|
description=f"```{await resp.text() or 'No response body'}```"[:4096],
|
||
|
color=discord.Color.red(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
embed.set_footer(text="Unable to continue.")
|
||
|
return await ctx.edit(embed=embed)
|
||
|
except aiohttp.ClientConnectionError as e:
|
||
|
embed = discord.Embed(
|
||
|
title="Connection error while checking for model.",
|
||
|
description=f"```{e}```"[:4096],
|
||
|
color=discord.Color.red(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
embed.set_footer(text="Unable to continue.")
|
||
|
return await ctx.edit(embed=embed)
|
||
|
|
||
|
if resp.status == 404:
|
||
|
def progress_bar(value: float, action: str = None):
|
||
|
bar = "\N{green large square}" * round(value / 10)
|
||
|
bar += "\N{white large square}" * (10 - len(bar))
|
||
|
bar += f" {value:.2f}%"
|
||
|
if action:
|
||
|
return f"{action} {bar}"
|
||
|
return bar
|
||
|
|
||
|
embed = discord.Embed(
|
||
|
title=f"Downloading {model!r}",
|
||
|
description=f"Downloading {model!r} from {server_config['base_url']}",
|
||
|
color=discord.Color.blurple(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
embed.add_field(name="Progress", value=progress_bar(0))
|
||
|
await ctx.edit(embed=embed)
|
||
|
|
||
|
last_update = time.time()
|
||
|
|
||
|
async with session.post("/pull", json={"name": model, "stream": True}, timeout=None) as response:
|
||
|
if response.status != 200:
|
||
|
embed = discord.Embed(
|
||
|
url=response.url,
|
||
|
title=f"HTTP {response.status} {response.reason!r} while downloading model.",
|
||
|
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
||
|
color=discord.Color.red(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
embed.set_footer(text="Unable to continue.")
|
||
|
return await ctx.edit(embed=embed)
|
||
|
|
||
|
async for line in self.ollama_stream(response.content):
|
||
|
if time.time() >= (last_update + 5.1):
|
||
|
if line.get("total") is not None and line.get("completed") is not None:
|
||
|
percent = (line["completed"] / line["total"]) * 100
|
||
|
else:
|
||
|
percent = 50.0
|
||
|
|
||
|
embed.fields[0].value = progress_bar(percent, line["status"])
|
||
|
await ctx.edit(embed=embed)
|
||
|
last_update = time.time()
|
||
|
|
||
|
embed = discord.Embed(
|
||
|
title="Generating response...",
|
||
|
description=">>> \u200b",
|
||
|
color=discord.Color.blurple()
|
||
|
)
|
||
|
async with session.post(
|
||
|
"/generate",
|
||
|
json={
|
||
|
"model": model,
|
||
|
"prompt": query,
|
||
|
"format": "json",
|
||
|
"system": system_prompt,
|
||
|
"stream": True
|
||
|
}
|
||
|
) as response:
|
||
|
if response.status != 200:
|
||
|
embed = discord.Embed(
|
||
|
url=response.url,
|
||
|
title=f"HTTP {response.status} {response.reason!r} while generating response.",
|
||
|
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
||
|
color=discord.Color.red(),
|
||
|
timestamp=discord.utils.utcnow()
|
||
|
)
|
||
|
embed.set_footer(text="Unable to continue.")
|
||
|
return await ctx.edit(embed=embed)
|
||
|
|
||
|
last_update = time.time()
|
||
|
async for line in self.ollama_stream(response.content):
|
||
|
if line.get("done", False) is True or time.time() >= (last_update + 5.1):
|
||
|
if line.get("done"):
|
||
|
embed.title = "Done!"
|
||
|
embed.color = discord.Color.green()
|
||
|
embed.description += line["text"]
|
||
|
if len(embed.description) >= 4096:
|
||
|
embed.description = embed.description[:4093] + "..."
|
||
|
break
|
||
|
await ctx.edit(embed=embed)
|
||
|
last_update = time.time()
|
||
|
|
||
|
|
||
|
def setup(bot):
    """Extension entry point: create the Ollama cog and add it to *bot*."""
    cog = Ollama(bot)
    bot.add_cog(cog)
|