diff --git a/.idea/college-bot-2.0.iml b/.idea/college-bot-2.0.iml
index 2c80e12..5432048 100644
diff --git a/.idea/misc.xml b/.idea/misc.xml
index d807ba6..32c9382 100644
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 259b868..e80bae3 100644
diff --git a/Dockerfile b/Dockerfile
index 30e5a0e..f3572d4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,7 @@
 FROM python:3.11-bookworm
 
 RUN DEBIAN_FRONTEND=noninteractive apt-get update
+RUN DEBIAN_FRONTEND=noninteractive apt-get upgrade -y
 RUN DEBIAN_FRONTEND=noninteractive apt-get install -y \
     traceroute \
     iputils-ping \
@@ -24,7 +25,7 @@
 COPY requirements.txt /tmp/requirements.txt
 RUN pip install -Ur /tmp/requirements.txt --break-system-packages --no-input
 
 WORKDIR /app
-COPY ./ /app/
-COPY cogs/ /app/cogs/
+COPY ./src/ /app/
+COPY ./src/cogs/ /app/cogs/
 CMD ["python", "main.py"]
diff --git a/config.example.toml b/config.example.toml
index a058a4f..2be9124 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -3,7 +3,17 @@
 token = "token"  # the bot token
 debug_guilds = [994710566612500550]  # server IDs to create slash commands in. Set to null for all guilds.
 
 [logging]
-level = "DEBUG"  # can be one of DEBUG, INFO, WARNING, ERROR, CRITICAL
+level = "DEBUG"  # can be one of DEBUG, INFO, WARNING, ERROR, CRITICAL. Defaults to INFO
+file = "jimmy.log"  # if omitted, defaults to jimmy.log. Always pretty prints to stdout.
+mode = "a"  # can be over(w)rite or (a)ppend. Defaults to append.
+suppress = [
+    "discord.client",
+    "discord.gateway",
+    "discord.http",
+    "selenium.webdriver.remote.remote_connection"  # make sure to include this one to prevent /screenshot from putting
+    # literal images (in base64) in your logs.
+]
+# All the loggers specified here will have their log level set to WARNING.
 
 [ollama.internal]  # name is "internal"
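For reference, the new cog in src/cogs/ollama.py reads exactly two keys from each [ollama.<name>] table: base_url, which is handed to aiohttp.ClientSession, and allowed_models, a list of fnmatch patterns; the table name itself ("internal" above) is what the /ollama command offers as the server choice. A filled-in entry might look roughly like this sketch; the URL and the wildcard pattern are illustrative guesses, not values contained in this diff:

[ollama.internal]                         # table name = the "server" option in /ollama
base_url = "http://ollama:11434"          # assumed address of the ollama service added to docker-compose below
allowed_models = ["llama2-uncensored:*"]  # fnmatch patterns a requested model must match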
diff --git a/docker-compose.yml b/docker-compose.yml
index 9de50ee..51c63a0 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -7,4 +7,10 @@ services:
       - ./config.toml:/app/config.toml
       - ./jimmy.log:/app/jimmy.log
       - /dev/dri:/dev/dri
-
+  ollama:
+    image: ollama/ollama:latest
+    restart: unless-stopped
+    ports:
+      - 11434:11434
+    volumes:
+      - ollama-data:/root/.ollama
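One hedged note on this hunk: the new service mounts the named volume ollama-data, and no top-level declaration for it is visible in the diff. Unless such a block already exists outside the hunk, recent versions of docker compose refuse to start the project with an "undefined volume" error, so a declaration along these lines would be needed (its presence elsewhere in the file is unknown, this is an assumption):

volumes:
  ollama-data: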
diff --git a/assets/ollama-prompt.txt b/src/assets/ollama-prompt.txt
similarity index 100%
rename from assets/ollama-prompt.txt
rename to src/assets/ollama-prompt.txt
diff --git a/src/cogs/__init__.py b/src/cogs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cogs/net.py b/src/cogs/net.py
similarity index 100%
rename from cogs/net.py
rename to src/cogs/net.py
diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py
new file mode 100644
index 0000000..2e32549
--- /dev/null
+++ b/src/cogs/ollama.py
@@ -0,0 +1,252 @@
+import collections
+import json
+import logging
+import time
+import typing
+from fnmatch import fnmatch
+
+import aiohttp
+import discord
+from discord.ext import commands
+from conf import CONFIG
+
+
+SERVER_KEYS = list(CONFIG["ollama"].keys())
+
+
+class Ollama(commands.Cog):
+    def __init__(self, bot: commands.Bot):
+        self.bot = bot
+        self.log = logging.getLogger("jimmy.cogs.ollama")
+
+    async def ollama_stream(self, iterator: aiohttp.StreamReader) -> typing.AsyncIterator[dict]:
+        async for line in iterator:
+            original_line = line
+            line = line.decode("utf-8", "replace").strip()
+            try:
+                line = json.loads(line)
+            except json.JSONDecodeError:
+                self.log.warning("Unable to decode JSON: %r", original_line)
+                continue
+            else:
+                self.log.debug("Decoded JSON %r -> %r", original_line, line)
+            yield line
+
+    @commands.slash_command()
+    async def ollama(
+        self,
+        ctx: discord.ApplicationContext,
+        query: typing.Annotated[
+            str,
+            discord.Option(
+                str,
+                "The query to feed into ollama. Not the system prompt.",
+                default=None
+            )
+        ],
+        model: typing.Annotated[
+            str,
+            discord.Option(
+                str,
+                "The model to use for ollama. Defaults to 'llama2-uncensored:latest'.",
+                default="llama2-uncensored:latest"
+            )
+        ],
+        server: typing.Annotated[
+            str,
+            discord.Option(
+                str,
+                "The server to use for ollama.",
+                default=SERVER_KEYS[0],
+                choices=SERVER_KEYS
+            )
+        ],
+    ):
+        with open("./assets/ollama-prompt.txt") as file:
+            system_prompt = file.read()
+
+        if query is None:
+            class InputPrompt(discord.ui.Modal):
+                def __init__(self, is_owner: bool):
+                    super().__init__(
+                        discord.ui.InputText(
+                            label="User Prompt",
+                            placeholder="Enter prompt",
+                            min_length=1,
+                            max_length=4000,
+                            style=discord.InputTextStyle.long,
+                        ),
+                        title="Enter prompt",
+                        timeout=120,
+                    )
+                    if is_owner:
+                        self.add_item(
+                            discord.ui.InputText(
+                                label="System Prompt",
+                                placeholder="Enter prompt",
+                                min_length=1,
+                                max_length=4000,
+                                style=discord.InputTextStyle.long,
+                                value=system_prompt,
+                            )
+                        )
+
+                    self.user_prompt = None
+                    self.system_prompt = system_prompt
+
+                async def callback(self, interaction: discord.Interaction):
+                    self.user_prompt = self.children[0].value
+                    if len(self.children) > 1:
+                        self.system_prompt = self.children[1].value
+                    await interaction.response.defer()
+                    self.stop()
+
+            modal = InputPrompt(await self.bot.is_owner(ctx.author))
+            await ctx.send_modal(modal)
+            await modal.wait()
+            query = modal.user_prompt
+            if not modal.user_prompt:
+                return
+            system_prompt = modal.system_prompt or system_prompt
+        else:
+            await ctx.defer()
+
+        model = model.casefold()
+        try:
+            model, tag = model.split(":", 1)
+            model = model + ":" + tag
+            self.log.debug("Model %r already has a tag", model)
+        except ValueError:
+            model = model + ":latest"
+        self.log.debug("Resolved model to %r", model)
+
+        if server not in CONFIG["ollama"]:
+            await ctx.respond("Invalid server")
+            return
+
+        server_config = CONFIG["ollama"][server]
+        for model_pattern in server_config["allowed_models"]:
+            if fnmatch(model, model_pattern):
+                break
+        else:
+            allowed_models = ", ".join(map(discord.utils.escape_markdown, server_config["allowed_models"]))
+            await ctx.respond(f"Invalid model. You can only use one of the following models: {allowed_models}")
+            return
+
+        async with aiohttp.ClientSession(
+            base_url=server_config["base_url"],
+        ) as session:
+            embed = discord.Embed(
+                title="Checking server...",
+                description=f"Checking that specified model and tag ({model}) are available on the server.",
+                color=discord.Color.blurple(),
+                timestamp=discord.utils.utcnow()
+            )
+            await ctx.respond(embed=embed)
+
+            try:
+                async with session.post("/show", json={"name": model}) as resp:
+                    if resp.status not in [404, 200]:
+                        embed = discord.Embed(
+                            url=resp.url,
+                            title=f"HTTP {resp.status} {resp.reason!r} while checking for model.",
+                            description=f"```{await resp.text() or 'No response body'}```"[:4096],
+                            color=discord.Color.red(),
+                            timestamp=discord.utils.utcnow()
+                        )
+                        embed.set_footer(text="Unable to continue.")
+                        return await ctx.edit(embed=embed)
+            except aiohttp.ClientConnectionError as e:
+                embed = discord.Embed(
+                    title="Connection error while checking for model.",
+                    description=f"```{e}```"[:4096],
+                    color=discord.Color.red(),
+                    timestamp=discord.utils.utcnow()
+                )
+                embed.set_footer(text="Unable to continue.")
+                return await ctx.edit(embed=embed)
+
+            if resp.status == 404:
+                def progress_bar(value: float, action: str = None):
+                    bar = "\N{green large square}" * round(value / 10)
+                    bar += "\N{white large square}" * (10 - len(bar))
+                    bar += f" {value:.2f}%"
+                    if action:
+                        return f"{action} {bar}"
+                    return bar
+
+                embed = discord.Embed(
+                    title=f"Downloading {model!r}",
+                    description=f"Downloading {model!r} from {server_config['base_url']}",
+                    color=discord.Color.blurple(),
+                    timestamp=discord.utils.utcnow()
+                )
+                embed.add_field(name="Progress", value=progress_bar(0))
+                await ctx.edit(embed=embed)
+
+                last_update = time.time()
+
+                async with session.post("/pull", json={"name": model, "stream": True}, timeout=None) as response:
+                    if response.status != 200:
+                        embed = discord.Embed(
+                            url=response.url,
+                            title=f"HTTP {response.status} {response.reason!r} while downloading model.",
+                            description=f"```{await response.text() or 'No response body'}```"[:4096],
+                            color=discord.Color.red(),
+                            timestamp=discord.utils.utcnow()
+                        )
+                        embed.set_footer(text="Unable to continue.")
+                        return await ctx.edit(embed=embed)
+
+                    async for line in self.ollama_stream(response.content):
+                        if time.time() >= (last_update + 5.1):
+                            if line.get("total") is not None and line.get("completed") is not None:
+                                percent = (line["completed"] / line["total"]) * 100
+                            else:
+                                percent = 50.0
+
+                            embed.fields[0].value = progress_bar(percent, line["status"])
+                            await ctx.edit(embed=embed)
+                            last_update = time.time()
+
+            embed = discord.Embed(
+                title="Generating response...",
+                description=">>> \u200b",
+                color=discord.Color.blurple()
+            )
+            async with session.post(
+                "/generate",
+                json={
+                    "model": model,
+                    "prompt": query,
+                    "format": "json",
+                    "system": system_prompt,
+                    "stream": True
+                }
+            ) as response:
+                if response.status != 200:
+                    embed = discord.Embed(
+                        url=response.url,
+                        title=f"HTTP {response.status} {response.reason!r} while generating response.",
+                        description=f"```{await response.text() or 'No response body'}```"[:4096],
+                        color=discord.Color.red(),
+                        timestamp=discord.utils.utcnow()
+                    )
+                    embed.set_footer(text="Unable to continue.")
+                    return await ctx.edit(embed=embed)
+
+                last_update = time.time()
+                async for line in self.ollama_stream(response.content):
+                    if line.get("done", False) is True or time.time() >= (last_update + 5.1):
+                        if line.get("done"):
+                            embed.title = "Done!"
+                            embed.color = discord.Color.green()
+                        embed.description += line["text"]
+                        if len(embed.description) >= 4096:
+                            embed.description = embed.description[:4093] + "..."
+                            break
+                        await ctx.edit(embed=embed)
+                        last_update = time.time()
+
+
+def setup(bot):
+    bot.add_cog(Ollama(bot))
diff --git a/cogs/screenshot.py b/src/cogs/screenshot.py
similarity index 100%
rename from cogs/screenshot.py
rename to src/cogs/screenshot.py
diff --git a/cogs/ytdl.py b/src/cogs/ytdl.py
similarity index 100%
rename from cogs/ytdl.py
rename to src/cogs/ytdl.py
diff --git a/conf.py b/src/conf.py
similarity index 70%
rename from conf.py
rename to src/conf.py
index f82e55a..8cef6c8 100644
--- a/conf.py
+++ b/src/conf.py
@@ -4,6 +4,9 @@
 from pathlib import Path
 try:
     CONFIG = toml.load('config.toml')
+    CONFIG.setdefault("logging", {})
+    CONFIG.setdefault("jimmy", {})
+    CONFIG.setdefault("ollama", {})
 except FileNotFoundError:
     cwd = Path.cwd()
     logging.getLogger("jimmy.autoconf").critical("Unable to locate config.toml in %s.", cwd, exc_info=True)
diff --git a/main.py b/src/main.py
similarity index 82%
rename from main.py
rename to src/main.py
index b3d9ee2..5357f01 100644
--- a/main.py
+++ b/src/main.py
@@ -1,36 +1,35 @@
 import datetime
 import logging
+import sys
 import traceback
 from logging import FileHandler
-from pathlib import Path
 
 import discord
-import toml
 from discord.ext import commands
 from rich.logging import RichHandler
 
 from conf import CONFIG
 
 log = logging.getLogger("jimmy")
-
-
+CONFIG.setdefault("logging", {})
 logging.basicConfig(
+    filename=CONFIG["logging"].get("file", "jimmy.log"),
+    filemode="a",
     format="%(asctime)s %(levelname)s %(name)s %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
-    level=CONFIG.get("logging", {}).get("level", "INFO"),
+    level=CONFIG["logging"].get("level", "INFO"),
     handlers=[
         RichHandler(
-            level=CONFIG.get("logging", {}).get("level", "INFO"),
+            level=CONFIG["logging"].get("level", "INFO"),
             show_time=False,
             show_path=False,
             markup=True
         ),
-        FileHandler(
-            filename=CONFIG.get("logging", {}).get("file", "jimmy.log"),
-            mode="a",
-        )
     ]
 )
+for logger in CONFIG["logging"].get("suppress", []):
+    logging.getLogger(logger).setLevel(logging.WARNING)
+    log.info(f"Suppressed logging for {logger}")
 
 bot = commands.Bot(
     command_prefix=commands.when_mentioned_or("h!", "H!"),
@@ -84,4 +83,7 @@ async def on_application_command_completion(ctx: discord.ApplicationContext):
 )
 
 
+if not CONFIG["jimmy"].get("token"):
+    log.critical("No token specified in config.toml. Exiting. (hint: set jimmy.token in config.toml)")
+    sys.exit(1)
 bot.run(CONFIG["jimmy"]["token"])
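A caveat on the main.py hunk above: logging.basicConfig() raises a ValueError when filename= (or stream=) is passed together with handlers=, so moving the log file into basicConfig while keeping the RichHandler cannot work in a single call as written. A minimal sketch of an equivalent setup keeps the file output as an explicit handler instead, and also reads the mode key that config.example.toml documents; both of those choices are my assumption, not something this diff contains:

import logging
from logging import FileHandler

from rich.logging import RichHandler

from conf import CONFIG

_level = CONFIG["logging"].get("level", "INFO")
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=_level,
    handlers=[
        # pretty console output, same options as in the diff
        RichHandler(level=_level, show_time=False, show_path=False, markup=True),
        # file output as an explicit handler so it can coexist with handlers=[...]
        FileHandler(
            CONFIG["logging"].get("file", "jimmy.log"),
            mode=CONFIG["logging"].get("mode", "a"),
        ),
    ],
)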