From 3f894c493d054064b8d872fe2d01bca8d0400fce Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 14 Apr 2024 18:26:51 +0100 Subject: [PATCH] Don't pass clients to handlers --- src/cogs/ollama.py | 79 ++++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 45 deletions(-) diff --git a/src/cogs/ollama.py b/src/cogs/ollama.py index 0f7ea14..aad739c 100644 --- a/src/cogs/ollama.py +++ b/src/cogs/ollama.py @@ -64,8 +64,8 @@ def get_time_spent(nanoseconds: int) -> str: class OllamaDownloadHandler: - def __init__(self, client: httpx.AsyncClient, model: str): - self.client = client + def __init__(self, base_url: str, model: str): + self.base_url = base_url self.model = model self._abort = asyncio.Event() self._total = 1 @@ -102,13 +102,14 @@ class OllamaDownloadHandler: self.status = line["status"] async def __aiter__(self): - async with self.client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response: - response.raise_for_status() - async for line in ollama_stream(response.content): - self.parse_line(line) - if self._abort.is_set(): - break - yield line + async with httpx.AsyncClient(base_url=self.base_url) as client: + async with client.post("/api/pull", json={"name": self.model, "stream": True}, timeout=None) as response: + response.raise_for_status() + async for line in ollama_stream(response.content): + self.parse_line(line) + if self._abort.is_set(): + break + yield line async def flatten(self) -> "OllamaDownloadHandler": """Returns the current instance, but fully consumed.""" @@ -118,8 +119,8 @@ class OllamaDownloadHandler: class OllamaChatHandler: - def __init__(self, client: httpx.AsyncClient, model: str, messages: list): - self.client = client + def __init__(self, base_url: str, model: str, messages: list): + self.base_url = base_url self.model = model self.messages = messages self._abort = asyncio.Event() @@ -163,35 +164,26 @@ class OllamaChatHandler: self._abort.set() async def __aiter__(self): - async with self.client.post( - "/api/chat", - json={ - "model": self.model, - "stream": True, - "messages": self.messages - } - ) as response: - response.raise_for_status() - async for line in ollama_stream(response.content): - if self._abort.is_set(): - break + async with httpx.AsyncClient(base_url=self.base_url) as client: + async with client.post( + "/api/chat", + json={ + "model": self.model, + "stream": True, + "messages": self.messages + } + ) as response: + response.raise_for_status() + async for line in ollama_stream(response.content): + if self._abort.is_set(): + break - if line.get("message"): - self.buffer.write(line["message"]["content"]) - yield line + if line.get("message"): + self.buffer.write(line["message"]["content"]) + yield line - if line.get("done"): - break - - @classmethod - async def get_streamless(cls, client: httpx.AsyncClient, model: str, messages: list) -> "OllamaChatHandler": - async with client.post("/api/chat", json={"model": model, "messages": messages, "stream": False}) as response: - response.raise_for_status() - handler = cls(client, model, messages) - line = await response.json() - handler.parse_line(line) - handler.buffer.write(line["message"]["content"]) - return handler + if line.get("done"): + break class OllamaClient: @@ -199,10 +191,6 @@ class OllamaClient: self.base_url = base_url self.authorisation = authorisation - def _with_async_client(self, t) -> contextlib.AbstractContextManager[httpx.AsyncClient]: - with httpx.AsyncClient(base_url=self.base_url, timeout=t, auth=self.authorisation) as client: - yield client - def with_client( self, timeout: httpx.Timeout | float | int | None = None @@ -218,7 +206,8 @@ class OllamaClient: timeout = httpx.Timeout(timeout) else: timeout = timeout or httpx.Timeout(60) - yield from self._with_async_client(timeout) + async with httpx.AsyncClient(base_url=self.base_url, timeout=timeout, auth=self.authorisation) as client: + yield client async def get_tags(self) -> dict[typing.Literal["models"], dict[str, str, int, dict[str, str, None]]]: """ @@ -248,7 +237,7 @@ class OllamaClient: :param tag: The tag of the model. Defaults to latest. :return: An OllamaDownloadHandler instance. """ - handler = OllamaDownloadHandler(httpx.AsyncClient(base_url=self.base_url), name + ":" + tag) + handler = OllamaDownloadHandler(self.base_url, name + ":" + tag) return handler def new_chat( @@ -263,7 +252,7 @@ class OllamaClient: :param messages: :return: """ - handler = OllamaChatHandler(httpx.AsyncClient(base_url=self.base_url), model, messages) + handler = OllamaChatHandler(self.base_url, model, messages) return handler