Start writing ollama client class
parent f6d093d45d
commit 0cc5ce7150
1 changed file with 252 additions and 14 deletions
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import json
import logging
import os
@@ -7,6 +8,8 @@ import time
import typing
import base64
import io

import httpx
import redis
from discord import Interaction
@@ -20,6 +23,16 @@ from discord.ext import commands
from conf import CONFIG


async def ollama_stream(iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
    async for line in iterator:
        line = line.decode("utf-8", "replace").strip()
        try:
            line = json.loads(line)
        except json.JSONDecodeError:
            continue
        yield line
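

# Illustrative usage sketch, not part of this commit: ollama_stream expects the
# NDJSON body of a streaming Ollama response, e.g. aiohttp's `response.content`
# StreamReader. The server URL and model name below are assumed example values.
async def _example_stream_pull():
    import aiohttp  # assumed to be available alongside httpx
    async with aiohttp.ClientSession(base_url="http://localhost:11434") as session:
        async with session.post("/api/pull", json={"name": "orca-mini:3b", "stream": True}) as response:
            async for event in ollama_stream(response.content):
                print(event.get("status"))  # each event is one decoded JSON line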


def get_time_spent(nanoseconds: int) -> str:
    hours, minutes, seconds = 0, 0, 0
    seconds = nanoseconds / 1e9
@@ -50,6 +63,195 @@ def get_time_spent(nanoseconds: int) -> str:
    return ", ".join(reversed(result))


class OllamaDownloadHandler:
    def __init__(self, client: httpx.AsyncClient, model: str):
        self.client = client
        self.model = model
        self._abort = asyncio.Event()
        self._total = 1
        self._completed = 0
        self.status = "starting"

        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        self._abort.set()

    @property
    def percent(self) -> float:
        return round((self._completed / self._total) * 100, 2)

    def parse_line(self, line: dict):
        if line.get("total"):
            self._total = line["total"]
        if line.get("completed"):
            self._completed = line["completed"]
        if line.get("status"):
            self.status = line["status"]

    async def __aiter__(self):
        async with self.client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
            response.raise_for_status()
            async for line in ollama_stream(response.content):
                self.parse_line(line)
                if self._abort.is_set():
                    break
                yield line

    def __await__(self):
        # ``async for`` is only valid inside an async function, so drain the
        # iterator from a small helper coroutine and delegate to it.
        async def _drain():
            async for _ in self:
                pass
        return _drain().__await__()
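

# Illustrative usage sketch, not part of this commit: drive the download handler
# by iterating it for progress updates, or simply ``await`` it to completion.
# The base URL and model tag below are assumed example values.
async def _example_pull_model():
    client = httpx.AsyncClient(base_url="http://localhost:11434")
    handler = OllamaDownloadHandler(client, "orca-mini:3b")
    async for _ in handler:
        print(f"{handler.status}: {handler.percent}%")  # percent comes from the pull progress events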


class OllamaChatHandler:
    def __init__(self, client: httpx.AsyncClient, model: str, messages: list):
        self.client = client
        self.model = model
        self.messages = messages
        self._abort = asyncio.Event()
        self.buffer = io.StringIO()

        self.total_duration_s = 0
        self.load_duration_s = 0
        self.prompt_eval_duration_s = 0
        self.eval_duration_s = 0
        self.eval_count = 0
        self.prompt_eval_count = 0

    def abort(self):
        self._abort.set()

    @property
    def result(self) -> str:
        """The current response. Can be called multiple times."""
        return self.buffer.getvalue()

    def parse_line(self, line: dict):
        if line.get("total_duration"):
            self.total_duration_s = line["total_duration"] / 1e9
        if line.get("load_duration"):
            self.load_duration_s = line["load_duration"] / 1e9
        if line.get("prompt_eval_duration"):
            self.prompt_eval_duration_s = line["prompt_eval_duration"] / 1e9
        if line.get("eval_duration"):
            self.eval_duration_s = line["eval_duration"] / 1e9

        if line.get("eval_count"):
            self.eval_count = line["eval_count"]
        if line.get("prompt_eval_count"):
            self.prompt_eval_count = line["prompt_eval_count"]

    async def __aiter__(self):
        async with self.client.post(
            "/api/chat",
            json={
                "model": self.model,
                "stream": True,
                "messages": self.messages
            }
        ) as response:
            response.raise_for_status()
            async for line in ollama_stream(response.content):
                if self._abort.is_set():
                    break

                if line.get("message"):
                    self.buffer.write(line["message"]["content"])
                yield line

                if line.get("done"):
                    break

    @classmethod
    async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
        async with client.post("/api/chat", json={"model": model, "messages": messages, "stream": False}) as response:
            response.raise_for_status()
            handler = cls(client, model, messages)
            line = await response.json()
            handler.parse_line(line)
            handler.buffer.write(line["message"]["content"])
            return handler
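

# Illustrative usage sketch, not part of this commit: stream a chat and read the
# accumulated text from ``handler.result``. The base URL, model, and messages
# below are assumed example values.
async def _example_chat():
    client = httpx.AsyncClient(base_url="http://localhost:11434")
    messages = [{"role": "user", "content": "Why is the sky blue?"}]
    handler = OllamaChatHandler(client, "orca-mini:3b", messages)
    async for chunk in handler:
        if chunk.get("done"):
            break
    print(handler.result)  # the full concatenated response text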


class OllamaClient:
    def __init__(self, base_url: str, authorisation: tuple[str, str] = None):
        self.base_url = base_url
        self.authorisation = authorisation

    def _with_async_client(self, t) -> contextlib.AbstractContextManager[httpx.AsyncClient]:
        with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
            yield client

    def with_client(
        self,
        timeout: httpx.Timeout | float | int | None = None
    ) -> contextlib.AbstractContextManager[httpx.AsyncClient]:
        """
        Creates a client instance for a request, with properly populated values.

        :param timeout: The request timeout, as seconds or an httpx.Timeout; -1 disables the timeout.
        :return: A context manager yielding a configured httpx.AsyncClient.
        """
        if isinstance(timeout, (float, int)):
            if timeout == -1:
                timeout = None
            timeout = httpx.Timeout(timeout)
        else:
            timeout = timeout or httpx.Timeout(60)
        return self._with_async_client(timeout)

    async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
        """
        Gets the tags for the server.

        :return: The decoded JSON response from /api/tags.
        """
        async with self.with_client() as client:
            async with client.get("/api/tags") as resp:
                return await resp.json()

    async def has_model_named(self, name: str, tag: str = None) -> bool:
        """Checks that the given server has the model downloaded, optionally with a tag.

        :param name: The name of the model (e.g. orca-mini, orca-mini:latest)
        :param tag: A specific tag to check for (e.g. latest, chat)
        :return: A boolean indicating an existing download."""
        if tag is not None:
            name += ":" + tag
        async with self.with_client() as client:
            async with client.post("/api/show", json={"name": name}) as resp:
                return resp.status == 200

    def download_model(self, name: str, tag: str = "latest") -> OllamaDownloadHandler:
        """Starts the download for a model.

        :param name: The name of the model.
        :param tag: The tag of the model. Defaults to latest.
        :return: An OllamaDownloadHandler instance.
        """
        handler = OllamaDownloadHandler(httpx.AsyncClient(base_url=self.base_url), name + ":" + tag)
        return handler

    def new_chat(
        self,
        model: str,
        messages: list[dict[str, str]],
    ) -> OllamaChatHandler:
        """
        Starts a chat with the given messages.

        :param model: The model to use (e.g. orca-mini:3b).
        :param messages: The message history, as a list of role/content dicts.
        :return: An OllamaChatHandler instance.
        """
        handler = OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url), model, messages)
        return handler
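

# Illustrative usage sketch, not part of this commit: the high-level client as
# the cog uses it further down in ask_ai (the base URL here is an assumption).
async def _example_client():
    client = OllamaClient("http://localhost:11434")
    if not await client.has_model_named("orca-mini", "3b"):
        await client.download_model("orca-mini", "3b")
    handler = client.new_chat("orca-mini:3b", [{"role": "user", "content": "Hello!"}])
    async for _ in handler:
        pass
    print(handler.result)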


class OllamaView(View):
    def __init__(self, ctx: discord.ApplicationContext):
        super().__init__(timeout=3600, disable_on_timeout=True)
@@ -93,6 +295,10 @@ class ChatHistory:
    def create_thread(self, member: discord.Member, default: str | None = None) -> str:
        """
        Creates a thread, returns its ID.

        :param member: The member who created the thread.
        :param default: The system prompt to use.
        :return: The thread's ID.
        """
        key = os.urandom(3).hex()
        self._internal[key] = {
@@ -269,18 +475,6 @@ class Ollama(commands.Cog):
        self.last_server += 1
        return SERVER_KEYS[self.last_server % len(SERVER_KEYS)]

    async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
        async for line in iterator:
            original_line = line
            line = line.decode("utf-8", "replace").strip()
            try:
                line = json.loads(line)
            except json.JSONDecodeError:
                self.log.warning("Unable to decode JSON: %r", original_line)
                continue
            else:
                self.log.debug("Decoded JSON %r -> %r", original_line, line)
            yield line

    async def check_server(self, url: str) -> bool:
        """Checks that a server is online and responding."""
@@ -513,7 +707,7 @@ class Ollama(commands.Cog):
                embed.set_footer(text="Unable to continue.")
                return await ctx.edit(embed=embed)
            view = OllamaView(ctx)
            async for line in self.ollama_stream(response.content):
            async for line in ollama_stream(response.content):
                if view.cancel.is_set():
                    embed = discord.Embed(
                        title="Download cancelled.",
@@ -613,7 +807,7 @@ class Ollama(commands.Cog):
            last_update = time.time()
            buffer = io.StringIO()
            if not view.cancel.is_set():
                async for line in self.ollama_stream(response.content):
                async for line in ollama_stream(response.content):
                    buffer.write(line["message"]["content"])
                    embed.description += line["message"]["content"]
                    embed.timestamp = discord.utils.utcnow()
@@ -719,6 +913,50 @@ class Ollama(commands.Cog):
        for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
            await ctx.respond(embeds=chunk, ephemeral=ephemeral)

    @commands.message_command(name="Ask AI")
    async def ask_ai(self, ctx: discord.ApplicationContext, message: discord.Message):
        thread = self.history.create_thread(message.author)
        content = message.clean_content
        if not content:
            if message.embeds:
                content = message.embeds[0].description or message.embeds[0].title
            if not content:
                return await ctx.respond("No content to send to AI.", ephemeral=True)
        await ctx.defer()
        user_message = {
            "role": "user",
            "content": message.content
        }
        self.history.add_message(thread, "user", user_message["content"])

        server = self.next_server(False)
        while not await self.check_server(server):
            server = self.next_server()
            if server:
                break
            else:
                return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True)

        client = OllamaClient(CONFIG["ollama"][server]["base_url"])
        if not await client.has_model_named("orca-mini", "3b"):
            await client.download_model("orca-mini", "3b")

        messages = self.history.get_history(thread)
        embed = discord.Embed(description=">>> ")
        async for ln in client.new_chat("orca-mini:3b", messages):
            embed.description += ln["message"]["content"]
            if len(embed.description) >= 4032:
                break
            if len(embed.description) >= 3250:
                embed.colour = discord.Color.gold()
                embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
            else:
                embed.colour = discord.Color.blurple()
                embed.set_footer(text="Using server %r" % server, icon_url=CONFIG["ollama"][server].get("icon_url"))
            await ctx.edit(embed=embed)
            if ln.get("done"):
                break


def setup(bot):
    bot.add_cog(Ollama(bot))