Add context and acid
This commit is contained in:
parent
8ccdeb8a0e
commit
1f7ada8e93
1 changed files with 51 additions and 10 deletions
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
@ -35,6 +36,7 @@ class Ollama(commands.Cog):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.log = logging.getLogger("jimmy.cogs.ollama")
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||||||
self.last_server = 0
|
self.last_server = 0
|
||||||
|
self.contexts = {}
|
||||||
|
|
||||||
def next_server(self, increment: bool = True) -> str:
|
def next_server(self, increment: bool = True) -> str:
|
||||||
"""Returns the next server key."""
|
"""Returns the next server key."""
|
||||||
|
@ -83,7 +85,27 @@ class Ollama(commands.Cog):
|
||||||
choices=SERVER_KEYS
|
choices=SERVER_KEYS
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
context: typing.Annotated[
|
||||||
|
str,
|
||||||
|
discord.Option(
|
||||||
|
str,
|
||||||
|
"The context key of a previous ollama response to use as context.",
|
||||||
|
default=None
|
||||||
|
)
|
||||||
|
],
|
||||||
|
give_acid: typing.Annotated[
|
||||||
|
bool,
|
||||||
|
discord.Option(
|
||||||
|
bool,
|
||||||
|
"Whether to give the AI acid, LSD, and other hallucinogens before responding.",
|
||||||
|
default=False
|
||||||
|
)
|
||||||
|
]
|
||||||
):
|
):
|
||||||
|
if context is not None:
|
||||||
|
if context not in self.contexts:
|
||||||
|
await ctx.respond("Invalid context key.")
|
||||||
|
return
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
system_prompt = file.read()
|
system_prompt = file.read()
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
@ -150,10 +172,10 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
if resp.status == 404:
|
if resp.status == 404:
|
||||||
self.log.debug("Beginning download of %r", model)
|
self.log.debug("Beginning download of %r", model)
|
||||||
def progress_bar(value: float, action: str = None):
|
def progress_bar(_v: float, action: str = None):
|
||||||
bar = "\N{large green square}" * round(value / 10)
|
bar = "\N{large green square}" * round(_v / 10)
|
||||||
bar += "\N{white large square}" * (10 - len(bar))
|
bar += "\N{white large square}" * (10 - len(bar))
|
||||||
bar += f" {value:.2f}%"
|
bar += f" {_v:.2f}%"
|
||||||
if action:
|
if action:
|
||||||
return f"{action} {bar}"
|
return f"{action} {bar}"
|
||||||
return bar
|
return bar
|
||||||
|
@ -194,6 +216,8 @@ class Ollama(commands.Cog):
|
||||||
else:
|
else:
|
||||||
self.log.debug("Model %r already exists on server.", model)
|
self.log.debug("Model %r already exists on server.", model)
|
||||||
|
|
||||||
|
key = os.urandom(6).hex()
|
||||||
|
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="Generating response...",
|
title="Generating response...",
|
||||||
description=">>> ",
|
description=">>> ",
|
||||||
|
@ -216,15 +240,26 @@ class Ollama(commands.Cog):
|
||||||
await ctx.edit(embed=embed, view=view)
|
await ctx.edit(embed=embed, view=view)
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
await ctx.respond(embed=embed, view=view)
|
await ctx.respond(embed=embed, view=view)
|
||||||
self.log.debug("Beginning to generate response.")
|
self.log.debug("Beginning to generate response with key %r.", key)
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
if give_acid is True:
|
||||||
|
params["temperature"] = 5
|
||||||
|
params["top_k"] = 500
|
||||||
|
params["top_p"] = 5
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": query,
|
||||||
|
"system": system_prompt,
|
||||||
|
"stream": True,
|
||||||
|
"options": params,
|
||||||
|
}
|
||||||
|
if context is not None:
|
||||||
|
payload["context"] = self.contexts[context]
|
||||||
async with session.post(
|
async with session.post(
|
||||||
"/api/generate",
|
"/api/generate",
|
||||||
json={
|
json=payload,
|
||||||
"model": model,
|
|
||||||
"prompt": query,
|
|
||||||
"system": system_prompt,
|
|
||||||
"stream": True
|
|
||||||
},
|
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
@ -239,8 +274,11 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
buffer = io.StringIO()
|
buffer = io.StringIO()
|
||||||
|
context = []
|
||||||
if not view.cancel.is_set():
|
if not view.cancel.is_set():
|
||||||
async for line in self.ollama_stream(response.content):
|
async for line in self.ollama_stream(response.content):
|
||||||
|
if "context" in line:
|
||||||
|
context = line["context"]
|
||||||
buffer.write(line["response"])
|
buffer.write(line["response"])
|
||||||
embed.description += line["response"]
|
embed.description += line["response"]
|
||||||
embed.timestamp = discord.utils.utcnow()
|
embed.timestamp = discord.utils.utcnow()
|
||||||
|
@ -255,6 +293,9 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
view.stop()
|
view.stop()
|
||||||
|
if context:
|
||||||
|
self.contexts[key] = context
|
||||||
|
embed.add_field(name="Context Key", value=key, inline=True)
|
||||||
self.log.debug("Ollama finished consuming.")
|
self.log.debug("Ollama finished consuming.")
|
||||||
embed.title = "Done!"
|
embed.title = "Done!"
|
||||||
embed.color = discord.Color.green()
|
embed.color = discord.Color.green()
|
||||||
|
|
Loading…
Reference in a new issue