Improve context management

This commit is contained in:
Nexus 2023-11-13 20:20:55 +00:00
parent c531975569
commit 59f5101869
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -92,12 +92,10 @@ except Exception as _pyttsx3_err:
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[ async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
dict[str, str | int | bool], None dict[str, str | int | bool], None
]: ]:
print("Starting to iterate over ollama response %r..." % response, file=sys.stderr)
async for chunk in response.aiter_lines(): async for chunk in response.aiter_lines():
# Each line is a JSON string # Each line is a JSON string
try: try:
loaded = json.loads(chunk) loaded = json.loads(chunk)
print("Loaded chunk: %r" % loaded)
yield loaded yield loaded
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr) print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr)
@ -137,6 +135,7 @@ class OtherCog(commands.Cog):
self._worker_task = self.bot.loop.create_task(self.cache_population_job()) self._worker_task = self.bot.loop.create_task(self.cache_population_job())
self.ollama_locks: dict[discord.Message, asyncio.Event] = {} self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
self.context_cache: dict[str, list[int]] = {}
def cog_unload(self): def cog_unload(self):
self._worker_task.cancel() self._worker_task.cancel()
@ -1840,7 +1839,8 @@ class OtherCog(commands.Cog):
ctx: discord.ApplicationContext, ctx: discord.ApplicationContext,
model: str = "orca-mini", model: str = "orca-mini",
query: str = None, query: str = None,
context: str = None context: str = None,
server: str = "auto"
): ):
""":3""" """:3"""
with open("./assets/ollama-prompt.txt") as file: with open("./assets/ollama-prompt.txt") as file:
@ -1892,20 +1892,16 @@ class OtherCog(commands.Cog):
await ctx.defer() await ctx.defer()
if context: if context:
try: if context not in self.context_cache:
context_decoded = base64.b64decode(context).decode() return await ctx.respond(":x: Context not found in cache.")
context_decompressed = await asyncio.to_thread( context = self.context_cache[context]
functools.partial(zlib.decompress, context_decoded.encode())
)
context = json.loads(context_decompressed)
except (ValueError, zlib.error, UnicodeDecodeError) as e:
return await ctx.respond("Failed to decode context: " + str(e))
content = None content = None
try_hosts = { try_hosts = {
"127.0.0.1:11434": "localhost", "127.0.0.1:11434": "localhost",
"100.106.34.86:11434": "Nex Laptop", "100.106.34.86:11434": "NexTop",
"100.66.187.46:11434": "Nexbox", "ollama.shronk.net": "Alibaba Cloud",
"100.66.187.46:11434": "NexBox",
"100.116.242.161:11434": "PortaPi" "100.116.242.161:11434": "PortaPi"
} }
model = model.casefold() model = model.casefold()
@ -1918,19 +1914,20 @@ class OtherCog(commands.Cog):
ephemeral=True ephemeral=True
) )
model = "orca-mini" model = "orca-mini"
async with httpx.AsyncClient(follow_redirects=True) as client: if server != "auto":
for host in try_hosts.keys(): async with httpx.AsyncClient(follow_redirects=True) as client:
try: for host in try_hosts.keys():
response = await client.get( try:
f"http://{host}/api/tags", response = await client.get(
) f"http://{host}/api/tags",
response.raise_for_status() )
except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError): response.raise_for_status()
continue except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
continue
else:
break
else: else:
break return await ctx.respond(":x: No servers available.")
else:
return await ctx.respond(":x: No servers available.")
embed = discord.Embed( embed = discord.Embed(
colour=discord.Colour.greyple() colour=discord.Colour.greyple()
@ -2117,43 +2114,33 @@ class OtherCog(commands.Cog):
context: Optional[list[int]] = chunk.get("context") context: Optional[list[int]] = chunk.get("context")
# noinspection PyTypeChecker # noinspection PyTypeChecker
if context: if context:
context_json = json.dumps(context) key = os.urandom(8).hex()
start = time() self.context_cache[key] = context
context_json_compressed = await asyncio.to_thread(
functools.partial(zlib.compress, context_json.encode())
)
end = time()
compress_time_spent = format(round(end * 1000 - start * 1000), ",")
context: str = base64.b64encode(context_json_compressed).decode()
else: else:
compress_time_spent = "N/A" context = key = None
context = None
value = ("* Total: {}\n" value = ("* Total: {}\n"
"* Model load: {}\n" "* Model load: {}\n"
"* Sample generation: {}\n" "* Sample generation: {}\n"
"* Prompt eval: {}\n" "* Prompt eval: {}\n"
"* Response generation: {}\n" "* Response generation: {}\n").format(
"* Context compression: {} milliseconds").format(
total_time_spent, total_time_spent,
load_time_spent, load_time_spent,
sample_time_sent, sample_time_sent,
prompt_eval_time_spent, prompt_eval_time_spent,
eval_time_spent, eval_time_spent,
compress_time_spent
) )
embed.add_field( embed.add_field(
name="Timings", name="Timings",
value=value value=value,
inline=False
) )
await msg.edit(content=None, embed=embed, view=None)
if context: if context:
await ctx.respond( embed.add_field(
"Context:\n" name="Context Key",
"```\n" value=key,
f"{context}\n" inline=True
"```",
ephemeral=True
) )
await msg.edit(content=None, embed=embed, view=None)
self.ollama_locks.pop(msg, None) self.ollama_locks.pop(msg, None)