From 21a9bcecf9167a144981421ebbaaa4dc2a93f76a Mon Sep 17 00:00:00 2001 From: nex Date: Mon, 13 Nov 2023 19:36:45 +0000 Subject: [PATCH] Implement context --- cogs/other.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/cogs/other.py b/cogs/other.py index 6761d04..a58de82 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -1890,6 +1890,17 @@ class OtherCog(commands.Cog): system_prompt = modal.system_prompt or system_prompt else: await ctx.defer() + + if context: + try: + context_decoded = base64.b85decode(context).decode() + context_decompressed = await asyncio.to_thread( + functools.partial(zlib.decompress, context_decoded.encode()) + ) + context = json.loads(context_decompressed) + except (ValueError, zlib.error, UnicodeDecodeError) as e: + return await ctx.respond("Failed to decode context: " + str(e)) + content = None try_hosts = { "127.0.0.1:11434": "localhost", @@ -2006,16 +2017,19 @@ class OtherCog(commands.Cog): embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})") await msg.edit(embed=embed) async with ctx.channel.typing(): + payload = { + "model": model, + "prompt": query, + "format": "json", + "system": system_prompt, + "stream": True + } + if context: + payload["context"] = context async with client.stream( "POST", "/generate", - json={ - "model": model, - "prompt": query, - "format": "json", - "system": system_prompt, - "stream": True - }, + json=payload, timeout=None ) as response: if response.status_code != 200: @@ -2104,13 +2118,11 @@ class OtherCog(commands.Cog): context_json = json.dumps(context) start = time() context_json_compressed = await asyncio.to_thread( - zlib.compress, - context_json.encode(), - 9 + functools.partial(zlib.compress, context_json.encode()) ) end = time() - compress_time_spent = format(end * 1000 - start * 1000, ",") - context: str = base64.b64encode(context_json_compressed).decode() + compress_time_spent = format(round(end * 1000 - start * 1000), ",") + context: str = base64.b85encode(context_json_compressed).decode() else: compress_time_spent = "N/A" context = None @@ -2135,7 +2147,7 @@ class OtherCog(commands.Cog): embed.add_field( name="Context", value=f"```\n{context}\n```"[:1024], - inline=True + inline=False ) await msg.edit(content=None, embed=embed, view=None) self.ollama_locks.pop(msg, None)