nonsensebot/app/modules/ai.py
2024-07-28 23:33:15 +01:00

147 lines
5.3 KiB
Python

import json
import logging
import tomllib
import typing
from contextlib import asynccontextmanager
from pathlib import Path

import httpx
import niobot
from ollama import AsyncClient, ResponseError

if typing.TYPE_CHECKING:
    from ..main import TortoiseIntegratedBot
@asynccontextmanager
async def ollama_client(url: str) -> typing.AsyncIterator[AsyncClient]:
    """Yield an Ollama ``AsyncClient`` bound to *url*.

    The client's inner HTTP session is entered as an async context manager,
    so leaving the ``async with ollama_client(...)`` block closes it.

    :param url: Base URL of the Ollama server.
    :yields: A ready-to-use ``AsyncClient``.
    """
    client = AsyncClient(url)
    # NOTE(review): ``_client`` is a private attribute of ollama's AsyncClient
    # (presumably its httpx.AsyncClient); re-check this when upgrading ollama.
    async with client._client:
        yield client
class AIModule(niobot.Module):
    # Bot commands that forward prompts to Ollama servers listed in the bot
    # config. Access is gated by a JSON whitelist mapping user ID -> model name.

    bot: "TortoiseIntegratedBot"

    # Location of the whitelist file (user ID -> model name).
    _USERS_PATH = Path("./store/users.json")

    async def find_server(self, gpu_only: bool = True) -> dict[str, str | bool] | None:
        """Return the first configured Ollama server that answers a ``ps`` call.

        Servers are taken from ``self.bot.cfg["ollama"]`` in iteration order;
        each entry is expected to have ``url`` and ``gpu`` keys.

        :param gpu_only: Skip servers whose ``gpu`` setting is falsy.
        :returns: The server's config plus a ``name`` key, or ``None`` when no
            suitable server responded successfully.
        """
        for name, cfg in self.bot.cfg["ollama"].items():
            if gpu_only and not cfg["gpu"]:
                continue
            async with ollama_client(cfg["url"]) as client:
                try:
                    await client.ps()
                except (httpx.HTTPError, ConnectionError, ResponseError):
                    # Unreachable or unhealthy (non-2xx -> ResponseError) server:
                    # try the next one instead of aborting the whole search.
                    continue
                else:
                    return {"name": name, **cfg}
        return None

    @classmethod
    def read_users(cls) -> dict[str, str]:
        """Load the whitelist; returns ``{}`` when the file does not exist yet."""
        path = cls._USERS_PATH
        if not path.exists():
            return {}
        return json.loads(path.read_text(encoding="utf-8"))

    @classmethod
    def write_users(cls, users: dict[str, str]) -> None:
        """Persist the whitelist, creating its parent directory if needed."""
        path = cls._USERS_PATH
        path.parent.mkdir(parents=True, exist_ok=True)
        # Write a sibling temp file and swap it in so an interrupted write
        # cannot leave a truncated users.json behind.
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(json.dumps(users), encoding="utf-8")
        tmp.replace(path)

    @niobot.command("ping")
    async def ping_command(self, ctx: niobot.Context):
        """Checks the bot is running."""
        reply = await ctx.respond("Pong!")
        server = await self.find_server()
        if not server:
            await reply.edit("Pong :(\nNo servers available.")
            return
        await reply.edit(f"Pong!\nSelected server: {server['name']}")

    @niobot.command("whitelist.add")
    @niobot.is_owner()
    async def whitelist_add(self, ctx: niobot.Context, user_id: str, model: str = "llama3:latest"):
        """[Owner] Adds a user to the whitelist."""
        users = self.read_users()
        users[user_id] = model
        self.write_users(users)
        await ctx.respond(f"Added {user_id} to the whitelist.")

    @niobot.command("whitelist.list")
    @niobot.is_owner()
    async def whitelist_list(self, ctx: niobot.Context):
        """[Owner] Lists all users in the whitelist."""
        users = self.read_users()
        if not users:
            await ctx.respond("No users in the whitelist.")
            return
        await ctx.respond("\n".join(f"{k}: {v}" for k, v in users.items()))

    @niobot.command("whitelist.remove")
    @niobot.is_owner()
    async def whitelist_remove(self, ctx: niobot.Context, user_id: str):
        """[Owner] Removes a user from the whitelist."""
        users = self.read_users()
        if user_id not in users:
            await ctx.respond(f"{user_id} not in the whitelist.")
            return
        del users[user_id]
        self.write_users(users)
        await ctx.respond(f"Removed {user_id} from the whitelist.")

    @niobot.command("ollama.set-model")
    async def set_model(self, ctx: niobot.Context, model: str):
        """Sets the model you want to use."""
        users = self.read_users()
        if ctx.message.sender not in users:
            await ctx.respond("You must be whitelisted first.")
            return
        users[ctx.message.sender] = model
        self.write_users(users)
        await ctx.respond(f"Set model to {model}. Don't forget to pull it with `h!ollama.pull`.")

    @niobot.command("ollama.pull")
    async def pull_model(self, ctx: niobot.Context):
        """Pulls the model you set."""
        users = self.read_users()
        # Every whitelisted user already has a model (whitelist.add assigns one),
        # so absence here means "not whitelisted", not "no model set".
        if ctx.message.sender not in users:
            await ctx.respond("You must be whitelisted first.")
            return
        model = users[ctx.message.sender]
        server = await self.find_server()
        if not server:
            await ctx.respond("No servers available.")
            return
        msg = await ctx.respond(f"Pulling {model} on {server['name']!r}...")
        async with ollama_client(server["url"]) as client:
            try:
                await client.pull(model)
            except (httpx.HTTPError, ResponseError) as e:
                # Report the failure instead of leaving "Pulling..." forever.
                await msg.edit(f"Failed to pull {model}: {e}")
                return
        await msg.edit(f"Pulled model {model}.")

    @niobot.command("ollama.chat", greedy=True)
    async def chat(self, ctx: niobot.Context):
        """Chat with the model."""
        # Bound before the try so the error handler knows whether a status
        # message exists yet; otherwise an early failure raised NameError.
        res = None
        try:
            message = " ".join(ctx.args)
            users = self.read_users()
            if ctx.message.sender not in users:
                await ctx.respond("You must be whitelisted first.")
                return
            if not message.strip():
                await ctx.respond("Usage: `h!ollama.chat <message>`")
                return
            model = users[ctx.message.sender]
            res = await ctx.respond("Finding server...")
            server = await self.find_server()
            if not server:
                await res.edit(content="No servers available.")
                return
            async with ollama_client(server["url"]) as client:
                await res.edit(content="Generating response...")
                try:
                    response = await client.chat(model, [{"role": "user", "content": message}])
                except (httpx.HTTPError, ResponseError) as e:
                    # ResponseError covers server-side failures such as an
                    # unknown / not-yet-pulled model; show it to the user.
                    response = {"message": {"content": f"Error: {e}"}}
                await res.edit(content=response["message"]["content"])
        except Exception:
            logging.getLogger(__name__).exception("ollama.chat failed")
            if res is None:
                await ctx.respond("An error occurred.")
            else:
                await res.edit(content="An error occurred.")