Start writing ollama client class

Nexus 2024-04-13 23:51:50 +01:00
parent f6d093d45d
commit 0cc5ce7150
Signed by: nex
GPG key ID: 0FA334385D0B689F


@@ -1,4 +1,5 @@
import asyncio
import contextlib
import json
import logging
import os
@@ -7,6 +8,8 @@ import time
import typing
import base64
import io
import httpx
import redis
from discord import Interaction
@@ -20,6 +23,16 @@ from discord.ext import commands
from conf import CONFIG
async def ollama_stream(iterator: typing.AsyncIterable[bytes | str]) -> typing.AsyncIterator[dict]:
    # Accepts aiohttp's StreamReader (which yields bytes lines) as well as
    # iterators that yield str, such as httpx's Response.aiter_lines().
    async for line in iterator:
        if isinstance(line, bytes):
            line = line.decode("utf-8", "replace")
        line = line.strip()
        if not line:
            continue
        try:
            line = json.loads(line)
        except json.JSONDecodeError:
            continue
        yield line
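For context, Ollama's streaming endpoints emit newline-delimited JSON, one object per line; the helper above just decodes and filters those lines. A minimal sketch of driving it directly with httpx (the base URL, model name, and /api/generate payload here are illustrative, not part of this commit):

    import asyncio
    import httpx

    async def _demo_stream():
        # Hypothetical: stream a completion from a local Ollama server and
        # print each decoded NDJSON chunk as it arrives.
        async with httpx.AsyncClient(base_url="http://localhost:11434") as client:
            async with client.stream(
                "POST",
                "/api/generate",
                json={"model": "orca-mini:3b", "prompt": "Hello!", "stream": True},
                timeout=None,
            ) as response:
                async for chunk in ollama_stream(response.aiter_lines()):
                    print(chunk.get("response", ""), end="", flush=True)

    asyncio.run(_demo_stream())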
def get_time_spent(nanoseconds: int) -> str:
    hours, minutes, seconds = 0, 0, 0
    seconds = nanoseconds / 1e9
@@ -50,6 +63,195 @@ def get_time_spent(nanoseconds: int) -> str:
    return ", ".join(reversed(result))
class OllamaDownloadHandler:
    def __init__(self, client: httpx.AsyncClient, model: str):
        self.client = client
        self.model = model
        self._abort = asyncio.Event()
        self._total = 1
        self._completed = 0
        self.status = "starting"
        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        self._abort.set()

    @property
    def percent(self) -> float:
        return round((self._completed / self._total) * 100, 2)

    def parse_line(self, line: dict):
        if line.get("total"):
            self._total = line["total"]
        if line.get("completed"):
            self._completed = line["completed"]
        if line.get("status"):
            self.status = line["status"]

    async def __aiter__(self):
        # httpx exposes response streaming via client.stream(), not via
        # `async with client.post(...)` (that is aiohttp's API).
        async with self.client.stream(
                "POST", "/api/pull", json={"name": self.model, "stream": True}, timeout=None
        ) as response:
            response.raise_for_status()
            async for line in ollama_stream(response.aiter_lines()):
                self.parse_line(line)
                if self._abort.is_set():
                    break
                yield line

    def __await__(self):
        # `async for` is only legal inside an async function, so wrap the
        # drain in a coroutine and delegate to its __await__.
        async def _drain():
            async for _ in self:
                pass
            return self
        return _drain().__await__()
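A rough usage sketch for the download handler, assuming a reachable server at a placeholder URL; the reporting task and five-second interval are illustrative:

    async def _demo_pull():
        client = httpx.AsyncClient(base_url="http://localhost:11434")
        handler = OllamaDownloadHandler(client, "orca-mini:3b")

        async def report():
            # Poll the handler's parsed state while the pull is in flight.
            while True:
                print(f"{handler.status}: {handler.percent}%")
                await asyncio.sleep(5)

        reporter = asyncio.create_task(report())
        try:
            await handler  # drains the stream; parse_line() keeps status/percent fresh
        finally:
            reporter.cancel()
            await client.aclose()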
class OllamaChatHandler:
    def __init__(self, client: httpx.AsyncClient, model: str, messages: list):
        self.client = client
        self.model = model
        self.messages = messages
        self._abort = asyncio.Event()
        self.buffer = io.StringIO()
        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        self._abort.set()

    @property
    def result(self) -> str:
        """The current accumulated response. Can be read multiple times."""
        return self.buffer.getvalue()

    def parse_line(self, line: dict):
        # Ollama reports durations in nanoseconds; convert them to seconds.
        if line.get("total_duration"):
            self.total_duration_s = line["total_duration"] / 1e9
        if line.get("load_duration"):
            self.load_duration_s = line["load_duration"] / 1e9
        if line.get("prompt_eval_duration"):
            self.prompt_eval_duration_s = line["prompt_eval_duration"] / 1e9
        if line.get("eval_duration"):
            self.eval_duration_s = line["eval_duration"] / 1e9
        if line.get("eval_count"):
            self.eval_count = line["eval_count"]
        if line.get("prompt_eval_count"):
            self.prompt_eval_count = line["prompt_eval_count"]

    async def __aiter__(self):
        async with self.client.stream(
                "POST",
                "/api/chat",
                json={
                    "model": self.model,
                    "stream": True,
                    "messages": self.messages
                }
        ) as response:
            response.raise_for_status()
            async for line in ollama_stream(response.aiter_lines()):
                if self._abort.is_set():
                    break
                if line.get("message"):
                    self.buffer.write(line["message"]["content"])
                yield line
                if line.get("done"):
                    break

    @classmethod
    async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
        # Non-streaming requests return a single JSON object, and httpx's
        # .json() is synchronous (the body is already read).
        response = await client.post("/api/chat", json={"model": model, "messages": messages, "stream": False})
        response.raise_for_status()
        handler = cls(client, model, messages)
        line = response.json()
        handler.parse_line(line)
        handler.buffer.write(line["message"]["content"])
        return handler
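A corresponding sketch for the chat handler (again with a placeholder URL and model): chunks can be consumed incrementally as they stream, or the accumulated text read back from .result afterwards.

    async def _demo_chat():
        client = httpx.AsyncClient(base_url="http://localhost:11434")
        handler = OllamaChatHandler(
            client, "orca-mini:3b",
            [{"role": "user", "content": "Why is the sky blue?"}],
        )
        async for line in handler:
            if line.get("message"):
                # Print tokens as they stream in.
                print(line["message"]["content"], end="", flush=True)
        print("\n---\nAccumulated:", handler.result)
        await client.aclose()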
class OllamaClient:
    def __init__(self, base_url: str, authorisation: tuple[str, str] | None = None):
        self.base_url = base_url
        self.authorisation = authorisation

    @contextlib.asynccontextmanager
    async def _with_async_client(self, t) -> typing.AsyncIterator[httpx.AsyncClient]:
        # httpx.AsyncClient is an *async* context manager; a plain `with` would fail.
        async with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
            yield client

    def with_client(
            self,
            timeout: httpx.Timeout | float | int | None = None
    ) -> contextlib.AbstractAsyncContextManager[httpx.AsyncClient]:
        """
        Creates a client instance for a request, with properly populated values.

        :param timeout: A timeout in seconds, a pre-built httpx.Timeout, or -1 to disable entirely.
        :return: An async context manager yielding a configured httpx.AsyncClient.
        """
        if isinstance(timeout, (float, int)):
            if timeout == -1:
                timeout = None
            timeout = httpx.Timeout(timeout)
        else:
            timeout = timeout or httpx.Timeout(60)
        return self._with_async_client(timeout)

    async def get_tags(self) -> dict[typing.Literal["models"], list[dict[str, typing.Any]]]:
        """
        Gets the tags (locally available models) from the server.

        :return: The decoded /api/tags response.
        """
        async with self.with_client() as client:
            resp = await client.get("/api/tags")
            resp.raise_for_status()
            return resp.json()

    async def has_model_named(self, name: str, tag: str | None = None) -> bool:
        """Checks that the given server has the model downloaded, optionally with a tag.

        :param name: The name of the model (e.g. orca-mini, orca-mini:latest)
        :param tag: A specific tag to check for (e.g. latest, chat)
        :return: A boolean indicating an existing download."""
        if tag is not None:
            name += ":" + tag
        async with self.with_client() as client:
            resp = await client.post("/api/show", json={"name": name})
            return resp.status_code == 200

    def download_model(self, name: str, tag: str = "latest") -> OllamaDownloadHandler:
        """Starts the download for a model.

        :param name: The name of the model.
        :param tag: The tag of the model. Defaults to latest.
        :return: An OllamaDownloadHandler instance.
        """
        client = httpx.AsyncClient(base_url=self.base_url, auth=self.authorisation)
        return OllamaDownloadHandler(client, name + ":" + tag)

    def new_chat(
            self,
            model: str,
            messages: list[dict[str, str]],
    ) -> OllamaChatHandler:
        """
        Starts a chat with the given messages.

        :param model: The model (name, or name:tag) to chat with.
        :param messages: The conversation history as role/content dicts.
        :return: An OllamaChatHandler instance.
        """
        return OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url, auth=self.authorisation), model, messages)
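Putting the client together, a hedged end-to-end sketch (placeholder URL and model; the payload shape of get_tags follows Ollama's /api/tags response):

    async def _demo_client():
        api = OllamaClient("http://localhost:11434")

        # List models already present on the server.
        tags = await api.get_tags()
        print([model["name"] for model in tags["models"]])

        # Pull the model if it is missing; awaiting the handler drains the stream.
        if not await api.has_model_named("orca-mini", "3b"):
            await api.download_model("orca-mini", "3b")

        # Stream a short chat and read the accumulated reply.
        handler = api.new_chat("orca-mini:3b", [{"role": "user", "content": "Hello!"}])
        async for _ in handler:
            pass
        print(handler.result)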
class OllamaView(View):
    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=3600, disable_on_timeout=True)
@@ -93,6 +295,10 @@ class ChatHistory:
    def create_thread(self, member: discord.Member, default: str | None = None) -> str:
        """
        Creates a thread, returns its ID.

        :param member: The member who created the thread.
        :param default: The system prompt to use.
        :return: The thread's ID.
        """
        key = os.urandom(3).hex()
        self._internal[key] = {
@@ -269,18 +475,6 @@ class Ollama(commands.Cog):
        self.last_server += 1
        return SERVER_KEYS[self.last_server % len(SERVER_KEYS)]
    async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
        async for line in iterator:
            original_line = line
            line = line.decode("utf-8", "replace").strip()
            try:
                line = json.loads(line)
            except json.JSONDecodeError:
                self.log.warning("Unable to decode JSON: %r", original_line)
                continue
            else:
                self.log.debug("Decoded JSON %r -> %r", original_line, line)
            yield line
    async def check_server(self, url: str) -> bool:
        """Checks that a server is online and responding."""
@@ -513,7 +707,7 @@
embed.set_footer(text="Unable to continue.")
return await ctx.edit(embed=embed)
view = OllamaView(ctx)
async for line in self.ollama_stream(response.content):
async for line in ollama_stream(response.content):
if view.cancel.is_set():
embed = discord.Embed(
title="Download cancelled.",
@@ -613,7 +807,7 @@
            last_update = time.time()
            buffer = io.StringIO()
            if not view.cancel.is_set():
                async for line in self.ollama_stream(response.content):
                async for line in ollama_stream(response.content):
                    buffer.write(line["message"]["content"])
                    embed.description += line["message"]["content"]
                    embed.timestamp = discord.utils.utcnow()
@@ -719,6 +913,50 @@
        for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
            await ctx.respond(embeds=chunk, ephemeral=ephemeral)
@commands.message_command(name="Ask AI")
async def ask_ai(self, ctx: discord.ApplicationContext, message: discord.Message):
thread = self.history.create_thread(message.author)
content = message.clean_content
if not content:
if message.embeds:
content = message.embeds[0].description or message.embeds[0].title
if not content:
return await ctx.respond("No content to send to AI.", ephemeral=True)
await ctx.defer()
user_message = {
"role": "user",
"content": message.content
}
self.history.add_message(thread, "user", user_message["content"])
server = self.next_server(False)
while not await self.check_server(server):
server = self.next_server()
if server:
break
else:
return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True)
client = OllamaClient(CONFIG["ollama"][server]["base_url"])
if not await client.has_model_named("orca-mini", "3b"):
await client.download_model("orca-mini", "3b")
messages = self.history.get_history(thread)
embed = discord.Embed(description=">>> ")
async for ln in client.new_chat("orca-mini:3b", messages):
embed.description += ln["message"]["content"]
if len(embed.description) >= 4032:
break
if len(embed.description) >= 3250:
embed.colour = discord.Color.gold()
embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
else:
embed.colour = discord.Color.blurple()
embed.set_footer(text="Using server %r" % server, icon_url=CONFIG["ollama"][server].get("icon_url"))
await ctx.edit(embed=embed)
if ln.get("done"):
break
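The loop above edits the embed on every streamed chunk, while the earlier download flow spaces edits out with a last_update timestamp. A sketch of the same throttling applied to chat streaming; the helper name and the 2.5-second interval are assumptions, not part of this commit:

    async def stream_to_embed(ctx, client: OllamaClient, model: str, messages: list, embed):
        # Hypothetical helper: append streamed chunks to an embed, editing the
        # Discord message at most once every few seconds to respect rate limits.
        last_edit = 0.0
        async for line in client.new_chat(model, messages):
            if line.get("message"):
                embed.description += line["message"]["content"]
            done = bool(line.get("done"))
            if done or time.time() - last_edit >= 2.5:
                await ctx.edit(embed=embed)
                last_edit = time.time()
            if done:
                break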
def setup(bot):
    bot.add_cog(Ollama(bot))