Don't pass clients to handlers

This commit is contained in:
Nexus 2024-04-14 18:26:51 +01:00
parent 0a2deba623
commit 3f894c493d
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -64,8 +64,8 @@ def get_time_spent(nanoseconds: int) -> str:
class OllamaDownloadHandler:
def __init__(self, client: httpx.AsyncClient, model: str):
self.client = client
def __init__(self, base_url: str, model: str):
self.base_url = base_url
self.model = model
self._abort = asyncio.Event()
self._total = 1
@ -102,13 +102,14 @@ class OllamaDownloadHandler:
self.status = line["status"]
async def __aiter__(self):
async with self.client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
self.parse_line(line)
if self._abort.is_set():
break
yield line
async with httpx.AsyncClient(base_url=self.base_url) as client:
async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
self.parse_line(line)
if self._abort.is_set():
break
yield line
async def flatten(self) -> "OllamaDownloadHandler":
"""Returns the current instance, but fully consumed."""
@ -118,8 +119,8 @@ class OllamaDownloadHandler:
class OllamaChatHandler:
def __init__(self, client: httpx.AsyncClient, model: str, messages: list):
self.client = client
def __init__(self, base_url: str, model: str, messages: list):
self.base_url = base_url
self.model = model
self.messages = messages
self._abort = asyncio.Event()
@ -163,35 +164,26 @@ class OllamaChatHandler:
self._abort.set()
async def __aiter__(self):
async with self.client.post(
"/api/chat",
json={
"model": self.model,
"stream": True,
"messages": self.messages
}
) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
if self._abort.is_set():
break
async with httpx.AsyncClient(base_url=self.base_url) as client:
async with client.post(
"/api/chat",
json={
"model": self.model,
"stream": True,
"messages": self.messages
}
) as response:
response.raise_for_status()
async for line in ollama_stream(response.content):
if self._abort.is_set():
break
if line.get("message"):
self.buffer.write(line["message"]["content"])
yield line
if line.get("message"):
self.buffer.write(line["message"]["content"])
yield line
if line.get("done"):
break
@classmethod
async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
async with client.post("/api/chat", json={"model": model, "messages": messages, "stream": False}) as response:
response.raise_for_status()
handler = cls(client, model, messages)
line = await response.json()
handler.parse_line(line)
handler.buffer.write(line["message"]["content"])
return handler
if line.get("done"):
break
class OllamaClient:
@ -199,10 +191,6 @@ class OllamaClient:
self.base_url = base_url
self.authorisation = authorisation
def _with_async_client(self, t) -> contextlib.AbstractContextManager[httpx.AsyncClient]:
with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
yield client
def with_client(
self,
timeout: httpx.Timeout | float | int | None = None
@ -218,7 +206,8 @@ class OllamaClient:
timeout = httpx.Timeout(timeout)
else:
timeout = timeout or httpx.Timeout(60)
yield from self._with_async_client(timeout)
async with httpx.AsyncClient(base_url=self.base_url, timeout=timeout, auth=self.authorisation) as client:
yield client
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
"""
@ -248,7 +237,7 @@ class OllamaClient:
:param tag: The tag of the model. Defaults to latest.
:return: An OllamaDownloadHandler instance.
"""
handler = OllamaDownloadHandler(httpx.AsyncClient(base_url=self.base_url), name + ":" + tag)
handler = OllamaDownloadHandler(self.base_url, name + ":" + tag)
return handler
def new_chat(
@ -263,7 +252,7 @@ class OllamaClient:
:param messages:
:return:
"""
handler = OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url), model, messages)
handler = OllamaChatHandler(self.base_url, model, messages)
return handler