Add a failover/priority system to ollama
All checks were successful
Build and Publish Jimmy.2 / build_and_publish (push) Successful in 8s

This commit is contained in:
Nexus 2024-05-28 00:48:18 +01:00
parent 4dc0e6418f
commit 6ed7de1d73
Signed by: nex
GPG key ID: 0FA334385D0B689F
2 changed files with 24 additions and 17 deletions

View file

@ -439,18 +439,27 @@ class Ollama(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.log = logging.getLogger("jimmy.cogs.ollama") self.log = logging.getLogger("jimmy.cogs.ollama")
self.last_server = 0
self.contexts = {} self.contexts = {}
self.history = ChatHistory() self.history = ChatHistory()
self.lock = asyncio.Lock() self.servers = {
server: asyncio.Lock() for server in CONFIG["ollama"]
}
if CONFIG["ollama"].get("order"):
self.servers = {}
for key in CONFIG["ollama"]["order"]:
self.servers[key] = asyncio.Lock()
def next_server(self, increment: bool = True) -> str: def next_server(self) -> str:
"""Returns the next server key.""" """
if increment: Returns the next server key.
self.last_server += 1
s = SERVER_KEYS[self.last_server % len(SERVER_KEYS)] :returns: The key for the server
self.log.info("Next server is %s", s) :raises: RuntimeError - If no servers are available.
return s """
for server_name, locked in self.servers.items():
if not locked.locked():
return server_name
raise RuntimeError("No servers available.")
async def check_server(self, url: str) -> bool: async def check_server(self, url: str) -> bool:
"""Checks that a server is online and responding.""" """Checks that a server is online and responding."""
@ -558,9 +567,7 @@ class Ollama(commands.Cog):
else: else:
image_data = None image_data = None
if server == "next": if server not in CONFIG["ollama"]:
server = self.next_server()
elif server not in CONFIG["ollama"]:
await ctx.respond("Invalid server") await ctx.respond("Invalid server")
return return
@ -587,7 +594,10 @@ class Ollama(commands.Cog):
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
if not await self.check_server(server_config["base_url"]): if not await self.check_server(server_config["base_url"]):
for i in range(10): for i in range(10):
try:
server = self.next_server() server = self.next_server()
except RuntimeError:
continue
embed = discord.Embed( embed = discord.Embed(
title="Server was offline. Trying next server.", title="Server was offline. Trying next server.",
description=f"Trying server {server}...", description=f"Trying server {server}...",

View file

@ -37,7 +37,7 @@ except FileNotFoundError:
CONFIG = {} CONFIG = {}
CONFIG.setdefault("logging", {}) CONFIG.setdefault("logging", {})
CONFIG.setdefault("jimmy", {}) CONFIG.setdefault("jimmy", {})
CONFIG.setdefault("ollama", {}) CONFIG.setdefault("ollama", {"order": []})
CONFIG.setdefault("screenshot", {}) CONFIG.setdefault("screenshot", {})
CONFIG.setdefault("responder", {}) CONFIG.setdefault("responder", {})
CONFIG.setdefault("network", {}) CONFIG.setdefault("network", {})
@ -45,9 +45,6 @@ CONFIG.setdefault("quote_a", {"channel": None})
CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True}) CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True})
CONFIG.setdefault("starboard", {}) CONFIG.setdefault("starboard", {})
# if CONFIG["redis"].pop("db", None) is not None:
# log.warning("`redis.db` cannot be manually specified, each cog that uses redis has its own db value! Value ignored")
if CONFIG["redis"].pop("no_ping", None) is not None: if CONFIG["redis"].pop("no_ping", None) is not None:
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.") log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")