Add a failover/priority system to ollama
All checks were successful
Build and Publish Jimmy.2 / build_and_publish (push) Successful in 8s
All checks were successful
Build and Publish Jimmy.2 / build_and_publish (push) Successful in 8s
This commit is contained in:
parent
4dc0e6418f
commit
6ed7de1d73
2 changed files with 24 additions and 17 deletions
|
@ -439,18 +439,27 @@ class Ollama(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.ollama")
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||||||
self.last_server = 0
|
|
||||||
self.contexts = {}
|
self.contexts = {}
|
||||||
self.history = ChatHistory()
|
self.history = ChatHistory()
|
||||||
self.lock = asyncio.Lock()
|
self.servers = {
|
||||||
|
server: asyncio.Lock() for server in CONFIG["ollama"]
|
||||||
|
}
|
||||||
|
if CONFIG["ollama"].get("order"):
|
||||||
|
self.servers = {}
|
||||||
|
for key in CONFIG["ollama"]["order"]:
|
||||||
|
self.servers[key] = asyncio.Lock()
|
||||||
|
|
||||||
def next_server(self, increment: bool = True) -> str:
|
def next_server(self) -> str:
|
||||||
"""Returns the next server key."""
|
"""
|
||||||
if increment:
|
Returns the next server key.
|
||||||
self.last_server += 1
|
|
||||||
s = SERVER_KEYS[self.last_server % len(SERVER_KEYS)]
|
:returns: The key for the server
|
||||||
self.log.info("Next server is %s", s)
|
:raises: RuntimeError - If no servers are available.
|
||||||
return s
|
"""
|
||||||
|
for server_name, locked in self.servers.items():
|
||||||
|
if not locked.locked():
|
||||||
|
return server_name
|
||||||
|
raise RuntimeError("No servers available.")
|
||||||
|
|
||||||
async def check_server(self, url: str) -> bool:
|
async def check_server(self, url: str) -> bool:
|
||||||
"""Checks that a server is online and responding."""
|
"""Checks that a server is online and responding."""
|
||||||
|
@ -558,9 +567,7 @@ class Ollama(commands.Cog):
|
||||||
else:
|
else:
|
||||||
image_data = None
|
image_data = None
|
||||||
|
|
||||||
if server == "next":
|
if server not in CONFIG["ollama"]:
|
||||||
server = self.next_server()
|
|
||||||
elif server not in CONFIG["ollama"]:
|
|
||||||
await ctx.respond("Invalid server")
|
await ctx.respond("Invalid server")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -587,7 +594,10 @@ class Ollama(commands.Cog):
|
||||||
await ctx.respond(embed=embed)
|
await ctx.respond(embed=embed)
|
||||||
if not await self.check_server(server_config["base_url"]):
|
if not await self.check_server(server_config["base_url"]):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
server = self.next_server()
|
try:
|
||||||
|
server = self.next_server()
|
||||||
|
except RuntimeError:
|
||||||
|
continue
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Server was offline. Trying next server.",
|
title="Server was offline. Trying next server.",
|
||||||
description=f"Trying server {server}...",
|
description=f"Trying server {server}...",
|
||||||
|
|
|
@ -37,7 +37,7 @@ except FileNotFoundError:
|
||||||
CONFIG = {}
|
CONFIG = {}
|
||||||
CONFIG.setdefault("logging", {})
|
CONFIG.setdefault("logging", {})
|
||||||
CONFIG.setdefault("jimmy", {})
|
CONFIG.setdefault("jimmy", {})
|
||||||
CONFIG.setdefault("ollama", {})
|
CONFIG.setdefault("ollama", {"order": []})
|
||||||
CONFIG.setdefault("screenshot", {})
|
CONFIG.setdefault("screenshot", {})
|
||||||
CONFIG.setdefault("responder", {})
|
CONFIG.setdefault("responder", {})
|
||||||
CONFIG.setdefault("network", {})
|
CONFIG.setdefault("network", {})
|
||||||
|
@ -45,9 +45,6 @@ CONFIG.setdefault("quote_a", {"channel": None})
|
||||||
CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True})
|
CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True})
|
||||||
CONFIG.setdefault("starboard", {})
|
CONFIG.setdefault("starboard", {})
|
||||||
|
|
||||||
# if CONFIG["redis"].pop("db", None) is not None:
|
|
||||||
# log.warning("`redis.db` cannot be manually specified, each cog that uses redis has its own db value! Value ignored")
|
|
||||||
|
|
||||||
if CONFIG["redis"].pop("no_ping", None) is not None:
|
if CONFIG["redis"].pop("no_ping", None) is not None:
|
||||||
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")
|
log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue