Don't pass clients to handlers

This commit is contained in:
Nexus 2024-04-14 18:26:51 +01:00
parent 0a2deba623
commit 3f894c493d
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -64,8 +64,8 @@ def get_time_spent(nanoseconds: int) -> str:
class OllamaDownloadHandler: class OllamaDownloadHandler:
def __init__(self, client: httpx.AsyncClient, model: str): def __init__(self, base_url: str, model: str):
self.client = client self.base_url = base_url
self.model = model self.model = model
self._abort = asyncio.Event() self._abort = asyncio.Event()
self._total = 1 self._total = 1
@ -102,13 +102,14 @@ class OllamaDownloadHandler:
self.status = line["status"] self.status = line["status"]
async def __aiter__(self): async def __aiter__(self):
async with self.client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response: async with httpx.AsyncClient(base_url=self.base_url) as client:
response.raise_for_status() async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response:
async for line in ollama_stream(response.content): response.raise_for_status()
self.parse_line(line) async for line in ollama_stream(response.content):
if self._abort.is_set(): self.parse_line(line)
break if self._abort.is_set():
yield line break
yield line
async def flatten(self) -> "OllamaDownloadHandler": async def flatten(self) -> "OllamaDownloadHandler":
"""Returns the current instance, but fully consumed.""" """Returns the current instance, but fully consumed."""
@ -118,8 +119,8 @@ class OllamaDownloadHandler:
class OllamaChatHandler: class OllamaChatHandler:
def __init__(self, client: httpx.AsyncClient, model: str, messages: list): def __init__(self, base_url: str, model: str, messages: list):
self.client = client self.base_url = base_url
self.model = model self.model = model
self.messages = messages self.messages = messages
self._abort = asyncio.Event() self._abort = asyncio.Event()
@ -163,35 +164,26 @@ class OllamaChatHandler:
self._abort.set() self._abort.set()
async def __aiter__(self): async def __aiter__(self):
async with self.client.post( async with httpx.AsyncClient(base_url=self.base_url) as client:
"/api/chat", async with client.post(
json={ "/api/chat",
"model": self.model, json={
"stream": True, "model": self.model,
"messages": self.messages "stream": True,
} "messages": self.messages
) as response: }
response.raise_for_status() ) as response:
async for line in ollama_stream(response.content): response.raise_for_status()
if self._abort.is_set(): async for line in ollama_stream(response.content):
break if self._abort.is_set():
break
if line.get("message"): if line.get("message"):
self.buffer.write(line["message"]["content"]) self.buffer.write(line["message"]["content"])
yield line yield line
if line.get("done"): if line.get("done"):
break break
@classmethod
async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler":
async with client.post("/api/chat", json={"model": model, "messages": messages, "stream": False}) as response:
response.raise_for_status()
handler = cls(client, model, messages)
line = await response.json()
handler.parse_line(line)
handler.buffer.write(line["message"]["content"])
return handler
class OllamaClient: class OllamaClient:
@ -199,10 +191,6 @@ class OllamaClient:
self.base_url = base_url self.base_url = base_url
self.authorisation = authorisation self.authorisation = authorisation
def _with_async_client(self, t) -> contextlib.AbstractContextManager[httpx.AsyncClient]:
with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client:
yield client
def with_client( def with_client(
self, self,
timeout: httpx.Timeout | float | int | None = None timeout: httpx.Timeout | float | int | None = None
@ -218,7 +206,8 @@ class OllamaClient:
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
else: else:
timeout = timeout or httpx.Timeout(60) timeout = timeout or httpx.Timeout(60)
yield from self._with_async_client(timeout) async with httpx.AsyncClient(base_url=self.base_url, timeout=timeout, auth=self.authorisation) as client:
yield client
async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]: async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]:
""" """
@ -248,7 +237,7 @@ class OllamaClient:
:param tag: The tag of the model. Defaults to latest. :param tag: The tag of the model. Defaults to latest.
:return: An OllamaDownloadHandler instance. :return: An OllamaDownloadHandler instance.
""" """
handler = OllamaDownloadHandler(httpx.AsyncClient(base_url=self.base_url), name + ":" + tag) handler = OllamaDownloadHandler(self.base_url, name + ":" + tag)
return handler return handler
def new_chat( def new_chat(
@ -263,7 +252,7 @@ class OllamaClient:
:param messages: :param messages:
:return: :return:
""" """
handler = OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url), model, messages) handler = OllamaChatHandler(self.base_url, model, messages)
return handler return handler