mirror of
https://github.com/nexy7574/LCC-bot.git
synced 2024-09-19 10:03:40 +01:00
Fix ollama iter error
This commit is contained in:
parent
b90f138758
commit
f6b6f6edd3
1 changed files with 38 additions and 20 deletions
|
@ -3,6 +3,8 @@ import functools
|
|||
import glob
|
||||
import io
|
||||
import json
|
||||
import typing
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
@ -64,24 +66,37 @@ except Exception as _pyttsx3_err:
|
|||
VOICES = []
|
||||
|
||||
|
||||
class OllamaStreamReader:
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.stream = response.aiter_bytes(1)
|
||||
self._buffer = b""
|
||||
# class OllamaStreamReader:
|
||||
# def __init__(self, response: httpx.Response):
|
||||
# self.response = response
|
||||
# self.stream = response.aiter_bytes(1)
|
||||
# self._buffer = b""
|
||||
#
|
||||
# async def __aiter__(self):
|
||||
# return self
|
||||
#
|
||||
# async def __anext__(self) -> dict[str, str | int | bool]:
|
||||
# if self.response.is_stream_consumed:
|
||||
# raise StopAsyncIteration
|
||||
# self._buffer = b""
|
||||
# while not self._buffer.endswith(b"}\n"):
|
||||
# async for char in self.stream:
|
||||
# self._buffer += char
|
||||
#
|
||||
# return json.loads(self._buffer.decode("utf-8", "replace"))
|
||||
|
||||
async def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> dict[str, str | int | bool]:
|
||||
if self.response.is_stream_consumed:
|
||||
raise StopAsyncIteration
|
||||
self._buffer = b""
|
||||
while not self._buffer.endswith(b"}\n"):
|
||||
async for char in self.stream:
|
||||
self._buffer += char
|
||||
|
||||
return json.loads(self._buffer.decode("utf-8", "replace"))
|
||||
async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
|
||||
dict[str, str | int | bool], None
|
||||
]:
|
||||
stream = response.aiter_bytes(1)
|
||||
_buffer = b""
|
||||
while not response.is_stream_consumed:
|
||||
_buffer = b""
|
||||
while not _buffer.endswith(b"}\n"):
|
||||
async for char in stream:
|
||||
_buffer += char
|
||||
yield json.loads(_buffer.decode("utf-8", "replace"))
|
||||
|
||||
|
||||
def format_autocomplete(ctx: discord.AutocompleteContext):
|
||||
|
@ -1812,7 +1827,7 @@ class OtherCog(commands.Cog):
|
|||
except httpx.TransportError as e:
|
||||
return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)
|
||||
if response.status_code == 404:
|
||||
await msg.edit(content="Downloading model %r, please wait.")
|
||||
await msg.edit(content=f"Downloading model {model}, please wait.")
|
||||
async with ctx.channel.typing():
|
||||
async with client.stream(
|
||||
"POST",
|
||||
|
@ -1822,7 +1837,7 @@ class OtherCog(commands.Cog):
|
|||
) as response:
|
||||
if response.status_code != 200:
|
||||
return await msg.edit(content="Failed to download model: `%s`" % response.text)
|
||||
async for chunk in OllamaStreamReader(response):
|
||||
async for chunk in ollama_stream_reader(response):
|
||||
if "total" in chunk and "completed" in chunk:
|
||||
completed = chunk["completed"] or 1 # avoid division by zero
|
||||
total = chunk["total"] or 1
|
||||
|
@ -1857,13 +1872,16 @@ class OtherCog(commands.Cog):
|
|||
if response.status_code != 200:
|
||||
return await msg.edit(content="Failed to generate text: `%s`" % response.text)
|
||||
last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
|
||||
async for chunk in OllamaStreamReader(response):
|
||||
async for chunk in ollama_stream_reader(response):
|
||||
if "done" not in chunk or "response" not in chunk:
|
||||
continue
|
||||
else:
|
||||
content = "Response is still being generated..."
|
||||
if chunk["done"] is True:
|
||||
content = None
|
||||
output.description = chunk["response"]
|
||||
if (time() - last_edit) >= 5 or chunk["done"] is True:
|
||||
await msg.edit(content=None, embed=output)
|
||||
await msg.edit(content=content, embed=output)
|
||||
break
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue