From 0cc5ce71503ae72d2815ddfa9da9f76f80cccc70 Mon Sep 17 00:00:00 2001
From: nexy7574
Date: Sat, 13 Apr 2024 23:51:50 +0100
Subject: [PATCH] Start writing ollama client class

---
 src/cogs/ollama.py | 266 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 252 insertions(+), 14 deletions(-)

diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py
index 36dad41..d0a1659 100644
--- a/src/cogs/ollama.py
+++ b/src/cogs/ollama.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import json
 import logging
 import os
@@ -7,6 +8,8 @@ import time
 import typing
 import base64
 import io
+
+import httpx
 import redis
 
 from discord import Interaction
@@ -20,6 +23,16 @@ from discord.ext import commands
 from conf import CONFIG
 
 
+async def ollama_stream(iterator: typing.AsyncIterable[bytes | str]) -> typing.AsyncIterator[dict]:
+    """Decodes a stream of newline-delimited JSON objects.
+
+    Accepts both aiohttp's StreamReader (bytes) and httpx's Response.aiter_lines() (str).
+    """
+    async for line in iterator:
+        if isinstance(line, bytes):
+            line = line.decode("utf-8", "replace")
+        line = line.strip()
+        if not line:
+            continue
+        try:
+            line = json.loads(line)
+        except json.JSONDecodeError:
+            continue
+        yield line
+
+
 def get_time_spent(nanoseconds: int) -> str:
     hours, minutes, seconds = 0, 0, 0
     seconds = nanoseconds / 1e9
@@ -50,6 +63,195 @@ def get_time_spent(nanoseconds: int) -> str:
     return ", ".join(reversed(result))
 
 
+class OllamaDownloadHandler:
+    """Tracks the progress of a model download, line by line."""
+
+    def __init__(self, client: httpx.AsyncClient, model: str):
+        self.client = client
+        self.model = model
+        self._abort = asyncio.Event()
+        self._total = 1
+        self._completed = 0
+        self.status = "starting"
+
+    def abort(self):
+        self._abort.set()
+
+    @property
+    def percent(self) -> float:
+        return round((self._completed / self._total) * 100, 2)
+
+    def parse_line(self, line: dict):
+        if line.get("total"):
+            self._total = line["total"]
+        if line.get("completed"):
+            self._completed = line["completed"]
+        if line.get("status"):
+            self.status = line["status"]
+
+    async def __aiter__(self):
+        # httpx does not support `async with client.post(...)`; streaming
+        # requests go through client.stream().
+        async with self.client.stream(
+            "POST", "/api/pull", json={"name": self.model, "stream": True}, timeout=None
+        ) as response:
+            response.raise_for_status()
+            async for line in ollama_stream(response.aiter_lines()):
+                self.parse_line(line)
+                if self._abort.is_set():
+                    break
+                yield line
+
+    def __await__(self):
+        # `async for` is not valid in a plain function; wrap the drain in a
+        # coroutine so the handler itself can be awaited to completion.
+        async def _drain():
+            async for _ in self:
+                pass
+            return self
+        return _drain().__await__()
+
+
+class OllamaChatHandler:
+    """Streams a chat response, buffering the message content as it arrives."""
+
+    def __init__(self, client: httpx.AsyncClient, model: str, messages: list):
+        self.client = client
+        self.model = model
+        self.messages = messages
+        self._abort = asyncio.Event()
+        self.buffer = io.StringIO()
+
+        self.total_duration_s = 0
+        self.load_duration_s = 0
+        self.prompt_eval_duration_s = 0
+        self.eval_duration_s = 0
+        self.eval_count = 0
+        self.prompt_eval_count = 0
+
+    def abort(self):
+        self._abort.set()
+
+    @property
+    def result(self) -> str:
+        """The current response. Can be called multiple times."""
+        return self.buffer.getvalue()
+
+    def parse_line(self, line: dict):
+        if line.get("total_duration"):
+            self.total_duration_s = line["total_duration"] / 1e9
+        if line.get("load_duration"):
+            self.load_duration_s = line["load_duration"] / 1e9
+        if line.get("prompt_eval_duration"):
+            self.prompt_eval_duration_s = line["prompt_eval_duration"] / 1e9
+        if line.get("eval_duration"):
+            self.eval_duration_s = line["eval_duration"] / 1e9
+
+        if line.get("eval_count"):
+            self.eval_count = line["eval_count"]
+        if line.get("prompt_eval_count"):
+            self.prompt_eval_count = line["prompt_eval_count"]
+
+    async def __aiter__(self):
+        async with self.client.stream(
+            "POST",
+            "/api/chat",
+            json={
+                "model": self.model,
+                "stream": True,
+                "messages": self.messages
+            },
+            timeout=None
+        ) as response:
+            response.raise_for_status()
+            async for line in ollama_stream(response.aiter_lines()):
+                if self._abort.is_set():
+                    break
+
+                # Populate the duration/count statistics as they arrive.
+                self.parse_line(line)
+                if line.get("message"):
+                    self.buffer.write(line["message"]["content"])
+                yield line
+
+                if line.get("done"):
+                    break
+
+    @classmethod
+    async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
+        # Non-streaming requests are plain awaitables in httpx, and .json()
+        # is synchronous on the response.
+        response = await client.post("/api/chat", json={"model": model, "messages": messages, "stream": False})
+        response.raise_for_status()
+        handler = cls(client, model, messages)
+        line = response.json()
+        handler.parse_line(line)
+        handler.buffer.write(line["message"]["content"])
+        return handler
+
+
+class OllamaClient:
+    def __init__(self, base_url: str, authorisation: tuple[str, str] | None = None):
+        self.base_url = base_url
+        self.authorisation = authorisation
+
+    @contextlib.asynccontextmanager
+    async def _with_async_client(self, t) -> typing.AsyncIterator[httpx.AsyncClient]:
+        # AsyncClient must be entered with `async with`, so the wrapper has to
+        # be an async context manager rather than a sync generator.
+        async with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
+            yield client
+
+    def with_client(
+        self,
+        timeout: httpx.Timeout | float | int | None = None
+    ) -> contextlib.AbstractAsyncContextManager[httpx.AsyncClient]:
+        """
+        Creates a client for a single request, with properly populated values.
+
+        :param timeout: A timeout in seconds, an httpx.Timeout, or -1/None for no timeout.
+        :return: An async context manager yielding a configured httpx.AsyncClient.
+        """
+        if isinstance(timeout, (float, int)):
+            if timeout == -1:
+                timeout = None
+            timeout = httpx.Timeout(timeout)
+        else:
+            timeout = timeout or httpx.Timeout(60)
+        return self._with_async_client(timeout)
+
+    async def get_tags(self) -> dict[str, typing.Any]:
+        """
+        Gets the models ("tags") available on the server.
+
+        :return: The decoded JSON payload, e.g. {"models": [...]}.
+        """
+        async with self.with_client() as client:
+            response = await client.get("/api/tags")
+            response.raise_for_status()
+            return response.json()
+
+    async def has_model_named(self, name: str, tag: str | None = None) -> bool:
+        """Checks that the given server has the model downloaded, optionally with a tag.
+
+        :param name: The name of the model (e.g. orca-mini, orca-mini:latest)
+        :param tag: A specific tag to check for (e.g. latest, chat)
+        :return: A boolean indicating an existing download.
+        """
+        if tag is not None:
+            name += ":" + tag
+        async with self.with_client() as client:
+            response = await client.post("/api/show", json={"name": name})
+            return response.status_code == 200
+
+    def download_model(self, name: str, tag: str = "latest") -> OllamaDownloadHandler:
+        """Starts the download for a model.
+
+        :param name: The name of the model.
+        :param tag: The tag of the model. Defaults to latest.
+        :return: An OllamaDownloadHandler instance.
+        """
+        handler = OllamaDownloadHandler(
+            httpx.AsyncClient(base_url=self.base_url, auth=self.authorisation), name + ":" + tag
+        )
+        return handler
+
+    def new_chat(
+        self,
+        model: str,
+        messages: list[dict[str, str]],
+    ) -> OllamaChatHandler:
+        """
+        Starts a streaming chat with the given message history.
+
+        :param model: The model (name:tag) to chat with.
+        :param messages: The message history, as a list of {"role": ..., "content": ...} dicts.
+        :return: An OllamaChatHandler instance.
+        """
+        handler = OllamaChatHandler(
+            httpx.AsyncClient(base_url=self.base_url, auth=self.authorisation), model, messages
+        )
+        return handler
+
+
 class OllamaView(View):
     def __init__(self, ctx: discord.ApplicationContext):
         super().__init__(timeout=3600, disable_on_timeout=True)
@@ -93,6 +295,10 @@ class ChatHistory:
     def create_thread(self, member: discord.Member, default: str | None = None) -> str:
         """
         Creates a thread, returns its ID.
+
+        :param member: The member who created the thread.
+        :param default: The system prompt to use.
+        :return: The thread's ID.
         """
         key = os.urandom(3).hex()
         self._internal[key] = {
@@ -269,18 +475,6 @@ class Ollama(commands.Cog):
         self.last_server += 1
         return SERVER_KEYS[self.last_server % len(SERVER_KEYS)]
 
-    async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
-        async for line in iterator:
-            original_line = line
-            line = line.decode("utf-8", "replace").strip()
-            try:
-                line = json.loads(line)
-            except json.JSONDecodeError:
-                self.log.warning("Unable to decode JSON: %r", original_line)
-                continue
-            else:
-                self.log.debug("Decoded JSON %r -> %r", original_line, line)
-            yield line
-
     async def check_server(self, url: str) -> bool:
         """Checks that a server is online and responding."""
@@ -513,7 +707,7 @@ class Ollama(commands.Cog):
                 embed.set_footer(text="Unable to continue.")
                 return await ctx.edit(embed=embed)
             view = OllamaView(ctx)
-            async for line in self.ollama_stream(response.content):
+            async for line in ollama_stream(response.content):
                 if view.cancel.is_set():
                     embed = discord.Embed(
                         title="Download cancelled.",
@@ -613,7 +807,7 @@ class Ollama(commands.Cog):
             last_update = time.time()
             buffer = io.StringIO()
             if not view.cancel.is_set():
-                async for line in self.ollama_stream(response.content):
+                async for line in ollama_stream(response.content):
                     buffer.write(line["message"]["content"])
                     embed.description += line["message"]["content"]
                     embed.timestamp = discord.utils.utcnow()
@@ -719,6 +913,50 @@ class Ollama(commands.Cog):
         for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
             await ctx.respond(embeds=chunk, ephemeral=ephemeral)
 
+    @commands.message_command(name="Ask AI")
+    async def ask_ai(self, ctx: discord.ApplicationContext, message: discord.Message):
+        thread = self.history.create_thread(message.author)
+        content = message.clean_content
+        if not content:
+            if message.embeds:
+                content = message.embeds[0].description or message.embeds[0].title
+            if not content:
+                return await ctx.respond("No content to send to AI.", ephemeral=True)
+        await ctx.defer()
+        # Use the derived content, not message.content, which may be empty
+        # when the text came from an embed.
+        self.history.add_message(thread, "user", content)
+
+        # Try each server once, starting from the current one, before giving up.
+        server = self.next_server(False)
+        for _ in range(len(SERVER_KEYS)):
+            if await self.check_server(CONFIG["ollama"][server]["base_url"]):
+                break
+            server = self.next_server()
+        else:
+            return await ctx.respond(
+                "All servers are offline. Please try again later.", ephemeral=True
+            )
+
+        client = OllamaClient(CONFIG["ollama"][server]["base_url"])
+        if not await client.has_model_named("orca-mini", "3b"):
+            # Awaiting the handler runs the download to completion.
+            await client.download_model("orca-mini", "3b")
+
+        messages = self.history.get_history(thread)
+        embed = discord.Embed(description=">>> ")
+        last_update = time.time()
+        async for ln in client.new_chat("orca-mini:3b", messages):
+            if ln.get("message"):
+                embed.description += ln["message"]["content"]
+            if len(embed.description) >= 4032:
+                break
+            if len(embed.description) >= 3250:
+                embed.colour = discord.Color.gold()
+                embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
+            else:
+                embed.colour = discord.Color.blurple()
+                embed.set_footer(text="Using server %r" % server, icon_url=CONFIG["ollama"][server].get("icon_url"))
+            # Throttle edits so long responses do not hammer the ratelimit.
+            if time.time() - last_update >= 2.5:
+                await ctx.edit(embed=embed)
+                last_update = time.time()
+            if ln.get("done"):
+                break
+        await ctx.edit(embed=embed)
+
 
 def setup(bot):
     bot.add_cog(Ollama(bot))
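
Reviewer note, not part of the patch: a minimal sketch of how the new client class can be driven outside the cog, assuming the module is importable as cogs.ollama (the file lives at src/cogs/ollama.py) and a hypothetical local server at http://127.0.0.1:11434; the orca-mini:3b model matches what ask_ai uses above.

    import asyncio

    from cogs.ollama import OllamaClient  # assumed import path


    async def main():
        # Hypothetical base URL; the cog reads real values from CONFIG["ollama"].
        client = OllamaClient("http://127.0.0.1:11434")

        # Pull the model if it is missing. Iterating the handler exposes
        # progress; simply awaiting it (as ask_ai does) drains it silently.
        if not await client.has_model_named("orca-mini", "3b"):
            download = client.download_model("orca-mini", "3b")
            async for _ in download:
                print(f"{download.status}: {download.percent}%")

        # Stream a chat and print tokens as they arrive.
        chat = client.new_chat("orca-mini:3b", [{"role": "user", "content": "Hello!"}])
        async for line in chat:
            print(line.get("message", {}).get("content", ""), end="", flush=True)
        print(f"\nDone in {chat.total_duration_s:.2f}s ({chat.eval_count} tokens).")


    asyncio.run(main())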