Enable load-balanced servers

This commit is contained in:
Nexus 2024-01-10 15:11:36 +00:00
parent 4028afce3d
commit f3aadd2ce0

View file

@ -1,5 +1,4 @@
import asyncio
import collections
import json
import logging
import textwrap
@ -35,6 +34,13 @@ class Ollama(commands.Cog):
def __init__(self, bot: commands.Bot):
    """Set up the cog's initial state.

    :param bot: The bot instance this cog is attached to; kept on
        ``self.bot``.
    """
    # Rotation counter used by ``next_server``; begins at zero.
    self.last_server = 0
    # Logger dedicated to this cog.
    self.log = logging.getLogger("jimmy.cogs.ollama")
    self.bot = bot
def next_server(self, increment: bool = True) -> str:
    """Return a server key from ``SERVER_KEYS`` selected by the rotation counter.

    :param increment: When true, ``self.last_server`` is advanced by one
        before the lookup. When false, the counter is left as-is and the
        key it currently points at is returned.
    :returns: The entry of ``SERVER_KEYS`` at the counter's position,
        wrapped around with a modulo on ``len(SERVER_KEYS)``.
    """
    # NOTE(review): the modulo below divides by len(SERVER_KEYS), so an empty
    # SERVER_KEYS would raise ZeroDivisionError - confirm it is never empty.
    step = 1 if increment else 0
    counter = self.last_server + step
    if step:
        # Only write back when the counter actually moved.
        self.last_server = counter
    slot = counter % len(SERVER_KEYS)
    return SERVER_KEYS[slot]
async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
async for line in iterator:
@ -73,7 +79,7 @@ class Ollama(commands.Cog):
discord.Option(
str,
"The server to use for ollama.",
default=SERVER_KEYS[0],
default="next",
choices=SERVER_KEYS
)
],
@ -86,12 +92,14 @@ class Ollama(commands.Cog):
try:
model, tag = model.split(":", 1)
model = model + ":" + tag
self.log.debug("Model %r already has a tag")
self.log.debug("Model %r already has a tag", model)
except ValueError:
model = model + ":latest"
self.log.debug("Resolved model to %r" % model)
if server not in CONFIG["ollama"]:
if server == "next":
server = self.next_server()
elif server not in CONFIG["ollama"]:
await ctx.respond("Invalid server")
return