import json
import logging
import typing
from contextlib import asynccontextmanager
from pathlib import Path

import httpx
import niobot
import tomllib
from ollama import AsyncClient, ResponseError

if typing.TYPE_CHECKING:
    from ..main import TortoiseIntegratedBot

log = logging.getLogger(__name__)

# Whitelist storage: a JSON object mapping Matrix user ID -> model name.
# Relative path, so it is resolved against the process working directory.
USERS_FILE = Path("./store/users.json")

# Errors that mean "the Ollama server could not serve this request".
# httpx errors cover transport/HTTP failures, ConnectionError covers
# connection refusals surfaced as builtin exceptions, and ResponseError is
# raised by the ollama client for error responses from the server.
_OLLAMA_ERRORS = (httpx.HTTPError, ConnectionError, ResponseError)


@asynccontextmanager
async def ollama_client(url: str) -> typing.AsyncIterator[AsyncClient]:
    """Yield an ``ollama.AsyncClient`` bound to *url* and close it on exit.

    The client's private ``_client`` attribute (presumably the underlying
    ``httpx.AsyncClient`` -- verify if the ollama package is upgraded) is used
    as the async context manager so its connections are released when the
    ``async with`` block ends.
    """
    client = AsyncClient(url)
    async with client._client:
        yield client


class AIModule(niobot.Module):
    bot: "TortoiseIntegratedBot"

    # NOTE: the command docstrings below are left unchanged on purpose;
    # niobot presumably shows them as command help text.

    async def find_server(self, gpu_only: bool = True) -> dict[str, str | bool] | None:
        """Return the first configured Ollama server that answers a ``ps()`` probe.

        Servers are read from ``self.bot.cfg["ollama"]``, a mapping of server
        name -> config dict that must contain ``"url"`` and ``"gpu"`` keys.

        :param gpu_only: when true, servers whose ``gpu`` flag is falsy are skipped.
        :returns: ``{"name": <server name>, **<server config>}`` for the first
            reachable server, or ``None`` if none responded.
        """
        for name, cfg in self.bot.cfg["ollama"].items():
            url = cfg["url"]
            gpu = cfg["gpu"]
            if gpu_only and not gpu:
                continue
            async with ollama_client(url) as client:
                try:
                    await client.ps()
                except _OLLAMA_ERRORS:
                    # Unreachable or unhealthy server: try the next one.
                    continue
            return {"name": name, **cfg}
        return None

    @staticmethod
    def read_users() -> dict[str, str]:
        """Load the whitelist; an absent file means an empty whitelist."""
        if not USERS_FILE.exists():
            return {}
        # json.dump writes ASCII by default, so UTF-8 decoding is always safe.
        return json.loads(USERS_FILE.read_text(encoding="utf-8"))

    @staticmethod
    def write_users(users: dict[str, str]) -> None:
        """Persist the whitelist, creating ``./store/`` if needed.

        The data is written to a sibling temp file and then moved over the real
        file, so a crash mid-write cannot leave a truncated/corrupt whitelist.
        """
        USERS_FILE.parent.mkdir(parents=True, exist_ok=True)
        tmp = USERS_FILE.with_name(USERS_FILE.name + ".tmp")
        tmp.write_text(json.dumps(users), encoding="utf-8")
        tmp.replace(USERS_FILE)

    @niobot.command("ping")
    async def ping_command(self, ctx: niobot.Context):
        """Checks the bot is running."""
        reply = await ctx.respond("Pong!")
        server = await self.find_server()
        if not server:
            await reply.edit("Pong :(\nNo servers available.")
            return
        await reply.edit(f"Pong!\nSelected server: {server['name']}")

    @niobot.command("whitelist.add")
    @niobot.is_owner()
    async def whitelist_add(self, ctx: niobot.Context, user_id: str, model: str = "llama3:latest"):
        """[Owner] Adds a user to the whitelist."""
        users = self.read_users()
        users[user_id] = model
        self.write_users(users)
        await ctx.respond(f"Added {user_id} to the whitelist.")

    @niobot.command("whitelist.list")
    @niobot.is_owner()
    async def whitelist_list(self, ctx: niobot.Context):
        """[Owner] Lists all users in the whitelist."""
        users = self.read_users()
        if not users:
            await ctx.respond("No users in the whitelist.")
            return
        await ctx.respond("\n".join(f"{k}: {v}" for k, v in users.items()))

    @niobot.command("whitelist.remove")
    @niobot.is_owner()
    async def whitelist_remove(self, ctx: niobot.Context, user_id: str):
        """[Owner] Removes a user from the whitelist."""
        users = self.read_users()
        if user_id not in users:
            await ctx.respond(f"{user_id} not in the whitelist.")
            return
        del users[user_id]
        self.write_users(users)
        await ctx.respond(f"Removed {user_id} from the whitelist.")

    @niobot.command("ollama.set-model")
    async def set_model(self, ctx: niobot.Context, model: str):
        """Sets the model you want to use."""
        users = self.read_users()
        # Only already-whitelisted senders may change their model.
        if ctx.message.sender not in users:
            await ctx.respond("You must be whitelisted first.")
            return
        users[ctx.message.sender] = model
        self.write_users(users)
        await ctx.respond(f"Set model to {model}. Don't forget to pull it with `h!ollama.pull`.")

    @niobot.command("ollama.pull")
    async def pull_model(self, ctx: niobot.Context):
        """Pulls the model you set."""
        users = self.read_users()
        if ctx.message.sender not in users:
            await ctx.respond("You need to set a model first. See: `h!help ollama.set-model`")
            return
        model = users[ctx.message.sender]
        server = await self.find_server()
        if not server:
            await ctx.respond("No servers available.")
            return
        msg = await ctx.respond(f"Pulling {model} on {server['name']!r}...")
        try:
            async with ollama_client(server["url"]) as client:
                await client.pull(model)
        except _OLLAMA_ERRORS as e:
            # Report the failure instead of leaving the "Pulling..." message stale.
            await msg.edit(f"Failed to pull model {model}: {e}")
            return
        await msg.edit(f"Pulled model {model}.")

    @niobot.command("ollama.chat", greedy=True)
    async def chat(self, ctx: niobot.Context):
        """Chat with the model."""
        # Status message; stays None until the first response is sent so the
        # error handler knows whether it can edit it or must send a new one.
        res = None
        try:
            message = " ".join(ctx.args)
            users = self.read_users()
            if ctx.message.sender not in users:
                await ctx.respond("You need to set a model first. See: `h!help ollama.set-model`")
                return
            if not message.strip():
                await ctx.respond("Please provide a message to send to the model.")
                return
            model = users[ctx.message.sender]
            res = await ctx.respond("Finding server...")
            server = await self.find_server()
            if not server:
                await res.edit(content="No servers available.")
                return
            async with ollama_client(server["url"]) as client:
                await res.edit(content="Generating response...")
                try:
                    response = await client.chat(model, [{"role": "user", "content": message}])
                except _OLLAMA_ERRORS as e:
                    # Surface server/model errors (e.g. an unpulled model) to the user.
                    response = {"message": {"content": f"Error: {e}"}}
            content = response["message"]["content"] or "(The model returned an empty response.)"
            await res.edit(content=content)
        except Exception:
            log.exception("Unhandled error in ollama.chat")
            if res is None:
                await ctx.respond("An error occurred.")
            else:
                await res.edit(content="An error occurred.")