Enable load-balanced servers

This commit is contained in:
Nexus 2024-01-10 15:11:36 +00:00
parent 4028afce3d
commit f3aadd2ce0

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
import collections
import json import json
import logging import logging
import textwrap import textwrap
@ -35,6 +34,13 @@ class Ollama(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.log = logging.getLogger("jimmy.cogs.ollama") self.log = logging.getLogger("jimmy.cogs.ollama")
self.last_server = 0
def next_server(self, increment: bool = True) -> str:
    """Return the server key selected by the rotation counter.

    The counter is stored on the instance as ``self.last_server`` and is
    reduced modulo ``len(SERVER_KEYS)`` to index into ``SERVER_KEYS``, so
    successive incrementing calls cycle through the keys in order.

    :param increment: When true (the default), advance ``self.last_server``
        by one before selecting. When false, return the key for the current
        counter value without changing it.
    :returns: The entry of ``SERVER_KEYS`` at the (possibly advanced)
        counter position.
    """
    position = self.last_server
    if increment:
        # Persist the advanced counter so the following call moves on.
        position += 1
        self.last_server = position
    # The counter itself is never wrapped; only the lookup index is.
    return SERVER_KEYS[position % len(SERVER_KEYS)]
async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]: async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
async for line in iterator: async for line in iterator:
@ -73,7 +79,7 @@ class Ollama(commands.Cog):
discord.Option( discord.Option(
str, str,
"The server to use for ollama.", "The server to use for ollama.",
default=SERVER_KEYS[0], default="next",
choices=SERVER_KEYS choices=SERVER_KEYS
) )
], ],
@ -86,12 +92,14 @@ class Ollama(commands.Cog):
try: try:
model, tag = model.split(":", 1) model, tag = model.split(":", 1)
model = model + ":" + tag model = model + ":" + tag
self.log.debug("Model %r already has a tag") self.log.debug("Model %r already has a tag", model)
except ValueError: except ValueError:
model = model + ":latest" model = model + ":latest"
self.log.debug("Resolved model to %r" % model) self.log.debug("Resolved model to %r" % model)
if server not in CONFIG["ollama"]: if server == "next":
server = self.next_server()
elif server not in CONFIG["ollama"]:
await ctx.respond("Invalid server") await ctx.respond("Invalid server")
return return