mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 10:03:40 +01:00
Add ollama command
This commit is contained in:
parent
c59128da41
commit
e64aff7c0c
1 changed file with 98 additions and 3 deletions
101
cogs/other.py
101
cogs/other.py
|
@ -64,6 +64,26 @@ except Exception as _pyttsx3_err:
|
|||
VOICES = []
|
||||
|
||||
|
||||
class OllamaStreamReader:
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.stream = response.aiter_bytes(1)
|
||||
self._buffer = b""
|
||||
|
||||
async def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> dict[str, str | int | bool]:
|
||||
if self.response.is_stream_consumed:
|
||||
raise StopAsyncIteration
|
||||
self._buffer = b""
|
||||
while not self._buffer.endswith(b"}\n"):
|
||||
async for char in self.stream:
|
||||
self._buffer += char
|
||||
|
||||
return json.loads(self._buffer.decode("utf-8", "replace"))
|
||||
|
||||
|
||||
def format_autocomplete(ctx: discord.AutocompleteContext):
|
||||
url = ctx.options.get("url", os.urandom(6).hex())
|
||||
self: "OtherCog" = ctx.bot.cogs["OtherCog"] # type: ignore
|
||||
|
@ -1021,9 +1041,7 @@ class OtherCog(commands.Cog):
|
|||
try:
|
||||
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False)
|
||||
except yt_dlp.utils.DownloadError:
|
||||
title = chosen_format = chosen_format_id = final_extension = format_note = "error"
|
||||
resolution = vcodec = acodec = "error"
|
||||
fps = 0
|
||||
title = "error"
|
||||
thumbnail_url = webpage_url = discord.Embed.Empty
|
||||
else:
|
||||
title = extracted_info.get("title", url)
|
||||
|
@ -1771,6 +1789,83 @@ class OtherCog(commands.Cog):
|
|||
% (content, output_location.name)
|
||||
)
|
||||
|
||||
@commands.command(hidden=True)
@commands.is_owner()
@commands.max_concurrency(1, wait=True)
async def ollama(self, ctx: commands.Context, *, query: str):
    """Send ``query`` to a local Ollama server and show the reply in an embed.

    The query may start with ``model:<name>[:<tag>]`` followed by a space and
    the prompt; the tag defaults to ``latest``. Without that prefix the
    ``orca-mini`` model is used. If the server reports the model as missing
    (HTTP 404 from ``/show``) it is pulled first, with progress shown in the
    status message.

    :param ctx: the invocation context; the status message is sent as a reply.
    :param query: the prompt, optionally prefixed with a model selector.
    """
    if query.startswith("model:"):
        # partition() never raises, unlike split(" ", 1) unpacking, when the
        # user supplied a model selector but no prompt.
        model, _, query = query.partition(" ")
        model = model[6:].casefold()
        query = query.strip()
        if not model or not query:
            return await ctx.reply("Usage: `model:<name>[:<tag>] <prompt>`")
        if ":" not in model:
            model += ":latest"
    else:
        model = "orca-mini"

    msg = await ctx.reply(f"Preparing {model!r} <a:loading:1101463077586735174>")
    async with httpx.AsyncClient(base_url="http://localhost:11434/api") as client:
        # Check whether the model is already available.
        try:
            response = await client.post("/show", json={"name": model})
        except httpx.TransportError as e:
            return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)

        if response.status_code == 404:
            await msg.edit(content="Downloading model %r, please wait." % model)
            async with ctx.channel.typing():
                async with client.stream(
                    "POST",
                    "/pull",
                    json={"name": model},
                    timeout=None
                ) as response:
                    if response.status_code != 200:
                        # A streamed body must be read before .text is usable.
                        await response.aread()
                        return await msg.edit(content="Failed to download model: `%s`" % response.text)

                    last_content = None
                    async for chunk in OllamaStreamReader(response):
                        status = chunk.get("status", "unknown")
                        if "total" in chunk and "completed" in chunk:
                            total = chunk["total"] or 1  # avoid division by zero
                            percent = (chunk["completed"] or 0) / total * 100
                            # Report progress in 10% steps; testing a float
                            # with `% 10 == 0` almost never matches.
                            content = f"`{status}` - {int(percent // 10) * 10}%"
                        else:
                            content = f"`{status}`"
                        # Only edit when the visible text actually changes.
                        if content != last_content:
                            last_content = content
                            await msg.edit(content=content)
        elif response.status_code != 200:
            return await msg.edit(content="Failed to get model: `%s`" % response.text)

        output = discord.Embed(
            title=f"{model} says:",
            description="",
            colour=discord.Colour.blurple(),
        )
        output.set_footer(text="Powered by Ollama")

        async with ctx.channel.typing():
            async with client.stream(
                "POST",
                "/generate",
                json={
                    "model": model,
                    "prompt": query,
                    "format": "json",
                    "system": "You are a discord bot called Jimmy Saville. "
                              "Be helpful and make sure your response is safe for work, "
                              "and is less than 3500 characters"
                }
            ) as response:
                if response.status_code != 200:
                    # A streamed body must be read before .text is usable.
                    await response.aread()
                    return await msg.edit(content="Failed to generate text: `%s`" % response.text)

                last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
                async for chunk in OllamaStreamReader(response):
                    if "error" in chunk:
                        return await msg.edit(content="Failed to generate text: `%s`" % chunk["error"])
                    # Each chunk carries a fragment of the reply, so append it
                    # rather than overwrite. 4096 is the embed description cap.
                    fragment = str(chunk.get("response") or "")
                    output.description = (output.description + fragment)[:4096]
                    if chunk.get("done") is True:
                        break
                    # Throttle intermediate edits to one every five seconds.
                    if (time() - last_edit) >= 5:
                        await msg.edit(content=None, embed=output)
                        last_edit = time()

                # Always publish the final text, even if the stream ended
                # without an explicit "done" marker.
                await msg.edit(content=None, embed=output)
|
||||
|
||||
|
||||
def setup(bot):
    """Create an ``OtherCog`` bound to ``bot`` and add it to the bot."""
    cog = OtherCog(bot)
    bot.add_cog(cog)
|
||||
|
|
Loading…
Reference in a new issue