From 59f51018698bcdd1fc2752994f93c15256ecdf19 Mon Sep 17 00:00:00 2001 From: nex Date: Mon, 13 Nov 2023 20:20:55 +0000 Subject: [PATCH] Improve context management --- cogs/other.py | 79 +++++++++++++++++++++------------------------------ 1 file changed, 33 insertions(+), 46 deletions(-) diff --git a/cogs/other.py b/cogs/other.py index 49ba29c..fb1b47f 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -92,12 +92,10 @@ except Exception as _pyttsx3_err: async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[ dict[str, str | int | bool], None ]: - print("Starting to iterate over ollama response %r..." % response, file=sys.stderr) async for chunk in response.aiter_lines(): # Each line is a JSON string try: loaded = json.loads(chunk) - print("Loaded chunk: %r" % loaded) yield loaded except json.JSONDecodeError as e: print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr) @@ -137,6 +135,7 @@ class OtherCog(commands.Cog): self._worker_task = self.bot.loop.create_task(self.cache_population_job()) self.ollama_locks: dict[discord.Message, asyncio.Event] = {} + self.context_cache: dict[str, list[int]] = {} def cog_unload(self): self._worker_task.cancel() @@ -1840,7 +1839,8 @@ class OtherCog(commands.Cog): ctx: discord.ApplicationContext, model: str = "orca-mini", query: str = None, - context: str = None + context: str = None, + server: str = "auto" ): """:3""" with open("./assets/ollama-prompt.txt") as file: @@ -1892,20 +1892,16 @@ class OtherCog(commands.Cog): await ctx.defer() if context: - try: - context_decoded = base64.b64decode(context).decode() - context_decompressed = await asyncio.to_thread( - functools.partial(zlib.decompress, context_decoded.encode()) - ) - context = json.loads(context_decompressed) - except (ValueError, zlib.error, UnicodeDecodeError) as e: - return await ctx.respond("Failed to decode context: " + str(e)) + if context not in self.context_cache: + return await ctx.respond(":x: Context not found in cache.") + context = self.context_cache[context] content = None try_hosts = { "127.0.0.1:11434": "localhost", - "100.106.34.86:11434": "Nex Laptop", - "100.66.187.46:11434": "Nexbox", + "100.106.34.86:11434": "NexTop", + "ollama.shronk.net": "Alibaba Cloud", + "100.66.187.46:11434": "NexBox", "100.116.242.161:11434": "PortaPi" } model = model.casefold() @@ -1918,19 +1914,20 @@ class OtherCog(commands.Cog): ephemeral=True ) model = "orca-mini" - async with httpx.AsyncClient(follow_redirects=True) as client: - for host in try_hosts.keys(): - try: - response = await client.get( - f"http://{host}/api/tags", - ) - response.raise_for_status() - except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError): - continue + if server != "auto": + async with httpx.AsyncClient(follow_redirects=True) as client: + for host in try_hosts.keys(): + try: + response = await client.get( + f"http://{host}/api/tags", + ) + response.raise_for_status() + except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError): + continue + else: + break else: - break - else: - return await ctx.respond(":x: No servers available.") + return await ctx.respond(":x: No servers available.") embed = discord.Embed( colour=discord.Colour.greyple() @@ -2117,43 +2114,33 @@ class OtherCog(commands.Cog): context: Optional[list[int]] = chunk.get("context") # noinspection PyTypeChecker if context: - context_json = json.dumps(context) - start = time() - context_json_compressed = await asyncio.to_thread( - functools.partial(zlib.compress, context_json.encode()) - ) - end = time() - compress_time_spent = format(round(end * 1000 - start * 1000), ",") - context: str = base64.b64encode(context_json_compressed).decode() + key = os.urandom(8).hex() + self.context_cache[key] = context else: - compress_time_spent = "N/A" - context = None + context = key = None value = ("* Total: {}\n" "* Model load: {}\n" "* Sample generation: {}\n" "* Prompt eval: {}\n" - "* Response generation: {}\n" - "* Context compression: {} milliseconds").format( + "* Response generation: {}\n").format( total_time_spent, load_time_spent, sample_time_sent, prompt_eval_time_spent, eval_time_spent, - compress_time_spent ) embed.add_field( name="Timings", - value=value + value=value, + inline=False ) - await msg.edit(content=None, embed=embed, view=None) if context: - await ctx.respond( - "Context:\n" - "```\n" - f"{context}\n" - "```", - ephemeral=True + embed.add_field( + name="Context Key", + value=key, + inline=True ) + await msg.edit(content=None, embed=embed, view=None) self.ollama_locks.pop(msg, None)