mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Improve context management
This commit is contained in:
parent
c531975569
commit
59f5101869
1 changed file with 33 additions and 46 deletions
|
@ -92,12 +92,10 @@ except Exception as _pyttsx3_err:
|
||||||
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
|
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
|
||||||
dict[str, str | int | bool], None
|
dict[str, str | int | bool], None
|
||||||
]:
|
]:
|
||||||
print("Starting to iterate over ollama response %r..." % response, file=sys.stderr)
|
|
||||||
async for chunk in response.aiter_lines():
|
async for chunk in response.aiter_lines():
|
||||||
# Each line is a JSON string
|
# Each line is a JSON string
|
||||||
try:
|
try:
|
||||||
loaded = json.loads(chunk)
|
loaded = json.loads(chunk)
|
||||||
print("Loaded chunk: %r" % loaded)
|
|
||||||
yield loaded
|
yield loaded
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr)
|
print("Failed to decode chunk %r: %r" % (chunk, e), file=sys.stderr)
|
||||||
|
@ -137,6 +135,7 @@ class OtherCog(commands.Cog):
|
||||||
self._worker_task = self.bot.loop.create_task(self.cache_population_job())
|
self._worker_task = self.bot.loop.create_task(self.cache_population_job())
|
||||||
|
|
||||||
self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
|
self.ollama_locks: dict[discord.Message, asyncio.Event] = {}
|
||||||
|
self.context_cache: dict[str, list[int]] = {}
|
||||||
|
|
||||||
def cog_unload(self):
|
def cog_unload(self):
|
||||||
self._worker_task.cancel()
|
self._worker_task.cancel()
|
||||||
|
@ -1840,7 +1839,8 @@ class OtherCog(commands.Cog):
|
||||||
ctx: discord.ApplicationContext,
|
ctx: discord.ApplicationContext,
|
||||||
model: str = "orca-mini",
|
model: str = "orca-mini",
|
||||||
query: str = None,
|
query: str = None,
|
||||||
context: str = None
|
context: str = None,
|
||||||
|
server: str = "auto"
|
||||||
):
|
):
|
||||||
""":3"""
|
""":3"""
|
||||||
with open("./assets/ollama-prompt.txt") as file:
|
with open("./assets/ollama-prompt.txt") as file:
|
||||||
|
@ -1892,20 +1892,16 @@ class OtherCog(commands.Cog):
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
try:
|
if context not in self.context_cache:
|
||||||
context_decoded = base64.b64decode(context).decode()
|
return await ctx.respond(":x: Context not found in cache.")
|
||||||
context_decompressed = await asyncio.to_thread(
|
context = self.context_cache[context]
|
||||||
functools.partial(zlib.decompress, context_decoded.encode())
|
|
||||||
)
|
|
||||||
context = json.loads(context_decompressed)
|
|
||||||
except (ValueError, zlib.error, UnicodeDecodeError) as e:
|
|
||||||
return await ctx.respond("Failed to decode context: " + str(e))
|
|
||||||
|
|
||||||
content = None
|
content = None
|
||||||
try_hosts = {
|
try_hosts = {
|
||||||
"127.0.0.1:11434": "localhost",
|
"127.0.0.1:11434": "localhost",
|
||||||
"100.106.34.86:11434": "Nex Laptop",
|
"100.106.34.86:11434": "NexTop",
|
||||||
"100.66.187.46:11434": "Nexbox",
|
"ollama.shronk.net": "Alibaba Cloud",
|
||||||
|
"100.66.187.46:11434": "NexBox",
|
||||||
"100.116.242.161:11434": "PortaPi"
|
"100.116.242.161:11434": "PortaPi"
|
||||||
}
|
}
|
||||||
model = model.casefold()
|
model = model.casefold()
|
||||||
|
@ -1918,19 +1914,20 @@ class OtherCog(commands.Cog):
|
||||||
ephemeral=True
|
ephemeral=True
|
||||||
)
|
)
|
||||||
model = "orca-mini"
|
model = "orca-mini"
|
||||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
if server != "auto":
|
||||||
for host in try_hosts.keys():
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
try:
|
for host in try_hosts.keys():
|
||||||
response = await client.get(
|
try:
|
||||||
f"http://{host}/api/tags",
|
response = await client.get(
|
||||||
)
|
f"http://{host}/api/tags",
|
||||||
response.raise_for_status()
|
)
|
||||||
except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
|
response.raise_for_status()
|
||||||
continue
|
except (httpx.TransportError, httpx.NetworkError, httpx.HTTPStatusError):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
break
|
return await ctx.respond(":x: No servers available.")
|
||||||
else:
|
|
||||||
return await ctx.respond(":x: No servers available.")
|
|
||||||
|
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
colour=discord.Colour.greyple()
|
colour=discord.Colour.greyple()
|
||||||
|
@ -2117,43 +2114,33 @@ class OtherCog(commands.Cog):
|
||||||
context: Optional[list[int]] = chunk.get("context")
|
context: Optional[list[int]] = chunk.get("context")
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
if context:
|
if context:
|
||||||
context_json = json.dumps(context)
|
key = os.urandom(8).hex()
|
||||||
start = time()
|
self.context_cache[key] = context
|
||||||
context_json_compressed = await asyncio.to_thread(
|
|
||||||
functools.partial(zlib.compress, context_json.encode())
|
|
||||||
)
|
|
||||||
end = time()
|
|
||||||
compress_time_spent = format(round(end * 1000 - start * 1000), ",")
|
|
||||||
context: str = base64.b64encode(context_json_compressed).decode()
|
|
||||||
else:
|
else:
|
||||||
compress_time_spent = "N/A"
|
context = key = None
|
||||||
context = None
|
|
||||||
value = ("* Total: {}\n"
|
value = ("* Total: {}\n"
|
||||||
"* Model load: {}\n"
|
"* Model load: {}\n"
|
||||||
"* Sample generation: {}\n"
|
"* Sample generation: {}\n"
|
||||||
"* Prompt eval: {}\n"
|
"* Prompt eval: {}\n"
|
||||||
"* Response generation: {}\n"
|
"* Response generation: {}\n").format(
|
||||||
"* Context compression: {} milliseconds").format(
|
|
||||||
total_time_spent,
|
total_time_spent,
|
||||||
load_time_spent,
|
load_time_spent,
|
||||||
sample_time_sent,
|
sample_time_sent,
|
||||||
prompt_eval_time_spent,
|
prompt_eval_time_spent,
|
||||||
eval_time_spent,
|
eval_time_spent,
|
||||||
compress_time_spent
|
|
||||||
)
|
)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Timings",
|
name="Timings",
|
||||||
value=value
|
value=value,
|
||||||
|
inline=False
|
||||||
)
|
)
|
||||||
await msg.edit(content=None, embed=embed, view=None)
|
|
||||||
if context:
|
if context:
|
||||||
await ctx.respond(
|
embed.add_field(
|
||||||
"Context:\n"
|
name="Context Key",
|
||||||
"```\n"
|
value=key,
|
||||||
f"{context}\n"
|
inline=True
|
||||||
"```",
|
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
|
await msg.edit(content=None, embed=embed, view=None)
|
||||||
self.ollama_locks.pop(msg, None)
|
self.ollama_locks.pop(msg, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue