Enable ollama pull
Some checks failed
Build and Publish / build_and_publish (push) Failing after 1m53s

This commit is contained in:
Nexus 2024-06-11 00:58:17 +01:00
parent 99001a60ba
commit 28908f217c
Signed by: nex
GPG key ID: 0FA334385D0B689F
2 changed files with 135 additions and 11 deletions

View file

@ -46,10 +46,13 @@ async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
tags = (await client.list())["models"]
return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
return [ctx.value, *v][:25]
_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
[x.name for x in get_servers()]
)
class Chat(commands.Cog):
@ -112,7 +115,15 @@ class Chat(commands.Cog):
)
await ctx.edit(embed=embed)
@commands.slash_command(name="ollama")
ollama_group = discord.SlashCommandGroup(
name="ollama",
description="Commands related to ollama.",
guild_only=True,
max_concurrency=commands.MaxConcurrency(1, per=commands.BucketType.user, wait=False),
cooldown=commands.CooldownMapping(commands.Cooldown(1, 10), commands.BucketType.user)
)
@ollama_group.command(name="chat")
async def start_ollama_chat(
self,
ctx: discord.ApplicationContext,
@ -130,7 +141,7 @@ class Chat(commands.Cog):
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
choices=_ServerOptionChoices,
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
],
@ -140,7 +151,7 @@ class Chat(commands.Cog):
discord.SlashCommandOptionType.string,
description="The model to use.",
autocomplete=get_available_tags_autocomplete,
default="llama3:latest"
default="default"
)
],
image: typing.Annotated[
@ -173,7 +184,9 @@ class Chat(commands.Cog):
"""Have a chat with ollama"""
await ctx.defer()
server = get_server(server)
if not await server.is_online():
if not server:
return await ctx.respond("\N{cross mark} Unknown Server.")
elif not await server.is_online():
await ctx.respond(
content=f"{server} is offline. Finding a suitable server...",
)
@ -183,10 +196,13 @@ class Chat(commands.Cog):
return await ctx.edit(content=str(err), delete_after=30)
await ctx.delete(delay=5)
async with self.server_locks[server.name]:
if model == "default":
model = server.default_model
async with ollama_client(str(server.base_url)) as client:
client: AsyncClient
self.log.info("Checking if %r has the model %r", server, model)
tags = (await client.list())["models"]
# Download code. It's recommended to collapse this in the editor.
if model not in [x["model"] for x in tags]:
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
@ -265,6 +281,7 @@ class Chat(commands.Cog):
await ctx.edit(embed=embed, delete_after=30, view=None)
messages = []
thread = None
if thread_id:
thread = await OllamaThread.get_or_none(thread_id=thread_id)
if thread:
@ -330,6 +347,7 @@ class Chat(commands.Cog):
else:
file = discord.utils.MISSING
if not thread:
thread = OllamaThread(
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
)
@ -337,6 +355,112 @@ class Chat(commands.Cog):
embed.set_footer(text=f"Chat ID: {thread.thread_id}")
await msg.edit(embed=embed, view=None, file=file)
@ollama_group.command(name="pull")
async def pull_ollama_model(
self,
ctx: discord.ApplicationContext,
server: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The server to use.",
autocomplete=_ServerOptionAutocomplete,
default=get_servers()[0].name
)
],
model: typing.Annotated[
str,
discord.Option(
discord.SlashCommandOptionType.string,
description="The model to use.",
autocomplete=get_available_tags_autocomplete,
default="llama3:latest"
)
],
):
"""Downloads a tag on the target server"""
await ctx.defer()
server = get_server(server)
if not server:
return await ctx.respond("\N{cross mark} Unknown server.")
elif not await server.is_online():
return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
embed = discord.Embed(
title=f"Downloading {model} on {server}.",
description=f"Initiating download...",
color=discord.Color.blurple()
)
view = StopDownloadView(ctx)
await ctx.respond(
embed=embed,
view=view
)
last_edit = 0
async with ctx.typing():
try:
last_completed = 0
last_completed_ts = time.time()
async for line in await client.pull(model, stream=True):
if view.event.is_set():
embed.add_field(name="Error!", value="Download cancelled.")
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed)
return
self.log.debug("Response from %r: %r", server, line)
if line["status"] in {
"pulling manifest",
"verifying sha256 digest",
"writing manifest",
"removing any unused layers",
"success"
}:
embed.description = line["status"].capitalize()
else:
total = line["total"]
completed = line.get("completed", 0)
percent = round(completed / total * 100, 1)
pb_fill = "" * int(percent / 10)
pb_empty = "" * (10 - int(percent / 10))
bytes_per_second = completed - last_completed
bytes_per_second /= (time.time() - last_completed_ts)
last_completed = completed
last_completed_ts = time.time()
mbps = round((bytes_per_second * 8) / 1024 / 1024)
eta = (total - completed) / max(1, bytes_per_second)
progress_bar = f"[{pb_fill}{pb_empty}]"
ns_total = naturalsize(total, binary=True)
ns_completed = naturalsize(completed, binary=True)
embed.description = (
f"{line['status'].capitalize()} {percent}% {progress_bar} "
f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
f"[ETA: {naturaldelta(eta)}]"
)
if time.time() - last_edit >= 2.5:
await ctx.edit(embed=embed)
last_edit = time.time()
except ResponseError as err:
if err.error.endswith("file does not exist"):
await ctx.edit(
embed=None,
content="The model %r does not exist." % model,
delete_after=60,
view=None
)
else:
embed.add_field(
name="Error!",
value=err.error
)
embed.colour = discord.Colour.red()
await ctx.edit(embed=embed, view=None)
return
else:
embed.colour = discord.Colour.green()
embed.description = f"Downloaded {model} on {server}."
await ctx.edit(embed=embed, delete_after=30, view=None)
def setup(bot):
bot.add_cog(Chat(bot))

View file

@ -66,7 +66,7 @@ def get_server(name_or_base_url: str) -> ServerConfig | None:
else:
if parsed.netloc and parsed.scheme in ["http", "https"]:
defaults = {
"name": ":temporary:",
"name": parsed.hostname,
"base_url": "{0.scheme}://{0.netloc}".format(parsed),
"gpu": False,
"vram_gb": 2,