Fix ollama iter error

This commit is contained in:
Nexus 2023-11-10 22:38:47 +00:00
parent b90f138758
commit f6b6f6edd3
Signed by: nex
GPG key ID: 0FA334385D0B689F

View file

@ -3,6 +3,8 @@ import functools
import glob import glob
import io import io
import json import json
import typing
import math import math
import os import os
import random import random
@ -64,24 +66,37 @@ except Exception as _pyttsx3_err:
VOICES = [] VOICES = []
class OllamaStreamReader: # class OllamaStreamReader:
def __init__(self, response: httpx.Response): # def __init__(self, response: httpx.Response):
self.response = response # self.response = response
self.stream = response.aiter_bytes(1) # self.stream = response.aiter_bytes(1)
self._buffer = b"" # self._buffer = b""
#
# async def __aiter__(self):
# return self
#
# async def __anext__(self) -> dict[str, str | int | bool]:
# if self.response.is_stream_consumed:
# raise StopAsyncIteration
# self._buffer = b""
# while not self._buffer.endswith(b"}\n"):
# async for char in self.stream:
# self._buffer += char
#
# return json.loads(self._buffer.decode("utf-8", "replace"))
async def __aiter__(self):
return self
async def __anext__(self) -> dict[str, str | int | bool]: async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
if self.response.is_stream_consumed: dict[str, str | int | bool], None
raise StopAsyncIteration ]:
self._buffer = b"" stream = response.aiter_bytes(1)
while not self._buffer.endswith(b"}\n"): _buffer = b""
async for char in self.stream: while not response.is_stream_consumed:
self._buffer += char _buffer = b""
while not _buffer.endswith(b"}\n"):
return json.loads(self._buffer.decode("utf-8", "replace")) async for char in stream:
_buffer += char
yield json.loads(_buffer.decode("utf-8", "replace"))
def format_autocomplete(ctx: discord.AutocompleteContext): def format_autocomplete(ctx: discord.AutocompleteContext):
@ -1812,7 +1827,7 @@ class OtherCog(commands.Cog):
except httpx.TransportError as e: except httpx.TransportError as e:
return await msg.edit(content="Failed to connect to Ollama: `%s`" % e) return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)
if response.status_code == 404: if response.status_code == 404:
await msg.edit(content="Downloading model %r, please wait.") await msg.edit(content=f"Downloading model {model}, please wait.")
async with ctx.channel.typing(): async with ctx.channel.typing():
async with client.stream( async with client.stream(
"POST", "POST",
@ -1822,7 +1837,7 @@ class OtherCog(commands.Cog):
) as response: ) as response:
if response.status_code != 200: if response.status_code != 200:
return await msg.edit(content="Failed to download model: `%s`" % response.text) return await msg.edit(content="Failed to download model: `%s`" % response.text)
async for chunk in OllamaStreamReader(response): async for chunk in ollama_stream_reader(response):
if "total" in chunk and "completed" in chunk: if "total" in chunk and "completed" in chunk:
completed = chunk["completed"] or 1 # avoid division by zero completed = chunk["completed"] or 1 # avoid division by zero
total = chunk["total"] or 1 total = chunk["total"] or 1
@ -1857,13 +1872,16 @@ class OtherCog(commands.Cog):
if response.status_code != 200: if response.status_code != 200:
return await msg.edit(content="Failed to generate text: `%s`" % response.text) return await msg.edit(content="Failed to generate text: `%s`" % response.text)
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp() last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
async for chunk in OllamaStreamReader(response): async for chunk in ollama_stream_reader(response):
if "done" not in chunk or "response" not in chunk: if "done" not in chunk or "response" not in chunk:
continue continue
else: else:
content = "Response is still being generated..."
if chunk["done"] is True:
content = None
output.description = chunk["response"] output.description = chunk["response"]
if (time() - last_edit) >= 5 or chunk["done"] is True: if (time() - last_edit) >= 5 or chunk["done"] is True:
await msg.edit(content=None, embed=output) await msg.edit(content=content, embed=output)
break break