Migrate to chat endpoint
This commit is contained in:
parent
980c61d08b
commit
380d500e32
1 changed files with 76 additions and 19 deletions
|
@ -65,6 +65,63 @@ class OllamaView(View):
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHistory:
    """
    In-memory store of chat threads, keyed by a short random hex ID.

    Each thread is a dict with two entries: ``member`` (whatever was passed to
    :meth:`create_thread`) and ``messages`` (the ordered list of message dicts,
    each with ``role``, ``content`` and, when supplied, ``images``).
    """

    def __init__(self, prompt_path: str = "./assets/ollama-prompt.txt"):
        """
        :param prompt_path: File whose contents become the system message of every new thread.
        """
        self._internal = {}
        self._prompt_path = prompt_path

    def create_thread(self, member: "discord.Member") -> str:
        """
        Creates a thread, returns its ID.

        The new thread starts with a single ``system`` message read from the prompt file.

        :param member: The member stored alongside the thread.
        :return: The new thread's ID (six hexadecimal characters).
        :raises OSError: If the prompt file cannot be read; no thread is created in that case.
        """
        # Read the prompt before registering anything, so a missing or unreadable
        # file does not leave an empty, prompt-less thread behind.
        with open(self._prompt_path, encoding="utf-8") as file:
            system_prompt = file.read()

        # Three random bytes is a small key space: re-roll on collision instead of
        # silently overwriting an existing thread.
        key = os.urandom(3).hex()
        while key in self._internal:
            key = os.urandom(3).hex()

        self._internal[key] = {
            "member": member,
            "messages": [],
        }
        self.add_message(key, "system", system_prompt)
        return key

    @staticmethod
    def _construct_message(
        role: str,
        content: str,
        images: typing.Optional[list[str]] = None,
    ) -> dict[str, typing.Any]:
        """
        Builds a single message dict.

        :param role: The author of the message.
        :param content: The message's actual content.
        :param images: Any images attached to the message; the key is omitted when this is None or empty.
        :return: ``{"role": ..., "content": ...}``, plus ``"images"`` when images were given.
        """
        message: dict[str, typing.Any] = {
            "role": role,
            "content": content,
        }
        if images:
            message["images"] = images
        return message

    def add_message(
        self,
        thread: str,
        role: typing.Literal["user", "assistant", "system"],
        content: str,
        images: typing.Optional[list[str]] = None,
    ) -> None:
        """
        Appends a message to the given thread.

        :param thread: The thread's ID.
        :param role: The author of the message.
        :param content: The message's actual content.
        :param images: Any images that were attached to the message, in base64.
        :return: None
        :raises KeyError: If no thread with that ID exists.
        """
        self._internal[thread]["messages"].append(self._construct_message(role, content, images))

    def get_history(self, thread: str) -> list[dict[str, typing.Any]]:
        """
        Gets the history of a thread.

        The returned list and the message dicts in it are copies, so appending to
        the list or reassigning a message's keys does not alter the stored
        history. The ``images`` lists themselves are still shared with the store.

        :param thread: The thread's ID.
        :return: The thread's messages, oldest first.
        :raises KeyError: If no thread with that ID exists.
        """
        return [dict(message) for message in self._internal[thread]["messages"]]
|
||||||
|
|
||||||
|
|
||||||
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
SERVER_KEYS = list(CONFIG["ollama"].keys())
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +131,7 @@ class Ollama(commands.Cog):
|
||||||
self.log = logging.getLogger("jimmy.cogs.ollama")
|
self.log = logging.getLogger("jimmy.cogs.ollama")
|
||||||
self.last_server = 0
|
self.last_server = 0
|
||||||
self.contexts = {}
|
self.contexts = {}
|
||||||
|
self.history = ChatHistory()
|
||||||
|
|
||||||
def next_server(self, increment: bool = True) -> str:
|
def next_server(self, increment: bool = True) -> str:
|
||||||
"""Returns the next server key."""
|
"""Returns the next server key."""
|
||||||
|
@ -163,9 +221,6 @@ class Ollama(commands.Cog):
|
||||||
if context not in self.contexts:
|
if context not in self.contexts:
|
||||||
await ctx.respond("Invalid context key.")
|
await ctx.respond("Invalid context key.")
|
||||||
return
|
return
|
||||||
return await ctx.respond("Context is currently disabled.", ephemeral=True)
|
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
|
||||||
system_prompt = file.read()
|
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
|
||||||
model = model.casefold()
|
model = model.casefold()
|
||||||
|
@ -366,19 +421,24 @@ class Ollama(commands.Cog):
|
||||||
params["top_k"] = 500
|
params["top_k"] = 500
|
||||||
params["top_p"] = 500
|
params["top_p"] = 500
|
||||||
|
|
||||||
|
if context is None or context in self.contexts:
|
||||||
|
context = self.history.create_thread(ctx.user)
|
||||||
|
messages = self.history.get_history(context)
|
||||||
|
user_message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": query
|
||||||
|
}
|
||||||
|
if image_data:
|
||||||
|
user_message["images"] = [image_data]
|
||||||
|
messages.append(user_message)
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": query,
|
|
||||||
"system": system_prompt,
|
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"options": params,
|
"options": params,
|
||||||
|
"messages": messages
|
||||||
}
|
}
|
||||||
if context is not None:
|
|
||||||
payload["context"] = self.contexts[context]
|
|
||||||
if image_data:
|
|
||||||
payload["images"] = [image_data]
|
|
||||||
async with session.post(
|
async with session.post(
|
||||||
"/api/generate",
|
"/api/chat",
|
||||||
json=payload,
|
json=payload,
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
|
@ -394,16 +454,13 @@ class Ollama(commands.Cog):
|
||||||
|
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
buffer = io.StringIO()
|
buffer = io.StringIO()
|
||||||
context = []
|
|
||||||
if not view.cancel.is_set():
|
if not view.cancel.is_set():
|
||||||
async for line in self.ollama_stream(response.content):
|
async for line in self.ollama_stream(response.content):
|
||||||
if "context" in line:
|
buffer.write(line["assistant"])
|
||||||
context = line["context"]
|
embed.description += line["assistant"]
|
||||||
buffer.write(line["response"])
|
|
||||||
embed.description += line["response"]
|
|
||||||
embed.timestamp = discord.utils.utcnow()
|
embed.timestamp = discord.utils.utcnow()
|
||||||
if len(embed.description) >= 4096:
|
if len(embed.description) >= 4096:
|
||||||
embed.description = embed.description = "..." + line["response"]
|
embed.description = embed.description = "..." + line["assistant"]
|
||||||
|
|
||||||
if view.cancel.is_set():
|
if view.cancel.is_set():
|
||||||
break
|
break
|
||||||
|
@ -413,9 +470,9 @@ class Ollama(commands.Cog):
|
||||||
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
self.log.debug(f"Updating message ({last_update} -> {time.time()})")
|
||||||
last_update = time.time()
|
last_update = time.time()
|
||||||
view.stop()
|
view.stop()
|
||||||
if context:
|
self.history.add_message(context, "user", user_message["content"], user_message["images"])
|
||||||
self.contexts[key] = context
|
self.history.add_message(context, "assistant", buffer.getvalue())
|
||||||
embed.add_field(name="Context Key", value=key, inline=True)
|
embed.add_field(name="Context Key", value=context, inline=True)
|
||||||
self.log.debug("Ollama finished consuming.")
|
self.log.debug("Ollama finished consuming.")
|
||||||
embed.title = "Done!"
|
embed.title = "Done!"
|
||||||
embed.colour = discord.Color.green()
|
embed.colour = discord.Color.green()
|
||||||
|
|
Loading…
Reference in a new issue