Rest API support and cleanup
eagarvey-amd committed Aug 8, 2024
1 parent d5f37ea commit 4759e80
Showing 18 changed files with 450 additions and 211 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/nightly.yml
@@ -46,18 +46,18 @@ jobs:
           draft: true
           prerelease: true
 
-      - name: Build Package
+      - name: Build Package (api only)
        shell: powershell
        run: |
          ./setup_venv.ps1
          python process_skipfiles.py
          $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
          pip install -e .
          pip freeze -l
-          pyinstaller .\apps\shark_studio\shark_studio.spec
+          pyinstaller .\apps\shark_studio\shark_studio_apionly.spec
          mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
          signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
       - name: Upload Release Assets
        id: upload-release-assets
        uses: dwenegar/upload-release-assets@v1
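The nightly workflow now builds the release executable from a dedicated API-only PyInstaller spec. That spec file lives elsewhere in the repo and is not part of this diff, so the outline below is only a hypothetical sketch of what such a spec could contain; the entry script, the excludes list, and everything else in it are assumptions, not the actual file:

# shark_studio_apionly.spec: hypothetical sketch only; the real spec is not
# shown in this commit. Spec files are plain Python executed by PyInstaller,
# which injects Analysis/PYZ/EXE into the namespace (no imports needed).
a = Analysis(
    ["apps/shark_studio/shark_studio.py"],  # assumed entry point
    excludes=["gradio"],  # drop web-UI-only dependencies from the API build
)
pyz = PYZ(a.pure)
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.datas,
    name="nodai_shark_studio",  # matches the mv/signtool steps above
)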
11 changes: 4 additions & 7 deletions apps/shark_studio/api/initializers.py
@@ -34,9 +34,9 @@ def imports():
         action="ignore", category=UserWarning, module="huggingface-hub"
     )
 
-    import gradio  # noqa: F401
+    # import gradio  # noqa: F401
 
-    startup_timer.record("import gradio")
+    # startup_timer.record("import gradio")
 
     import apps.shark_studio.web.utils.globals as global_obj
 
@@ -56,9 +56,8 @@ def initialize():
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.
 
-    config_tmp()
-    # clear_tmp_mlir()
-    clear_tmp_imgs()
+    # config_tmp()
+    # clear_tmp_imgs()
 
     from apps.shark_studio.web.utils.file_utils import (
         create_model_folders,
@@ -67,8 +66,6 @@ def initialize():
     # Create custom models folders if they don't exist
     create_model_folders()
 
-    import gradio as gr
-
     # initialize_rest(reload_script_modules=False)
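Commenting the gradio import out keeps imports() lightweight for the headless REST entry point. One alternative worth considering, sketched below under stated assumptions rather than taken from this commit, is to gate the UI-only work behind a flag (api_only is a hypothetical parameter):

# Hypothetical alternative, not what the commit does: gate gradio behind a
# flag so a single initializer serves both the web UI and the REST API.
def imports(api_only: bool = True):
    import apps.shark_studio.web.utils.globals as global_obj  # noqa: F401

    if not api_only:
        import gradio  # noqa: F401  # only the web UI needs gradio
        startup_timer.record("import gradio")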
155 changes: 94 additions & 61 deletions apps/shark_studio/api/sd.py
@@ -14,7 +14,6 @@
 from random import randint
 
 
-
 from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.api.utils import parse_device
 from apps.shark_studio.web.utils.state import status_label
@@ -30,6 +29,7 @@
 
 
 from subprocess import check_output
+
 EMPTY_SD_MAP = {
     "clip": None,
     "scheduler": None,
@@ -114,11 +114,14 @@ def __init__(
             from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
                 SharkSDXLPipeline,
             )
+
             self.turbine_pipe = SharkSDXLPipeline
             self.dynamic_steps = False
             self.model_map = EMPTY_SDXL_MAP
         else:
-            from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
+            from turbine_models.custom_models.sd_inference.sd_pipeline import (
+                SharkSDPipeline,
+            )
+
             self.turbine_pipe = SharkSDPipeline
             self.dynamic_steps = True
@@ -209,6 +212,7 @@ def prepare_pipe(
             preprocessCKPT,
             save_irpa,
         )
+
         custom_weights = os.path.join(
             get_checkpoints_path("checkpoints"),
             safe_name(self.base_model_id.split("/")[-1]),
Expand All @@ -223,14 +227,20 @@ def prepare_pipe(
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")

elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
if key in ["mmdit"]:
mmdit_weights_path = os.path.join(
diffusers_weights_path,
"mmdit",
"diffusion_pytorch_model_fp16.safetensors",
)
weights[key] = save_irpa(mmdit_weights_path, "mmdit.")
elif key in ["clip", "prompt_encoder", "text_encoder"]:
if not self.is_sdxl and not self.is_custom:
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
elif self.is_sdxl:
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
Expand All @@ -243,14 +253,35 @@ def prepare_pipe(
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2."),
]

elif self.is_custom:
clip_g_path = os.path.join(
diffusers_weights_path,
"text_encoder",
"model.fp16.safetensors",
)
clip_l_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.fp16.safetensors",
)
t5xxl_path = os.path.join(
diffusers_weights_path,
"text_encoder_3",
"model.fp16.safetensors",
)
weights[key] = [
save_irpa(clip_g_path, "clip_g.transformer."),
save_irpa(clip_l_path, "clip_l.transformer."),
save_irpa(t5xxl_path, "t5xxl.transformer."),
]
elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(
diffusers_weights_path,
"vae",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")

progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")

vmfbs, weights = self.sd_pipe.check_prepared(
@@ -291,49 +322,6 @@ def generate_images(
         return img
 
 
-def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
-    print("\n[LOG] Submitting Request...")
-
-    for key in sd_kwargs:
-        if sd_kwargs[key] in [None, []]:
-            sd_kwargs[key] = None
-        if sd_kwargs[key] in ["None"]:
-            sd_kwargs[key] = ""
-        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
-            sd_kwargs[key] = int(sd_kwargs[key])
-        if key == "seed":
-            sd_kwargs[key] = int(sd_kwargs[key])
-
-    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
-    if not sd_kwargs["device"]:
-        gr.Warning("No device specified. Please specify a device.")
-        return None, ""
-    if sd_kwargs["height"] not in [512, 1024]:
-        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["height"] != sd_kwargs["width"]:
-        gr.Warning("Height and width must be the same. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
-        if sd_kwargs["steps"] > 10:
-            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
-            return None, ""
-        if sd_kwargs["guidance_scale"] > 3:
-            gr.Warning(
-                "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
-            )
-            return None, ""
-    if sd_kwargs["target_triple"] == "":
-        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
-            gr.Warning(
-                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
-            )
-            return None, ""
-
-    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
-    return generated_imgs
-
-
 def shark_sd_fn(
     prompt,
     negative_prompt,
@@ -359,7 +347,8 @@ def shark_sd_fn(
     controlnets: dict,
     embeddings: dict,
     seed_increment: str | int = 1,
-    progress=gr.Progress(),
+    output_type: str = "png",
+    # progress=gr.Progress(),
 ):
     sd_kwargs = locals()
     if not isinstance(sd_init_image, list):
@@ -464,8 +453,8 @@ def shark_sd_fn(
     if submit_run_kwargs["seed"] in [-1, "-1"]:
         submit_run_kwargs["seed"] = randint(0, 4294967295)
         seed_increment = "random"
-        #print(f"\n[LOG] Random seed: {seed}")
-    progress(None, desc=f"Generating...")
+        # print(f"\n[LOG] Random seed: {seed}")
+    # progress(None, desc=f"Generating...")
 
     for current_batch in range(batch_count):
         start_time = time.time()
@@ -479,13 +468,14 @@ def shark_sd_fn(
             # break
         # else:
         for batch in range(batch_size):
-            save_output_img(
-                out_imgs[batch],
-                seed,
-                sd_kwargs,
-            )
+            if output_type == "png":
+                save_output_img(
+                    out_imgs[batch],
+                    seed,
+                    sd_kwargs,
+                )
         generated_imgs.extend(out_imgs)
+
         yield generated_imgs, status_label(
             "Stable Diffusion", current_batch + 1, batch_count, batch_size
         )
Expand All @@ -495,13 +485,56 @@ def shark_sd_fn(
return (generated_imgs, "")


def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
print("\n[LOG] Submitting Request...")

for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
sd_kwargs[key] = int(sd_kwargs[key])
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])

# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""

generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
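# Editor's sketch (hypothetical, not part of the commit): driving the
# dict-input entry point above from a REST handler. shark_sd_fn_dict_input
# is a generator, so callers iterate it for (images, status) updates. A real
# request must carry every parameter of shark_sd_fn; this dict is an assumed,
# abridged subset shown for illustration only.
example_request = {
    "prompt": "a photo of a corgi",  # assumed key, not visible in this diff
    "negative_prompt": "",
    "base_model_id": "stabilityai/sdxl-turbo",
    "height": 512,
    "width": 512,
    "steps": 2,
    "guidance_scale": 0,
    "seed": -1,
    "device": "AMD Radeon RX 7900 XTX => rocm://0",  # parsed by clean_device_info
    "target_triple": "gfx1100",
}
# for images, status in shark_sd_fn_dict_input(example_request):
#     print(status)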


 def get_next_seed(seed, seed_increment: str | int = 10):
     if isinstance(seed_increment, int):
-        #print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
+        # print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
         return int(seed + seed_increment)
     elif seed_increment == "random":
         seed = randint(0, 4294967295)
-        #print(f"\n[LOG] Random seed: {seed}")
+        # print(f"\n[LOG] Random seed: {seed}")
         return seed
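The seed flow is worth spelling out: shark_sd_fn re-rolls a seed of -1 and switches seed_increment to "random", and get_next_seed then advances the seed between batches. A minimal sketch, grounded in the function as written above:

# Seed progression across batches, per get_next_seed above.
seed = 12345
for _ in range(3):  # e.g. batch_count = 3
    seed = get_next_seed(seed, seed_increment=1)  # 12346, then 12347, then 12348
seed = get_next_seed(seed, seed_increment="random")  # uniform draw from [0, 4294967295]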
16 changes: 7 additions & 9 deletions apps/shark_studio/api/utils.py
@@ -63,9 +63,9 @@ def iree_target_map(device):
     }
 
 
-
 def get_available_devices():
-    return ['rocm', 'cpu']
+    return ["rocm", "cpu"]
 
     def get_devices_by_name(driver_name):
 
         device_list = []
@@ -94,7 +94,7 @@ def get_devices_by_name(driver_name):
             device_list.append(f"{device_name} => {driver_name}://{i}")
         return device_list
 
-    #set_iree_runtime_flags()
+    # set_iree_runtime_flags()
 
     available_devices = []
     rocm_devices = get_devices_by_name("rocm")
@@ -140,17 +140,14 @@ def get_devices_by_name(driver_name):
             break
     return available_devices
 
+
 def clean_device_info(raw_device):
     # return appropriate device and device_id for consumption by Studio pipeline
     # Multiple devices only supported for vulkan and rocm (as of now).
     # default device must be selected for all others
 
     device_id = None
-    device = (
-        raw_device
-        if "=>" not in raw_device
-        else raw_device.split("=>")[1].strip()
-    )
+    device = raw_device if "=>" not in raw_device else raw_device.split("=>")[1].strip()
     if "://" in device:
         device, device_id = device.split("://")
         if len(device_id) <= 2:
@@ -162,6 +159,7 @@ def clean_device_info(raw_device):
             device_id = 0
     return device, device_id
 
+
 def parse_device(device_str, target_override=""):
 
     rt_driver, device_id = clean_device_info(device_str)
@@ -287,4 +285,4 @@ def get_all_devices(driver_name):
 # # Due to lack of support for multi-reduce, we always collapse reduction
 # # dims before dispatch formation right now.
 # iree_flags += ["--iree-flow-collapse-reduction-dims"]
-# return iree_flags
\ No newline at end of file
+# return iree_flags
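For reference, a short usage sketch of the device-string parsing above; the device name in the first example is hypothetical, but the format ("Name => driver://index") is the one get_devices_by_name emits:

# What clean_device_info accepts, per the code above.
clean_device_info("AMD Radeon RX 7900 XTX => rocm://0")  # -> ("rocm", 0)
clean_device_info("rocm://1")  # -> ("rocm", 1)
clean_device_info("cpu")  # -> ("cpu", 0), default device_id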
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/shared_cmd_opts.py
@@ -597,7 +597,7 @@ def is_valid_file(arg):
     "--defaults",
     default="sdxl-turbo.json",
     type=str,
-    help="Path to the default API request .json file. Works for CLI and webui."
+    help="Path to the default API request .json file. Works for CLI and webui.",
 )
 
 p.add_argument(
[Diff truncated: the remaining 13 of 18 changed files are not shown.]