From f3aadd2ce092a076fd0d84abe2b675d7de55cbb3 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Wed, 10 Jan 2024 15:11:36 +0000 Subject: [PATCH] Enable load balanced servers --- src/cogs/ollama.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 1612ba2..51f7e95 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -1,5 +1,4 @@ import asyncio -import collections import json import logging import textwrap @@ -35,6 +34,13 @@ class Ollama(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot self.log = logging.getLogger("jimmy.cogs.ollama") + self.last_server = 0 + + def next_server(self, increment: bool = True) -> str: + """Returns the next server key.""" + if increment: + self.last_server += 1 + return SERVER_KEYS[self.last_server % len(SERVER_KEYS)] async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]: async for line in iterator: @@ -73,7 +79,7 @@ class Ollama(commands.Cog): discord.Option( str, "The server to use for ollama.", - default=SERVER_KEYS[0], + default="next", choices=SERVER_KEYS ) ], @@ -86,12 +92,14 @@ class Ollama(commands.Cog): try: model, tag = model.split(":", 1) model = model + ":" + tag - self.log.debug("Model %r already has a tag") + self.log.debug("Model %r already has a tag", model) except ValueError: model = model + ":latest" self.log.debug("Resolved model to %r" % model) - if server not in CONFIG["ollama"]: + if server == "next": + server = self.next_server() + elif server not in CONFIG["ollama"]: await ctx.respond("Invalid server") return