2024-06-10 17:03:58 +01:00
|
|
|
import os
|
2024-06-10 03:14:52 +01:00
|
|
|
import tomllib
|
|
|
|
import logging
|
2024-06-11 00:53:48 +01:00
|
|
|
import urllib.parse
|
2024-06-10 03:14:52 +01:00
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
from pydantic import BaseModel, Field, AnyHttpUrl
|
|
|
|
|
|
|
|
# Module-level logger; get_servers() reports the discovered and ordered server keys through it.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class ServerConfig(BaseModel):
    """Connection settings for one model server.

    Instances come either from a ``[servers.<name>]`` table of the config
    (``get_servers``) or are built ad hoc from a URL (``get_server``).
    Presumably an Ollama-compatible server, given the ``/api/tags`` probe —
    confirm before relying on other endpoints.
    """

    # Display name, validated to 1-32 characters.
    name: str = Field(min_length=1, max_length=32)
    # Root URL of the server's HTTP API.
    base_url: AnyHttpUrl
    # GPU flag as configured (nothing here probes the hardware).
    gpu: bool = False
    # Configured VRAM size, presumably in GiB per the field name.
    vram_gb: int = 4
    # Model name used when a caller does not pick one.
    default_model: str = "llama3:latest"

    def __repr__(self):
        template = "<ServerConfig name={0.name} base_url={0.base_url} gpu={0.gpu!s} vram_gb={0.vram_gb}>"
        return template.format(self)

    def __str__(self):
        # A server prints as its configured name.
        return self.name

    def __hash__(self):
        # Identity for hashing is the base URL alone, so servers can be
        # used as set members / dict keys.
        return hash(self.base_url)

    async def is_online(self) -> bool:
        """Report whether the server answers ``GET /api/tags`` with HTTP 200.

        Request-level failures (``httpx.RequestError``: connection errors,
        timeouts, ...) are reported as offline instead of propagating.
        """
        root = str(self.base_url)
        async with httpx.AsyncClient(base_url=root) as http:
            try:
                reply = await http.get("/api/tags")
            except httpx.RequestError:
                return False
            return reply.status_code == 200
|
|
|
|
|
|
|
|
|
|
|
|
def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
    """Return the configured servers, optionally filtered.

    Each ``[servers.<key>]`` table of the config becomes a ``ServerConfig``
    named ``<key>``.

    - If ``servers.order`` is a non-empty list, only the keys it names are
      returned, in that order.
    - Otherwise every server table is returned, in file order.

    :param filter_func: Optional predicate; only servers for which it returns
        a truthy value are kept.
    :returns: The (possibly filtered) list of servers.
    :raises KeyError: If ``servers.order`` names a key that has no table.
    """
    servers_table = get_config()["servers"]

    # "order" sits next to the server tables but is not a server itself.
    # Remove it before treating the remaining keys as server names.
    # An empty/missing order means "use every server".
    order = servers_table.pop("order", None)
    keys = list(servers_table)
    log.info("Servers: %r", keys)
    if order:
        keys = list(order)
        log.info("Ordered keys: %r", keys)

    servers = [ServerConfig(name=key, **servers_table[key]) for key in keys]
    if filter_func:
        servers = [server for server in servers if filter_func(server)]
    return servers
|
|
|
|
|
|
|
|
|
|
|
|
def get_server(name_or_base_url: str) -> ServerConfig | None:
    """Find a configured server by name or base URL, or build one from a URL.

    Configured servers (``get_servers()``) are searched first. A server
    matches when its ``name`` equals the argument, or when its base URL equals
    the argument ignoring trailing slashes.

    If nothing matches, the argument is interpreted as an ad-hoc server URL
    (see ``_server_from_url``).

    :param name_or_base_url: A configured server name, or an http(s) URL.
    :returns: The matching or newly built server; ``None`` if nothing matched
        and the argument is not a usable http(s) URL.
    """
    wanted_url = name_or_base_url.rstrip("/")
    for server in get_servers():
        # Depending on the pydantic version, base_url may be a URL object that
        # never compares equal to a plain str, and bare-host URLs gain a
        # trailing "/". Compare normalised strings instead.
        if server.name == name_or_base_url or str(server.base_url).rstrip("/") == wanted_url:
            return server
    return _server_from_url(name_or_base_url)


def _server_from_url(url: str) -> ServerConfig | None:
    """Build an unconfigured ServerConfig from an http(s) URL, or return None.

    The base URL is ``scheme://netloc``; a path ending in ``/api`` or
    ``/api/`` is kept on it. Query parameters override the defaults below
    (see ``_coerce_query_value`` for how values are converted).
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except ValueError:
        # urlparse rejects some malformed input, e.g. unbalanced IPv6 brackets.
        return None
    if not parsed.netloc or parsed.scheme not in ("http", "https"):
        return None

    options = {
        # ServerConfig.name allows at most 32 characters; truncate so long
        # hostnames do not fail validation.
        "name": (parsed.hostname or parsed.netloc)[:32],
        "base_url": f"{parsed.scheme}://{parsed.netloc}",
        "gpu": False,
        "vram_gb": 2,
        "default_model": "orca-mini:3b",
    }
    if parsed.path.endswith(("/api", "/api/")):
        options["base_url"] += parsed.path

    for key, values in urllib.parse.parse_qs(parsed.query).items():
        if not values:
            continue
        value = _coerce_query_value(key, values[0])
        if value is not None:
            options[key] = value

    try:
        return ServerConfig(**options)
    except ValueError as exc:
        # pydantic's ValidationError subclasses ValueError. A URL that still
        # fails validation (e.g. no host) resolves to "no server", like other
        # unusable input.
        log.warning("Could not build a server from %r: %s", url, exc)
        return None


def _coerce_query_value(key: str, raw: str) -> object | None:
    """Convert one query-string value for field *key*; ``None`` means ignore it.

    - ``gpu`` is true when the value starts with t/1/y (case-insensitive).
    - ``vram_gb`` must parse as an int, otherwise it is ignored.
    - Any other key is passed through as the raw string.
    """
    if key == "gpu":
        return raw[:1].lower() in ("t", "1", "y")
    if key == "vram_gb":
        try:
            return int(raw)
        except ValueError:
            return None
    return raw
|
|
|
|
|
|
|
|
|
|
|
|
def get_config():
|
|
|
|
with open("config.toml", "rb") as file:
|
|
|
|
_loaded = tomllib.load(file)
|
|
|
|
|
|
|
|
_loaded.setdefault("servers", {})
|
|
|
|
_loaded["servers"].setdefault("order", [])
|
|
|
|
_loaded.setdefault("bot", {})
|
2024-06-10 17:03:58 +01:00
|
|
|
if database_url := os.getenv("DATABASE_URL"):
|
|
|
|
_loaded["bot"]["db_url"] = database_url
|
2024-06-10 03:14:52 +01:00
|
|
|
return _loaded
|