2024-01-10 10:13:37 +00:00
|
|
|
import asyncio
|
2024-04-18 00:24:58 +01:00
|
|
|
import base64
|
|
|
|
import io
|
2024-01-06 21:43:52 +00:00
|
|
|
import json
|
|
|
|
import logging
|
2024-01-10 15:59:13 +00:00
|
|
|
import os
|
2024-01-09 14:49:29 +00:00
|
|
|
import textwrap
|
2024-01-06 21:43:52 +00:00
|
|
|
import time
|
|
|
|
import typing
|
2024-04-18 00:24:58 +01:00
|
|
|
from fnmatch import fnmatch
|
2024-04-13 23:51:50 +01:00
|
|
|
|
2024-04-18 00:24:58 +01:00
|
|
|
import aiohttp
|
|
|
|
import discord
|
2024-01-12 16:47:45 +00:00
|
|
|
import redis
|
2024-03-22 09:08:03 +00:00
|
|
|
from discord import Interaction
|
2024-04-18 00:24:58 +01:00
|
|
|
from discord.ext import commands
|
2024-01-10 10:13:37 +00:00
|
|
|
from discord.ui import View, button
|
2024-01-16 10:14:50 +00:00
|
|
|
from yarl import URL
|
2024-04-18 00:24:58 +01:00
|
|
|
|
2024-01-06 21:56:18 +00:00
|
|
|
from conf import CONFIG
|
2024-01-06 21:43:52 +00:00
|
|
|
|
|
|
|
|
2024-04-13 23:51:50 +01:00
|
|
|
async def ollama_stream(iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
    """Decode a streamed ollama response line by line, yielding each JSON object.

    Lines that are not valid JSON are silently skipped.
    """
    async for raw in iterator:
        text = raw.decode("utf-8", "replace").strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Partial or junk line - skip it and keep streaming.
            continue
        yield parsed
|
|
|
|
|
|
|
|
|
2024-01-12 10:00:12 +00:00
|
|
|
def get_time_spent(nanoseconds: int) -> str:
    """Render a duration given in nanoseconds as a human-readable string.

    e.g. 3661e9 nanoseconds -> ``"1 hour, 1 minute, 1 second"``. Zero-valued
    components are omitted; an input of 0 yields an empty string.

    :param nanoseconds: The duration, in nanoseconds.
    :return: A comma-separated string, largest unit first.
    """
    hours = minutes = 0
    seconds = nanoseconds / 1e9
    if seconds >= 60:
        minutes, seconds = divmod(seconds, 60)
        if minutes >= 60:
            hours, minutes = divmod(minutes, 60)

    parts = []
    # Fix: pluralise based on the *rounded* value (which is what gets shown),
    # so e.g. 1.2 seconds renders as "1 second" instead of "1 seconds".
    for value, unit in ((seconds, "second"), (minutes, "minute"), (hours, "hour")):
        if value:
            rounded = round(value)
            label = unit if rounded == 1 else unit + "s"
            parts.append(f"{rounded} {label}")
    return ", ".join(reversed(parts))
|
|
|
|
|
|
|
|
|
2024-04-13 23:51:50 +01:00
|
|
|
class OllamaDownloadHandler:
    """Drives and tracks a model pull (`/api/pull`) on an ollama server.

    Iterate the instance asynchronously to stream progress lines, or call
    :meth:`flatten` to consume the whole stream. Use as a context manager to
    guarantee the stream is aborted on exit.
    """

    def __init__(self, base_url: str, model: str):
        self.base_url = base_url
        self.model = model
        self._abort = asyncio.Event()
        # _total starts at 1 so `percent` can never divide by zero.
        self._total = 1
        self._completed = 0
        self.status = "starting"

        # Timing/token statistics (populated externally; kept for parity with
        # OllamaChatHandler).
        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        """Signal the streaming loop to stop at the next line."""
        self._abort.set()

    def __enter__(self):
        self._abort.clear()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._abort.set()

    @property
    def percent(self) -> float:
        """Download completion as a percentage, rounded to two decimal places."""
        return round((self._completed / self._total) * 100, 2)

    def parse_line(self, line: dict):
        """Update progress counters from one status line of the pull stream."""
        for key, attr in (("total", "_total"), ("completed", "_completed"), ("status", "status")):
            if line.get(key):
                setattr(self, attr, line[key])

    async def __aiter__(self):
        async with aiohttp.ClientSession(base_url=self.base_url) as client:
            async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
                response.raise_for_status()
                async for line in ollama_stream(response.content):
                    self.parse_line(line)
                    if self._abort.is_set():
                        break
                    yield line

    async def flatten(self) -> "OllamaDownloadHandler":
        """Returns the current instance, but fully consumed."""
        async for _ in self:
            pass
        return self
|
2024-04-13 23:51:50 +01:00
|
|
|
|
|
|
|
|
|
|
|
class OllamaChatHandler:
    """Streams a chat completion (`/api/chat`) from an ollama server.

    Accumulates the streamed reply in :attr:`buffer`; iterate the instance
    asynchronously to consume the stream. Use as a context manager to
    guarantee the stream is aborted on exit.
    """

    def __init__(self, base_url: str, model: str, messages: list):
        self.base_url = base_url
        self.model = model
        self.messages = messages
        self._abort = asyncio.Event()
        # Accumulates the assistant's streamed message content.
        self.buffer = io.StringIO()

        # Timing/token statistics, filled in by parse_line().
        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        """Signal the streaming loop to stop at the next line."""
        self._abort.set()

    @property
    def result(self) -> str:
        """The current response. Can be called multiple times."""
        return self.buffer.getvalue()

    def parse_line(self, line: dict):
        """Copy timing and token statistics out of one status line."""
        # Durations arrive in nanoseconds; store them as seconds.
        for key in ("total_duration", "load_duration", "prompt_eval_duration", "eval_duration"):
            if line.get(key):
                setattr(self, key + "_s", line[key] / 1e9)
        for key in ("eval_count", "prompt_eval_count"):
            if line.get(key):
                setattr(self, key, line[key])

    def __enter__(self):
        self._abort.clear()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._abort.set()

    async def __aiter__(self):
        async with aiohttp.ClientSession(base_url=self.base_url) as client:
            async with client.post(
                "/api/chat", json={"model": self.model, "stream": True, "messages": self.messages}
            ) as response:
                response.raise_for_status()
                async for line in ollama_stream(response.content):
                    if self._abort.is_set():
                        break

                    if line.get("message"):
                        self.buffer.write(line["message"]["content"])
                    yield line

                    if line.get("done"):
                        break
|
2024-04-13 23:51:50 +01:00
|
|
|
|
|
|
|
|
|
|
|
class OllamaClient:
    """Thin client for one ollama server's HTTP API."""

    def __init__(self, base_url: str, authorisation: tuple[str, str] = None):
        self.base_url = base_url
        # Optional (login, password) pair forwarded to aiohttp's BasicAuth.
        self.authorisation = authorisation

    def with_client(self, timeout: aiohttp.ClientTimeout | float | int | None = None) -> aiohttp.ClientSession:
        """
        Creates an instance for a request, with properly populated values.

        :param timeout: A ClientTimeout, a number of seconds, -1 for "effectively unlimited" (3h), or None (120s).
        :return: A configured aiohttp.ClientSession bound to this server.
        """
        if isinstance(timeout, (float, int)):
            # -1 is shorthand for "practically no limit": three hours.
            seconds = 10800 if timeout == -1 else timeout
            timeout = aiohttp.ClientTimeout(seconds)
        else:
            timeout = timeout or aiohttp.ClientTimeout(120)
        return aiohttp.ClientSession(self.base_url, timeout=timeout, auth=self.authorisation)

    async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
        """
        Gets the tags for the server.

        :return: The decoded JSON body of /api/tags.
        """
        async with self.with_client() as client:
            async with client.get("/api/tags") as resp:
                return await resp.json()

    async def has_model_named(self, name: str, tag: str = None) -> bool:
        """Checks that the given server has the model downloaded, optionally with a tag.

        :param name: The name of the model (e.g. orca-mini, orca-mini:latest)
        :param tag: a specific tag to check for (e.g. latest, chat)
        :return: A boolean indicating an existing download."""
        if tag is not None:
            name = name + ":" + tag
        async with self.with_client() as client:
            async with client.post("/api/show", json={"name": name}) as resp:
                # /api/show answers 200 only for models present on the server.
                return resp.status == 200

    def download_model(self, name: str, tag: str = "latest") -> OllamaDownloadHandler:
        """Starts the download for a model.

        :param name: The name of the model.
        :param tag: The tag of the model. Defaults to latest.
        :return: An OllamaDownloadHandler instance.
        """
        return OllamaDownloadHandler(self.base_url, name + ":" + tag)

    def new_chat(
        self,
        model: str,
        messages: list[dict[str, str]],
    ) -> OllamaChatHandler:
        """
        Starts a chat with the given messages.

        :param model: The model to chat with.
        :param messages: The conversation so far, as role/content dicts.
        :return: An OllamaChatHandler streaming the reply.
        """
        return OllamaChatHandler(self.base_url, model, messages)
|
|
|
|
|
|
|
|
|
2024-01-10 10:13:37 +00:00
|
|
|
class OllamaView(View):
    """View attached to an in-progress ollama response, offering a Stop button."""

    def __init__(self, ctx: discord.ApplicationContext):
        # One-hour timeout; the generation loop also polls `cancel`.
        super().__init__(timeout=3600, disable_on_timeout=True)
        self.ctx = ctx
        # Set when the invoker presses "Stop"; checked by the command loop.
        self.cancel = asyncio.Event()

    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        # Only the command invoker may use the buttons.
        return self.ctx.user == interaction.user

    @button(label="Stop", style=discord.ButtonStyle.danger, emoji="\N{wastebasket}\U0000fe0f")
    async def _stop(self, btn: discord.ui.Button, interaction: discord.Interaction):
        """Signal cancellation, disable the button, and stop listening."""
        self.cancel.set()
        btn.disabled = True
        await interaction.response.edit_message(view=self)
        self.stop()
|
|
|
|
|
|
|
|
|
2024-01-12 15:39:39 +00:00
|
|
|
class ChatHistory:
    """In-memory store of ollama chat threads, persisted to redis."""

    def __init__(self):
        # thread_id -> {"member": int, "seed": int, "messages": [dict, ...]}
        self._internal = {}
        self.log = logging.getLogger("jimmy.cogs.ollama.history")
        # `no_ping` is our own flag - it must be popped before the remaining
        # config is forwarded to redis.Redis as keyword arguments.
        no_ping = CONFIG["redis"].pop("no_ping", False)
        self.redis = redis.Redis(**CONFIG["redis"])
        if no_ping is False:
            assert self.redis.ping(), "Redis appears to be offline."

    def load_thread(self, thread_id: str):
        """Loads a thread from redis into the cache and returns it (or None)."""
        value: str = self.redis.get("threads:" + thread_id)
        if value:
            self.log.debug("Loaded thread %r: %r", thread_id, value)
            loaded = json.loads(value)
            self._internal.update(loaded)
            return self.get_thread(thread_id)

    def save_thread(self, thread_id: str):
        """Persists a cached thread to redis."""
        self.log.info("Saving thread:%s - %r", thread_id, self._internal[thread_id])
        self.redis.set("threads:" + thread_id, json.dumps(self._internal[thread_id]))

    def create_thread(self, member: discord.Member, default: str | None = None) -> str:
        """
        Creates a thread, returns its ID.

        :param member: The member who created the thread.
        :param default: The system prompt to use.
        :return: The thread's ID.
        """
        key = os.urandom(3).hex()
        self._internal[key] = {"member": member.id, "seed": round(time.time()), "messages": []}
        # Only read the default prompt file when no override was supplied.
        if not default:
            with open("./assets/ollama-prompt.txt") as file:
                default = file.read()
        self.add_message(key, "system", default)
        return key

    @staticmethod
    def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]:
        # Omit the "images" key entirely when there are none.
        x = {"role": role, "content": content}
        if images:
            x["images"] = images
        return x

    @staticmethod
    def autocomplete(ctx: discord.AutocompleteContext):
        """Autocomplete source listing the invoking user's thread IDs.

        Fix: iterating the dict returned by threads_for() yields the key
        strings themselves; the previous ``map(lambda d: list(d.keys()), ...)``
        treated each key as a dict and raised AttributeError.
        """
        # noinspection PyTypeChecker
        cog: Ollama = ctx.bot.get_cog("Ollama")
        instance = cog.history
        return [
            thread_id
            for thread_id in instance.threads_for(ctx.interaction.user)
            if (ctx.value or thread_id) in thread_id
        ]

    def all_threads(self) -> dict[str, dict[str, list[dict[str, str]] | int]]:
        """Returns all saved threads."""
        return self._internal.copy()

    def threads_for(self, user: discord.Member) -> dict[str, dict[str, list[dict[str, str]] | int]]:
        """Returns all saved threads for a specific user"""
        return {k: v for k, v in self.all_threads().items() if v["member"] == user.id}

    def add_message(
        self,
        thread: str,
        role: typing.Literal["user", "assistant", "system"],
        content: str,
        images: typing.Optional[list[str]] = None,
    ) -> None:
        """
        Appends a message to the given thread.

        :param thread: The thread's ID.
        :param role: The author of the message.
        :param content: The message's actual content.
        :param images: Any images that were attached to the message, in base64.
        :return: None
        """
        new = self._construct_message(role, content, images)
        self.log.debug("Adding message to thread %r: %r", thread, new)
        self._internal[thread]["messages"].append(new)

    def get_history(self, thread: str) -> list[dict[str, str]]:
        """
        Gets the history of a thread.
        """
        if self._internal.get(thread) is None:
            return []
        return self._internal[thread]["messages"].copy()  # copy() makes it immutable.

    def get_thread(self, thread: str) -> dict[str, list[dict[str, str]] | discord.Member | int]:
        """Gets a copy of an entire thread"""
        return self._internal.get(thread, {}).copy()

    def find_thread(self, thread_id: str):
        """Attempts to find a thread."""
        self.log.debug("Checking cache for %r...", thread_id)
        if c := self.get_thread(thread_id):
            return c
        self.log.debug("Checking db for %r...", thread_id)
        if d := self.load_thread(thread_id):
            return d
        self.log.warning("No thread with ID %r found.", thread_id)
|
2024-01-12 17:11:16 +00:00
|
|
|
|
2024-01-12 15:39:39 +00:00
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
# Identifiers of every configured ollama server; used for round-robin
# selection and as slash-command choices.
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
|
|
|
|
2024-01-11 14:41:07 +00:00
|
|
|
|
2024-03-22 09:08:03 +00:00
|
|
|
class OllamaGetPrompt(discord.ui.Modal):
    """Modal that collects a long-form system or user prompt from the invoker."""

    def __init__(self, ctx: discord.ApplicationContext, prompt_type: str = "User"):
        super().__init__(
            discord.ui.InputText(
                style=discord.InputTextStyle.long,
                label="%s prompt" % prompt_type,
                placeholder="Enter your prompt here.",
            ),
            timeout=300,
            title="Ollama %s prompt" % prompt_type,
        )
        self.ctx = ctx
        self.prompt_type = prompt_type
        # Set by callback() once the modal is submitted; None until then.
        self.value = None

    async def interaction_check(self, interaction: discord.Interaction) -> bool:
        # Only the original invoker may submit this modal.
        return self.ctx.user == interaction.user

    async def callback(self, interaction: Interaction):
        """Record the submitted text and release anyone awaiting wait()."""
        await interaction.response.defer()
        self.value = self.children[0].value
        self.stop()
|
|
|
|
|
|
|
|
|
2024-03-22 09:14:52 +00:00
|
|
|
class PromptSelector(discord.ui.View):
    """Lets the invoker optionally override the system and user prompts via modals."""

    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=600, disable_on_timeout=True)
        self.ctx = ctx
        # Both remain None unless explicitly set through the buttons below.
        self.system_prompt = None
        self.user_prompt = None

    async def interaction_check(self, interaction: Interaction) -> bool:
        # Only the command invoker may use this view.
        return interaction.user == self.ctx.user

    def update_ui(self):
        """Grey out the buttons whose prompts have already been set."""
        if self.system_prompt is not None:
            self.get_item("sys").style = discord.ButtonStyle.secondary  # type: ignore
        if self.user_prompt is not None:
            self.get_item("usr").style = discord.ButtonStyle.secondary  # type: ignore

    @discord.ui.button(label="Set System Prompt", style=discord.ButtonStyle.primary, custom_id="sys")
    async def set_system_prompt(self, btn: discord.ui.Button, interaction: Interaction):
        """Open a modal to collect a system prompt, then refresh the view."""
        modal = OllamaGetPrompt(self.ctx, "System")
        await interaction.response.send_modal(modal)
        await modal.wait()
        self.system_prompt = modal.value
        self.update_ui()
        await interaction.edit_original_response(view=self)

    @discord.ui.button(label="Set User Prompt", style=discord.ButtonStyle.primary, custom_id="usr")
    async def set_user_prompt(self, btn: discord.ui.Button, interaction: Interaction):
        """Open a modal to collect a user prompt, then refresh the view."""
        modal = OllamaGetPrompt(self.ctx)
        await interaction.response.send_modal(modal)
        await modal.wait()
        self.user_prompt = modal.value
        self.update_ui()
        await interaction.edit_original_response(view=self)

    @discord.ui.button(label="Done", style=discord.ButtonStyle.success, custom_id="done")
    async def done(self, btn: discord.ui.Button, interaction: Interaction):
        # Fix: acknowledge the click before stopping, otherwise the interaction
        # is never responded to and Discord shows "This interaction failed".
        await interaction.response.defer(invisible=True)
        self.stop()
|
|
|
|
|
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
class Ollama(commands.Cog):
|
|
|
|
def __init__(self, bot: commands.Bot):
|
|
|
|
self.bot = bot
|
|
|
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
2024-01-10 15:11:36 +00:00
|
|
|
self.last_server = 0
|
2024-01-10 15:59:13 +00:00
|
|
|
self.contexts = {}
|
2024-01-12 15:39:39 +00:00
|
|
|
self.history = ChatHistory()
|
2024-04-18 00:43:37 +01:00
|
|
|
self.lock = asyncio.Lock()
|
2024-01-10 15:11:36 +00:00
|
|
|
|
|
|
|
def next_server(self, increment: bool = True) -> str:
|
|
|
|
"""Returns the next server key."""
|
|
|
|
if increment:
|
|
|
|
self.last_server += 1
|
2024-04-14 18:02:22 +01:00
|
|
|
s = SERVER_KEYS[self.last_server % len(SERVER_KEYS)]
|
|
|
|
self.log.info("Next server is %s", s)
|
|
|
|
return s
|
2024-01-06 21:43:52 +00:00
|
|
|
|
2024-01-11 13:20:32 +00:00
|
|
|
async def check_server(self, url: str) -> bool:
|
|
|
|
"""Checks that a server is online and responding."""
|
2024-04-14 18:02:22 +01:00
|
|
|
if url in SERVER_KEYS:
|
|
|
|
url = CONFIG["ollama"][url]["base_url"]
|
2024-01-11 13:32:41 +00:00
|
|
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(10)) as session:
|
2024-04-14 18:11:57 +01:00
|
|
|
self.log.info("Checking if %r is online.", url)
|
2024-01-11 13:20:32 +00:00
|
|
|
try:
|
2024-01-11 13:32:41 +00:00
|
|
|
async with session.get(url + "/api/tags") as resp:
|
2024-04-14 18:11:57 +01:00
|
|
|
self.log.info("%r is online.", resp.url.host)
|
2024-01-11 13:20:32 +00:00
|
|
|
return resp.ok
|
2024-01-11 14:34:46 +00:00
|
|
|
except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
|
2024-01-11 14:41:07 +00:00
|
|
|
self.log.warning("%r is offline.", url, exc_info=True)
|
2024-01-11 13:20:32 +00:00
|
|
|
return False
|
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
@commands.slash_command()
|
|
|
|
async def ollama(
|
2024-04-16 00:46:26 +01:00
|
|
|
self,
|
|
|
|
ctx: discord.ApplicationContext,
|
|
|
|
query: typing.Annotated[
|
|
|
|
str,
|
|
|
|
discord.Option(
|
2024-01-06 21:43:52 +00:00
|
|
|
str,
|
2024-04-16 00:46:26 +01:00
|
|
|
"The query to feed into ollama. Not the system prompt.",
|
|
|
|
),
|
|
|
|
],
|
|
|
|
model: typing.Annotated[
|
|
|
|
str,
|
|
|
|
discord.Option(
|
2024-01-10 15:59:13 +00:00
|
|
|
str,
|
2024-04-16 00:46:26 +01:00
|
|
|
"The model to use for ollama. Defaults to 'llama2-uncensored:latest'.",
|
|
|
|
default="llama2-uncensored:7b-chat",
|
|
|
|
),
|
|
|
|
],
|
|
|
|
server: typing.Annotated[
|
|
|
|
str, discord.Option(str, "The server to use for ollama.", default="next", choices=SERVER_KEYS)
|
|
|
|
],
|
|
|
|
context: typing.Annotated[
|
|
|
|
str, discord.Option(str, "The context key of a previous ollama response to use as context.", default=None)
|
|
|
|
],
|
|
|
|
give_acid: typing.Annotated[
|
|
|
|
bool,
|
|
|
|
discord.Option(
|
|
|
|
bool, "Whether to give the AI acid, LSD, and other hallucinogens before responding.", default=False
|
|
|
|
),
|
|
|
|
],
|
|
|
|
image: typing.Annotated[
|
|
|
|
discord.Attachment,
|
|
|
|
discord.Option(discord.Attachment, "An image to feed into ollama. Only works with llava.", default=None),
|
|
|
|
],
|
2024-01-06 21:43:52 +00:00
|
|
|
):
|
2024-04-25 19:33:34 +01:00
|
|
|
if not SERVER_KEYS:
|
|
|
|
return await ctx.respond("No servers available. Please try again later.")
|
2024-03-22 09:15:45 +00:00
|
|
|
system_query = None
|
2024-01-10 15:59:13 +00:00
|
|
|
if context is not None:
|
2024-01-12 15:54:33 +00:00
|
|
|
if not self.history.get_thread(context):
|
2024-01-10 15:59:13 +00:00
|
|
|
await ctx.respond("Invalid context key.")
|
|
|
|
return
|
2024-03-22 09:08:03 +00:00
|
|
|
|
|
|
|
try:
|
|
|
|
await ctx.defer()
|
|
|
|
except discord.HTTPException:
|
|
|
|
pass
|
2024-01-06 21:43:52 +00:00
|
|
|
|
2024-03-22 09:14:52 +00:00
|
|
|
if query == "$":
|
|
|
|
v = PromptSelector(ctx)
|
|
|
|
await ctx.respond("Select edit your prompts, as desired. Click done when you want to continue.", view=v)
|
|
|
|
await v.wait()
|
|
|
|
query = v.user_prompt or query
|
2024-03-22 09:15:45 +00:00
|
|
|
system_query = v.system_prompt
|
2024-03-22 09:18:08 +00:00
|
|
|
await ctx.delete(delay=0.1)
|
2024-03-22 09:14:52 +00:00
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
model = model.casefold()
|
|
|
|
try:
|
|
|
|
model, tag = model.split(":", 1)
|
|
|
|
model = model + ":" + tag
|
2024-01-10 15:11:36 +00:00
|
|
|
self.log.debug("Model %r already has a tag", model)
|
2024-01-06 21:43:52 +00:00
|
|
|
except ValueError:
|
2024-03-22 09:08:03 +00:00
|
|
|
model += ":latest"
|
2024-01-06 21:43:52 +00:00
|
|
|
self.log.debug("Resolved model to %r" % model)
|
|
|
|
|
2024-01-10 16:10:45 +00:00
|
|
|
if image:
|
|
|
|
if fnmatch(model, "llava:*") is False:
|
2024-01-11 13:21:41 +00:00
|
|
|
await ctx.respond(
|
2024-04-16 00:46:26 +01:00
|
|
|
"You can only use images with llava. Switching model to `llava:latest`.", delete_after=5
|
2024-01-11 13:21:41 +00:00
|
|
|
)
|
|
|
|
model = "llava:latest"
|
|
|
|
|
|
|
|
if image.size > 1024 * 1024 * 25:
|
2024-01-10 16:10:45 +00:00
|
|
|
await ctx.respond("Attachment is too large. Maximum size is 25 MB, for sanity. Try compressing it.")
|
|
|
|
return
|
|
|
|
elif not fnmatch(image.content_type, "image/*"):
|
|
|
|
await ctx.respond("Attachment is not an image. Try using a different file.")
|
|
|
|
return
|
|
|
|
else:
|
|
|
|
data = io.BytesIO()
|
|
|
|
await image.save(data)
|
|
|
|
data.seek(0)
|
2024-03-22 09:08:03 +00:00
|
|
|
image_data = base64.b64encode(data.read()).decode()
|
2024-01-10 16:10:45 +00:00
|
|
|
else:
|
|
|
|
image_data = None
|
|
|
|
|
2024-01-10 15:11:36 +00:00
|
|
|
if server == "next":
|
|
|
|
server = self.next_server()
|
|
|
|
elif server not in CONFIG["ollama"]:
|
2024-01-06 21:43:52 +00:00
|
|
|
await ctx.respond("Invalid server")
|
|
|
|
return
|
|
|
|
|
|
|
|
server_config = CONFIG["ollama"][server]
|
|
|
|
for model_pattern in server_config["allowed_models"]:
|
|
|
|
if fnmatch(model, model_pattern):
|
|
|
|
break
|
|
|
|
else:
|
|
|
|
allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
|
|
|
|
await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
|
|
|
|
return
|
|
|
|
|
|
|
|
async with aiohttp.ClientSession(
|
2024-04-16 00:46:26 +01:00
|
|
|
base_url=server_config["base_url"],
|
|
|
|
timeout=aiohttp.ClientTimeout(connect=30, sock_read=10800, sock_connect=30, total=10830),
|
2024-01-06 21:43:52 +00:00
|
|
|
) as session:
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="Checking server...",
|
|
|
|
description=f"Checking that specified model and tag ({model}) are available on the server.",
|
|
|
|
color=discord.Color.blurple(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
2024-01-11 13:20:32 +00:00
|
|
|
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
2024-01-06 21:43:52 +00:00
|
|
|
await ctx.respond(embed=embed)
|
2024-01-11 13:20:32 +00:00
|
|
|
if not await self.check_server(server_config["base_url"]):
|
|
|
|
for i in range(10):
|
|
|
|
server = self.next_server()
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="Server was offline. Trying next server.",
|
|
|
|
description=f"Trying server {server}...",
|
|
|
|
color=discord.Color.gold(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-11 13:20:32 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
|
|
|
await ctx.edit(embed=embed)
|
2024-01-11 14:41:07 +00:00
|
|
|
await asyncio.sleep(1)
|
2024-01-11 13:20:32 +00:00
|
|
|
if await self.check_server(CONFIG["ollama"][server]["base_url"]):
|
|
|
|
server_config = CONFIG["ollama"][server]
|
2024-01-16 10:14:50 +00:00
|
|
|
setattr(session, "_base_url", URL(server_config["base_url"]))
|
2024-01-11 13:20:32 +00:00
|
|
|
break
|
|
|
|
else:
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="All servers are offline.",
|
|
|
|
description="Please try again later.",
|
|
|
|
color=discord.Color.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-11 13:20:32 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Unable to continue.")
|
|
|
|
return await ctx.edit(embed=embed)
|
2024-01-06 21:43:52 +00:00
|
|
|
|
|
|
|
try:
|
2024-01-06 22:11:20 +00:00
|
|
|
self.log.debug("Connecting to %r", server_config["base_url"])
|
2024-01-06 21:58:09 +00:00
|
|
|
async with session.post("/api/show", json={"name": model}) as resp:
|
2024-01-06 22:11:20 +00:00
|
|
|
self.log.debug("%r responded.", server_config["base_url"])
|
2024-01-06 21:43:52 +00:00
|
|
|
if resp.status not in [404, 200]:
|
|
|
|
embed = discord.Embed(
|
|
|
|
url=resp.url,
|
|
|
|
title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
|
|
|
|
description=f"```{await resp.text() or 'No response body'}```"[:4096],
|
|
|
|
color=discord.Color.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Unable to continue.")
|
|
|
|
return await ctx.edit(embed=embed)
|
|
|
|
except aiohttp.ClientConnectionError as e:
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="Connection error while checking for model.",
|
|
|
|
description=f"```{e}```"[:4096],
|
|
|
|
color=discord.Color.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Unable to continue.")
|
|
|
|
return await ctx.edit(embed=embed)
|
|
|
|
|
|
|
|
if resp.status == 404:
|
2024-01-06 22:11:20 +00:00
|
|
|
self.log.debug("Beginning download of %r", model)
|
2024-01-10 20:20:00 +00:00
|
|
|
|
2024-01-10 15:59:13 +00:00
|
|
|
def progress_bar(_v: float, action: str = None):
|
|
|
|
bar = "\N{large green square}" * round(_v / 10)
|
2024-01-06 21:43:52 +00:00
|
|
|
bar += "\N{white large square}" * (10 - len(bar))
|
2024-01-10 15:59:13 +00:00
|
|
|
bar += f" {_v:.2f}%"
|
2024-01-06 21:43:52 +00:00
|
|
|
if action:
|
|
|
|
return f"{action} {bar}"
|
|
|
|
return bar
|
|
|
|
|
|
|
|
embed = discord.Embed(
|
|
|
|
title=f"Downloading {model!r}",
|
|
|
|
description=f"Downloading {model!r} from {server_config['base_url']}",
|
|
|
|
color=discord.Color.blurple(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
|
|
|
embed.add_field(name="Progress", value=progress_bar(0))
|
|
|
|
await ctx.edit(embed=embed)
|
|
|
|
|
|
|
|
last_update = time.time()
|
|
|
|
|
2024-01-06 21:58:09 +00:00
|
|
|
async with session.post("/api/pull", json={"name": model, "stream": True}, timeout=None) as response:
|
2024-01-06 21:43:52 +00:00
|
|
|
if response.status != 200:
|
|
|
|
embed = discord.Embed(
|
|
|
|
url=response.url,
|
|
|
|
title=f"HTTP {response.status} {response.reason!r} while downloading model.",
|
|
|
|
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
|
|
|
color=discord.Color.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Unable to continue.")
|
|
|
|
return await ctx.edit(embed=embed)
|
2024-01-11 14:58:00 +00:00
|
|
|
view = OllamaView(ctx)
|
2024-04-13 23:51:50 +01:00
|
|
|
async for line in ollama_stream(response.content):
|
2024-01-11 14:58:00 +00:00
|
|
|
if view.cancel.is_set():
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="Download cancelled.",
|
|
|
|
colour=discord.Colour.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-11 14:58:00 +00:00
|
|
|
)
|
|
|
|
return await ctx.edit(embed=embed, view=None)
|
2024-01-06 21:43:52 +00:00
|
|
|
if time.time() >= (last_update + 5.1):
|
|
|
|
if line.get("total") is not None and line.get("completed") is not None:
|
|
|
|
percent = (line["completed"] / line["total"]) * 100
|
|
|
|
else:
|
|
|
|
percent = 50.0
|
|
|
|
|
|
|
|
embed.fields[0].value = progress_bar(percent, line["status"])
|
2024-01-11 14:58:00 +00:00
|
|
|
await ctx.edit(embed=embed, view=view)
|
2024-01-06 21:43:52 +00:00
|
|
|
last_update = time.time()
|
2024-01-06 22:11:20 +00:00
|
|
|
else:
|
|
|
|
self.log.debug("Model %r already exists on server.", model)
|
2024-01-06 21:43:52 +00:00
|
|
|
|
2024-01-10 15:59:13 +00:00
|
|
|
key = os.urandom(6).hex()
|
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
embed = discord.Embed(
|
|
|
|
title="Generating response...",
|
2024-01-06 22:07:31 +00:00
|
|
|
description=">>> ",
|
|
|
|
color=discord.Color.blurple(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
2024-01-10 15:35:18 +00:00
|
|
|
embed.set_author(
|
|
|
|
name=model,
|
|
|
|
url="https://ollama.ai/library/" + model.split(":")[0],
|
2024-04-16 00:46:26 +01:00
|
|
|
icon_url="https://ollama.ai/public/ollama.png",
|
2024-01-10 15:35:18 +00:00
|
|
|
)
|
2024-01-10 10:29:48 +00:00
|
|
|
embed.add_field(
|
2024-04-16 00:46:26 +01:00
|
|
|
name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False
|
2024-01-10 10:29:48 +00:00
|
|
|
)
|
2024-01-10 10:39:37 +00:00
|
|
|
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
2024-01-10 16:20:03 +00:00
|
|
|
if image_data:
|
|
|
|
if (image.height / image.width) >= 1.5:
|
|
|
|
embed.set_image(url=image.url)
|
|
|
|
else:
|
|
|
|
embed.set_thumbnail(url=image.url)
|
2024-01-10 10:29:48 +00:00
|
|
|
view = OllamaView(ctx)
|
2024-01-09 22:54:37 +00:00
|
|
|
try:
|
2024-01-10 10:29:48 +00:00
|
|
|
await ctx.edit(embed=embed, view=view)
|
2024-01-09 22:54:37 +00:00
|
|
|
except discord.NotFound:
|
2024-01-10 10:29:48 +00:00
|
|
|
await ctx.respond(embed=embed, view=view)
|
2024-01-10 15:59:13 +00:00
|
|
|
self.log.debug("Beginning to generate response with key %r.", key)
|
|
|
|
|
2024-01-12 16:47:45 +00:00
|
|
|
if context is None:
|
2024-03-22 09:08:03 +00:00
|
|
|
context = self.history.create_thread(ctx.user, system_query)
|
2024-01-12 16:47:45 +00:00
|
|
|
elif context is not None and self.history.get_thread(context) is None:
|
2024-01-12 17:11:16 +00:00
|
|
|
__thread = self.history.find_thread(context)
|
|
|
|
if not __thread:
|
2024-01-12 16:47:45 +00:00
|
|
|
return await ctx.respond("Invalid thread ID.")
|
|
|
|
else:
|
2024-01-12 17:11:16 +00:00
|
|
|
context = list(__thread.keys())[0]
|
2024-01-12 16:47:45 +00:00
|
|
|
|
2024-01-12 15:39:39 +00:00
|
|
|
messages = self.history.get_history(context)
|
2024-04-16 00:46:26 +01:00
|
|
|
user_message = {"role": "user", "content": query}
|
2024-01-12 15:39:39 +00:00
|
|
|
if image_data:
|
|
|
|
user_message["images"] = [image_data]
|
|
|
|
messages.append(user_message)
|
2024-01-12 15:51:25 +00:00
|
|
|
|
|
|
|
params = {"seed": self.history.get_thread(context)["seed"]}
|
|
|
|
if give_acid is True:
|
2024-02-05 15:15:20 +00:00
|
|
|
params["temperature"] = 2
|
|
|
|
params["top_k"] = 0
|
|
|
|
params["top_p"] = 2
|
|
|
|
params["repeat_penalty"] = 2
|
2024-01-12 15:51:25 +00:00
|
|
|
|
2024-04-16 00:46:26 +01:00
|
|
|
payload = {"model": model, "stream": True, "options": params, "messages": messages}
|
2024-01-06 21:43:52 +00:00
|
|
|
async with session.post(
|
2024-01-12 15:39:39 +00:00
|
|
|
"/api/chat",
|
2024-01-10 15:59:13 +00:00
|
|
|
json=payload,
|
2024-01-06 21:43:52 +00:00
|
|
|
) as response:
|
|
|
|
if response.status != 200:
|
|
|
|
embed = discord.Embed(
|
|
|
|
url=response.url,
|
|
|
|
title=f"HTTP {response.status} {response.reason!r} while generating response.",
|
|
|
|
description=f"```{await response.text() or 'No response body'}```"[:4096],
|
|
|
|
color=discord.Color.red(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-06 21:43:52 +00:00
|
|
|
)
|
|
|
|
embed.set_footer(text="Unable to continue.")
|
|
|
|
return await ctx.edit(embed=embed)
|
|
|
|
|
|
|
|
last_update = time.time()
|
2024-01-09 14:49:29 +00:00
|
|
|
buffer = io.StringIO()
|
2024-01-10 10:29:48 +00:00
|
|
|
if not view.cancel.is_set():
|
2024-04-13 23:51:50 +01:00
|
|
|
async for line in ollama_stream(response.content):
|
2024-01-12 15:43:16 +00:00
|
|
|
buffer.write(line["message"]["content"])
|
|
|
|
embed.description += line["message"]["content"]
|
2024-01-10 10:29:48 +00:00
|
|
|
embed.timestamp = discord.utils.utcnow()
|
2024-02-06 00:54:53 +00:00
|
|
|
if len(embed.description) >= 4000:
|
|
|
|
embed.description = "[...]" + line["message"]["content"]
|
|
|
|
if len(embed.description) >= 3250:
|
|
|
|
embed.colour = discord.Color.gold()
|
|
|
|
embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
|
|
|
|
else:
|
|
|
|
embed.colour = discord.Color.blurple()
|
|
|
|
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
|
2024-01-10 10:13:37 +00:00
|
|
|
|
2024-01-10 10:29:48 +00:00
|
|
|
if view.cancel.is_set():
|
|
|
|
break
|
2024-01-10 10:13:37 +00:00
|
|
|
|
2024-01-10 10:29:48 +00:00
|
|
|
if time.time() >= (last_update + 5.1):
|
|
|
|
await ctx.edit(embed=embed, view=view)
|
|
|
|
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
|
|
|
last_update = time.time()
|
2024-01-10 10:13:37 +00:00
|
|
|
view.stop()
|
2024-01-12 15:48:58 +00:00
|
|
|
self.history.add_message(context, "user", user_message["content"], user_message.get("images"))
|
2024-01-12 15:39:39 +00:00
|
|
|
self.history.add_message(context, "assistant", buffer.getvalue())
|
2024-01-12 17:18:23 +00:00
|
|
|
self.history.save_thread(context)
|
|
|
|
|
2024-01-12 15:39:39 +00:00
|
|
|
embed.add_field(name="Context Key", value=context, inline=True)
|
2024-01-09 09:34:55 +00:00
|
|
|
self.log.debug("Ollama finished consuming.")
|
2024-01-09 08:56:46 +00:00
|
|
|
embed.title = "Done!"
|
2024-01-12 10:00:12 +00:00
|
|
|
embed.colour = discord.Color.green()
|
2024-01-09 14:49:29 +00:00
|
|
|
|
|
|
|
value = buffer.getvalue()
|
|
|
|
if len(value) >= 4096:
|
|
|
|
embeds = [discord.Embed(title="Done!", colour=discord.Color.green())]
|
2024-04-16 00:46:26 +01:00
|
|
|
|
2024-01-09 14:49:29 +00:00
|
|
|
current_page = ""
|
|
|
|
for word in value.split():
|
|
|
|
if len(current_page) + len(word) >= 4096:
|
|
|
|
embeds.append(discord.Embed(description=current_page))
|
|
|
|
current_page = ""
|
|
|
|
current_page += word + " "
|
|
|
|
else:
|
|
|
|
embeds.append(discord.Embed(description=current_page))
|
2024-04-16 00:46:26 +01:00
|
|
|
|
2024-01-10 10:41:30 +00:00
|
|
|
await ctx.edit(embeds=embeds, view=None)
|
2024-01-09 14:49:29 +00:00
|
|
|
else:
|
2024-01-10 10:41:30 +00:00
|
|
|
await ctx.edit(embed=embed, view=None)
|
2024-01-06 21:43:52 +00:00
|
|
|
|
2024-01-12 09:53:57 +00:00
|
|
|
if line.get("done"):
|
2024-01-12 10:00:12 +00:00
|
|
|
total_duration = get_time_spent(line["total_duration"])
|
|
|
|
load_duration = get_time_spent(line["load_duration"])
|
|
|
|
prompt_eval_duration = get_time_spent(line["prompt_eval_duration"])
|
|
|
|
eval_duration = get_time_spent(line["eval_duration"])
|
2024-01-12 09:53:57 +00:00
|
|
|
|
|
|
|
embed = discord.Embed(
|
|
|
|
title="Timings",
|
|
|
|
description=f"Total: {total_duration}\nLoad: {load_duration}\n"
|
2024-04-16 00:46:26 +01:00
|
|
|
f"Prompt Eval: {prompt_eval_duration}\nEval: {eval_duration}",
|
2024-01-12 09:53:57 +00:00
|
|
|
color=discord.Color.blurple(),
|
2024-04-16 00:46:26 +01:00
|
|
|
timestamp=discord.utils.utcnow(),
|
2024-01-12 09:53:57 +00:00
|
|
|
)
|
|
|
|
return await ctx.respond(embed=embed, ephemeral=True)
|
|
|
|
|
2024-01-12 16:26:32 +00:00
|
|
|
@commands.slash_command(name="ollama-history")
async def ollama_history(
    self,
    ctx: discord.ApplicationContext,
    thread_id: typing.Annotated[
        str,
        discord.Option(
            name="thread_id",
            description="Thread/Context ID",
            type=str,
            autocomplete=ChatHistory.autocomplete,
        ),
    ],
):
    """Shows the history for a thread."""
    if not SERVER_KEYS:
        return await ctx.respond("No servers available. Please try again later.")

    # 4000-character pages joined by blank lines; no code-block framing.
    paginator = commands.Paginator("", "", 4000, "\n\n")

    thread = self.history.load_thread(thread_id)
    if not thread:
        return await ctx.respond("No thread with that ID exists.")
    history = self.history.get_history(thread_id)
    if not history:
        return await ctx.respond("No history or invalid context key.")

    # One quoted line per non-system message, shortened so the role
    # prefix plus content never exceeds a single page.
    for entry in history:
        role = entry["role"]
        if role == "system":
            continue
        budget = 4000 - len("> **%s**: " % role)
        paginator.add_line("> **{}**: {}".format(role, textwrap.shorten(entry["content"], budget)))

    embeds = [discord.Embed(description=page) for page in paginator.pages]
    # Multi-embed replies are made ephemeral to avoid flooding the channel.
    ephemeral = len(embeds) > 1
    for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
        await ctx.respond(embeds=chunk, ephemeral=ephemeral)
|
2024-01-12 16:26:32 +00:00
|
|
|
|
2024-04-13 23:51:50 +01:00
|
|
|
@commands.message_command(name="Ask AI")
async def ask_ai(self, ctx: discord.ApplicationContext, message: discord.Message):
    """Message command: feed the target message's text to orca-mini:7b and reply with the answer.

    Falls back to the first embed's description/title when the message has
    no plain content. Streams the model output into an embed, editing the
    reply at most every ~5 seconds.
    """
    if not SERVER_KEYS:
        return await ctx.respond("No servers available. Please try again later.")
    thread = self.history.create_thread(message.author)
    content = message.clean_content
    if not content:
        if message.embeds:
            content = message.embeds[0].description or message.embeds[0].title
    if not content:
        return await ctx.respond("No content to send to AI.", ephemeral=True)
    await ctx.defer()
    # BUG FIX: previously this used message.content, discarding the embed
    # fallback computed above (and sending raw, un-cleaned mentions).
    user_message = {"role": "user", "content": content}
    self.history.add_message(thread, "user", user_message["content"])

    # Try up to 10 servers round-robin until one responds.
    for _ in range(10):
        server = self.next_server()
        if await self.check_server(CONFIG["ollama"][server]["base_url"]):
            break
    else:
        return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True)

    client = OllamaClient(CONFIG["ollama"][server]["base_url"])
    if not await client.has_model_named("orca-mini", "7b"):
        with client.download_model("orca-mini", "7b") as handler:
            async for _ in handler:
                self.log.info(
                    "Downloading orca-mini:7b on server %r - %s (%.2f%%)", server, handler.status, handler.percent
                )

    # Give the user feedback if another generation currently holds the lock.
    if self.lock.locked():
        await ctx.respond("Waiting for server to be free...")
        async with self.lock:
            await ctx.delete(delay=0.1)
    messages = self.history.get_history(thread)
    embed = discord.Embed(description="*Waking Ollama up...*")
    self.log.debug("Acquiring lock")
    async with self.lock:
        await ctx.respond(embed=embed, ephemeral=True)
        last_edit = time.time()
        msg = None  # the public reply; created on the first streamed chunk
        with client.new_chat("orca-mini:7b", messages) as handler:
            self.log.info("New chat connection established.")
            async for ln in handler:
                done = ln.get("done") is True
                embed.description = handler.result
                if len(embed.description) >= 4096:
                    # Hard embed limit reached - stop consuming the stream.
                    break
                if len(embed.description) >= 3250:
                    embed.colour = discord.Color.gold()
                    embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
                else:
                    embed.colour = discord.Color.blurple()
                    embed.set_footer(text="Using server %r" % server, icon_url=CONFIG["ollama"][server].get("icon_url"))
                if msg is None:
                    # First chunk: swap the ephemeral placeholder for a real reply.
                    await ctx.delete(delay=0.1)
                    msg = await message.reply(embed=embed)
                    last_edit = time.time()
                else:
                    # Throttle edits to avoid Discord rate limits.
                    if time.time() >= (last_edit + 5.1) or done is True:
                        await msg.edit(embed=embed)
                        last_edit = time.time()
                if done:
                    break
        embed.colour = discord.Colour.dark_theme()
        # BUG FIX: guard against an empty stream - msg is None if no chunk
        # ever arrived, and msg.edit would raise AttributeError.
        if msg is None:
            return await message.reply(embed=embed)
        return await msg.edit(embed=embed)
|
2024-04-13 23:51:50 +01:00
|
|
|
|
2024-01-12 09:53:57 +00:00
|
|
|
|
2024-01-06 21:43:52 +00:00
|
|
|
def setup(bot):
    # Extension entry point: py-cord calls this when the module is loaded
    # via bot.load_extension(), registering the Ollama cog on the bot.
    bot.add_cog(Ollama(bot))
|