Migrate to chat endpoint

This commit is contained in:
Nexus 2024-01-12 15:39:39 +00:00
parent 980c61d08b
commit 380d500e32

View file

@ -65,6 +65,63 @@ class OllamaView(View):
self.stop()
class ChatHistory:
def __init__(self):
self._internal = {}
def create_thread(self, member: discord.Member) -> str:
"""
Creates a thread, returns its ID.
"""
key = os.urandom(3).hex()
self._internal[key] = {
"member": member,
"messages": []
}
with open("./assets/ollama-prompt.txt") as file:
system_prompt = file.read()
self.add_message(
key,
"system",
system_prompt
)
return key
@staticmethod
def _construct_message(role: str, content: str, images: typing.Optional[list[str]]) -> dict[str, str]:
x = {
"role": role,
"content": content
}
if images:
x["images"] = images
return x
def add_message(
self,
thread: str,
role: typing.Literal["user", "assistant", "system"],
content: str,
images: typing.Optional[list[str]] = None
) -> None:
"""
Appends a message to the given thread.
:param thread: The thread's ID.
:param role: The author of the message.
:param content: The message's actual content.
:param images: Any images that were attached to the message, in base64.
:return: None
"""
self._internal[thread]["messages"].append(self._construct_message(role, content, images))
def get_history(self, thread: str) -> list[dict[str, str]]:
"""
Gets the history of a thread.
"""
return self._internal[thread]["messages"].copy() # copy() makes it immutable.
SERVER_KEYS = list(CONFIG["ollama"].keys())
@ -74,6 +131,7 @@ class Ollama(commands.Cog):
self.log = logging.getLogger("jimmy.cogs.ollama")
self.last_server = 0
self.contexts = {}
self.history = ChatHistory()
def next_server(self, increment: bool = True) -> str:
"""Returns the next server key."""
@ -163,9 +221,6 @@ class Ollama(commands.Cog):
if context not in self.contexts:
await ctx.respond("Invalid context key.")
return
return await ctx.respond("Context is currently disabled.", ephemeral=True)
with open("./assets/ollama-prompt.txt") as file:
system_prompt = file.read()
await ctx.defer()
model = model.casefold()
@ -366,19 +421,24 @@ class Ollama(commands.Cog):
params["top_k"] = 500
params["top_p"] = 500
if context is None or context in self.contexts:
context = self.history.create_thread(ctx.user)
messages = self.history.get_history(context)
user_message = {
"role": "user",
"content": query
}
if image_data:
user_message["images"] = [image_data]
messages.append(user_message)
payload = {
"model": model,
"prompt": query,
"system": system_prompt,
"stream": True,
"options": params,
"messages": messages
}
if context is not None:
payload["context"] = self.contexts[context]
if image_data:
payload["images"] = [image_data]
async with session.post(
"/api/generate",
"/api/chat",
json=payload,
) as response:
if response.status != 200:
@ -394,16 +454,13 @@ class Ollama(commands.Cog):
last_update = time.time()
buffer = io.StringIO()
context = []
if not view.cancel.is_set():
async for line in self.ollama_stream(response.content):
if "context" in line:
context = line["context"]
buffer.write(line["response"])
embed.description += line["response"]
buffer.write(line["assistant"])
embed.description += line["assistant"]
embed.timestamp = discord.utils.utcnow()
if len(embed.description) >= 4096:
embed.description = embed.description = "..." + line["response"]
embed.description = embed.description = "..." + line["assistant"]
if view.cancel.is_set():
break
@ -413,9 +470,9 @@ class Ollama(commands.Cog):
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
last_update = time.time()
view.stop()
if context:
self.contexts[key] = context
embed.add_field(name="Context Key", value=key, inline=True)
self.history.add_message(context, "user", user_message["content"], user_message["images"])
self.history.add_message(context, "assistant", buffer.getvalue())
embed.add_field(name="Context Key", value=context, inline=True)
self.log.debug("Ollama finished consuming.")
embed.title = "Done!"
embed.colour = discord.Color.green()