Fix ollama iter error

This commit is contained in:
Nexus 2023-11-10 22:38:47 +00:00
parent b90f138758
commit f6b6f6edd3
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -3,6 +3,8 @@ import functools
import glob
import io
import json
import typing
import math
import os
import random
@ -64,24 +66,37 @@ except Exception as _pyttsx3_err:
VOICES = []
class OllamaStreamReader:
def __init__(self, response: httpx.Response):
self.response = response
self.stream = response.aiter_bytes(1)
self._buffer = b""
# class OllamaStreamReader:
# def __init__(self, response: httpx.Response):
# self.response = response
# self.stream = response.aiter_bytes(1)
# self._buffer = b""
#
# async def __aiter__(self):
# return self
#
# async def __anext__(self) -> dict[str, str | int | bool]:
# if self.response.is_stream_consumed:
# raise StopAsyncIteration
# self._buffer = b""
# while not self._buffer.endswith(b"}\n"):
# async for char in self.stream:
# self._buffer += char
#
# return json.loads(self._buffer.decode("utf-8", "replace"))
async def __aiter__(self):
return self
async def __anext__(self) -> dict[str, str | int | bool]:
if self.response.is_stream_consumed:
raise StopAsyncIteration
self._buffer = b""
while not self._buffer.endswith(b"}\n"):
async for char in self.stream:
self._buffer += char
return json.loads(self._buffer.decode("utf-8", "replace"))
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
dict[str, str | int | bool], None
]:
stream = response.aiter_bytes(1)
_buffer = b""
while not response.is_stream_consumed:
_buffer = b""
while not _buffer.endswith(b"}\n"):
async for char in stream:
_buffer += char
yield json.loads(_buffer.decode("utf-8", "replace"))
def format_autocomplete(ctx: discord.AutocompleteContext):
@ -1812,7 +1827,7 @@ class OtherCog(commands.Cog):
except httpx.TransportError as e:
return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)
if response.status_code == 404:
await msg.edit(content="Downloading model %r, please wait.")
await msg.edit(content=f"Downloading model {model}, please wait.")
async with ctx.channel.typing():
async with client.stream(
"POST",
@ -1822,7 +1837,7 @@ class OtherCog(commands.Cog):
) as response:
if response.status_code != 200:
return await msg.edit(content="Failed to download model: `%s`" % response.text)
async for chunk in OllamaStreamReader(response):
async for chunk in ollama_stream_reader(response):
if "total" in chunk and "completed" in chunk:
completed = chunk["completed"] or 1 # avoid division by zero
total = chunk["total"] or 1
@ -1857,13 +1872,16 @@ class OtherCog(commands.Cog):
if response.status_code != 200:
return await msg.edit(content="Failed to generate text: `%s`" % response.text)
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
async for chunk in OllamaStreamReader(response):
async for chunk in ollama_stream_reader(response):
if "done" not in chunk or "response" not in chunk:
continue
else:
content = "Response is still being generated..."
if chunk["done"] is True:
content = None
output.description = chunk["response"]
if (time() - last_edit) >= 5 or chunk["done"] is True:
await msg.edit(content=None, embed=output)
await msg.edit(content=content, embed=output)
break