fix: image gen
Some checks failed
black-action / runner / black formatter (push) Failing after 8s

This commit is contained in:
0xMRTT 2023-06-21 10:51:02 +02:00
parent 54ef8641fc
commit 9362ec4563
Signed by: 0xMRTT
GPG Key ID: 910B287304120902

View File

@ -3,6 +3,8 @@ import os
import sys import sys
from baichat_py import Completion from baichat_py import Completion
import traceback
PREFIX = "!" PREFIX = "!"
USERNAME = os.environ.get("MATRIX_USERNAME", "ai") USERNAME = os.environ.get("MATRIX_USERNAME", "ai")
@ -542,14 +544,15 @@ def validate_cfg(cfg: float) -> str:
class AsyncImagine: class AsyncImagine:
"""Async class for handling API requests to the Imagine service.""" """Async class for handling API requests to the Imagine service."""
HEADERS = {"accept": "*/*", "user-agent": "okhttp/4.10.0"} HEADERS = {
"accept": "*/*",
"user-agent": "okhttp/4.10.0"
}
def __init__(self): def __init__(self):
self.asset = "https://1966211409.rsc.cdn77.org" self.asset = "https://1966211409.rsc.cdn77.org"
self.api = "https://inferenceengine.vyro.ai" self.api = "https://inferenceengine.vyro.ai"
self.session = aiohttp.ClientSession( self.session = aiohttp.ClientSession(raise_for_status=True, headers=self.HEADERS)
raise_for_status=True, headers=self.HEADERS
)
self.version = "1" self.version = "1"
async def close(self) -> None: async def close(self) -> None:
@ -568,149 +571,91 @@ class AsyncImagine:
async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes: async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes:
"""Gets the assets.""" """Gets the assets."""
async with self.session.get(url=self.get_style_url(style=style)) as resp: async with self.session.get(
return await resp.read() url=self.get_style_url(style=style)
async def variate(
self, image: bytes, prompt: str, style: Style = Style.IMAGINE_V1
) -> bytes:
async with self.session.post(
url=f"{self.api}/variate",
data={
"model_version": self.version,
"prompt": prompt + (style.value[3] or ""),
"strength": "0",
"style_id": str(style.value[0]),
"image": self.bytes_to_io(image, "image.png"),
},
) as resp: ) as resp:
return await resp.read() return await resp.read()
async def sdprem( async def sdprem(self, prompt: str, negative: str | bool = None, priority: str = None, steps: str = None,
self, high_res_results: str = None, style: Style = Style.IMAGINE_V1, seed: str = None,
prompt: str, ratio: Ratio = Ratio.RATIO_1X1, cfg: float = 9.5) -> bytes | None:
negative: str = None,
priority: str = None,
steps: str = None,
high_res_results: str = None,
style: Style = Style.IMAGINE_V1,
seed: str = None,
ratio: Ratio = Ratio.RATIO_1X1,
cfg: float = 9.5,
) -> bytes:
"""Generates AI Art.""" """Generates AI Art."""
try: try:
validated_cfg = validate_cfg(cfg) validated_cfg = validate_cfg(cfg)
except Exception as e: except Exception as e:
print(f"An error occurred while validating cfg: {e}") print(f"An error occurred while validating cfg: {e}")
traceback.print_exc() # Print the full traceback for detailed debugging
return None return None
try: for attempt in range(2):
async with self.session.post( try:
url=f"{self.api}/sdprem", async with self.session.post(
data={ url=f"{self.api}/sdprem",
"model_version": self.version, data={
"prompt": prompt + (style.value[3] or ""), "model_version": self.version,
"negative_prompt": negative or "", "prompt": prompt + (style.value[3] or ""),
"style_id": style.value[0], "negative_prompt": negative or "ugly, disfigured, low quality, blurry, nsfw",
"width": ratio.value[0], "style_id": style.value[0],
"height": ratio.value[1], "width": ratio.value[0],
"seed": seed or "", "height": ratio.value[1],
"steps": steps or "30", "seed": seed or "",
"cfg": validated_cfg, "steps": steps or "30",
"priority": priority or "0", "cfg": validated_cfg,
"high_res_results": high_res_results or "0", "priority": priority or "0",
}, "high_res_results": high_res_results or "0"
) as resp: }
return await resp.read() ) as resp:
except Exception as e: return await resp.read()
print(f"An error occurred while making the request: {e}") except Exception as e:
return None print(f"An error occurred while making the request: {e}")
traceback.print_exc() # Print the full traceback for detailed debugging
if attempt == 0:
await asyncio.sleep(0.4)
print("Retrying....")
else:
return None
async def upscale(self, image: bytes) -> bytes: async def upscale(self, image: bytes) -> bytes | None:
"""Upscales the image.""" """Upscales the image."""
try: try:
async with self.session.post( async with self.session.post(
url=f"{self.api}/upscale", url=f"{self.api}/upscale",
data={ data={
"model_version": self.version, "model_version": self.version,
"image": self.bytes_to_io(image, "test.png"), "image": self.bytes_to_io(image, "test.png")
}, }
) as resp: ) as resp:
return await resp.read() return await resp.read()
except Exception as e: except Exception as e:
print(f"An error occurred while making the request: {e}") print(f"An error occurred while making the request: {e}")
return None return None
async def translate(self, prompt: str) -> str:
"""Translates the prompt."""
async with self.session.post(
url=f"{self.api}/translate",
data={"q": prompt, "source": detect(prompt), "target": "en"},
) as resp:
return (await resp.json())["translatedText"]
async def interrogator(self, image: bytes) -> str: async def interrogator(self, image: bytes) -> str:
"""Generates a prompt.""" """Generates a prompt."""
async with self.session.post( async with self.session.post(
url=f"{self.api}/interrogator", url=f"{self.api}/interrogator",
data={ data={
"model_version": str(self.version), "model_version": str(self.version),
"image": self.bytes_to_io(image, "prompt_generator_temp.png"), "image": self.bytes_to_io(image, "prompt_generator_temp.png")
}, }
) as resp: ) as resp:
return await resp.text() return await resp.text()
async def sdimg( async def sdimg(self, image: bytes, prompt: str, negative: str = None, seed: str = None, cfg: float = 9.5) -> bytes:
self,
image: bytes,
prompt: str,
negative: str = None,
seed: str = None,
cfg: float = 9.5,
) -> bytes:
"""Performs inpainting.""" """Performs inpainting."""
async with self.session.post( async with self.session.post(
url=f"{self.api}/sdimg", url=f"{self.api}/sdimg",
data={ data={
"model_version": self.version, "model_version": self.version,
"prompt": prompt, "prompt": prompt,
"negative_prompt": negative or "", "negative_prompt": negative or "",
"seed": seed or "", "seed": seed or "",
"cfg": validate_cfg(cfg), "cfg": validate_cfg(cfg),
"image": self.bytes_to_io(image, "image.png"), "image": self.bytes_to_io(image, "image.png")
}, }
) as resp: ) as resp:
return await resp.read() return await resp.read()
async def controlnet(
self,
image: bytes,
prompt: str,
negative: str = None,
cfg: float = 9.5,
control: Control = Control.SCRIBBLE,
style: Style = Style.IMAGINE_V1,
seed: str = None,
) -> bytes:
"""Performs image remix."""
async with self.session.post(
url=f"{self.api}/controlnet",
data={
"model_version": self.version,
"prompt": prompt + (style.value[3] or ""),
"negative_prompt": negative or "",
"strength": "0",
"cfg": validate_cfg(cfg),
"control": control.value,
"style_id": str(style.value[0]),
"seed": seed or "",
"image": self.bytes_to_io(image, "image.png"),
},
) as resp:
return await resp.read()
def run(): def run():
if not USERNAME or not SERVER or not PASSWORD: if not USERNAME or not SERVER or not PASSWORD:
print( print(
@ -777,17 +722,12 @@ def run():
else: else:
prompt += arg + " " prompt += arg + " "
async def generate_image( async def generate_image(image_prompt, style_value, ratio_value, negative, upscale):
image_prompt, style_value, ratio_value, negative if negative is None:
): negative = False
imagine = AsyncImagine() imagine = AsyncImagine()
filename = str(uuid.uuid4()) + ".png" style_enum = Style[style_value]
try: ratio_enum = Ratio[ratio_value]
style_enum = Style[style_value]
ratio_enum = Ratio[ratio_value]
except KeyError:
style_enum = Style.IMAGINE_V3
ratio_enum = Ratio.RATIO_1X1
img_data = await imagine.sdprem( img_data = await imagine.sdprem(
prompt=image_prompt, prompt=image_prompt,
style=style_enum, style=style_enum,
@ -795,21 +735,23 @@ def run():
priority="1", priority="1",
high_res_results="1", high_res_results="1",
steps="70", steps="70",
negative=negative, negative=negative
) )
if upscale:
img_data = await imagine.upscale(image=img_data)
try: try:
with open(filename, mode="wb") as img_file: img_file = io.BytesIO(img_data)
img_file.write(img_data)
except Exception as e: except Exception as e:
print(f"An error occurred while writing the image to file: {e}") print(
f"An error occurred while creating the in-memory image file: {e}")
return None return None
await imagine.close() await imagine.close()
return img_file
return filename filename = await generate_image(prompt, style, ratio, negative, upscale=False)
filename = await generate_image(prompt, style, ratio, negative)
await bot.api.send_image_message( await bot.api.send_image_message(
room_id=room.room_id, image_filepath=filename room_id=room.room_id, image_filepath=filename