2024-07-29 00:38:53 +01:00
|
|
|
import asyncio
|
2024-07-28 23:33:15 +01:00
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
from pathlib import Path
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
import typing
|
2024-07-29 00:36:32 +01:00
|
|
|
|
2024-07-28 23:33:15 +01:00
|
|
|
from ollama import AsyncClient
|
|
|
|
|
|
|
|
import niobot
|
|
|
|
|
|
|
|
if typing.TYPE_CHECKING:
|
|
|
|
from ..main import TortoiseIntegratedBot
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def ollama_client(url: str) -> typing.AsyncIterator[AsyncClient]:
    """Yield an ollama :class:`AsyncClient` whose HTTP transport is closed on exit.

    :param url: Base URL of the ollama server to connect to.

    Note: the decorated function is an async generator, so the correct
    annotation is ``AsyncIterator[AsyncClient]`` (the original said
    ``-> AsyncClient``, which is what the *context manager* yields, not
    what this function returns).
    """
    client = AsyncClient(url)
    # AsyncClient exposes no public close/aclose here, so we enter the
    # wrapped httpx client directly (private attribute — revisit if the
    # ollama library grows a public context-manager API).
    async with client._client:
        yield client
|
|
|
|
|
|
|
|
|
|
|
|
class AIModule(niobot.Module):
    """Bot module exposing ollama-backed chat behind a per-user whitelist."""

    bot: "TortoiseIntegratedBot"
    log = logging.getLogger(__name__)

    async def find_server(self, gpu_only: bool = True) -> dict[str, str | bool] | None:
        """Return the first reachable ollama server from the bot config.

        :param gpu_only: When True, skip servers whose config says ``gpu`` is falsy.
        :return: ``{"name": <server name>, **cfg}`` for the first server whose
            ``ps()`` probe succeeds, or ``None`` when none respond.
        """
        for name, cfg in self.bot.cfg["ollama"].items():
            url = cfg["url"]
            gpu = cfg["gpu"]
            if gpu_only and not gpu:
                self.log.info("Skipping %r, need GPU.", name)
                continue
            async with ollama_client(url) as client:
                try:
                    self.log.info("Trying %r", name)
                    await client.ps()
                except (httpx.HTTPError, ConnectionError):
                    self.log.warning("%r is offline, trying next.", name)
                    continue
                else:
                    # BUG FIX: original called self.log.info("%r is online.")
                    # with no argument for the %r placeholder.
                    self.log.info("%r is online.", name)
                    return {"name": name, **cfg}
        self.log.warning("No suitable ollama server is online.")
        return None
|
2024-07-28 23:33:15 +01:00
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def read_users():
|
|
|
|
p = Path("./store/users.json")
|
|
|
|
if not p.exists():
|
|
|
|
return {}
|
|
|
|
return json.loads(p.read_text())
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def write_users(users: dict[str, str]):
|
|
|
|
p = Path("./store/users.json")
|
|
|
|
with open(p, "w") as _fd:
|
|
|
|
json.dump(users, _fd)
|
|
|
|
|
|
|
|
@niobot.command("whitelist.add")
|
|
|
|
@niobot.is_owner()
|
|
|
|
async def whitelist_add(self, ctx: niobot.Context, user_id: str, model: str = "llama3:latest"):
|
|
|
|
"""[Owner] Adds a user to the whitelist."""
|
|
|
|
users = self.read_users()
|
|
|
|
users[user_id] = model
|
|
|
|
self.write_users(users)
|
|
|
|
await ctx.respond(f"Added {user_id} to the whitelist.")
|
|
|
|
|
|
|
|
@niobot.command("whitelist.list")
|
|
|
|
@niobot.is_owner()
|
|
|
|
async def whitelist_list(self, ctx: niobot.Context):
|
|
|
|
"""[Owner] Lists all users in the whitelist."""
|
|
|
|
users = self.read_users()
|
|
|
|
if not users:
|
|
|
|
await ctx.respond("No users in the whitelist.")
|
|
|
|
return
|
|
|
|
await ctx.respond("\n".join(f"{k}: {v}" for k, v in users.items()))
|
|
|
|
|
|
|
|
@niobot.command("whitelist.remove")
|
|
|
|
@niobot.is_owner()
|
|
|
|
async def whitelist_remove(self, ctx: niobot.Context, user_id: str):
|
|
|
|
"""[Owner] Removes a user from the whitelist."""
|
|
|
|
users = self.read_users()
|
|
|
|
if user_id not in users:
|
|
|
|
await ctx.respond(f"{user_id} not in the whitelist.")
|
|
|
|
return
|
|
|
|
del users[user_id]
|
|
|
|
self.write_users(users)
|
|
|
|
await ctx.respond(f"Removed {user_id} from the whitelist.")
|
|
|
|
|
|
|
|
@niobot.command("ollama.set-model")
|
2024-09-12 01:45:38 +01:00
|
|
|
@niobot.from_homeserver("nexy7574.co.uk", "nicroxio.co.uk", "shronk.net")
|
2024-07-28 23:33:15 +01:00
|
|
|
async def set_model(self, ctx: niobot.Context, model: str):
|
|
|
|
"""Sets the model you want to use."""
|
|
|
|
users = self.read_users()
|
|
|
|
if ctx.message.sender not in users:
|
|
|
|
await ctx.respond("You must be whitelisted first.")
|
|
|
|
return
|
|
|
|
users[ctx.message.sender] = model
|
|
|
|
self.write_users(users)
|
|
|
|
await ctx.respond(f"Set model to {model}. Don't forget to pull it with `h!ollama.pull`.")
|
|
|
|
|
|
|
|
@niobot.command("ollama.pull")
|
2024-09-12 01:45:38 +01:00
|
|
|
@niobot.from_homeserver("nexy7574.co.uk", "nicroxio.co.uk", "shronk.net")
|
2024-07-28 23:33:15 +01:00
|
|
|
async def pull_model(self, ctx: niobot.Context):
|
|
|
|
"""Pulls the model you set."""
|
|
|
|
users = self.read_users()
|
|
|
|
if ctx.message.sender not in users:
|
|
|
|
await ctx.respond("You need to set a model first. See: `h!help ollama.set-model`")
|
|
|
|
return
|
|
|
|
model = users[ctx.message.sender]
|
2024-07-29 00:21:06 +01:00
|
|
|
server = await self.find_server(gpu_only=False)
|
2024-07-28 23:33:15 +01:00
|
|
|
if not server:
|
|
|
|
await ctx.respond("No servers available.")
|
|
|
|
return
|
|
|
|
msg = await ctx.respond(f"Pulling {model} on {server['name']!r}...")
|
|
|
|
async with ollama_client(server["url"]) as client:
|
|
|
|
await client.pull(model)
|
|
|
|
await msg.edit(f"Pulled model {model}.")
|
|
|
|
|
|
|
|
@niobot.command("ollama.chat", greedy=True)
|
2024-09-12 01:45:38 +01:00
|
|
|
@niobot.from_homeserver("nexy7574.co.uk", "nicroxio.co.uk", "shronk.net")
|
2024-07-28 23:33:15 +01:00
|
|
|
async def chat(self, ctx: niobot.Context):
|
|
|
|
"""Chat with the model."""
|
2024-07-28 23:50:29 +01:00
|
|
|
if "--gpu" in ctx.args:
|
|
|
|
ctx.args.remove("--gpu")
|
|
|
|
gpu_only = True
|
|
|
|
else:
|
|
|
|
gpu_only = False
|
2024-07-28 23:33:15 +01:00
|
|
|
try:
|
|
|
|
message = " ".join(ctx.args)
|
|
|
|
users = self.read_users()
|
|
|
|
if ctx.message.sender not in users:
|
|
|
|
await ctx.respond("You need to set a model first. See: `h!help ollama.set-model`")
|
|
|
|
return
|
|
|
|
model = users[ctx.message.sender]
|
|
|
|
res = await ctx.respond("Finding server...")
|
2024-07-28 23:50:29 +01:00
|
|
|
server = await self.find_server(gpu_only=gpu_only)
|
2024-07-28 23:33:15 +01:00
|
|
|
if not server:
|
|
|
|
await res.edit(content="No servers available.")
|
|
|
|
return
|
|
|
|
async with ollama_client(server["url"]) as client:
|
|
|
|
await res.edit(content=f"Generating response...")
|
|
|
|
try:
|
|
|
|
response = await client.chat(model, [{"role": "user", "content": message}])
|
|
|
|
except httpx.HTTPError as e:
|
|
|
|
response = {"message": {"content": f"Error: {e}"}}
|
|
|
|
await res.edit(content=response["message"]["content"])
|
|
|
|
except Exception as e:
|
|
|
|
logging.exception(e)
|
2024-07-29 00:36:32 +01:00
|
|
|
await ctx.respond(content="An error occurred.")
|
|
|
|
|
2024-09-12 01:45:38 +01:00
|
|
|
@niobot.command("ollama.status", aliases=["ollama.ping"])
|
|
|
|
@niobot.from_homeserver("nexy7574.co.uk", "nicroxio.co.uk", "shronk.net")
|
2024-07-29 00:36:32 +01:00
|
|
|
async def status(self, ctx: niobot.Context, gpu_only: bool = False):
|
|
|
|
"""Checks which servers are online."""
|
|
|
|
lines: dict[str, dict[str, str | None | bool]] = {}
|
|
|
|
for name, cfg in self.bot.cfg["ollama"].items():
|
|
|
|
lines[name] = {
|
|
|
|
"url": cfg["url"],
|
|
|
|
"gpu": cfg["gpu"],
|
|
|
|
"online": None
|
|
|
|
}
|
|
|
|
|
|
|
|
emojis = {
|
|
|
|
True: "\N{white heavy check mark}",
|
|
|
|
False: "\N{cross mark}",
|
|
|
|
None: "\N{hourglass with flowing sand}"
|
|
|
|
}
|
|
|
|
|
|
|
|
def get_lines():
|
|
|
|
ln = []
|
|
|
|
for _n, _d in lines.items():
|
|
|
|
if gpu_only and _d["gpu"] is False:
|
|
|
|
continue
|
|
|
|
ln.append(f"* **{_n}**: {emojis[_d['online']]}")
|
|
|
|
return "\n".join(ln)
|
|
|
|
|
|
|
|
response = await ctx.respond(get_lines())
|
|
|
|
|
2024-07-29 00:38:53 +01:00
|
|
|
async def ping_task(target_url, target_name):
|
|
|
|
async with ollama_client(target_url) as client:
|
2024-07-29 00:36:32 +01:00
|
|
|
try:
|
2024-07-29 00:39:52 +01:00
|
|
|
self.log.info("[status] Checking %s (%r)", target_name, target_url)
|
2024-07-29 00:36:32 +01:00
|
|
|
await client.ps()
|
|
|
|
except (httpx.HTTPError, ConnectionError):
|
2024-07-29 00:38:53 +01:00
|
|
|
lines[target_name]["online"] = False
|
2024-07-29 00:36:32 +01:00
|
|
|
else:
|
2024-07-29 00:38:53 +01:00
|
|
|
lines[target_name]["online"] = True
|
2024-07-29 00:36:32 +01:00
|
|
|
await response.edit(content=get_lines())
|
2024-07-29 00:38:53 +01:00
|
|
|
|
|
|
|
tasks = []
|
|
|
|
for name, cfg in self.bot.cfg["ollama"].items():
|
|
|
|
url = cfg["url"]
|
|
|
|
gpu = cfg["gpu"]
|
|
|
|
if gpu_only and not gpu:
|
|
|
|
continue
|
|
|
|
t = asyncio.create_task(ping_task(url, name))
|
|
|
|
tasks.append(t)
|
|
|
|
await asyncio.gather(*tasks, return_exceptions=True)
|