Mirror of https://github.com/nexy7574/LCC-bot.git, synced 2024-09-20 02:26:32 +01:00
Fix ollama iter error
commit f6b6f6edd3
parent b90f138758
1 changed file with 38 additions and 20 deletions
|
@@ -3,6 +3,8 @@ import functools
 import glob
 import io
 import json
+import typing
 import math
 import os
 import random
@@ -64,24 +66,37 @@ except Exception as _pyttsx3_err:
 VOICES = []


-class OllamaStreamReader:
-    def __init__(self, response: httpx.Response):
-        self.response = response
-        self.stream = response.aiter_bytes(1)
-        self._buffer = b""
-
-    async def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> dict[str, str | int | bool]:
-        if self.response.is_stream_consumed:
-            raise StopAsyncIteration
-        self._buffer = b""
-        while not self._buffer.endswith(b"}\n"):
-            async for char in self.stream:
-                self._buffer += char
-
-        return json.loads(self._buffer.decode("utf-8", "replace"))
+# class OllamaStreamReader:
+#     def __init__(self, response: httpx.Response):
+#         self.response = response
+#         self.stream = response.aiter_bytes(1)
+#         self._buffer = b""
+#
+#     async def __aiter__(self):
+#         return self
+#
+#     async def __anext__(self) -> dict[str, str | int | bool]:
+#         if self.response.is_stream_consumed:
+#             raise StopAsyncIteration
+#         self._buffer = b""
+#         while not self._buffer.endswith(b"}\n"):
+#             async for char in self.stream:
+#                 self._buffer += char
+#
+#         return json.loads(self._buffer.decode("utf-8", "replace"))
+
+
+async def ollama_stream_reader(response: httpx.Response) -> typing.AsyncGenerator[
+    dict[str, str | int | bool], None
+]:
+    stream = response.aiter_bytes(1)
+    _buffer = b""
+    while not response.is_stream_consumed:
+        _buffer = b""
+        while not _buffer.endswith(b"}\n"):
+            async for char in stream:
+                _buffer += char
+        yield json.loads(_buffer.decode("utf-8", "replace"))


 def format_autocomplete(ctx: discord.AutocompleteContext):
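For reference, the replacement generator can be exercised on its own outside the bot. The sketch below is not code from this commit: the helper name, the local Ollama URL (http://localhost:11434), the model name "llama2", and the /api/generate endpoint are illustrative assumptions, and an explicit break is added so each "}\n"-terminated NDJSON record is parsed on its own.

# Standalone sketch of the streaming pattern introduced above; assumptions noted in the lead-in.
import asyncio
import json
import typing

import httpx


async def read_ndjson_stream(response: httpx.Response) -> typing.AsyncGenerator[
    dict[str, str | int | bool], None
]:
    # Accumulate bytes until one "}\n"-terminated NDJSON record is complete,
    # then yield it as a dict.
    stream = response.aiter_bytes(1)
    while not response.is_stream_consumed:
        buffer = b""
        while not buffer.endswith(b"}\n"):
            async for byte in stream:
                buffer += byte
                if buffer.endswith(b"}\n"):
                    break  # one complete record buffered
            else:
                return  # stream exhausted; nothing more to yield
        yield json.loads(buffer.decode("utf-8", "replace"))


async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:11434", timeout=None) as client:
        async with client.stream(
            "POST", "/api/generate", json={"model": "llama2", "prompt": "Say hello"}
        ) as response:
            async for chunk in read_ndjson_stream(response):
                print(chunk.get("response", ""), end="", flush=True)
                if chunk.get("done") is True:
                    break


if __name__ == "__main__":
    asyncio.run(main())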
@@ -1812,7 +1827,7 @@ class OtherCog(commands.Cog):
             except httpx.TransportError as e:
                 return await msg.edit(content="Failed to connect to Ollama: `%s`" % e)
             if response.status_code == 404:
-                await msg.edit(content="Downloading model %r, please wait.")
+                await msg.edit(content=f"Downloading model {model}, please wait.")
                 async with ctx.channel.typing():
                     async with client.stream(
                         "POST",
@@ -1822,7 +1837,7 @@ class OtherCog(commands.Cog):
                     ) as response:
                         if response.status_code != 200:
                             return await msg.edit(content="Failed to download model: `%s`" % response.text)
-                        async for chunk in OllamaStreamReader(response):
+                        async for chunk in ollama_stream_reader(response):
                             if "total" in chunk and "completed" in chunk:
                                 completed = chunk["completed"] or 1  # avoid division by zero
                                 total = chunk["total"] or 1
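The "completed"/"total" pair read above comes from the pull-status chunks streamed back while a model downloads. A small illustrative helper showing the same "or 1" zero-division guard turned into a progress string; the function name and formatting are assumptions, not repository code:

# Illustrative only: format a pull-status chunk ({"completed": ..., "total": ...}).
def format_pull_progress(chunk: dict) -> str:
    completed = chunk.get("completed") or 1  # avoid division by zero
    total = chunk.get("total") or 1
    return f"{completed}/{total} bytes ({completed / total * 100:.1f}%)"


print(format_pull_progress({"completed": 52428800, "total": 104857600}))
# 52428800/104857600 bytes (50.0%)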
@@ -1857,13 +1872,16 @@ class OtherCog(commands.Cog):
                 if response.status_code != 200:
                     return await msg.edit(content="Failed to generate text: `%s`" % response.text)
                 last_edit = msg.edited_at.timestamp() if msg.edited_at else msg.created_at.timestamp()
-                async for chunk in OllamaStreamReader(response):
+                async for chunk in ollama_stream_reader(response):
                     if "done" not in chunk or "response" not in chunk:
                         continue
                     else:
+                        content = "Response is still being generated..."
+                        if chunk["done"] is True:
+                            content = None
                         output.description = chunk["response"]
                         if (time() - last_edit) >= 5 or chunk["done"] is True:
-                            await msg.edit(content=None, embed=output)
+                            await msg.edit(content=content, embed=output)
                             break

