70 lines
1.9 KiB
Python
70 lines
1.9 KiB
Python
|
import tomllib
|
||
|
import logging
|
||
|
from typing import Callable
|
||
|
|
||
|
import httpx
|
||
|
from pydantic import BaseModel, Field, AnyHttpUrl
|
||
|
|
||
|
# Logger for this module; its name is the module's import path.
log = logging.getLogger(name=__name__)
|
||
|
|
||
|
|
||
|
class ServerConfig(BaseModel):
    """Settings for one backend server.

    Instances hash on ``base_url`` alone, so they can be used in sets or as
    dict keys.
    """

    # Identifier used for lookups and for ``str()``; 1-32 characters.
    name: str = Field(min_length=1, max_length=32)
    # Root URL; ``is_online`` issues its request relative to this.
    base_url: AnyHttpUrl
    # NOTE(review): meanings below are inferred from the names only —
    # presumably GPU availability, VRAM in GB, and a throttling toggle; confirm.
    gpu: bool = False
    vram_gb: int = 4
    throttle: bool = False

    def __repr__(self):
        # Deliberately omits ``throttle``.
        return (
            f"<ServerConfig name={self.name} base_url={self.base_url} "
            f"gpu={self.gpu!s} vram_gb={self.vram_gb}>"
        )

    def __str__(self):
        """Return the server's configured name."""
        return self.name

    async def is_online(self) -> bool:
        """Report whether the server answers ``GET /api/tags`` with HTTP 200.

        A fresh ``httpx.AsyncClient`` is opened for each call. Any
        ``httpx.RequestError`` raised by the request yields ``False`` instead
        of propagating.
        """
        async with httpx.AsyncClient(base_url=str(self.base_url)) as client:
            try:
                reply = await client.get("/api/tags")
            except httpx.RequestError:
                return False
            return reply.status_code == 200

    def __hash__(self):
        # Identity for hashing purposes is the URL only.
        return hash(self.base_url)
|
||
|
|
||
|
|
||
|
def get_servers(filter_func: Callable[[ServerConfig], bool] | None = None) -> list[ServerConfig]:
    """Build a ``ServerConfig`` for each server defined in the configuration.

    Each entry of the ``[servers]`` table except ``order`` becomes one
    ``ServerConfig``, whose ``name`` is the entry's key.

    If ``servers.order`` is a non-empty list, it selects and orders the
    servers; entries not listed are skipped. Otherwise every server is
    returned in the order the parsed table yields them.

    Args:
        filter_func: Optional predicate. Only servers for which it returns a
            truthy value are kept.

    Returns:
        The (possibly filtered) list of server configurations.

    Raises:
        KeyError: If ``servers.order`` names a server that is not defined.
        pydantic validation errors propagate if a server entry is invalid.
    """
    servers_table = get_config()["servers"]

    # ``order`` is a list, not a server definition, so it must be removed
    # before the remaining keys are treated as server names.
    # get_config() always supplies it, defaulting to [], so an empty list
    # means "no explicit order". The previous code used pop("order") inside
    # ``except ValueError``; since get_config() always provides the key,
    # that path returned zero servers whenever no order was configured.
    order = servers_table.pop("order", None)

    keys = list(servers_table)
    log.info("Servers: %r", keys)
    if order:
        keys = list(order)
        log.info("Ordered keys: %r", keys)

    servers = [ServerConfig(name=key, **servers_table[key]) for key in keys]
    if filter_func is not None:
        servers = [server for server in servers if filter_func(server)]
    return servers
|
||
|
|
||
|
|
||
|
def get_server(name_or_base_url: str) -> ServerConfig | None:
    """Find a configured server by its name or by its base URL.

    Args:
        name_or_base_url: Either a server's ``name`` or its ``base_url``.
            URL matching ignores trailing slashes, so ``"http://host:1234"``
            matches a server whose URL is stored as ``"http://host:1234/"``.

    Returns:
        The first server from ``get_servers()`` whose name or URL matches,
        or ``None`` if none does.
    """
    wanted_url = name_or_base_url.rstrip("/")
    for server in get_servers():
        if server.name == name_or_base_url:
            return server
        # ``base_url`` is an AnyHttpUrl. It may be a URL object, not a
        # ``str``, and then never compares equal to one; pydantic may also
        # have appended a trailing "/". Compare normalised string forms.
        if str(server.base_url).rstrip("/") == wanted_url:
            return server
    return None
|
||
|
|
||
|
|
||
|
def get_config():
|
||
|
with open("config.toml", "rb") as file:
|
||
|
_loaded = tomllib.load(file)
|
||
|
|
||
|
_loaded.setdefault("servers", {})
|
||
|
_loaded["servers"].setdefault("order", [])
|
||
|
_loaded.setdefault("bot", {})
|
||
|
return _loaded
|