mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 18:16:34 +01:00
Add ollama command
This commit is contained in:
parent
c59128da41
commit
e64aff7c0c
1 changed file with 98 additions and 3 deletions
101
cogs/other.py
101
cogs/other.py
|
@ -64,6 +64,26 @@ except Exception as _pyttsx3_err:
|
||||||
VOICES = []
|
VOICES = []
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaStreamReader:
|
||||||
|
def __init__(self, response: httpx.Response):
|
||||||
|
self.response = response
|
||||||
|
self.stream = response.aiter_bytes(1)
|
||||||
|
self._buffer = b""
|
||||||
|
|
||||||
|
async def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> dict[str, str | int | bool]:
|
||||||
|
if self.response.is_stream_consumed:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
self._buffer = b""
|
||||||
|
while not self._buffer.endswith(b"}\n"):
|
||||||
|
async for char in self.stream:
|
||||||
|
self._buffer += char
|
||||||
|
|
||||||
|
return json.loads(self._buffer.decode("utf-8", "replace"))
|
||||||
|
|
||||||
|
|
||||||
def format_autocomplete(ctx: discord.AutocompleteContext):
|
def format_autocomplete(ctx: discord.AutocompleteContext):
|
||||||
url = ctx.options.get("url", os.urandom(6).hex())
|
url = ctx.options.get("url", os.urandom(6).hex())
|
||||||
self: "OtherCog" = ctx.bot.cogs["OtherCog"] # type: ignore
|
self: "OtherCog" = ctx.bot.cogs["OtherCog"] # type: ignore
|
||||||
|
@ -1021,9 +1041,7 @@ class OtherCog(commands.Cog):
|
||||||
try:
|
try:
|
||||||
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False)
|
extracted_info = await asyncio.to_thread(downloader.extract_info, url, download=False)
|
||||||
except yt_dlp.utils.DownloadError:
|
except yt_dlp.utils.DownloadError:
|
||||||
title = chosen_format = chosen_format_id = final_extension = format_note = "error"
|
title = "error"
|
||||||
resolution = vcodec = acodec = "error"
|
|
||||||
fps = 0
|
|
||||||
thumbnail_url = webpage_url = discord.Embed.Empty
|
thumbnail_url = webpage_url = discord.Embed.Empty
|
||||||
else:
|
else:
|
||||||
title = extracted_info.get("title", url)
|
title = extracted_info.get("title", url)
|
||||||
|
@ -1771,6 +1789,83 @@ class OtherCog(commands.Cog):
|
||||||
% (content, output_location.name)
|
% (content, output_location.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@commands.command(hidden=True)
@commands.is_owner()
@commands.max_concurrency(1, wait=True)
async def ollama(self, ctx: commands.Context, *, query: str):
    """Ask a local Ollama model a question. Prefix the query with `model:<name>[:<tag>]` to pick the model."""
    # NOTE(review): the docstring is kept short because the command framework
    # presumably surfaces it as the command's help text - confirm.
    #
    # Flow:
    #   1. Work out which model to use ("orca-mini" unless the query starts
    #      with "model:<name>"; ":latest" is appended when no tag is given).
    #   2. POST /show to check the model exists; on 404, stream POST /pull and
    #      mirror the download progress into the status message.
    #   3. Stream POST /generate and periodically edit the reply with the text
    #      produced so far.
    if query.startswith("model:"):
        selector, _, query = query.partition(" ")
        model = selector.removeprefix("model:").casefold()
        if not model or not query.strip():
            # Previously an unhandled ValueError when no prompt followed the model.
            return await ctx.reply("Usage: `model:<name>[:<tag>] <prompt>`")
        if ":" not in model:
            model += ":latest"
    else:
        model = "orca-mini"

    msg = await ctx.reply(f"Preparing {model!r} <a:loading:1101463077586735174>")
    async with httpx.AsyncClient(base_url="http://localhost:11434/api") as client:
        # Check whether the model is already available locally.
        try:
            response = await client.post("/show", json={"name": model})
        except httpx.TransportError as e:
            return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)

        if response.status_code == 404:
            await msg.edit(content="Downloading model %r, please wait." % model)
            async with ctx.channel.typing():
                async with client.stream(
                    "POST",
                    "/pull",
                    json={"name": model},
                    timeout=None,
                ) as response:
                    if response.status_code != 200:
                        # A streamed body has to be read before .text is usable.
                        await response.aread()
                        return await msg.edit(content="Failed to download model: `%s`" % response.text)

                    last_progress = None  # (status, decile) shown most recently
                    last_status = None  # status text shown most recently
                    async for chunk in OllamaStreamReader(response):
                        # NOTE(review): Ollama appears to report mid-stream
                        # failures as {"error": "..."} - confirm against the API docs.
                        if "error" in chunk:
                            return await msg.edit(content="Failed to download model: `%s`" % chunk["error"])
                        if "total" in chunk and "completed" in chunk:
                            completed = chunk["completed"] or 0
                            total = chunk["total"] or 1  # avoid division by zero
                            percent = completed / total * 100
                            # Edit once per 10% step; testing the float for an
                            # exact multiple of 10 almost never matches.
                            progress = (chunk["status"], int(percent // 10))
                            if progress != last_progress:
                                last_progress = progress
                                await msg.edit(content=f"`{chunk['status']}` - {percent:.0f}%")
                        elif chunk["status"] != last_status:
                            last_status = chunk["status"]
                            await msg.edit(content=f"`{chunk['status']}`")
        elif response.status_code != 200:
            return await msg.edit(content="Failed to get model: `%s`" % response.text)

        output = discord.Embed(
            title=f"{model} says:",
            description="",
            colour=discord.Colour.blurple(),
        )
        output.set_footer(text="Powered by Ollama")

        async with ctx.channel.typing():
            async with client.stream(
                "POST",
                "/generate",
                json={
                    "model": model,
                    "prompt": query,
                    "format": "json",
                    "system": "You are a discord bot called Jimmy Saville. "
                              "Be helpful and make sure your response is safe for work, "
                              "and is less than 3500 characters"
                }
            ) as response:
                if response.status_code != 200:
                    # A streamed body has to be read before .text is usable.
                    await response.aread()
                    return await msg.edit(content="Failed to generate text: `%s`" % response.text)

                last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
                answer = ""
                async for chunk in OllamaStreamReader(response):
                    if "done" not in chunk or "response" not in chunk:
                        continue
                    # Each streamed chunk carries only the next fragment of the
                    # answer, so fragments are accumulated rather than overwritten.
                    answer += chunk["response"]
                    finished = chunk["done"] is True
                    if finished or (time() - last_edit) >= 5:
                        # Discord rejects embed descriptions above 4096 characters.
                        output.description = answer[:4096]
                        await msg.edit(content=None, embed=output)
                        last_edit = time()
                    if finished:
                        break
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
    """Create an ``OtherCog`` bound to ``bot`` and add it to the bot."""
    cog = OtherCog(bot)
    bot.add_cog(cog)
|
||||||
|
|
Loading…
Reference in a new issue