Enable ollama pull
Some checks failed
Build and Publish / build_and_publish (push) Failing after 1m53s
This commit is contained in:
parent 99001a60ba
commit 28908f217c

2 changed files with 135 additions and 11 deletions
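In short: the single `/ollama` chat command becomes an `ollama` slash-command group, and a new `/ollama pull` subcommand streams a model download from the selected server into a live-updating embed. Stripped of the Discord plumbing, the streaming loop it is built around looks roughly like the sketch below (illustrative only, assuming the `ollama` Python client and the dict-style progress events the cog itself reads; `pull_with_progress` and the localhost URL are made up for the example):

import asyncio

from ollama import AsyncClient


async def pull_with_progress(base_url: str, model: str) -> None:
    # Stream pull progress events from an Ollama server and print them --
    # roughly what the cog renders into an embed every ~2.5 seconds.
    client = AsyncClient(host=base_url)
    async for event in await client.pull(model, stream=True):
        status = event["status"]
        total = event.get("total")
        completed = event.get("completed", 0)
        if total:
            print(f"{status}: {completed}/{total} bytes ({completed / total:.1%})")
        else:
            print(status)


asyncio.run(pull_with_progress("http://localhost:11434", "llama3:latest"))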
@@ -46,10 +46,13 @@ async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
    chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
    async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
        tags = (await client.list())["models"]
        return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
        v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
        return [ctx.value, *v][:25]


_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
    [x.name for x in get_servers()]
)


class Chat(commands.Cog):
@@ -112,7 +115,15 @@ class Chat(commands.Cog):
        )
        await ctx.edit(embed=embed)

    @commands.slash_command(name="ollama")
    ollama_group = discord.SlashCommandGroup(
        name="ollama",
        description="Commands related to ollama.",
        guild_only=True,
        max_concurrency=commands.MaxConcurrency(1, per=commands.BucketType.user, wait=False),
        cooldown=commands.CooldownMapping(commands.Cooldown(1, 10), commands.BucketType.user)
    )

    @ollama_group.command(name="chat")
    async def start_ollama_chat(
        self,
        ctx: discord.ApplicationContext,
@@ -130,7 +141,7 @@ class Chat(commands.Cog):
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                choices=_ServerOptionChoices,
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ],
@@ -140,7 +151,7 @@ class Chat(commands.Cog):
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest"
                default="default"
            )
        ],
        image: typing.Annotated[
@@ -173,7 +184,9 @@ class Chat(commands.Cog):
        """Have a chat with ollama"""
        await ctx.defer()
        server = get_server(server)
        if not await server.is_online():
        if not server:
            return await ctx.respond("\N{cross mark} Unknown Server.")
        elif not await server.is_online():
            await ctx.respond(
                content=f"{server} is offline. Finding a suitable server...",
            )
@@ -183,10 +196,13 @@ class Chat(commands.Cog):
                return await ctx.edit(content=str(err), delete_after=30)
            await ctx.delete(delay=5)
        async with self.server_locks[server.name]:
            if model == "default":
                model = server.default_model
            async with ollama_client(str(server.base_url)) as client:
                client: AsyncClient
                self.log.info("Checking if %r has the model %r", server, model)
                tags = (await client.list())["models"]
                # Download code. It's recommended to collapse this in the editor.
                if model not in [x["model"] for x in tags]:
                    embed = discord.Embed(
                        title=f"Downloading {model} on {server}.",
@@ -265,6 +281,7 @@ class Chat(commands.Cog):
                    await ctx.edit(embed=embed, delete_after=30, view=None)

        messages = []
        thread = None
        if thread_id:
            thread = await OllamaThread.get_or_none(thread_id=thread_id)
            if thread:
@@ -330,6 +347,7 @@ class Chat(commands.Cog):
            else:
                file = discord.utils.MISSING

            if not thread:
                thread = OllamaThread(
                    messages=[{"role": m["role"], "content": m["content"]} for m in messages],
                )
@@ -337,6 +355,112 @@ class Chat(commands.Cog):
            embed.set_footer(text=f"Chat ID: {thread.thread_id}")
            await msg.edit(embed=embed, view=None, file=file)

    @ollama_group.command(name="pull")
    async def pull_ollama_model(
        self,
        ctx: discord.ApplicationContext,
        server: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The server to use.",
                autocomplete=_ServerOptionAutocomplete,
                default=get_servers()[0].name
            )
        ],
        model: typing.Annotated[
            str,
            discord.Option(
                discord.SlashCommandOptionType.string,
                description="The model to use.",
                autocomplete=get_available_tags_autocomplete,
                default="llama3:latest"
            )
        ],
    ):
        """Downloads a tag on the target server"""
        await ctx.defer()
        server = get_server(server)
        if not server:
            return await ctx.respond("\N{cross mark} Unknown server.")
        elif not await server.is_online():
            return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")
        embed = discord.Embed(
            title=f"Downloading {model} on {server}.",
            description=f"Initiating download...",
            color=discord.Color.blurple()
        )
        view = StopDownloadView(ctx)
        await ctx.respond(
            embed=embed,
            view=view
        )
        last_edit = 0
        async with ctx.typing():
            try:
                last_completed = 0
                last_completed_ts = time.time()

                async for line in await client.pull(model, stream=True):
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Download cancelled.")
                        embed.colour = discord.Colour.red()
                        await ctx.edit(embed=embed)
                        return
                    self.log.debug("Response from %r: %r", server, line)
                    if line["status"] in {
                        "pulling manifest",
                        "verifying sha256 digest",
                        "writing manifest",
                        "removing any unused layers",
                        "success"
                    }:
                        embed.description = line["status"].capitalize()
                    else:
                        total = line["total"]
                        completed = line.get("completed", 0)
                        percent = round(completed / total * 100, 1)
                        pb_fill = "▰" * int(percent / 10)
                        pb_empty = "▱" * (10 - int(percent / 10))
                        bytes_per_second = completed - last_completed
                        bytes_per_second /= (time.time() - last_completed_ts)
                        last_completed = completed
                        last_completed_ts = time.time()
                        mbps = round((bytes_per_second * 8) / 1024 / 1024)
                        eta = (total - completed) / max(1, bytes_per_second)
                        progress_bar = f"[{pb_fill}{pb_empty}]"
                        ns_total = naturalsize(total, binary=True)
                        ns_completed = naturalsize(completed, binary=True)
                        embed.description = (
                            f"{line['status'].capitalize()} {percent}% {progress_bar} "
                            f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
                            f"[ETA: {naturaldelta(eta)}]"
                        )

                    if time.time() - last_edit >= 2.5:
                        await ctx.edit(embed=embed)
                        last_edit = time.time()
            except ResponseError as err:
                if err.error.endswith("file does not exist"):
                    await ctx.edit(
                        embed=None,
                        content="The model %r does not exist." % model,
                        delete_after=60,
                        view=None
                    )
                else:
                    embed.add_field(
                        name="Error!",
                        value=err.error
                    )
                    embed.colour = discord.Colour.red()
                    await ctx.edit(embed=embed, view=None)
                return
            else:
                embed.colour = discord.Colour.green()
                embed.description = f"Downloaded {model} on {server}."
                await ctx.edit(embed=embed, delete_after=30, view=None)


def setup(bot):
    bot.add_cog(Chat(bot))
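A side note on the throughput figures in the pull loop above: the per-event byte delta divided by the elapsed time gives `bytes_per_second`, the displayed `Mb/s` value multiplies that by 8 and divides by 1024 twice (so it is strictly mebibits rather than megabits per second), and the ETA is the remaining bytes over that rate. A quick worked example with made-up numbers, using the same formulas:

total = 4 * 1024**3               # pretend the layer is 4 GiB
completed = 1 * 1024**3           # 1 GiB received so far
bytes_per_second = 30 * 1024**2   # ~30 MiB arrived since the last event, one second ago

percent = round(completed / total * 100, 1)                      # 25.0
bar = "▰" * int(percent / 10) + "▱" * (10 - int(percent / 10))   # "▰▰▱▱▱▱▱▱▱▱"
mbps = round((bytes_per_second * 8) / 1024 / 1024)               # 240
eta = (total - completed) / max(1, bytes_per_second)             # ~102.4 seconds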
@@ -66,7 +66,7 @@ def get_server(name_or_base_url: str) -> ServerConfig | None:
    else:
        if parsed.netloc and parsed.scheme in ["http", "https"]:
            defaults = {
                "name": ":temporary:",
                "name": parsed.hostname,
                "base_url": "{0.scheme}://{0.netloc}".format(parsed),
                "gpu": False,
                "vram_gb": 2,