diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 2f6e770..e2a7555 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -439,18 +439,27 @@ class Ollama(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot self.log = logging.getLogger("jimmy.cogs.ollama") - self.last_server = 0 self.contexts = {} self.history = ChatHistory() - self.lock = asyncio.Lock() + self.servers = { + server: asyncio.Lock() for server in CONFIG["ollama"] + } + if CONFIG["ollama"].get("order"): + self.servers = {} + for key in CONFIG["ollama"]["order"]: + self.servers[key] = asyncio.Lock() - def next_server(self, increment: bool = True) -> str: - """Returns the next server key.""" - if increment: - self.last_server += 1 - s = SERVER_KEYS[self.last_server % len(SERVER_KEYS)] - self.log.info("Next server is %s", s) - return s + def next_server(self) -> str: + """ + Returns the next server key. + + :returns: The key for the server + :raises: RuntimeError - If no servers are available. + """ + for server_name, locked in self.servers.items(): + if not locked.locked(): + return server_name + raise RuntimeError("No servers available.") async def check_server(self, url: str) -> bool: """Checks that a server is online and responding.""" @@ -558,9 +567,7 @@ class Ollama(commands.Cog): else: image_data = None - if server == "next": - server = self.next_server() - elif server not in CONFIG["ollama"]: + if server not in CONFIG["ollama"]: await ctx.respond("Invalid server") return @@ -587,7 +594,10 @@ class Ollama(commands.Cog): await ctx.respond(embed=embed) if not await self.check_server(server_config["base_url"]): for i in range(10): - server = self.next_server() + try: + server = self.next_server() + except RuntimeError: + continue embed = discord.Embed( title="Server was offline. Trying next server.", description=f"Trying server {server}...", diff --git a/src/conf.py b/src/conf.py index e3c677e..b2a4030 100644 --- a/src/conf.py +++ b/src/conf.py @@ -37,7 +37,7 @@ except FileNotFoundError: CONFIG = {} CONFIG.setdefault("logging", {}) CONFIG.setdefault("jimmy", {}) -CONFIG.setdefault("ollama", {}) +CONFIG.setdefault("ollama", {"order": []}) CONFIG.setdefault("screenshot", {}) CONFIG.setdefault("responder", {}) CONFIG.setdefault("network", {}) @@ -45,9 +45,6 @@ CONFIG.setdefault("quote_a", {"channel": None}) CONFIG.setdefault("redis", {"host": "redis", "port": 6379, "decode_responses": True}) CONFIG.setdefault("starboard", {}) -# if CONFIG["redis"].pop("db", None) is not None: -# log.warning("`redis.db` cannot be manually specified, each cog that uses redis has its own db value! Value ignored") - if CONFIG["redis"].pop("no_ping", None) is not None: log.warning("`redis.no_ping` was deprecated after 808D621F. Ping is now always mandatory.")