From 9362ec456365a2c5cd688a93b761f167b565bf6e Mon Sep 17 00:00:00 2001 From: 0xMRTT <0xMRTT@proton.me> Date: Wed, 21 Jun 2023 10:51:02 +0200 Subject: [PATCH] fix: image gen --- matrixai/__init__.py | 210 ++++++++++++++++--------------------------- 1 file changed, 76 insertions(+), 134 deletions(-) diff --git a/matrixai/__init__.py b/matrixai/__init__.py index 48bdc56..28182bc 100644 --- a/matrixai/__init__.py +++ b/matrixai/__init__.py @@ -3,6 +3,8 @@ import os import sys from baichat_py import Completion +import traceback + PREFIX = "!" USERNAME = os.environ.get("MATRIX_USERNAME", "ai") @@ -542,14 +544,15 @@ def validate_cfg(cfg: float) -> str: class AsyncImagine: """Async class for handling API requests to the Imagine service.""" - HEADERS = {"accept": "*/*", "user-agent": "okhttp/4.10.0"} + HEADERS = { + "accept": "*/*", + "user-agent": "okhttp/4.10.0" + } def __init__(self): self.asset = "https://1966211409.rsc.cdn77.org" self.api = "https://inferenceengine.vyro.ai" - self.session = aiohttp.ClientSession( - raise_for_status=True, headers=self.HEADERS - ) + self.session = aiohttp.ClientSession(raise_for_status=True, headers=self.HEADERS) self.version = "1" async def close(self) -> None: @@ -568,149 +571,91 @@ class AsyncImagine: async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes: """Gets the assets.""" - async with self.session.get(url=self.get_style_url(style=style)) as resp: - return await resp.read() - - async def variate( - self, image: bytes, prompt: str, style: Style = Style.IMAGINE_V1 - ) -> bytes: - async with self.session.post( - url=f"{self.api}/variate", - data={ - "model_version": self.version, - "prompt": prompt + (style.value[3] or ""), - "strength": "0", - "style_id": str(style.value[0]), - "image": self.bytes_to_io(image, "image.png"), - }, + async with self.session.get( + url=self.get_style_url(style=style) ) as resp: return await resp.read() - async def sdprem( - self, - prompt: str, - negative: str = None, - priority: str = None, - steps: str = None, - high_res_results: str = None, - style: Style = Style.IMAGINE_V1, - seed: str = None, - ratio: Ratio = Ratio.RATIO_1X1, - cfg: float = 9.5, - ) -> bytes: + async def sdprem(self, prompt: str, negative: str | bool = None, priority: str = None, steps: str = None, + high_res_results: str = None, style: Style = Style.IMAGINE_V1, seed: str = None, + ratio: Ratio = Ratio.RATIO_1X1, cfg: float = 9.5) -> bytes | None: """Generates AI Art.""" try: validated_cfg = validate_cfg(cfg) except Exception as e: print(f"An error occurred while validating cfg: {e}") + traceback.print_exc() # Print the full traceback for detailed debugging return None - try: - async with self.session.post( - url=f"{self.api}/sdprem", - data={ - "model_version": self.version, - "prompt": prompt + (style.value[3] or ""), - "negative_prompt": negative or "", - "style_id": style.value[0], - "width": ratio.value[0], - "height": ratio.value[1], - "seed": seed or "", - "steps": steps or "30", - "cfg": validated_cfg, - "priority": priority or "0", - "high_res_results": high_res_results or "0", - }, - ) as resp: - return await resp.read() - except Exception as e: - print(f"An error occurred while making the request: {e}") - return None + for attempt in range(2): + try: + async with self.session.post( + url=f"{self.api}/sdprem", + data={ + "model_version": self.version, + "prompt": prompt + (style.value[3] or ""), + "negative_prompt": negative or "ugly, disfigured, low quality, blurry, nsfw", + "style_id": style.value[0], + "width": ratio.value[0], + "height": ratio.value[1], + "seed": seed or "", + "steps": steps or "30", + "cfg": validated_cfg, + "priority": priority or "0", + "high_res_results": high_res_results or "0" + } + ) as resp: + return await resp.read() + except Exception as e: + print(f"An error occurred while making the request: {e}") + traceback.print_exc() # Print the full traceback for detailed debugging + if attempt == 0: + await asyncio.sleep(0.4) + print("Retrying....") + else: + return None - async def upscale(self, image: bytes) -> bytes: + async def upscale(self, image: bytes) -> bytes | None: """Upscales the image.""" try: async with self.session.post( - url=f"{self.api}/upscale", - data={ - "model_version": self.version, - "image": self.bytes_to_io(image, "test.png"), - }, + url=f"{self.api}/upscale", + data={ + "model_version": self.version, + "image": self.bytes_to_io(image, "test.png") + } ) as resp: return await resp.read() except Exception as e: print(f"An error occurred while making the request: {e}") return None - - async def translate(self, prompt: str) -> str: - """Translates the prompt.""" - async with self.session.post( - url=f"{self.api}/translate", - data={"q": prompt, "source": detect(prompt), "target": "en"}, - ) as resp: - return (await resp.json())["translatedText"] - + async def interrogator(self, image: bytes) -> str: """Generates a prompt.""" async with self.session.post( - url=f"{self.api}/interrogator", - data={ - "model_version": str(self.version), - "image": self.bytes_to_io(image, "prompt_generator_temp.png"), - }, + url=f"{self.api}/interrogator", + data={ + "model_version": str(self.version), + "image": self.bytes_to_io(image, "prompt_generator_temp.png") + } ) as resp: return await resp.text() - async def sdimg( - self, - image: bytes, - prompt: str, - negative: str = None, - seed: str = None, - cfg: float = 9.5, - ) -> bytes: + async def sdimg(self, image: bytes, prompt: str, negative: str = None, seed: str = None, cfg: float = 9.5) -> bytes: """Performs inpainting.""" async with self.session.post( - url=f"{self.api}/sdimg", - data={ - "model_version": self.version, - "prompt": prompt, - "negative_prompt": negative or "", - "seed": seed or "", - "cfg": validate_cfg(cfg), - "image": self.bytes_to_io(image, "image.png"), - }, + url=f"{self.api}/sdimg", + data={ + "model_version": self.version, + "prompt": prompt, + "negative_prompt": negative or "", + "seed": seed or "", + "cfg": validate_cfg(cfg), + "image": self.bytes_to_io(image, "image.png") + } ) as resp: return await resp.read() - async def controlnet( - self, - image: bytes, - prompt: str, - negative: str = None, - cfg: float = 9.5, - control: Control = Control.SCRIBBLE, - style: Style = Style.IMAGINE_V1, - seed: str = None, - ) -> bytes: - """Performs image remix.""" - async with self.session.post( - url=f"{self.api}/controlnet", - data={ - "model_version": self.version, - "prompt": prompt + (style.value[3] or ""), - "negative_prompt": negative or "", - "strength": "0", - "cfg": validate_cfg(cfg), - "control": control.value, - "style_id": str(style.value[0]), - "seed": seed or "", - "image": self.bytes_to_io(image, "image.png"), - }, - ) as resp: - return await resp.read() - - def run(): if not USERNAME or not SERVER or not PASSWORD: print( @@ -777,17 +722,12 @@ def run(): else: prompt += arg + " " - async def generate_image( - image_prompt, style_value, ratio_value, negative - ): + async def generate_image(image_prompt, style_value, ratio_value, negative, upscale): + if negative is None: + negative = False imagine = AsyncImagine() - filename = str(uuid.uuid4()) + ".png" - try: - style_enum = Style[style_value] - ratio_enum = Ratio[ratio_value] - except KeyError: - style_enum = Style.IMAGINE_V3 - ratio_enum = Ratio.RATIO_1X1 + style_enum = Style[style_value] + ratio_enum = Ratio[ratio_value] img_data = await imagine.sdprem( prompt=image_prompt, style=style_enum, @@ -795,21 +735,23 @@ def run(): priority="1", high_res_results="1", steps="70", - negative=negative, + negative=negative ) + if upscale: + img_data = await imagine.upscale(image=img_data) + try: - with open(filename, mode="wb") as img_file: - img_file.write(img_data) + img_file = io.BytesIO(img_data) except Exception as e: - print(f"An error occurred while writing the image to file: {e}") + print( + f"An error occurred while creating the in-memory image file: {e}") return None await imagine.close() + return img_file - return filename - - filename = await generate_image(prompt, style, ratio, negative) + filename = await generate_image(prompt, style, ratio, negative, upscale=False) await bot.api.send_image_message( room_id=room.room_id, image_filepath=filename