fix next_server
All checks were successful
Build and Publish Jimmy.2 / build_and_publish (push) Successful in 7s

This commit is contained in:
Nexus 2024-05-31 18:00:56 +01:00
parent 34a3cc622e
commit 70afac54ca
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@@ -449,15 +449,17 @@ class Ollama(commands.Cog):
for key in CONFIG["ollama"]["order"]: for key in CONFIG["ollama"]["order"]:
self.servers[key] = asyncio.Lock() self.servers[key] = asyncio.Lock()
def next_server(self) -> str: def next_server(self, tried: typing.Iterable[str] = None) -> str:
""" """
Returns the next server key. Returns the next server key.
:param tried: A list of keys already tried
:returns: The key for the server :returns: The key for the server
:raises: RuntimeError - If no servers are available. :raises: RuntimeError - If no servers are available.
""" """
tried = tried or set()
for server_name, locked in self.servers.items(): for server_name, locked in self.servers.items():
if not locked.locked(): if locked.locked() is False and server_name not in tried:
return server_name return server_name
raise RuntimeError("No servers available.") raise RuntimeError("No servers available.")
@@ -593,11 +595,15 @@ class Ollama(commands.Cog):
embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url")) embed.set_footer(text="Using server %r" % server, icon_url=server_config.get("icon_url"))
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
if not await self.check_server(server_config["base_url"]): if not await self.check_server(server_config["base_url"]):
tried = {server}
for i in range(10): for i in range(10):
try: try:
server = self.next_server() server = self.next_server(tried)
except RuntimeError: except RuntimeError:
tried.add(server)
continue continue
finally:
tried.add(server)
embed = discord.Embed( embed = discord.Embed(
title="Server was offline. Trying next server.", title="Server was offline. Trying next server.",
description=f"Trying server {server}...", description=f"Trying server {server}...",
@@ -709,8 +715,8 @@ class Ollama(commands.Cog):
) )
embed.set_author( embed.set_author(
name=model, name=model,
url="https://ollama.ai/library/" + model.split(":")[0], url="https://ollama.com/library/" + model.split(":")[0],
icon_url="https://ollama.ai/public/ollama.png", icon_url="https://ollama.com/public/ollama.png",
) )
embed.add_field( embed.add_field(
name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False name="Prompt", value=">>> " + textwrap.shorten(query, width=1020, placeholder="..."), inline=False
@@ -885,10 +891,12 @@ class Ollama(commands.Cog):
user_message = {"role": "user", "content": message.content} user_message = {"role": "user", "content": message.content}
self.history.add_message(thread, "user", user_message["content"]) self.history.add_message(thread, "user", user_message["content"])
tried = set()
for _ in range(10): for _ in range(10):
server = self.next_server() server = self.next_server()
if await self.check_server(CONFIG["ollama"][server]["base_url"]): if await self.check_server(CONFIG["ollama"][server]["base_url"]):
break break
tried.add(server)
else: else:
return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True) return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True)