From f6b6f6edd3f3bbf485f431210ab4ad66949bd240 Mon Sep 17 00:00:00 2001 From: nex Date: Fri, 10 Nov 2023 22:38:47 +0000 Subject: [PATCH] Fix ollama iter error --- cogs/other.py | 58 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/cogs/other.py b/cogs/other.py index da6d56d..cfd04c0 100644 --- a/cogs/other.py +++ b/cogs/other.py @@ -3,6 +3,8 @@ import functools import glob import io import json +import typing + import math import os import random @@ -64,24 +66,37 @@ except Exception as _pyttsx3_err: VOICES = [] -class OllamaStreamReader: - def __init__(self, response: httpx.Response): - self.response = response - self.stream = response.aiter_bytes(1) - self._buffer = b"" +# class OllamaStreamReader: +# def __init__(self, response: httpx.Response): +# self.response = response +# self.stream = response.aiter_bytes(1) +# self._buffer = b"" +# +# async def __aiter__(self): +# return self +# +# async def __anext__(self) -> dict[str, str | int | bool]: +# if self.response.is_stream_consumed: +# raise StopAsyncIteration +# self._buffer = b"" +# while not self._buffer.endswith(b"}\n"): +# async for char in self.stream: +# self._buffer += char +# +# return json.loads(self._buffer.decode("utf-8", "replace")) - async def __aiter__(self): - return self - async def __anext__(self) -> dict[str, str | int | bool]: - if self.response.is_stream_consumed: - raise StopAsyncIteration - self._buffer = b"" - while not self._buffer.endswith(b"}\n"): - async for char in self.stream: - self._buffer += char - - return json.loads(self._buffer.decode("utf-8", "replace")) +async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[ + dict[str, str | int | bool], None +]: + stream = response.aiter_bytes(1) + _buffer = b"" + while not response.is_stream_consumed: + _buffer = b"" + while not _buffer.endswith(b"}\n"): + async for char in stream: + _buffer += char + yield json.loads(_buffer.decode("utf-8", "replace")) def format_autocomplete(ctx: discord.AutocompleteContext): @@ -1812,7 +1827,7 @@ class OtherCog(commands.Cog): except httpx.TransportError as e: return await msg.edit(content="Failed to connect to Ollama: `%s`" % e) if response.status_code == 404: - await msg.edit(content="Downloading model %r, please wait.") + await msg.edit(content=f"Downloading model {model}, please wait.") async with ctx.channel.typing(): async with client.stream( "POST", @@ -1822,7 +1837,7 @@ class OtherCog(commands.Cog): ) as response: if response.status_code != 200: return await msg.edit(content="Failed to download model: `%s`" % response.text) - async for chunk in OllamaStreamReader(response): + async for chunk in ollama_stream_reader(response): if "total" in chunk and "completed" in chunk: completed = chunk["completed"] or 1 # avoid division by zero total = chunk["total"] or 1 @@ -1857,13 +1872,16 @@ class OtherCog(commands.Cog): if response.status_code != 200: return await msg.edit(content="Failed to generate text: `%s`" % response.text) last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp() - async for chunk in OllamaStreamReader(response): + async for chunk in ollama_stream_reader(response): if "done" not in chunk or "response" not in chunk: continue else: + content = "Response is still being generated..." + if chunk["done"] is True: + content = None output.description = chunk["response"] if (time() - last_edit) >= 5 or chunk["done"] is True: - await msg.edit(content=None, embed=output) + await msg.edit(content=content, embed=output) break