Enable ollama pull
Some checks failed
Build and Publish / build_and_publish (push) Failing after 1m53s
Some checks failed
Build and Publish / build_and_publish (push) Failing after 1m53s
This commit is contained in:
parent
99001a60ba
commit
28908f217c
2 changed files with 135 additions and 11 deletions
|
@ -46,10 +46,13 @@ async def get_available_tags_autocomplete(ctx: discord.AutocompleteContext):
|
||||||
chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
|
chosen_server = get_server(ctx.options.get("server") or get_servers()[0].name)
|
||||||
async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
|
async with ollama_client(str(chosen_server.base_url), timeout=2) as client:
|
||||||
tags = (await client.list())["models"]
|
tags = (await client.list())["models"]
|
||||||
return [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
|
|
||||||
|
|
||||||
|
v = [tag["model"] for tag in tags if ctx.value.casefold() in tag["model"].casefold()]
|
||||||
|
return [ctx.value, *v][:25]
|
||||||
|
|
||||||
_ServerOptionChoices = [discord.OptionChoice(server.name, server.name) for server in get_servers()]
|
_ServerOptionAutocomplete = discord.utils.basic_autocomplete(
|
||||||
|
[x.name for x in get_servers()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Chat(commands.Cog):
|
class Chat(commands.Cog):
|
||||||
|
@ -112,7 +115,15 @@ class Chat(commands.Cog):
|
||||||
)
|
)
|
||||||
await ctx.edit(embed=embed)
|
await ctx.edit(embed=embed)
|
||||||
|
|
||||||
@commands.slash_command(name="ollama")
|
ollama_group = discord.SlashCommandGroup(
|
||||||
|
name="ollama",
|
||||||
|
description="Commands related to ollama.",
|
||||||
|
guild_only=True,
|
||||||
|
max_concurrency=commands.MaxConcurrency(1, per=commands.BucketType.user, wait=False),
|
||||||
|
cooldown=commands.CooldownMapping(commands.Cooldown(1, 10), commands.BucketType.user)
|
||||||
|
)
|
||||||
|
|
||||||
|
@ollama_group.command(name="chat")
|
||||||
async def start_ollama_chat(
|
async def start_ollama_chat(
|
||||||
self,
|
self,
|
||||||
ctx: discord.ApplicationContext,
|
ctx: discord.ApplicationContext,
|
||||||
|
@ -130,7 +141,7 @@ class Chat(commands.Cog):
|
||||||
discord.Option(
|
discord.Option(
|
||||||
discord.SlashCommandOptionType.string,
|
discord.SlashCommandOptionType.string,
|
||||||
description="The server to use.",
|
description="The server to use.",
|
||||||
choices=_ServerOptionChoices,
|
autocomplete=_ServerOptionAutocomplete,
|
||||||
default=get_servers()[0].name
|
default=get_servers()[0].name
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
@ -140,7 +151,7 @@ class Chat(commands.Cog):
|
||||||
discord.SlashCommandOptionType.string,
|
discord.SlashCommandOptionType.string,
|
||||||
description="The model to use.",
|
description="The model to use.",
|
||||||
autocomplete=get_available_tags_autocomplete,
|
autocomplete=get_available_tags_autocomplete,
|
||||||
default="llama3:latest"
|
default="default"
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
image: typing.Annotated[
|
image: typing.Annotated[
|
||||||
|
@ -173,7 +184,9 @@ class Chat(commands.Cog):
|
||||||
"""Have a chat with ollama"""
|
"""Have a chat with ollama"""
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
server = get_server(server)
|
server = get_server(server)
|
||||||
if not await server.is_online():
|
if not server:
|
||||||
|
return await ctx.respond("\N{cross mark} Unknown Server.")
|
||||||
|
elif not await server.is_online():
|
||||||
await ctx.respond(
|
await ctx.respond(
|
||||||
content=f"{server} is offline. Finding a suitable server...",
|
content=f"{server} is offline. Finding a suitable server...",
|
||||||
)
|
)
|
||||||
|
@ -183,10 +196,13 @@ class Chat(commands.Cog):
|
||||||
return await ctx.edit(content=str(err), delete_after=30)
|
return await ctx.edit(content=str(err), delete_after=30)
|
||||||
await ctx.delete(delay=5)
|
await ctx.delete(delay=5)
|
||||||
async with self.server_locks[server.name]:
|
async with self.server_locks[server.name]:
|
||||||
|
if model == "default":
|
||||||
|
model = server.default_model
|
||||||
async with ollama_client(str(server.base_url)) as client:
|
async with ollama_client(str(server.base_url)) as client:
|
||||||
client: AsyncClient
|
client: AsyncClient
|
||||||
self.log.info("Checking if %r has the model %r", server, model)
|
self.log.info("Checking if %r has the model %r", server, model)
|
||||||
tags = (await client.list())["models"]
|
tags = (await client.list())["models"]
|
||||||
|
# Download code. It's recommended to collapse this in the editor.
|
||||||
if model not in [x["model"] for x in tags]:
|
if model not in [x["model"] for x in tags]:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title=f"Downloading {model} on {server}.",
|
title=f"Downloading {model} on {server}.",
|
||||||
|
@ -265,6 +281,7 @@ class Chat(commands.Cog):
|
||||||
await ctx.edit(embed=embed, delete_after=30, view=None)
|
await ctx.edit(embed=embed, delete_after=30, view=None)
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
|
thread = None
|
||||||
if thread_id:
|
if thread_id:
|
||||||
thread = await OllamaThread.get_or_none(thread_id=thread_id)
|
thread = await OllamaThread.get_or_none(thread_id=thread_id)
|
||||||
if thread:
|
if thread:
|
||||||
|
@ -330,6 +347,7 @@ class Chat(commands.Cog):
|
||||||
else:
|
else:
|
||||||
file = discord.utils.MISSING
|
file = discord.utils.MISSING
|
||||||
|
|
||||||
|
if not thread:
|
||||||
thread = OllamaThread(
|
thread = OllamaThread(
|
||||||
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
|
messages=[{"role": m["role"], "content": m["content"]} for m in messages],
|
||||||
)
|
)
|
||||||
|
@ -337,6 +355,112 @@ class Chat(commands.Cog):
|
||||||
embed.set_footer(text=f"Chat ID: {thread.thread_id}")
|
embed.set_footer(text=f"Chat ID: {thread.thread_id}")
|
||||||
await msg.edit(embed=embed, view=None, file=file)
|
await msg.edit(embed=embed, view=None, file=file)
|
||||||
|
|
||||||
|
@ollama_group.command(name="pull")
async def pull_ollama_model(
    self,
    ctx: discord.ApplicationContext,
    server: typing.Annotated[
        str,
        discord.Option(
            discord.SlashCommandOptionType.string,
            description="The server to use.",
            autocomplete=_ServerOptionAutocomplete,
            default=get_servers()[0].name
        )
    ],
    model: typing.Annotated[
        str,
        discord.Option(
            discord.SlashCommandOptionType.string,
            description="The model to use.",
            autocomplete=get_available_tags_autocomplete,
            default="llama3:latest"
        )
    ],
):
    """Downloads a tag on the target server"""
    # NOTE(review): the docstring is kept to a single short line on purpose; it is
    # presumably picked up by the command framework as the user-visible description.
    #
    # Flow:
    #   1. Resolve `server` (a name or base URL) and bail out early if it is
    #      unknown or not responding.
    #   2. Post a progress embed with a StopDownloadView attached.
    #   3. Stream `client.pull(model, stream=True)` and mirror each status line
    #      into the embed, throttling message edits.
    #   4. Report cancellation, a ResponseError, or success in the same embed.
    await ctx.defer()
    server = get_server(server)
    if not server:
        return await ctx.respond("\N{cross mark} Unknown server.")
    elif not await server.is_online():
        return await ctx.respond(f"\N{cross mark} Server {server.name!r} is not responding")

    embed = discord.Embed(
        title=f"Downloading {model} on {server}.",
        description="Initiating download...",
        color=discord.Color.blurple()
    )
    view = StopDownloadView(ctx)
    await ctx.respond(
        embed=embed,
        view=view
    )

    # Status lines that carry no byte counters; they are displayed verbatim.
    plain_statuses = {
        "pulling manifest",
        "verifying sha256 digest",
        "writing manifest",
        "removing any unused layers",
        "success"
    }

    last_edit = 0
    async with ctx.typing():
        try:
            last_completed = 0
            last_completed_ts = time.time()

            # The client must be opened here: the original body used `client`
            # without ever binding it, which raised NameError on the first pull.
            async with ollama_client(str(server.base_url)) as client:
                async for line in await client.pull(model, stream=True):
                    if view.event.is_set():
                        embed.add_field(name="Error!", value="Download cancelled.")
                        embed.colour = discord.Colour.red()
                        await ctx.edit(embed=embed)
                        return

                    self.log.debug("Response from %r: %r", server, line)
                    status = line["status"]
                    total = line.get("total")
                    if status in plain_statuses or not total:
                        # Also covers progress-less lines that lack "total" (or
                        # report 0), which previously raised KeyError or
                        # ZeroDivisionError.
                        embed.description = status.capitalize()
                    else:
                        completed = line.get("completed", 0)
                        percent = round(completed / total * 100, 1)
                        pb_fill = "▰" * int(percent / 10)
                        pb_empty = "▱" * (10 - int(percent / 10))

                        # Transfer rate since the previous progress line. Guard
                        # against a zero interval when two lines arrive within
                        # the clock's resolution.
                        now = time.time()
                        elapsed = now - last_completed_ts
                        if elapsed > 0:
                            bytes_per_second = (completed - last_completed) / elapsed
                        else:
                            bytes_per_second = 0
                        last_completed = completed
                        last_completed_ts = now

                        mbps = round((bytes_per_second * 8) / 1024 / 1024)
                        eta = (total - completed) / max(1, bytes_per_second)
                        progress_bar = f"[{pb_fill}{pb_empty}]"
                        ns_total = naturalsize(total, binary=True)
                        ns_completed = naturalsize(completed, binary=True)
                        embed.description = (
                            f"{status.capitalize()} {percent}% {progress_bar} "
                            f"({ns_completed}/{ns_total} @ {mbps} Mb/s) "
                            f"[ETA: {naturaldelta(eta)}]"
                        )

                    # Throttle edits so the message is updated at most once
                    # every 2.5 seconds.
                    if time.time() - last_edit >= 2.5:
                        await ctx.edit(embed=embed)
                        last_edit = time.time()
        except ResponseError as err:
            if err.error.endswith("file does not exist"):
                await ctx.edit(
                    embed=None,
                    content="The model %r does not exist." % model,
                    delete_after=60,
                    view=None
                )
            else:
                embed.add_field(
                    name="Error!",
                    value=err.error
                )
                embed.colour = discord.Colour.red()
                await ctx.edit(embed=embed, view=None)
            return
        else:
            embed.colour = discord.Colour.green()
            embed.description = f"Downloaded {model} on {server}."
            await ctx.edit(embed=embed, delete_after=30, view=None)
|
||||||
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
bot.add_cog(Chat(bot))
|
bot.add_cog(Chat(bot))
|
||||||
|
|
|
@ -66,7 +66,7 @@ def get_server(name_or_base_url: str) -> ServerConfig | None:
|
||||||
else:
|
else:
|
||||||
if parsed.netloc and parsed.scheme in ["http", "https"]:
|
if parsed.netloc and parsed.scheme in ["http", "https"]:
|
||||||
defaults = {
|
defaults = {
|
||||||
"name": ":temporary:",
|
"name": parsed.hostname,
|
||||||
"base_url": "{0.scheme}://{0.netloc}".format(parsed),
|
"base_url": "{0.scheme}://{0.netloc}".format(parsed),
|
||||||
"gpu": False,
|
"gpu": False,
|
||||||
"vram_gb": 2,
|
"vram_gb": 2,
|
||||||
|
|
Loading…
Reference in a new issue