sentient-jimmy/jimmy/config.py
nexy7574 af11baeeaa
All checks were successful
Build and Publish / build_and_publish (push) Successful in 45s
Clarify on-the-fly server names
2024-06-11 01:15:25 +01:00

107 lines
3.3 KiB
Python

import os
import tomllib
import logging
import urllib.parse
from typing import Callable
import httpx
from pydantic import BaseModel, Field, AnyHttpUrl
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class ServerConfig(BaseModel):
    """
    Validated description of a single backend server.

    Instances hash by their ``base_url``; ``str()`` yields the server's name.
    """

    # Display name; must be between 1 and 4096 characters long.
    name: str = Field(min_length=1, max_length=4096)
    # Root HTTP(S) URL that requests are sent to.
    base_url: AnyHttpUrl
    # Presumably whether the server has GPU acceleration - TODO confirm against callers.
    gpu: bool = False
    # Presumably the amount of video memory available, in gigabytes - TODO confirm.
    vram_gb: int = 4
    # Model identifier used when none is specified explicitly.
    default_model: str = "llama3:latest"

    def __str__(self):
        """Return the server's name."""
        return self.name

    def __repr__(self):
        """Return a compact debug representation (the default model is omitted)."""
        template = "<ServerConfig name={0.name} base_url={0.base_url} gpu={0.gpu!s} vram_gb={0.vram_gb}>"
        return template.format(self)

    def __hash__(self):
        """Hash on ``base_url`` only."""
        return hash(self.base_url)

    async def is_online(self) -> bool:
        """
        Checks that the current server is online and responding to requests.

        Sends ``GET /api/tags`` relative to ``base_url``. Returns ``True`` only
        for an HTTP 200 response, and ``False`` when the request itself fails
        with an ``httpx.RequestError``.
        """
        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
            try:
                response = await client.get("/api/tags")
            except httpx.RequestError:
                return False
            return response.status_code == 200
def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
    """
    Load every server declared in the ``servers`` table of the configuration.

    If the table contains a non-empty ``order`` list, only the servers named in
    it are returned, in that order. Otherwise all declared servers are returned
    in the order they appear in the configuration.

    :param filter_func: Optional predicate; only servers for which it returns a
        truthy value are kept.
    :return: The (possibly filtered) list of server configurations.
    """
    # Work on a shallow copy so that removing "order" never mutates the mapping
    # handed back by get_config().
    servers_table = dict(get_config()["servers"])

    # "order" is a reserved key, not a server definition. dict.pop() raises
    # KeyError (not ValueError) when the key is missing, so use a default instead
    # of a try/except.
    order = servers_table.pop("order", None)
    log.info("Servers: %r", list(servers_table))

    if order:
        log.info("Ordered keys: %r", order)
        keys = order
    else:
        # get_config() defaults "order" to an empty list; treating that as
        # "no explicit order" keeps every configured server available.
        keys = list(servers_table)

    servers = []
    for key in keys:
        if key not in servers_table:
            # A stale or misspelled entry in "order" should not take down every lookup.
            log.warning("Server %r is listed in the order but is not defined; skipping.", key)
            continue
        servers.append(ServerConfig(name=key, **servers_table[key]))

    if filter_func is not None:
        servers = [server for server in servers if filter_func(server)]
    return servers
def get_server(name_or_base_url: str) -> ServerConfig | None:
    """
    Resolve a server by its configured name or base URL.

    If no configured server matches and the argument is an ``http``/``https``
    URL, a temporary ``ServerConfig`` is built from it on the fly. Query-string
    parameters override the temporary server's defaults (``gpu`` is parsed as a
    boolean, ``vram_gb`` as an integer; other keys are passed through as
    strings).

    :param name_or_base_url: A configured server name, a configured base URL, or
        an arbitrary HTTP(S) URL.
    :return: The matching or temporary server, or ``None`` when the argument is
        neither a known server nor an HTTP(S) URL.
    """
    # base_url may be a URL object rather than a plain str (is_online() also
    # stringifies it), so compare string forms. A trailing slash is ignored on
    # both sides so "http://host" and "http://host/" refer to the same server.
    wanted_url = name_or_base_url.rstrip("/")
    for server in get_servers():
        if server.name == name_or_base_url or str(server.base_url).rstrip("/") == wanted_url:
            return server

    try:
        parsed = urllib.parse.urlparse(name_or_base_url)
    except ValueError:
        return None
    if not parsed.netloc or parsed.scheme not in ("http", "https"):
        return None

    defaults = {
        "name": ":temporary:-:%s:" % parsed.hostname,
        "base_url": "{0.scheme}://{0.netloc}".format(parsed),
        "gpu": False,
        "vram_gb": 2,
        "default_model": "orca-mini:3b",
    }
    # Only keep the path when it ends in an "/api" segment.
    if parsed.path.endswith(("/api", "/api/")):
        defaults["base_url"] += parsed.path

    for key, values in urllib.parse.parse_qs(parsed.query).items():
        if not values:
            continue
        # Only the first occurrence of a repeated parameter is honoured.
        raw = values[0]
        if key == "gpu":
            # Anything starting with t/1/y (true, 1, yes, ...) enables the flag.
            defaults[key] = raw[0].lower() in ("t", "1", "y")
        elif key == "vram_gb":
            try:
                defaults[key] = int(raw)
            except ValueError:
                # Not a number: keep the default rather than failing the lookup.
                continue
        else:
            defaults[key] = raw
    return ServerConfig(**defaults)
def get_config():
with open("config.toml", "rb") as file:
_loaded = tomllib.load(file)
_loaded.setdefault("servers", {})
_loaded["servers"].setdefault("order", [])
_loaded.setdefault("bot", {})
if database_url := os.getenv("DATABASE_URL"):
_loaded["bot"]["db_url"] = database_url
return _loaded