Don't pass clients to handlers
This commit is contained in:
parent
0a2deba623
commit
3f894c493d
1 changed files with 34 additions and 45 deletions
|
@ -64,8 +64,8 @@ def get_time_spent(nanoseconds: int) -> str:
|
||||||
|
|
||||||
|
|
||||||
class OllamaDownloadHandler:
|
class OllamaDownloadHandler:
|
||||||
def __init__(self, client: httpx.AsyncClient, model: str):
|
def __init__(self, base_url: str, model: str):
|
||||||
self.client = client
|
self.base_url = base_url
|
||||||
self.model = model
|
self.model = model
|
||||||
self._abort = asyncio.Event()
|
self._abort = asyncio.Event()
|
||||||
self._total = 1
|
self._total = 1
|
||||||
|
@ -102,7 +102,8 @@ class OllamaDownloadHandler:
|
||||||
self.status = line["status"]
|
self.status = line["status"]
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async with self.client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
|
async with httpx.AsyncClient(base_url=self.base_url) as client:
|
||||||
|
async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
async for line in ollama_stream(response.content):
|
async for line in ollama_stream(response.content):
|
||||||
self.parse_line(line)
|
self.parse_line(line)
|
||||||
|
@ -118,8 +119,8 @@ class OllamaDownloadHandler:
|
||||||
|
|
||||||
|
|
||||||
class OllamaChatHandler:
|
class OllamaChatHandler:
|
||||||
def __init__(self, client: httpx.AsyncClient, model: str, messages: list):
|
def __init__(self, base_url: str, model: str, messages: list):
|
||||||
self.client = client
|
self.base_url = base_url
|
||||||
self.model = model
|
self.model = model
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self._abort = asyncio.Event()
|
self._abort = asyncio.Event()
|
||||||
|
@ -163,7 +164,8 @@ class OllamaChatHandler:
|
||||||
self._abort.set()
|
self._abort.set()
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async with self.client.post(
|
async with httpx.AsyncClient(base_url=self.base_url) as client:
|
||||||
|
async with client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={
|
json={
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
|
@ -183,26 +185,12 @@ class OllamaChatHandler:
|
||||||
if line.get("done"):
|
if line.get("done"):
|
||||||
break
|
break
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
|
|
||||||
async with client.post("/api/chat", json={"model": model, "messages": messages, "stream": False}) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
handler = cls(client, model, messages)
|
|
||||||
line = await response.json()
|
|
||||||
handler.parse_line(line)
|
|
||||||
handler.buffer.write(line["message"]["content"])
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaClient:
|
class OllamaClient:
|
||||||
def __init__(self, base_url: str, authorisation: tuple[str, str] = None):
|
def __init__(self, base_url: str, authorisation: tuple[str, str] = None):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.authorisation = authorisation
|
self.authorisation = authorisation
|
||||||
|
|
||||||
def _with_async_client(self, t) -> contextlib.AbstractContextManager[httpx.AsyncClient]:
|
|
||||||
with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
def with_client(
|
def with_client(
|
||||||
self,
|
self,
|
||||||
timeout: httpx.Timeout | float | int | None = None
|
timeout: httpx.Timeout | float | int | None = None
|
||||||
|
@ -218,7 +206,8 @@ class OllamaClient:
|
||||||
timeout = httpx.Timeout(timeout)
|
timeout = httpx.Timeout(timeout)
|
||||||
else:
|
else:
|
||||||
timeout = timeout or httpx.Timeout(60)
|
timeout = timeout or httpx.Timeout(60)
|
||||||
yield from self._with_async_client(timeout)
|
async with httpx.AsyncClient(base_url=self.base_url, timeout=timeout, auth=self.authorisation) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
|
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -248,7 +237,7 @@ class OllamaClient:
|
||||||
:param tag: The tag of the model. Defaults to latest.
|
:param tag: The tag of the model. Defaults to latest.
|
||||||
:return: An OllamaDownloadHandler instance.
|
:return: An OllamaDownloadHandler instance.
|
||||||
"""
|
"""
|
||||||
handler = OllamaDownloadHandler(httpx.AsyncClient(base_url=self.base_url), name + ":" + tag)
|
handler = OllamaDownloadHandler(self.base_url, name + ":" + tag)
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def new_chat(
|
def new_chat(
|
||||||
|
@ -263,7 +252,7 @@ class OllamaClient:
|
||||||
:param messages:
|
:param messages:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
handler = OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url), model, messages)
|
handler = OllamaChatHandler(self.base_url, model, messages)
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue