mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Implement context
This commit is contained in:
parent
2fec64c1e2
commit
21a9bcecf9
1 changed files with 25 additions and 13 deletions
|
@ -1890,6 +1890,17 @@ class OtherCog(commands.Cog):
|
||||||
system_prompt = modal.system_prompt or system_prompt
|
system_prompt = modal.system_prompt or system_prompt
|
||||||
else:
|
else:
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
|
||||||
|
if context:
|
||||||
|
try:
|
||||||
|
context_decoded = base64.b85decode(context).decode()
|
||||||
|
context_decompressed = await asyncio.to_thread(
|
||||||
|
functools.partial(zlib.decompress, context_decoded.encode())
|
||||||
|
)
|
||||||
|
context = json.loads(context_decompressed)
|
||||||
|
except (ValueError, zlib.error, UnicodeDecodeError) as e:
|
||||||
|
return await ctx.respond("Failed to decode context: " + str(e))
|
||||||
|
|
||||||
content = None
|
content = None
|
||||||
try_hosts = {
|
try_hosts = {
|
||||||
"127.0.0.1:11434": "localhost",
|
"127.0.0.1:11434": "localhost",
|
||||||
|
@ -2006,16 +2017,19 @@ class OtherCog(commands.Cog):
|
||||||
embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
|
embed.set_footer(text=f"Powered by Ollama • {host} ({try_hosts.get(host, 'Other')})")
|
||||||
await msg.edit(embed=embed)
|
await msg.edit(embed=embed)
|
||||||
async with ctx.channel.typing():
|
async with ctx.channel.typing():
|
||||||
async with client.stream(
|
payload = {
|
||||||
"POST",
|
|
||||||
"/generate",
|
|
||||||
json={
|
|
||||||
"model": model,
|
"model": model,
|
||||||
"prompt": query,
|
"prompt": query,
|
||||||
"format": "json",
|
"format": "json",
|
||||||
"system": system_prompt,
|
"system": system_prompt,
|
||||||
"stream": True
|
"stream": True
|
||||||
},
|
}
|
||||||
|
if context:
|
||||||
|
payload["context"] = context
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
"/generate",
|
||||||
|
json=payload,
|
||||||
timeout=None
|
timeout=None
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -2104,13 +2118,11 @@ class OtherCog(commands.Cog):
|
||||||
context_json = json.dumps(context)
|
context_json = json.dumps(context)
|
||||||
start = time()
|
start = time()
|
||||||
context_json_compressed = await asyncio.to_thread(
|
context_json_compressed = await asyncio.to_thread(
|
||||||
zlib.compress,
|
functools.partial(zlib.compress, context_json.encode())
|
||||||
context_json.encode(),
|
|
||||||
9
|
|
||||||
)
|
)
|
||||||
end = time()
|
end = time()
|
||||||
compress_time_spent = format(end * 1000 - start * 1000, ",")
|
compress_time_spent = format(round(end * 1000 - start * 1000), ",")
|
||||||
context: str = base64.b64encode(context_json_compressed).decode()
|
context: str = base64.b85encode(context_json_compressed).decode()
|
||||||
else:
|
else:
|
||||||
compress_time_spent = "N/A"
|
compress_time_spent = "N/A"
|
||||||
context = None
|
context = None
|
||||||
|
@ -2135,7 +2147,7 @@ class OtherCog(commands.Cog):
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Context",
|
name="Context",
|
||||||
value=f"```\n{context}\n```"[:1024],
|
value=f"```\n{context}\n```"[:1024],
|
||||||
inline=True
|
inline=False
|
||||||
)
|
)
|
||||||
await msg.edit(content=None, embed=embed, view=None)
|
await msg.edit(content=None, embed=embed, view=None)
|
||||||
self.ollama_locks.pop(msg, None)
|
self.ollama_locks.pop(msg, None)
|
||||||
|
|
Loading…
Reference in a new issue