Add GPU-only requirement
Nexus 2024-06-09 19:24:47 +01:00
parent f9fa74b37a
commit 9fdddb8d45

@@ -911,75 +911,6 @@ class Ollama(commands.Cog):
         for chunk in discord.utils.as_chunks(iter(embeds or [discord.Embed(title="No Content.")]), 10):
             await ctx.respond(embeds=chunk, ephemeral=ephemeral)
-    @commands.message_command(name="Ask AI")
-    async def ask_ai(self, ctx: discord.ApplicationContext, message: discord.Message):
-        if not SERVER_KEYS:
-            return await ctx.respond("No servers available. Please try again later.")
-        thread = self.history.create_thread(message.author)
-        content = message.clean_content
-        if not content:
-            if message.embeds:
-                content = message.embeds[0].description or message.embeds[0].title
-        if not content:
-            return await ctx.respond("No content to send to AI.", ephemeral=True)
-        await ctx.defer()
-        user_message = {"role": "user", "content": message.content}
-        self.history.add_message(thread, "user", user_message["content"])
-        tried = set()
-        for _ in range(10):
-            server = self.next_server(tried)
-            if await self.check_server(CONFIG["ollama"][server]["base_url"]):
-                break
-            tried.add(server)
-        else:
-            return await ctx.respond("All servers are offline. Please try again later.", ephemeral=True)
-        client = OllamaClient(CONFIG["ollama"][server]["base_url"])
-        if not await client.has_model_named("orca-mini", "7b"):
-            with client.download_model("orca-mini", "7b") as handler:
-                async for _ in handler:
-                    self.log.info(
-                        "Downloading orca-mini:7b on server %r - %s (%.2f%%)", server, handler.status, handler.percent
-                    )
-        if self.lock.locked():
-            await ctx.respond("Waiting for server to be free...")
-        async with self.lock:
-            await ctx.delete(delay=0.1)
-            messages = self.history.get_history(thread)
-            embed = discord.Embed(description="*Waking Ollama up...*")
-            self.log.debug("Acquiring lock")
-            async with self.servers[server]:
-                await ctx.respond(embed=embed, ephemeral=True)
-                last_edit = time.time()
-                msg = None
-                with client.new_chat("orca-mini:7b", messages) as handler:
-                    self.log.info("New chat connection established.")
-                    async for ln in handler:
-                        done = ln.get("done") is True
-                        embed.description = handler.result
-                        if len(embed.description) >= 4096:
-                            break
-                        if len(embed.description) >= 3250:
-                            embed.colour = discord.Color.gold()
-                            embed.set_footer(text="Warning: {:,}/4096 characters.".format(len(embed.description)))
-                        else:
-                            embed.colour = discord.Color.blurple()
-                            embed.set_footer(text="Using server %r" % server, icon_url=CONFIG["ollama"][server].get("icon_url"))
-                        if msg is None:
-                            await ctx.delete(delay=0.1)
-                            msg = await message.reply(embed=embed)
-                            last_edit = time.time()
-                        else:
-                            if time.time() >= (last_edit + 5.1) or done is True:
-                                await msg.edit(embed=embed)
-                                last_edit = time.time()
-                        if done:
-                            break
-                embed.colour = discord.Colour.dark_theme()
-                return await msg.edit(embed=embed)
@commands.command(name="ollama-status", aliases=["ollama_status", "os"]) @commands.command(name="ollama-status", aliases=["ollama_status", "os"])
async def ollama_status(self, ctx: commands.Context): async def ollama_status(self, ctx: commands.Context):
embed = discord.Embed( embed = discord.Embed(
@ -988,7 +919,7 @@ class Ollama(commands.Cog):
timestamp=discord.utils.utcnow(), timestamp=discord.utils.utcnow(),
) )
if CONFIG["ollama"].get("order"): if CONFIG["ollama"].get("order"):
ln = [f"Server order:"] ln = ["Server order:"]
for n, key in enumerate(CONFIG["ollama"].get("order"), start=1): for n, key in enumerate(CONFIG["ollama"].get("order"), start=1):
ln.append(f"{n}. {key!r}") ln.append(f"{n}. {key!r}")
embed.description = "\n".join(ln) embed.description = "\n".join(ln)
@@ -1177,6 +1108,10 @@
         tried = set()
         for _ in range(10):
             server = self.next_server(tried)
+            is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
+            if not is_gpu:
+                self.log.info("Skipping server %r as it is not a GPU server.", server)
+                continue
             if await self.check_server(CONFIG["ollama"][server]["base_url"]):
                 break
             tried.add(server)
@@ -1416,6 +1351,10 @@
         tried = set()
         for _ in range(10):
             server = self.next_server(tried)
+            is_gpu = CONFIG["ollama"][server].get("is_gpu", False)
+            if not is_gpu:
+                self.log.info("Skipping server %r as it is not a GPU server.", server)
+                continue
            if await self.check_server(CONFIG["ollama"][server]["base_url"]):
                 break
             tried.add(server)
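
Context for the change above: both server-selection loops now skip any host whose config entry is not flagged as GPU-capable. The snippet below is a minimal standalone sketch of that logic, not the cog itself; the CONFIG shape (per-server "base_url" and "is_gpu" keys) is inferred from the diff, and next_server/check_server/pick_gpu_server are stand-ins for the cog's methods, not names from this repository.

    import asyncio

    # Assumed config shape, inferred from the CONFIG["ollama"][server] lookups in the diff.
    CONFIG = {
        "ollama": {
            "gpu-1": {"base_url": "http://gpu-1:11434", "is_gpu": True},
            "cpu-1": {"base_url": "http://cpu-1:11434"},  # no is_gpu -> treated as CPU-only
        }
    }

    async def pick_gpu_server(next_server, check_server, attempts: int = 10):
        """Return the first reachable GPU-flagged server, or None if none is found."""
        tried = set()
        for _ in range(attempts):
            server = next_server(tried)
            # Mirrors the new code: non-GPU servers are skipped outright. Note the diff
            # does not add them to `tried`, so a round-robin next_server may offer the
            # same host again; the loop is bounded at `attempts` iterations either way.
            if not CONFIG["ollama"][server].get("is_gpu", False):
                continue
            if await check_server(CONFIG["ollama"][server]["base_url"]):
                return server
            tried.add(server)
        return None

    # Example wiring with trivial stand-ins for the cog's helpers:
    async def main():
        order = list(CONFIG["ollama"])
        server = await pick_gpu_server(
            next_server=lambda tried: next((s for s in order if s not in tried), order[0]),
            check_server=lambda url: asyncio.sleep(0, result=True),  # pretend every host is up
        )
        print(server)  # -> "gpu-1"

    asyncio.run(main())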