Working shit
* /status works * starting up works * Config works * /ollama works * autocomplete works * downloading works * threads work * images work
This commit is contained in:
parent
e45e87167f
commit
2b4c324ba6
8 changed files with 547 additions and 1 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -281,3 +281,5 @@ pyrightconfig.json
|
||||||
.ionide
|
.ionide
|
||||||
|
|
||||||
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all
|
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all
|
||||||
|
.venv/
|
||||||
|
default.db
|
||||||
|
|
31
config.toml
Normal file
31
config.toml
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
[bot]
|
||||||
|
token = "MTI0OTUwNjU1ODIwODExNDgxMA.G_JIT7.R2R_1-2IHhdzEf6mHUgIa82oyRuwonBRrkd_Pc"
|
||||||
|
debug_guilds = [1106243455816052847]
|
||||||
|
|
||||||
|
[servers]
|
||||||
|
order = ["SpeedySHRoNK", "IvyPC", "nextop-ts", "optiplex", "shronk"]
|
||||||
|
|
||||||
|
[servers.SpeedySHRoNK]
|
||||||
|
base_url = "http://ollama.shronk.net:11434"
|
||||||
|
gpu = true
|
||||||
|
vram_gb = 10
|
||||||
|
|
||||||
|
[servers.IvyPC]
|
||||||
|
base_url = "http://192.168.0.26:11435"
|
||||||
|
gpu = true
|
||||||
|
vram_gb = 8
|
||||||
|
|
||||||
|
[servers.nextop-ts]
|
||||||
|
base_url = "http://laptop-linux.fluffy-gentoo.ts.net:11434"
|
||||||
|
gpu = true
|
||||||
|
vram_gb = 4
|
||||||
|
|
||||||
|
[servers.optiplex]
|
||||||
|
base_url = "http://192.168.0.254:11434"
|
||||||
|
gpu = false
|
||||||
|
vram_gb = 16
|
||||||
|
|
||||||
|
[servers.shronk]
|
||||||
|
base_url = "http://ollama.shronk.net:11434"
|
||||||
|
gpu = false
|
||||||
|
vram_gb = 16
|
328
jimmy/cogs/chat.py
Normal file
328
jimmy/cogs/chat.py
Normal file
|
@ -0,0 +1,328 @@
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import typing
|
||||||
|
import contextlib
|
||||||
|
from fnmatch import fnmatch
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import Interaction
|
||||||
|
from ollama import AsyncClient, ResponseError, Options
|
||||||
|
from discord.ext import commands
|
||||||
|
from jimmy.utils import async_ratio, create_ollama_message
|
||||||
|
from jimmy.config import get_servers, ServerConfig, get_server
|
||||||
|
from jimmy.db import OllamaThread
|
||||||
|
from humanize import naturalsize
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def ollama_client(host: str, **kwargs) -> AsyncClient:
|
||||||
|
host = str(host)
|
||||||
|
client = AsyncClient(host, **kwargs)
|
||||||
|
try:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
# Ollama doesn't auto-close the client, so we have to do it ourselves.
|
||||||
|
await client._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class StopDownloadView(discord.ui.View):
|
||||||
|
def __init__(self, ctx: discord.ApplicationContext):
|
||||||
|
super().__init__(timeout=None)
|
||||||
|
self.ctx = ctx
|
||||||
|
self.event = asyncio.Event()
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: Interaction) -> bool:
|
||||||
|
return interaction.user == self.ctx.user
|
||||||
|
|
||||||
|
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.danger)
|
||||||
|
async def cancel_download(self, _, interaction: discord.Interaction):
|
||||||
|
self.event.set()
|
||||||
|
self.stop()
|
||||||
|
await interaction.response.edit_message(view=None)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
|
||||||
|
chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
|
||||||
|
async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
|
||||||
|
tags = (await client.list())["models"]
|
||||||
|
return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
|
||||||
|
|
||||||
|
|
||||||
|
_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(commands.Cog):
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot = bot
|
||||||
|
self.server_locks = {}
|
||||||
|
for server in get_servers():
|
||||||
|
self.server_locks[server.name] = asyncio.Lock()
|
||||||
|
self.log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@commands.slash_command()
|
||||||
|
async def status(self, ctx: discord.ApplicationContext):
|
||||||
|
"""Checks the status on all servers."""
|
||||||
|
await ctx.defer()
|
||||||
|
|
||||||
|
def decorate_name(_s: ServerConfig):
|
||||||
|
if _s.gpu:
|
||||||
|
return f"{_s.name} (\u26A1)"
|
||||||
|
return _s.name
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Ollama Statuses:",
|
||||||
|
color=discord.Color.blurple()
|
||||||
|
)
|
||||||
|
fields = {}
|
||||||
|
for server in get_servers():
|
||||||
|
if server.throttle and self.server_locks[server.name].locked():
|
||||||
|
embed.add_field(
|
||||||
|
name=decorate_name(server),
|
||||||
|
value=f"\N{closed lock with key} In use.",
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
fields[server] = len(embed.fields) - 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
embed.add_field(
|
||||||
|
name=decorate_name(server),
|
||||||
|
value=f"\N{hourglass with flowing sand} Waiting...",
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
fields[server] = len(embed.fields) - 1
|
||||||
|
|
||||||
|
await ctx.respond(embed=embed)
|
||||||
|
tasks = {}
|
||||||
|
for server in get_servers():
|
||||||
|
if server.throttle and self.server_locks[server.name].locked():
|
||||||
|
continue
|
||||||
|
tasks[server] = asyncio.create_task(server.is_online())
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks.values())
|
||||||
|
for server, task in tasks.items():
|
||||||
|
if task.result():
|
||||||
|
embed.set_field_at(
|
||||||
|
fields[server],
|
||||||
|
name=decorate_name(server),
|
||||||
|
value=f"\N{white heavy check mark} Online.",
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embed.set_field_at(
|
||||||
|
fields[server],
|
||||||
|
name=decorate_name(server),
|
||||||
|
value=f"\N{cross mark} Offline.",
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
await ctx.edit(embed=embed)
|
||||||
|
|
||||||
|
@commands.slash_command(name="ollama")
|
||||||
|
async def start_ollama_chat(
|
||||||
|
self,
|
||||||
|
ctx: discord.ApplicationContext,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: typing.Annotated[
|
||||||
|
str | None,
|
||||||
|
discord.Option(
|
||||||
|
discord.SlashCommandOptionType.string,
|
||||||
|
description="The system prompt to use.",
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
],
|
||||||
|
server: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(
|
||||||
|
discord.SlashCommandOptionType.string,
|
||||||
|
description="The server to use.",
|
||||||
|
choices=_ServerOptionChoices,
|
||||||
|
default=get_servers()[0].name
|
||||||
|
)
|
||||||
|
],
|
||||||
|
model: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(
|
||||||
|
discord.SlashCommandOptionType.string,
|
||||||
|
description="The model to use.",
|
||||||
|
autocomplete=get_available_tags_autocomplete,
|
||||||
|
default="llama3:latest"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
image: typing.Annotated[
|
||||||
|
discord.Attachment | None,
|
||||||
|
discord.Option(
|
||||||
|
discord.SlashCommandOptionType.attachment,
|
||||||
|
description="The image to use for llava.",
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
],
|
||||||
|
thread_id: typing.Annotated[
|
||||||
|
str | None,
|
||||||
|
discord.Option(
|
||||||
|
discord.SlashCommandOptionType.string,
|
||||||
|
description="The thread ID to continue.",
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
]
|
||||||
|
):
|
||||||
|
"""Have a chat with ollama"""
|
||||||
|
await ctx.defer()
|
||||||
|
server = get_server(server)
|
||||||
|
async with self.server_locks[server.name]:
|
||||||
|
if not await server.is_online():
|
||||||
|
await ctx.respond(
|
||||||
|
content=f"{server} is offline.",
|
||||||
|
delete_after=60
|
||||||
|
)
|
||||||
|
return
|
||||||
|
async with ollama_client(str(server.base_url)) as client:
|
||||||
|
client: AsyncClient
|
||||||
|
self.log.info("Checking if %r has the model %r", server, model)
|
||||||
|
tags = (await client.list())["models"]
|
||||||
|
if model not in [x["model"] for x in tags]:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Downloading {model} on {server}.",
|
||||||
|
description=f"Initiating download...",
|
||||||
|
color=discord.Color.blurple()
|
||||||
|
)
|
||||||
|
view = StopDownloadView(ctx)
|
||||||
|
await ctx.respond(
|
||||||
|
embed=embed,
|
||||||
|
view=view
|
||||||
|
)
|
||||||
|
last_edit = 0
|
||||||
|
async with ctx.typing():
|
||||||
|
try:
|
||||||
|
last_completed = 0
|
||||||
|
last_completed_ts = time.time()
|
||||||
|
|
||||||
|
async for line in await client.pull(model, stream=True):
|
||||||
|
if view.event.is_set():
|
||||||
|
embed.add_field(name="Error!", value="Download cancelled.")
|
||||||
|
embed.colour = discord.Colour.red()
|
||||||
|
await ctx.edit(embed=embed)
|
||||||
|
return
|
||||||
|
self.log.info("Response from %r: %r", server, line)
|
||||||
|
if line["status"] in {
|
||||||
|
"pulling manifest",
|
||||||
|
"verifying sha256 digest",
|
||||||
|
"writing manifest",
|
||||||
|
"removing any unused layers",
|
||||||
|
"success"
|
||||||
|
}:
|
||||||
|
embed.description = line["status"].capitalize()
|
||||||
|
else:
|
||||||
|
total = line["total"]
|
||||||
|
completed = line.get("completed", 0)
|
||||||
|
percent = round(completed / total * 100, 1)
|
||||||
|
pb_fill = "▰" * int(percent / 10)
|
||||||
|
pb_empty = "▱" * (10 - int(percent / 10))
|
||||||
|
bytes_per_second = completed - last_completed
|
||||||
|
bytes_per_second /= (time.time() - last_completed_ts)
|
||||||
|
last_completed = completed
|
||||||
|
last_completed_ts = time.time()
|
||||||
|
mbps = round((bytes_per_second * 8) / 1024 / 1024)
|
||||||
|
progress_bar = f"[{pb_fill}{pb_empty}]"
|
||||||
|
ns_total = naturalsize(total, binary=True)
|
||||||
|
ns_completed = naturalsize(completed, binary=True)
|
||||||
|
embed.description = (
|
||||||
|
f"{line['status'].capitalize()} {percent}% {progress_bar} "
|
||||||
|
f"({ns_completed}/{ns_total} @ {mbps} Mb/s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if time.time() - last_edit >= 2.5:
|
||||||
|
await ctx.edit(embed=embed)
|
||||||
|
last_edit = time.time()
|
||||||
|
except ResponseError as err:
|
||||||
|
if err.error.endswith("file does not exist"):
|
||||||
|
await ctx.edit(
|
||||||
|
embed=None,
|
||||||
|
content="The model %r does not exist." % model,
|
||||||
|
delete_after=60,
|
||||||
|
view=None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embed.add_field(
|
||||||
|
name="Error!",
|
||||||
|
value=err.error
|
||||||
|
)
|
||||||
|
embed.colour = discord.Colour.red()
|
||||||
|
await ctx.edit(embed=embed, view=None)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
embed.colour = discord.Colour.green()
|
||||||
|
embed.description = f"Downloaded {model} on {server}."
|
||||||
|
await ctx.edit(embed=embed, delete_after=30, view=None)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if thread_id:
|
||||||
|
thread = await OllamaThread.get_or_none(thread_id=thread_id)
|
||||||
|
if thread:
|
||||||
|
for msg in thread.messages:
|
||||||
|
messages.append(
|
||||||
|
await create_ollama_message(msg["content"], role=msg["role"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.respond(content="No thread with that ID exists.", delete_after=30)
|
||||||
|
if system_prompt:
|
||||||
|
messages.append(await create_ollama_message(system_prompt, role="system"))
|
||||||
|
messages.append(await create_ollama_message(prompt, images=[await image.read()] if image else None))
|
||||||
|
embed = discord.Embed(title=f"{model}:", description="")
|
||||||
|
view = StopDownloadView(ctx)
|
||||||
|
msg = await ctx.respond(
|
||||||
|
embed=embed,
|
||||||
|
view=view
|
||||||
|
)
|
||||||
|
last_edit = time.time()
|
||||||
|
buffer = io.StringIO()
|
||||||
|
async for response in await client.chat(
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream=True,
|
||||||
|
options=Options(
|
||||||
|
num_ctx=4096,
|
||||||
|
low_vram=server.vram_gb < 8,
|
||||||
|
temperature=1.5
|
||||||
|
)
|
||||||
|
):
|
||||||
|
self.log.info("Response from %r: %r", server, response)
|
||||||
|
buffer.write(response["message"]["content"])
|
||||||
|
|
||||||
|
if len(buffer.getvalue()) > 4096:
|
||||||
|
value = "... " + buffer.getvalue()[4:]
|
||||||
|
else:
|
||||||
|
value = buffer.getvalue()
|
||||||
|
embed.description = value
|
||||||
|
if view.event.is_set():
|
||||||
|
embed.add_field(name="Error!", value="Chat cancelled.")
|
||||||
|
embed.colour = discord.Colour.red()
|
||||||
|
await msg.edit(embed=embed, view=None)
|
||||||
|
return
|
||||||
|
if time.time() - last_edit >= 2.5:
|
||||||
|
await msg.edit(embed=embed, view=view)
|
||||||
|
last_edit = time.time()
|
||||||
|
embed.colour = discord.Colour.green()
|
||||||
|
if len(buffer.getvalue()) > 4096:
|
||||||
|
file = discord.File(
|
||||||
|
io.BytesIO(buffer.getvalue().encode()),
|
||||||
|
filename="full-chat.txt"
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Full chat",
|
||||||
|
value="The chat was too long to fit in this message. "
|
||||||
|
f"You can download the `full-chat.txt` file to see the full message."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
file = discord.utils.MISSING
|
||||||
|
|
||||||
|
thread = OllamaThread(
|
||||||
|
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
|
||||||
|
)
|
||||||
|
await thread.save()
|
||||||
|
embed.set_footer(text=f"Chat ID: {thread.thread_id}")
|
||||||
|
await msg.edit(embed=embed, view=None, file=file)
|
||||||
|
|
||||||
|
|
||||||
|
def setup(bot):
|
||||||
|
bot.add_cog(Chat(bot))
|
69
jimmy/config.py
Normal file
69
jimmy/config.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import tomllib
|
||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel, Field, AnyHttpUrl
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ServerConfig(BaseModel):
|
||||||
|
name: str = Field(min_length=1, max_length=32)
|
||||||
|
base_url: AnyHttpUrl
|
||||||
|
gpu: bool = False
|
||||||
|
vram_gb: int = 4
|
||||||
|
throttle: bool = False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<ServerConfig name={0.name} base_url={0.base_url} gpu={0.gpu!s} vram_gb={0.vram_gb}>".format(self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
async def is_online(self) -> bool:
|
||||||
|
"""
|
||||||
|
Checks that the current server is online and responding to requests.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
|
||||||
|
try:
|
||||||
|
response = await client.get("/api/tags")
|
||||||
|
return response.status_code == 200
|
||||||
|
except httpx.RequestError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
def get_servers(filter_func: Callable[[ServerConfig], bool] = None) -> list[ServerConfig]:
|
||||||
|
config = get_config()
|
||||||
|
keys = list(config["servers"].keys())
|
||||||
|
log.info("Servers: %r", keys)
|
||||||
|
try:
|
||||||
|
keys = config["servers"].pop("order")
|
||||||
|
log.info("Ordered keys: %r", keys)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
servers = [ServerConfig(name=key, **config["servers"][key]) for key in keys]
|
||||||
|
if filter_func:
|
||||||
|
servers = list(filter(filter_func, servers))
|
||||||
|
return servers
|
||||||
|
|
||||||
|
|
||||||
|
def get_server(name_or_base_url: str) -> ServerConfig | None:
|
||||||
|
servers = get_servers()
|
||||||
|
for server in servers:
|
||||||
|
if server.name == name_or_base_url or server.base_url == name_or_base_url:
|
||||||
|
return server
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
with open("config.toml", "rb") as file:
|
||||||
|
_loaded = tomllib.load(file)
|
||||||
|
|
||||||
|
_loaded.setdefault("servers", {})
|
||||||
|
_loaded["servers"].setdefault("order", [])
|
||||||
|
_loaded.setdefault("bot", {})
|
||||||
|
return _loaded
|
13
jimmy/db.py
Normal file
13
jimmy/db.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tortoise.models import Model
|
||||||
|
from tortoise import fields
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaThread(Model):
|
||||||
|
thread_id = fields.CharField(max_length=255, unique=True, default=lambda: os.urandom(4).hex())
|
||||||
|
messages = fields.JSONField(default=[])
|
||||||
|
created_at = fields.DatetimeField(auto_now_add=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table = "ollama_threads"
|
53
jimmy/main.py
Normal file
53
jimmy/main.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from tortoise import Tortoise
|
||||||
|
sys.path.extend("..") # noqa: E402
|
||||||
|
from .config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class SentientJimmy(commands.Bot):
|
||||||
|
def __init__(self):
|
||||||
|
intents = discord.Intents.default()
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
intents.message_content = True
|
||||||
|
super().__init__(
|
||||||
|
commands.when_mentioned_or("."),
|
||||||
|
intents=intents,
|
||||||
|
case_insensitive=True,
|
||||||
|
strip_after_prefix=True,
|
||||||
|
debug_guilds=get_config()["bot"].get("debug_guilds"),
|
||||||
|
)
|
||||||
|
self.load_extension("jimmy.cogs.chat")
|
||||||
|
self.load_extension("jishaku")
|
||||||
|
|
||||||
|
async def start(self, token: str, *, reconnect: bool = True) -> None:
|
||||||
|
is_docker = os.path.exists("/.dockerenv")
|
||||||
|
default_db = "sqlite://:memory:" if is_docker else "sqlite://default.db"
|
||||||
|
await Tortoise.init(
|
||||||
|
db_url=get_config()["bot"].get("db_url", default_db),
|
||||||
|
modules={"models": ["jimmy.db"]}
|
||||||
|
)
|
||||||
|
await Tortoise.generate_schemas()
|
||||||
|
await super().start(token, reconnect=reconnect)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await Tortoise.close_connections()
|
||||||
|
await super().close()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
token = get_config()["bot"]["token"]
|
||||||
|
super().run(token)
|
||||||
|
|
||||||
|
|
||||||
|
bot = SentientJimmy()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bot.run()
|
47
jimmy/utils.py
Normal file
47
jimmy/utils.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
import asyncio
|
||||||
|
import typing
|
||||||
|
from functools import partial
|
||||||
|
from fuzzywuzzy.fuzz import ratio
|
||||||
|
from ollama import Message
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'async_ratio',
|
||||||
|
'create_ollama_message',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_ratio(a: str, b: str) -> int:
|
||||||
|
"""
|
||||||
|
Wraps fuzzywuzzy ratio in an async function
|
||||||
|
|
||||||
|
:param a: str - first string
|
||||||
|
:param b: str - second string
|
||||||
|
:return: int - ratio of similarity
|
||||||
|
"""
|
||||||
|
return await asyncio.to_thread(partial(ratio, a, b))
|
||||||
|
|
||||||
|
|
||||||
|
async def create_ollama_message(
|
||||||
|
content: str,
|
||||||
|
role: typing.Literal["system", "assistant", "user"] = "user",
|
||||||
|
images: typing.List[str | bytes] = None
|
||||||
|
) -> Message:
|
||||||
|
"""
|
||||||
|
Create a message for ollama.
|
||||||
|
|
||||||
|
:param content: str - the content of the message
|
||||||
|
:param role: str - the role of the message
|
||||||
|
:param images: list - the images to attach to the message
|
||||||
|
:return: dict - the message
|
||||||
|
"""
|
||||||
|
def factory(**kwargs):
|
||||||
|
return Message(**kwargs)
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
partial(
|
||||||
|
factory,
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
images=images
|
||||||
|
)
|
||||||
|
)
|
|
@ -1,5 +1,8 @@
|
||||||
py-cord~=2.5
|
py-cord~=2.5
|
||||||
ollama-python~=0.1
|
ollama~=0.2
|
||||||
tortoise-orm[asyncpg]~=0.21
|
tortoise-orm[asyncpg]~=0.21
|
||||||
uvicorn[standard]~=0.30
|
uvicorn[standard]~=0.30
|
||||||
fastapi~=0.111
|
fastapi~=0.111
|
||||||
|
jishaku~=2.5
|
||||||
|
fuzzywuzzy~=0.18
|
||||||
|
humanize~=4.9
|
||||||
|
|
Loading…
Reference in a new issue