Add context and acid

This commit is contained in:
Nexus 2024-01-10 15:59:13 +00:00
parent 8ccdeb8a0e
commit 1f7ada8e93

View file

@ -1,6 +1,7 @@
import asyncio
import json
import logging
import os
import textwrap
import time
import typing
@ -35,6 +36,7 @@ class Ollama(commands.Cog):
self.bot = bot
self.log = logging.getLogger("jimmy.cogs.ollama")
self.last_server = 0
self.contexts = {}
def next_server(self, increment: bool = True) -> str:
"""Returns the next server key."""
@ -83,7 +85,27 @@ class Ollama(commands.Cog):
choices=SERVER_KEYS
)
],
context: typing.Annotated[
str,
discord.Option(
str,
"The context key of a previous ollama response to use as context.",
default=None
)
],
give_acid: typing.Annotated[
bool,
discord.Option(
bool,
"Whether to give the AI acid, LSD, and other hallucinogens before responding.",
default=False
)
]
):
if context is not None:
if context not in self.contexts:
await ctx.respond("Invalid context key.")
return
with open("./assets/ollama-prompt.txt") as file:
system_prompt = file.read()
await ctx.defer()
@ -150,10 +172,10 @@ class Ollama(commands.Cog):
if resp.status == 404:
self.log.debug("Beginning download of %r", model)
def progress_bar(_v: float, action: typing.Optional[str] = None) -> str:
    """Render a percentage as a ten-cell text progress bar.

    :param _v: Progress as a percentage; each bar cell represents 10%.
    :param action: Optional label prefixed to the bar, separated by a
        single space. Falsy values (``None`` or ``""``) add no prefix.
    :returns: Green squares for the completed share, white squares for
        the remainder, then the percentage formatted to two decimals.
    """
    # Clamp the number of filled cells so an out-of-range percentage
    # cannot draw a bar that is wider (or narrower) than ten cells.
    # The printed percentage below still shows the real value.
    filled = max(0, min(10, round(_v / 10)))
    bar = "\N{large green square}" * filled
    bar += "\N{white large square}" * (10 - filled)
    bar += f" {_v:.2f}%"
    if action:
        return f"{action} {bar}"
    return bar
@ -194,6 +216,8 @@ class Ollama(commands.Cog):
else:
self.log.debug("Model %r already exists on server.", model)
key = os.urandom(6).hex()
embed = discord.Embed(
title="Generating response...",
description=">>> ",
@ -216,15 +240,26 @@ class Ollama(commands.Cog):
await ctx.edit(embed=embed, view=view)
except discord.NotFound:
await ctx.respond(embed=embed, view=view)
self.log.debug("Beginning to generate response.")
self.log.debug("Beginning to generate response with key %r.", key)
params = {}
if give_acid is True:
params["temperature"] = 5
params["top_k"] = 500
params["top_p"] = 5
payload = {
"model": model,
"prompt": query,
"system": system_prompt,
"stream": True,
"options": params,
}
if context is not None:
payload["context"] = self.contexts[context]
async with session.post(
"/api/generate",
json={
"model": model,
"prompt": query,
"system": system_prompt,
"stream": True
},
json=payload,
) as response:
if response.status != 200:
embed = discord.Embed(
@ -239,8 +274,11 @@ class Ollama(commands.Cog):
last_update = time.time()
buffer = io.StringIO()
context = []
if not view.cancel.is_set():
async for line in self.ollama_stream(response.content):
if "context" in line:
context = line["context"]
buffer.write(line["response"])
embed.description += line["response"]
embed.timestamp = discord.utils.utcnow()
@ -255,6 +293,9 @@ class Ollama(commands.Cog):
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
last_update = time.time()
view.stop()
if context:
self.contexts[key] = context
embed.add_field(name="Context Key", value=key, inline=True)
self.log.debug("Ollama finished consuming.")
embed.title = "Done!"
embed.color = discord.Color.green()