This commit is contained in:
parent
54ef8641fc
commit
9362ec4563
@ -3,6 +3,8 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from baichat_py import Completion
|
from baichat_py import Completion
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
|
||||||
PREFIX = "!"
|
PREFIX = "!"
|
||||||
|
|
||||||
USERNAME = os.environ.get("MATRIX_USERNAME", "ai")
|
USERNAME = os.environ.get("MATRIX_USERNAME", "ai")
|
||||||
@ -542,14 +544,15 @@ def validate_cfg(cfg: float) -> str:
|
|||||||
class AsyncImagine:
|
class AsyncImagine:
|
||||||
"""Async class for handling API requests to the Imagine service."""
|
"""Async class for handling API requests to the Imagine service."""
|
||||||
|
|
||||||
HEADERS = {"accept": "*/*", "user-agent": "okhttp/4.10.0"}
|
HEADERS = {
|
||||||
|
"accept": "*/*",
|
||||||
|
"user-agent": "okhttp/4.10.0"
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.asset = "https://1966211409.rsc.cdn77.org"
|
self.asset = "https://1966211409.rsc.cdn77.org"
|
||||||
self.api = "https://inferenceengine.vyro.ai"
|
self.api = "https://inferenceengine.vyro.ai"
|
||||||
self.session = aiohttp.ClientSession(
|
self.session = aiohttp.ClientSession(raise_for_status=True, headers=self.HEADERS)
|
||||||
raise_for_status=True, headers=self.HEADERS
|
|
||||||
)
|
|
||||||
self.version = "1"
|
self.version = "1"
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
@ -568,149 +571,91 @@ class AsyncImagine:
|
|||||||
|
|
||||||
async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes:
|
async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes:
|
||||||
"""Gets the assets."""
|
"""Gets the assets."""
|
||||||
async with self.session.get(url=self.get_style_url(style=style)) as resp:
|
async with self.session.get(
|
||||||
return await resp.read()
|
url=self.get_style_url(style=style)
|
||||||
|
|
||||||
async def variate(
|
|
||||||
self, image: bytes, prompt: str, style: Style = Style.IMAGINE_V1
|
|
||||||
) -> bytes:
|
|
||||||
async with self.session.post(
|
|
||||||
url=f"{self.api}/variate",
|
|
||||||
data={
|
|
||||||
"model_version": self.version,
|
|
||||||
"prompt": prompt + (style.value[3] or ""),
|
|
||||||
"strength": "0",
|
|
||||||
"style_id": str(style.value[0]),
|
|
||||||
"image": self.bytes_to_io(image, "image.png"),
|
|
||||||
},
|
|
||||||
) as resp:
|
) as resp:
|
||||||
return await resp.read()
|
return await resp.read()
|
||||||
|
|
||||||
async def sdprem(
|
async def sdprem(self, prompt: str, negative: str | bool = None, priority: str = None, steps: str = None,
|
||||||
self,
|
high_res_results: str = None, style: Style = Style.IMAGINE_V1, seed: str = None,
|
||||||
prompt: str,
|
ratio: Ratio = Ratio.RATIO_1X1, cfg: float = 9.5) -> bytes | None:
|
||||||
negative: str = None,
|
|
||||||
priority: str = None,
|
|
||||||
steps: str = None,
|
|
||||||
high_res_results: str = None,
|
|
||||||
style: Style = Style.IMAGINE_V1,
|
|
||||||
seed: str = None,
|
|
||||||
ratio: Ratio = Ratio.RATIO_1X1,
|
|
||||||
cfg: float = 9.5,
|
|
||||||
) -> bytes:
|
|
||||||
"""Generates AI Art."""
|
"""Generates AI Art."""
|
||||||
try:
|
try:
|
||||||
validated_cfg = validate_cfg(cfg)
|
validated_cfg = validate_cfg(cfg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while validating cfg: {e}")
|
print(f"An error occurred while validating cfg: {e}")
|
||||||
|
traceback.print_exc() # Print the full traceback for detailed debugging
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
for attempt in range(2):
|
||||||
async with self.session.post(
|
try:
|
||||||
url=f"{self.api}/sdprem",
|
async with self.session.post(
|
||||||
data={
|
url=f"{self.api}/sdprem",
|
||||||
"model_version": self.version,
|
data={
|
||||||
"prompt": prompt + (style.value[3] or ""),
|
"model_version": self.version,
|
||||||
"negative_prompt": negative or "",
|
"prompt": prompt + (style.value[3] or ""),
|
||||||
"style_id": style.value[0],
|
"negative_prompt": negative or "ugly, disfigured, low quality, blurry, nsfw",
|
||||||
"width": ratio.value[0],
|
"style_id": style.value[0],
|
||||||
"height": ratio.value[1],
|
"width": ratio.value[0],
|
||||||
"seed": seed or "",
|
"height": ratio.value[1],
|
||||||
"steps": steps or "30",
|
"seed": seed or "",
|
||||||
"cfg": validated_cfg,
|
"steps": steps or "30",
|
||||||
"priority": priority or "0",
|
"cfg": validated_cfg,
|
||||||
"high_res_results": high_res_results or "0",
|
"priority": priority or "0",
|
||||||
},
|
"high_res_results": high_res_results or "0"
|
||||||
) as resp:
|
}
|
||||||
return await resp.read()
|
) as resp:
|
||||||
except Exception as e:
|
return await resp.read()
|
||||||
print(f"An error occurred while making the request: {e}")
|
except Exception as e:
|
||||||
return None
|
print(f"An error occurred while making the request: {e}")
|
||||||
|
traceback.print_exc() # Print the full traceback for detailed debugging
|
||||||
|
if attempt == 0:
|
||||||
|
await asyncio.sleep(0.4)
|
||||||
|
print("Retrying....")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
async def upscale(self, image: bytes) -> bytes:
|
async def upscale(self, image: bytes) -> bytes | None:
|
||||||
"""Upscales the image."""
|
"""Upscales the image."""
|
||||||
try:
|
try:
|
||||||
async with self.session.post(
|
async with self.session.post(
|
||||||
url=f"{self.api}/upscale",
|
url=f"{self.api}/upscale",
|
||||||
data={
|
data={
|
||||||
"model_version": self.version,
|
"model_version": self.version,
|
||||||
"image": self.bytes_to_io(image, "test.png"),
|
"image": self.bytes_to_io(image, "test.png")
|
||||||
},
|
}
|
||||||
) as resp:
|
) as resp:
|
||||||
return await resp.read()
|
return await resp.read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while making the request: {e}")
|
print(f"An error occurred while making the request: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def translate(self, prompt: str) -> str:
|
|
||||||
"""Translates the prompt."""
|
|
||||||
async with self.session.post(
|
|
||||||
url=f"{self.api}/translate",
|
|
||||||
data={"q": prompt, "source": detect(prompt), "target": "en"},
|
|
||||||
) as resp:
|
|
||||||
return (await resp.json())["translatedText"]
|
|
||||||
|
|
||||||
async def interrogator(self, image: bytes) -> str:
|
async def interrogator(self, image: bytes) -> str:
|
||||||
"""Generates a prompt."""
|
"""Generates a prompt."""
|
||||||
async with self.session.post(
|
async with self.session.post(
|
||||||
url=f"{self.api}/interrogator",
|
url=f"{self.api}/interrogator",
|
||||||
data={
|
data={
|
||||||
"model_version": str(self.version),
|
"model_version": str(self.version),
|
||||||
"image": self.bytes_to_io(image, "prompt_generator_temp.png"),
|
"image": self.bytes_to_io(image, "prompt_generator_temp.png")
|
||||||
},
|
}
|
||||||
) as resp:
|
) as resp:
|
||||||
return await resp.text()
|
return await resp.text()
|
||||||
|
|
||||||
async def sdimg(
|
async def sdimg(self, image: bytes, prompt: str, negative: str = None, seed: str = None, cfg: float = 9.5) -> bytes:
|
||||||
self,
|
|
||||||
image: bytes,
|
|
||||||
prompt: str,
|
|
||||||
negative: str = None,
|
|
||||||
seed: str = None,
|
|
||||||
cfg: float = 9.5,
|
|
||||||
) -> bytes:
|
|
||||||
"""Performs inpainting."""
|
"""Performs inpainting."""
|
||||||
async with self.session.post(
|
async with self.session.post(
|
||||||
url=f"{self.api}/sdimg",
|
url=f"{self.api}/sdimg",
|
||||||
data={
|
data={
|
||||||
"model_version": self.version,
|
"model_version": self.version,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"negative_prompt": negative or "",
|
"negative_prompt": negative or "",
|
||||||
"seed": seed or "",
|
"seed": seed or "",
|
||||||
"cfg": validate_cfg(cfg),
|
"cfg": validate_cfg(cfg),
|
||||||
"image": self.bytes_to_io(image, "image.png"),
|
"image": self.bytes_to_io(image, "image.png")
|
||||||
},
|
}
|
||||||
) as resp:
|
) as resp:
|
||||||
return await resp.read()
|
return await resp.read()
|
||||||
|
|
||||||
async def controlnet(
|
|
||||||
self,
|
|
||||||
image: bytes,
|
|
||||||
prompt: str,
|
|
||||||
negative: str = None,
|
|
||||||
cfg: float = 9.5,
|
|
||||||
control: Control = Control.SCRIBBLE,
|
|
||||||
style: Style = Style.IMAGINE_V1,
|
|
||||||
seed: str = None,
|
|
||||||
) -> bytes:
|
|
||||||
"""Performs image remix."""
|
|
||||||
async with self.session.post(
|
|
||||||
url=f"{self.api}/controlnet",
|
|
||||||
data={
|
|
||||||
"model_version": self.version,
|
|
||||||
"prompt": prompt + (style.value[3] or ""),
|
|
||||||
"negative_prompt": negative or "",
|
|
||||||
"strength": "0",
|
|
||||||
"cfg": validate_cfg(cfg),
|
|
||||||
"control": control.value,
|
|
||||||
"style_id": str(style.value[0]),
|
|
||||||
"seed": seed or "",
|
|
||||||
"image": self.bytes_to_io(image, "image.png"),
|
|
||||||
},
|
|
||||||
) as resp:
|
|
||||||
return await resp.read()
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
if not USERNAME or not SERVER or not PASSWORD:
|
if not USERNAME or not SERVER or not PASSWORD:
|
||||||
print(
|
print(
|
||||||
@ -777,17 +722,12 @@ def run():
|
|||||||
else:
|
else:
|
||||||
prompt += arg + " "
|
prompt += arg + " "
|
||||||
|
|
||||||
async def generate_image(
|
async def generate_image(image_prompt, style_value, ratio_value, negative, upscale):
|
||||||
image_prompt, style_value, ratio_value, negative
|
if negative is None:
|
||||||
):
|
negative = False
|
||||||
imagine = AsyncImagine()
|
imagine = AsyncImagine()
|
||||||
filename = str(uuid.uuid4()) + ".png"
|
style_enum = Style[style_value]
|
||||||
try:
|
ratio_enum = Ratio[ratio_value]
|
||||||
style_enum = Style[style_value]
|
|
||||||
ratio_enum = Ratio[ratio_value]
|
|
||||||
except KeyError:
|
|
||||||
style_enum = Style.IMAGINE_V3
|
|
||||||
ratio_enum = Ratio.RATIO_1X1
|
|
||||||
img_data = await imagine.sdprem(
|
img_data = await imagine.sdprem(
|
||||||
prompt=image_prompt,
|
prompt=image_prompt,
|
||||||
style=style_enum,
|
style=style_enum,
|
||||||
@ -795,21 +735,23 @@ def run():
|
|||||||
priority="1",
|
priority="1",
|
||||||
high_res_results="1",
|
high_res_results="1",
|
||||||
steps="70",
|
steps="70",
|
||||||
negative=negative,
|
negative=negative
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if upscale:
|
||||||
|
img_data = await imagine.upscale(image=img_data)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(filename, mode="wb") as img_file:
|
img_file = io.BytesIO(img_data)
|
||||||
img_file.write(img_data)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while writing the image to file: {e}")
|
print(
|
||||||
|
f"An error occurred while creating the in-memory image file: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
await imagine.close()
|
await imagine.close()
|
||||||
|
return img_file
|
||||||
|
|
||||||
return filename
|
filename = await generate_image(prompt, style, ratio, negative, upscale=False)
|
||||||
|
|
||||||
filename = await generate_image(prompt, style, ratio, negative)
|
|
||||||
|
|
||||||
await bot.api.send_image_message(
|
await bot.api.send_image_message(
|
||||||
room_id=room.room_id, image_filepath=filename
|
room_id=room.room_id, image_filepath=filename
|
||||||
|
Loading…
Reference in New Issue
Block a user