import os
import tomllib
import logging
from typing import Callable

import httpx
from pydantic import BaseModel, Field, AnyHttpUrl

log = logging.getLogger(__name__)


class ServerConfig(BaseModel):
    """One server entry from the ``[servers.<name>]`` tables of the config file."""

    name: str = Field(min_length=1, max_length=32)
    base_url: AnyHttpUrl
    gpu: bool = False
    vram_gb: int = 4
    throttle: bool = False

    def __repr__(self):
        # The previous implementation returned "".format(self), i.e. always an
        # empty string, which made servers invisible in logs and debuggers.
        return (
            f"<{type(self).__name__} name={self.name!r} "
            f"base_url={str(self.base_url)!r} gpu={self.gpu} "
            f"vram_gb={self.vram_gb} throttle={self.throttle}>"
        )

    def __str__(self):
        return self.name

    async def is_online(self) -> bool:
        """
        Check that the server is online and responding to requests.

        Sends ``GET /api/tags`` relative to ``base_url``.

        Returns:
            True if the server answered with HTTP 200, False if it answered with
            any other status or the request failed at the transport level
            (connection error, timeout, ...).
        """
        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
            try:
                response = await client.get("/api/tags")
            except httpx.RequestError:
                return False
            return response.status_code == 200

    def __hash__(self):
        # Servers are identified by URL. NOTE: the model is mutable, so the hash
        # changes if base_url is reassigned after the object is placed in a set.
        return hash(self.base_url)


def get_servers(
    filter_func: Callable[[ServerConfig], bool] | None = None,
) -> list[ServerConfig]:
    """
    Build the list of configured servers.

    If ``servers.order`` is a non-empty list, servers are returned in that order
    (and only those listed). Otherwise every ``[servers.<name>]`` table is
    returned in file order.

    Args:
        filter_func: Optional predicate; only servers for which it returns a
            truthy value are kept.

    Returns:
        The (optionally filtered) list of ServerConfig objects.

    Raises:
        KeyError: if ``order`` names a server that has no table.
        pydantic.ValidationError: if a server table fails validation.
    """
    config = get_config()
    servers_table = config["servers"]
    # "order" is a reserved key (get_config always inserts it), not a server.
    # Previously it was popped inside a try/except ValueError that could never
    # fire, and an empty/missing order silently produced zero servers.
    order = servers_table.pop("order", [])
    keys = list(servers_table.keys())
    log.info("Servers: %r", keys)
    if order:
        keys = list(order)
        log.info("Ordered keys: %r", keys)

    servers = [ServerConfig(name=key, **servers_table[key]) for key in keys]
    if filter_func is not None:
        servers = [server for server in servers if filter_func(server)]
    return servers


def get_server(name_or_base_url: str) -> ServerConfig | None:
    """
    Look up a configured server by its name or its base URL.

    URL matching compares string forms with trailing slashes ignored, because
    pydantic normalizes URLs (e.g. appends "/" to a bare host) and its URL
    objects do not compare equal to plain strings.

    Returns:
        The first matching ServerConfig, or None if nothing matches.
    """
    wanted_url = name_or_base_url.rstrip("/")
    for server in get_servers():
        if server.name == name_or_base_url:
            return server
        if str(server.base_url).rstrip("/") == wanted_url:
            return server
    return None


def get_config(path: str | os.PathLike = "config.toml") -> dict:
    """
    Load the TOML configuration and fill in required defaults.

    Guarantees that the returned dict has a ``servers`` table containing an
    ``order`` list, and a ``bot`` table. If the ``DATABASE_URL`` environment
    variable is set (non-empty), it overrides ``bot.db_url``.

    Args:
        path: Config file location; defaults to ``config.toml`` relative to the
            current working directory.

    Raises:
        FileNotFoundError: if the file does not exist.
        tomllib.TOMLDecodeError: if the file is not valid TOML.
    """
    with open(path, "rb") as file:
        loaded = tomllib.load(file)
    loaded.setdefault("servers", {})
    loaded["servers"].setdefault("order", [])
    loaded.setdefault("bot", {})
    if database_url := os.getenv("DATABASE_URL"):
        loaded["bot"]["db_url"] = database_url
    return loaded