Improve context management

This commit is contained in:
Nexus 2023-11-13 20:20:55 +00:00
parent c531975569
commit 59f5101869
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -92,12 +92,10 @@ except Exception as _pyttsx3_err:
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
dict[str, str | int | bool], None
]:
print("Starting to iterate over ollama response %r..." % response, file=sys.stderr)
async for chunk in response.aiter_lines():
# Each line is a JSON string
try:
loaded = json.loads(chunk)
print("Loaded chunk: %r" % loaded)
yield loaded
except json.JSONDecodeError as e:
print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr)
@ -137,6 +135,7 @@ class OtherCog(commands.Cog):
self._worker_task = self.bot.loop.create_task(self.cache_population_job())
self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
self.context_cache: dict[str, list[int]] = {}
def cog_unload(self):
self._worker_task.cancel()
@ -1840,7 +1839,8 @@ class OtherCog(commands.Cog):
ctx: discord.ApplicationContext,
model: str = "orca-mini",
query: str = None,
context: str = None
context: str = None,
server: str = "auto"
):
""":3"""
with open("./assets/ollama-prompt.txt") as file:
@ -1892,20 +1892,16 @@ class OtherCog(commands.Cog):
await ctx.defer()
if context:
try:
context_decoded = base64.b64decode(context).decode()
context_decompressed = await asyncio.to_thread(
functools.partial(zlib.decompress, context_decoded.encode())
)
context = json.loads(context_decompressed)
except (ValueError, zlib.error, UnicodeDecodeError) as e:
return await ctx.respond("Failed to decode context: " + str(e))
if context not in self.context_cache:
return await ctx.respond(":x: Context not found in cache.")
context = self.context_cache[context]
content = None
try_hosts = {
"127.0.0.1:11434": "localhost",
"100.106.34.86:11434": "Nex Laptop",
"100.66.187.46:11434": "Nexbox",
"100.106.34.86:11434": "NexTop",
"ollama.shronk.net": "Alibaba Cloud",
"100.66.187.46:11434": "NexBox",
"100.116.242.161:11434": "PortaPi"
}
model = model.casefold()
@ -1918,6 +1914,7 @@ class OtherCog(commands.Cog):
ephemeral=True
)
model = "orca-mini"
if server != "auto":
async with httpx.AsyncClient(follow_redirects=True) as client:
for host in try_hosts.keys():
try:
@ -2117,43 +2114,33 @@ class OtherCog(commands.Cog):
context: Optional[list[int]] = chunk.get("context")
# noinspection PyTypeChecker
if context:
context_json = json.dumps(context)
start = time()
context_json_compressed = await asyncio.to_thread(
functools.partial(zlib.compress, context_json.encode())
)
end = time()
compress_time_spent = format(round(end * 1000 - start * 1000), ",")
context: str = base64.b64encode(context_json_compressed).decode()
key = os.urandom(8).hex()
self.context_cache[key] = context
else:
compress_time_spent = "N/A"
context = None
context = key = None
value = ("* Total: {}\n"
"* Model load: {}\n"
"* Sample generation: {}\n"
"* Prompt eval: {}\n"
"* Response generation: {}\n"
"* Context compression: {} milliseconds").format(
"* Response generation: {}\n").format(
total_time_spent,
load_time_spent,
sample_time_sent,
prompt_eval_time_spent,
eval_time_spent,
compress_time_spent
)
embed.add_field(
name="Timings",
value=value
value=value,
inline=False
)
if context:
embed.add_field(
name="Context Key",
value=key,
inline=True
)
await msg.edit(content=None, embed=embed, view=None)
if context:
await ctx.respond(
"Context:\n"
"```\n"
f"{context}\n"
"```",
ephemeral=True
)
self.ollama_locks.pop(msg, None)