Enable load balanced servers
This commit is contained in:
parent
4028afce3d
commit
f3aadd2ce0
1 changed file with 12 additions and 4 deletions
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
|
@ -35,6 +34,13 @@ class Ollama(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.ollama")
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||||||
|
self.last_server = 0
|
||||||
|
|
||||||
|
def next_server(self, increment: bool = True) -> str:
    """Returns the next server key.

    Rotates round-robin through the module-level ``SERVER_KEYS`` list.
    When *increment* is true, the internal cursor is advanced before
    the key is looked up; otherwise the current key is returned again.
    """
    if increment:
        self.last_server += 1
    # Modulo keeps the cursor inside the key list no matter how
    # many times it has been advanced.
    index = self.last_server % len(SERVER_KEYS)
    return SERVER_KEYS[index]
|
||||||
|
|
||||||
async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
|
async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
|
||||||
async for line in iterator:
|
async for line in iterator:
|
||||||
|
@ -73,7 +79,7 @@ class Ollama(commands.Cog):
|
||||||
discord.Option(
|
discord.Option(
|
||||||
str,
|
str,
|
||||||
"The server to use for ollama.",
|
"The server to use for ollama.",
|
||||||
default=SERVER_KEYS[0],
|
default="next",
|
||||||
choices=SERVER_KEYS
|
choices=SERVER_KEYS
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
@ -86,12 +92,14 @@ class Ollama(commands.Cog):
|
||||||
try:
|
try:
|
||||||
model, tag = model.split(":", 1)
|
model, tag = model.split(":", 1)
|
||||||
model = model + ":" + tag
|
model = model + ":" + tag
|
||||||
self.log.debug("Model %r already has a tag")
|
self.log.debug("Model %r already has a tag", model)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
model = model + ":latest"
|
model = model + ":latest"
|
||||||
self.log.debug("Resolved model to %r" % model)
|
self.log.debug("Resolved model to %r" % model)
|
||||||
|
|
||||||
if server not in CONFIG["ollama"]:
|
if server == "next":
|
||||||
|
server = self.next_server()
|
||||||
|
elif server not in CONFIG["ollama"]:
|
||||||
await ctx.respond("Invalid server")
|
await ctx.respond("Invalid server")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue