From 7b4be5431eb57d0eed189d5a5cb55518130e3cb8 Mon Sep 17 00:00:00 2001 From: winskuo-quic <143469905+winskuo-quic@users.noreply.github.com> Date: Sat, 24 Aug 2024 11:30:48 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - Use AIHub's context binary file for Stable Diffusion (#4836) Summary: - Add script for the export and runtime of AIHUB Stable Diffusion. - Add AIHUB Stable Diffusion runner - Add README tutorial --- backends/qualcomm/tests/test_qnn_delegate.py | 49 ++ examples/qualcomm/CMakeLists.txt | 5 + .../stable_diffusion/CMakeLists.txt | 26 + .../qaihub_scripts/stable_diffusion/README.md | 35 + .../stable_diffusion/install_requirements.sh | 3 + .../qaihub_stable_diffusion.py | 472 +++++++++++++ .../qaihub_stable_diffusion_runner.cpp | 140 ++++ .../stable_diffusion/runner/runner.cpp | 621 ++++++++++++++++++ .../stable_diffusion/runner/runner.h | 141 ++++ .../stable_diffusion/stable_diffusion_lib.py | 22 + 10 files changed, 1514 insertions(+) create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/README.md create mode 100755 examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index dd704c35c0..08fd907c40 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1998,6 +1998,55 @@ def test_llama3_8b(self): model_out = msg["result"] self.assertTrue(model_out.startswith(prompt)) + def test_stable_diffusion(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + prompt = "a photo of an astronaut riding a horse on mars" + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--text_encoder_bin", + f"{self.artifact_dir}/text_encoder.serialized.bin", + "--unet_bin", + f"{self.artifact_dir}/unet.serialized.bin", + "--vae_bin", + f"{self.artifact_dir}/vae.serialized.bin", + "--vocab_json", + f"{self.artifact_dir}/vocab.json", + "--num_time_steps", + "20", + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + f"{prompt}", + "--fix_latents", + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + # For the default settings and prompt, the expected results will be {PSNR: 23.258, SSIM: 0.852} + self.assertGreaterEqual(msg["PSNR"], 20) + self.assertGreaterEqual(msg["SSIM"], 0.8) + class TestExampleScript(TestQNN): def required_envs(self, conditions=None) -> bool: diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index fd9c1388b2..94af209cb6 100644 --- 
a/examples/qualcomm/CMakeLists.txt
+++ b/examples/qualcomm/CMakeLists.txt
@@ -81,3 +81,8 @@ add_subdirectory(
 add_subdirectory(
   ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama
 )
+
+# build qaihub_stable_diffusion_runner
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/stable_diffusion
+)
diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt
new file mode 100644
index 0000000000..c897f5f9f8
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# preprocess qaihub_stable_diffusion_runner_src files
+set(_qaihub_stable_diffusion_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/qaihub_stable_diffusion_runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+)
+
+# build qaihub_stable_diffusion_runner
+add_executable(qaihub_stable_diffusion_runner ${_qaihub_stable_diffusion_runner__srcs})
+target_include_directories(qaihub_stable_diffusion_runner
+  PUBLIC ${_common_include_directories}
+)
+target_link_libraries(qaihub_stable_diffusion_runner
+  qnn_executorch_backend
+  executorch_no_prim_ops
+  extension_data_loader
+  extension_module
+  gflags
+)
+target_compile_options(qaihub_stable_diffusion_runner PUBLIC ${_common_compile_options})
diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md b/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md
new file mode 100644
index 0000000000..21b3370df7
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md
@@ -0,0 +1,35 @@
+# Summary
+
+## Overview
+This file provides instructions for running Stable-Diffusion-v2.1 with different parameters via the Qualcomm HTP backend. We demonstrate how to run Stable Diffusion v2.1 on mobile devices using context binaries from Qualcomm AI Hub's Stable Diffusion v2.1.
+
+Please check the corresponding section for more information.
+
+## Stable-Diffusion-v2.1
+The model architecture, scheduler, and time embedding are from [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
+
+### Instructions
+#### Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct Backend.
+
+#### Step 2: Prepare Model
+1. Download the context binaries for TextEncoder, UNet, and VAEDecoder from https://huggingface.co/qualcomm/Stable-Diffusion-v2.1/tree/main
+2. Download vocab.json from https://huggingface.co/openai/clip-vit-base-patch32/tree/main
+
+#### Step 3: Install Requirements
+Before running the code, you need to install the necessary Python packages.
+
+We have verified the code with `diffusers`==0.29.0 and `piq`==0.8.0.
+Please follow the instructions here to install the required items:
+```bash
+sh examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh
+```
+
+#### Step 4: Run default example
+In this example, we execute the script for 20 time steps with the `prompt` 'a photo of an astronaut riding a horse on mars':
+```bash
+python examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py -a ${ARTIFACTS} -b build_android -m ${SOC_MODEL} -s ${SERIAL_NUM} --text_encoder_bin ${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY} --unet_bin ${PATH_TO_UNET_CONTEXT_BINARY} --vae_bin ${PATH_TO_VAE_CONTEXT_BINARY} --vocab_json ${PATH_TO_VOCAB_JSON_FILE} --num_time_steps 20 --prompt "a photo of an astronaut riding a horse on mars"
+```
+- Please replace `${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY}`, `${PATH_TO_UNET_CONTEXT_BINARY}`, and `${PATH_TO_VAE_CONTEXT_BINARY}` with the actual paths to your AI Hub context binary files.
+- Please replace `${PATH_TO_VOCAB_JSON_FILE}` with the actual path to your vocab.json file.
diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh b/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh
new file mode 100755
index 0000000000..bbb4767bee
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh
@@ -0,0 +1,3 @@
+# For Stable Diffusion V2.1
+pip install diffusers==0.29.0
+pip install piq==0.8.0
diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py
new file mode 100644
index 0000000000..862db31f17
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py
@@ -0,0 +1,472 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +import gc +import json +import os +from multiprocessing.connection import Client + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor +import numpy as np +import piq +import torch +from diffusers import EulerDiscreteScheduler, UNet2DConditionModel +from diffusers.models.embeddings import get_timestep_embedding +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + canonicalize_program, + from_context_binary, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + generate_qnn_executorch_option, +) + +from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import ( + StableDiffusion, +) +from executorch.examples.qualcomm.utils import ( + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.backend.backend_api import to_backend +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from PIL import Image +from torchvision.transforms import ToTensor + +target_names = ("text_encoder", "unet", "vae") + + +def get_quant_data( + encoding: dict, data: torch.Tensor, input_model: str, input_index: int +): + scale = encoding[f"{input_model}_input"]["scale"][input_index] + offset = encoding[f"{input_model}_input"]["offset"][input_index] + if offset < 0: + quant_data = data.div(scale).sub(offset).clip(min=0, max=65535).detach() + else: + quant_data = data.div(scale).add(offset).clip(min=0, max=65535).detach() + + return quant_data.to(dtype=torch.uint16) + + +def get_encoding( + path_to_shard: str, + compiler_specs: str, + get_input: bool, + get_output: bool, + num_input: int, + num_output: int, +): + encoding_list = [] + with open(path_to_shard, "rb") as f: + ctx_bin = f.read() + qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(compiler_specs), ctx_bin + ) + assert qnn_mgr.Init().value == 0, "failed to load context binary" + qnn_mgr.AllocateTensor() + if get_input: + encoding_input = {"scale": [], "offset": []} + for i in range(num_input): + inputs = qnn_mgr.GetGraphInputs()[i] + encoding = inputs.GetEncodings() + encoding_input["scale"].append(encoding.data["scale"].item()) + encoding_input["offset"].append(encoding.data["offset"].item()) + encoding_list.append(encoding_input) + if get_output: + encoding_output = {"scale": [], "offset": []} + for i in range(num_output): + outputs = qnn_mgr.GetGraphOutputs()[i] + encoding = outputs.GetEncodings() + encoding_output["scale"].append(encoding.data["scale"].item()) + encoding_output["offset"].append(encoding.data["offset"].item()) + encoding_list.append(encoding_output) + qnn_mgr.Destroy() + return encoding_list + + +def get_encodings( + path_to_shard_encoder: str, + path_to_shard_unet: str, + path_to_shard_vae: str, + compiler_specs, +): + text_encoder_encoding = get_encoding( + path_to_shard=path_to_shard_encoder, + compiler_specs=compiler_specs, + get_input=False, + get_output=True, + num_input=1, + num_output=1, + ) + unet_encoding = get_encoding( + path_to_shard=path_to_shard_unet, + compiler_specs=compiler_specs, + get_input=True, + get_output=True, + num_input=3, + num_output=1, + ) + vae_encoding = get_encoding( + path_to_shard=path_to_shard_vae, + compiler_specs=compiler_specs, + get_input=True, + get_output=True, + num_input=1, + num_output=1, + ) + + return ( + text_encoder_encoding[0], + unet_encoding[0], + unet_encoding[1], + vae_encoding[0], + vae_encoding[1], + ) + + +def 
get_time_embedding(timestep, time_embedding): + timestep = torch.tensor([timestep]) + t_emb = get_timestep_embedding(timestep, 320, True, 0) + emb = time_embedding(t_emb) + + return emb + + +def build_args_parser(): + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="Path for storing generated artifacts by this example. Default ./stable_diffusion_qai_hub", + default="./stable_diffusion_qai_hub", + type=str, + ) + + parser.add_argument( + "--pte_prefix", + help="Prefix of pte files name. Default qaihub_stable_diffusion", + default="qaihub_stable_diffusion", + type=str, + ) + + parser.add_argument( + "--text_encoder_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to Text Encoder.", + required=True, + ) + + parser.add_argument( + "--unet_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to UNet.", + required=True, + ) + + parser.add_argument( + "--vae_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to Vae Decoder.", + required=True, + ) + + parser.add_argument( + "--prompt", + default="a photo of an astronaut riding a horse on mars", + type=str, + help="Prompt to generate image from.", + ) + + parser.add_argument( + "--num_time_steps", + default=20, + type=int, + help="The number of diffusion time steps.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Strength of guidance (higher means more influence from prompt).", + ) + + parser.add_argument( + "--vocab_json", + type=str, + help="Path to tokenizer vocab.json file. Can get vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main", + required=True, + ) + + parser.add_argument( + "--pre_gen_pte", + help="folder path to pre-compiled ptes", + default=None, + type=str, + ) + + parser.add_argument( + "--fix_latents", + help="Enable this option to fix the latents in the unet diffuse step.", + action="store_true", + ) + + return parser + + +def broadcast_ut_result(output_image, seed): + sd = StableDiffusion(seed) + to_tensor = ToTensor() + target = sd(args.prompt, 512, 512, args.num_time_steps) + target = to_tensor(target).unsqueeze(0) + output_tensor = to_tensor( + Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) + ).unsqueeze(0) + + psnr_piq = piq.psnr(target, output_tensor) + ssim_piq = piq.ssim(target, output_tensor) + print(f"PSNR: {round(psnr_piq.item(), 3)}, SSIM: {round(ssim_piq.item(), 3)}") + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"PSNR": psnr_piq.item(), "SSIM": ssim_piq.item()})) + + +def save_result(output_image): + img = Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) + save_path = f"{args.artifact}/outputs/output_image.jpg" + img.save(save_path) + print(f"Output image saved at {save_path}") + + +def gen_pte_from_ctx_bin(args, compiler_specs): + # Create custom operators as context loader + bundle_programs = [ + from_context_binary(args.text_encoder_bin, "ctx_loader_0"), + from_context_binary(args.unet_bin, "ctx_loader_1"), + from_context_binary(args.vae_bin, "ctx_loader_2"), + ] + + # Lower with QnnBackend + lowered_modules = [ + to_backend("QnnBackend", prog["edge_program"], compiler_specs) + for prog in bundle_programs + ] + # Setup spill-fill buffer for relieving runtime memory usage + canonicalize_program(lowered_modules) + # export pte files + pte_files = [] + for target_name in target_names: + memory_planning_pass = MemoryPlanningPass( + 
memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=False, + ) + pte_files.append(f"{args.artifact}/{args.pte_prefix}_{target_name}.pte") + with open(pte_files[-1], "wb") as file: + file.write( + lowered_modules[0].buffer( + extract_delegate_segments=True, memory_planning=memory_planning_pass + ) + ) + # GC for reducing host memory consuming + bundle_programs.pop(0) + lowered_modules.pop(0) + gc.collect() + + return pte_files + + +def inference(args, compiler_specs, pte_files): + # Loading a pretrained EulerDiscreteScheduler from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. + scheduler = EulerDiscreteScheduler.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main" + ) + + # Loading a pretrained UNet2DConditionModel (which includes the time embedding) from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. + time_embedding = UNet2DConditionModel.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision="main" + ).time_embedding + + scheduler.set_timesteps(args.num_time_steps) + scheduler.config.prediction_type = "epsilon" + # Get encoding of unet and vae + ( + encoder_output, + unet_input, + unet_output, + vae_input, + vae_output, + ) = get_encodings( + args.text_encoder_bin, + args.unet_bin, + args.vae_bin, + compiler_specs, + ) + encoding = { + "encoder_output": encoder_output, + "unet_input": unet_input, + "unet_output": unet_output, + "vae_input": vae_input, + "vae_output": vae_output, + } + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=args.build_folder, + pte_path=pte_files, + workspace=f"/data/local/tmp/executorch/{args.pte_prefix}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner", + ) + + input_unet = () + input_list_unet = "" + + for i, t in enumerate(scheduler.timesteps): + time_emb = get_quant_data( + encoding, get_time_embedding(t, time_embedding), "unet", 1 + ) + input_list_unet += f"input_{i}_0.raw\n" + input_unet = input_unet + (time_emb,) + + qnn_executor_runner_args = [ + f"--text_encoder_path {adb.workspace}/{args.pte_prefix}_text_encoder.pte", + f"--unet_path {adb.workspace}/{args.pte_prefix}_unet.pte", + f"--vae_path {adb.workspace}/{args.pte_prefix}_vae.pte", + f"--input_list_path {adb.workspace}/input_list.txt", + f"--output_folder_path {adb.output_folder}", + f'--prompt "{args.prompt}"', + f"--guidance_scale {args.guidance_scale}", + f"--num_time_steps {args.num_time_steps}", + f"--vocab_json {adb.workspace}/vocab.json", + ] + if args.fix_latents: + qnn_executor_runner_args.append("--fix_latents") + + text_encoder_output_scale = encoding["encoder_output"]["scale"][0] + text_encoder_output_offset = encoding["encoder_output"]["offset"][0] + unet_input_latent_scale = encoding["unet_input"]["scale"][0] + unet_input_latent_offset = encoding["unet_input"]["offset"][0] + unet_input_text_emb_scale = encoding["unet_input"]["scale"][2] + unet_input_text_emb_offset = encoding["unet_input"]["offset"][2] + unet_output_scale = encoding["unet_output"]["scale"][0] + unet_output_offset = encoding["unet_output"]["offset"][0] + vae_input_scale = encoding["vae_input"]["scale"][0] + vae_input_offset = encoding["vae_input"]["offset"][0] + vae_output_scale = encoding["vae_output"]["scale"][0] + vae_output_offset = encoding["vae_output"]["offset"][0] + + qnn_executor_runner_args = qnn_executor_runner_args + [ + 
f"--text_encoder_output_scale {text_encoder_output_scale}", + f"--text_encoder_output_offset {text_encoder_output_offset}", + f"--unet_input_latent_scale {unet_input_latent_scale}", + f"--unet_input_latent_offset {unet_input_latent_offset}", + f"--unet_input_text_emb_scale {unet_input_text_emb_scale}", + f"--unet_input_text_emb_offset {unet_input_text_emb_offset}", + f"--unet_output_scale {unet_output_scale}", + f"--unet_output_offset {unet_output_offset}", + f"--vae_input_scale {vae_input_scale}", + f"--vae_input_offset {vae_input_offset}", + f"--vae_output_scale {vae_output_scale}", + f"--vae_output_offset {vae_output_offset}", + ] + + qnn_executor_runner_args = " ".join( + [ + f"cd {adb.workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. &&", + f"./qaihub_stable_diffusion_runner {' '.join(qnn_executor_runner_args)}", + ] + ) + + files = [args.vocab_json] + + if args.fix_latents: + seed = 42 + latents = torch.randn((1, 4, 64, 64), generator=torch.manual_seed(seed)).to( + "cpu" + ) + # We need to explicitly permute after init tensor or else the random value will be different + latents = latents.permute(0, 2, 3, 1).contiguous() + latents = latents * scheduler.init_noise_sigma + flattened_tensor = latents.view(-1) + # Save the flattened tensor to a .raw file + with open(os.path.join(args.artifact, "latents.raw"), "wb") as file: + file.write(flattened_tensor.numpy().tobytes()) + files.append(os.path.join(args.artifact, "latents.raw")) + + adb.push(inputs=input_unet, input_list=input_list_unet, files=files) + adb.execute(custom_runner_cmd=qnn_executor_runner_args) + + output_image = [] + + def post_process_vae(): + with open(f"{args.artifact}/outputs/output_0_0.raw", "rb") as f: + output_image.append( + np.fromfile(f, dtype=np.float32).reshape(1, 512, 512, 3) + ) + + adb.pull(output_path=args.artifact, callback=post_process_vae) + + if args.fix_latents: + broadcast_ut_result(output_image, seed) + else: + save_result(output_image) + + +def main(args): + os.makedirs(args.artifact, exist_ok=True) + + # common part for compile & inference + backend_options = generate_htp_compiler_spec( + use_fp16=False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + is_from_context_binary=True, + ) + + if args.pre_gen_pte is None: + pte_files = gen_pte_from_ctx_bin(args, compiler_specs) + assert ( + len(pte_files) == 3 + ), f"Error: Expected 3 PTE files, but got {len(pte_files)} files." + + else: + pte_files = [ + f"{args.pre_gen_pte}/{args.pte_prefix}_{target_name}.pte" + for target_name in target_names + ] + if args.compile_only: + return + + inference(args, compiler_specs, pte_files) + + +if __name__ == "__main__": # noqa: C901 + parser = build_args_parser() + args = parser.parse_args() + + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp new file mode 100644 index 0000000000..687a260c4a --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +DEFINE_string( + text_encoder_path, + "qaihub_stable_diffusion_text_encoder.pte", + "Text Encoder Model serialized in flatbuffer format."); +DEFINE_string( + unet_path, + "qaihub_stable_diffusion_unet.pte", + "Unet Model serialized in flatbuffer format."); +DEFINE_string( + vae_path, + "qaihub_stable_diffusion_vae.pte", + "Vae Model serialized in flatbuffer format."); +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); +DEFINE_string( + input_list_path, + "input_list.txt", + "Input list storing time embedding."); +DEFINE_string( + vocab_json, + "vocab.json", + "Json path to retrieve a list of vocabs."); +DEFINE_string( + prompt, + "a photo of an astronaut riding a horse on mars", + "User input prompt"); +DEFINE_int32(num_time_steps, 20, "Number of time steps."); +DEFINE_double(guidance_scale, 7.5, "Guidance Scale"); + +DEFINE_double(text_encoder_output_scale, 0.0, "Text encoder output scale"); +DEFINE_int32(text_encoder_output_offset, 0, "Text encoder output offset"); +DEFINE_double(unet_input_latent_scale, 0.0, "Unet input latent scale"); +DEFINE_int32(unet_input_latent_offset, 0, "Unet input latent offset"); +DEFINE_double(unet_input_text_emb_scale, 0.0, "Unet input text emb scale"); +DEFINE_int32(unet_input_text_emb_offset, 0, "Unet input text emb offset"); +DEFINE_double(unet_output_scale, 0.0, "Unet output scale"); +DEFINE_int32(unet_output_offset, 0, "Unet output offset"); +DEFINE_double(vae_input_scale, 0.0, "Vae input scale"); +DEFINE_int32(vae_input_offset, 0, "Vae input offset"); +DEFINE_double(vae_output_scale, 0.0, "Vae output scale"); +DEFINE_int32(vae_output_offset, 0, "Vae output offset"); +DEFINE_bool( + fix_latents, + false, + "Enable this option to fix the latents in the unet diffuse step."); + +void usage_message() { + std::string usage_message = + "This is a sample executor runner capable of executing stable diffusion models." + "Users will need binary .pte program files for text_encoder, unet, and vae. 
Below are the options to retrieve required .pte program files:\n" + "For further information on how to generate the .pte program files and example command to execute this runner, please refer to qaihub_stable_diffsion.py."; + gflags::SetUsageMessage(usage_message); +} + +int main(int argc, char** argv) { + using namespace torch::executor; + runtime_init(); + usage_message(); + gflags::ParseCommandLineFlags(&argc, &argv, true); + bool is_default = + gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_output_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_output_offset").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_input_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_input_offset").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_output_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_output_offset").is_default; + + ET_CHECK_MSG( + !is_default, + "Please provide scale and offset for unet latent input, unet output, and vae input/output." + "Please refer to qaihub_stable_diffusion.py if you are unsure how to retrieve these values."); + + ET_LOG(Info, "Stable Diffusion runner started"); + std::vector models_path = { + FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path}; + + // Create stable_diffusion_runner + Runner runner( + models_path, + FLAGS_num_time_steps, + FLAGS_guidance_scale, + FLAGS_text_encoder_output_scale, + FLAGS_text_encoder_output_offset, + FLAGS_unet_input_latent_scale, + FLAGS_unet_input_latent_offset, + FLAGS_unet_input_text_emb_scale, + FLAGS_unet_input_text_emb_offset, + FLAGS_unet_output_scale, + FLAGS_unet_output_offset, + FLAGS_vae_input_scale, + FLAGS_vae_input_offset, + FLAGS_vae_output_scale, + FLAGS_vae_output_offset, + FLAGS_output_folder_path, + FLAGS_fix_latents); + + ET_CHECK_MSG( + runner.init_tokenizer(FLAGS_vocab_json) == Error::Ok, + "Runner failed to init tokenizer"); + + ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); + + ET_CHECK_MSG( + runner.parse_input_list(FLAGS_input_list_path) == Error::Ok, + "Failed to parse time embedding input list"); + ET_CHECK_MSG( + runner.generate(FLAGS_prompt) == Error::Ok, "Runner failed to generate"); + + ET_CHECK_MSG( + runner.print_performance() == Error::Ok, + "Runner failed to print performance"); + + return 0; +} diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp new file mode 100644 index 0000000000..a997397855 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp @@ -0,0 +1,621 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple stable diffusion runner that includes preprocessing and post +// processing logic. The module takes in a string as input and emits a tensor as +// output. 
+ +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace executor { + +Runner::Runner( + const std::vector& models_path, + const int num_time_steps, + const float guidance_scale, + const float text_encoder_output_scale, + const int text_encoder_output_offset, + const float unet_input_latent_scale, + const int unet_input_latent_offset, + const float unet_input_text_emb_scale, + const float unet_input_text_emb_offset, + const float unet_output_scale, + const int unet_output_offset, + const float vae_input_scale, + const int vae_input_offset, + const float vae_output_scale, + const int vae_output_offset, + const std::string output_path, + const bool fix_latents) + : num_time_steps_(num_time_steps), + guidance_scale_(guidance_scale), + text_encoder_output_scale_(text_encoder_output_scale), + text_encoder_output_offset_(text_encoder_output_offset), + unet_input_latent_scale_(unet_input_latent_scale), + unet_input_latent_offset_(unet_input_latent_offset), + unet_input_text_emb_scale_(unet_input_text_emb_scale), + unet_input_text_emb_offset_(unet_input_text_emb_offset), + unet_output_scale_(unet_output_scale), + unet_output_offset_(unet_output_offset), + vae_input_scale_(vae_input_scale), + vae_input_offset_(vae_input_offset), + vae_output_scale_(vae_output_scale), + vae_output_offset_(vae_output_offset), + output_path_(output_path), + fix_latents_(fix_latents) { + for (int i = 0; i < models_path.size(); i++) { + modules_.push_back(std::make_unique( + models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); + ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str()); + } +} + +std::vector> Runner::get_methods_meta() { + std::vector> methods_meta; + for (std::unique_ptr& module : modules_) { + methods_meta.emplace_back(module->method_meta("forward")); + } + return methods_meta; +} + +bool Runner::is_loaded() const { + bool loaded = true; + for (const std::unique_ptr& module : modules_) { + loaded &= module->is_loaded(); + } + return loaded; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + stats_.model_load_start_ms = util::time_in_ms(); + for (auto& module : modules_) { + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); + } + stats_.model_load_end_ms = util::time_in_ms(); + return Error::Ok; +} + +Error Runner::parse_input_list(std::string& path) { + // Fill in data for input + std::ifstream input_list(path); + time_emb_list_.reserve(num_time_steps_); + ET_CHECK_MSG(input_list.is_open(), "Input list error opening file"); + std::string time_emb_file; + for (int i = 0; i < num_time_steps_; i++) { + std::getline(input_list, time_emb_file); + std::ifstream is; + is.open(time_emb_file, std::ios::binary); + is.seekg(0, std::ios::end); + size_t filesize = is.tellg(); + is.seekg(0, std::ios::beg); + std::vector time_emb; + time_emb.resize(filesize / sizeof(uint16_t)); + is.read(reinterpret_cast(time_emb.data()), filesize); + time_emb_list_.push_back(time_emb); + } + return Error::Ok; +} + +Error Runner::init_tokenizer(const std::string& vocab_json_path) { + ET_LOG(Info, "Loading Tokenizer from json"); + stats_.tokenizer_load_start_ms = util::time_in_ms(); + std::ifstream fin(vocab_json_path); + auto update_map = [this](std::string& target, std::regex& re) { + std::smatch sm; + std::regex_search(target, sm, re); + // replace special character, please extend this if any cornor case found + std::string text = sm[1]; + std::unordered_map post_process = { + {"\"", 
std::regex(R"(\\\")")}, + {" ", std::regex(R"()")}, + {"\\", std::regex(R"(\\\\)")}}; + for (auto& p : post_process) { + text = std::regex_replace(text, p.second, p.first); + } + vocab_to_token_map_[text] = std::stoi(sm[2]); + }; + + if (fin.is_open()) { + std::string line, text; + while (getline(fin, line)) { + text += line; + } + fin.close(); + + std::regex re_anchor(R"(\d,\")"); + std::regex re_pattern(R"(\{?\"(.*)\":([\d]+)\}?)"); + auto begin = std::sregex_iterator(text.begin(), text.end(), re_anchor); + auto end = std::sregex_iterator(); + size_t pos = 0; + for (std::sregex_iterator iter = begin; iter != end; ++iter) { + std::smatch match; + size_t len = iter->position() - pos + 1; + std::string target = text.substr(pos, len); + update_map(target, re_pattern); + pos = iter->position() + 1; + } + // process last vocabulary + std::string target = text.substr(pos); + update_map(target, re_pattern); + } + stats_.tokenizer_load_end_ms = util::time_in_ms(); + return Error::Ok; +} + +std::vector Runner::tokenize(std::string prompt) { + std::string bos("<|startoftext|>"), eos("<|endoftext|>"); + std::vector vocabs; + vocabs.reserve(max_tokens_); + std::vector tokens(1, vocab_to_token_map_[bos]); + + // pretokenize + // ref: https://github.com/monatis/clip.cpp + // https://huggingface.co/openai/clip-vit-base-patch32 + std::string text; + std::regex re( + R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"); + std::smatch sm; + while (std::regex_search(prompt, sm, re)) { + for (auto& v : sm) { + vocabs.push_back(v); + } + prompt = sm.suffix(); + } + for (std::string& v : vocabs) { + std::string word = (v[0] == ' ') ? v.substr(1) : v; + word += " "; + auto iter = vocab_to_token_map_.find(word); + if (iter != vocab_to_token_map_.end()) { + tokens.push_back(iter->second); + continue; + } + for (int i = 0; i < v.size(); ++i) { + for (int j = v.size() - 1; j >= i; --j) { + std::string token = v.substr(i, j - 1 + 1); + auto iter = vocab_to_token_map_.find(token); + if (iter != vocab_to_token_map_.end()) { + tokens.push_back(iter->second); + i = j + 1; + break; + } else if (j == i) { + ET_LOG(Error, "unknown token found: %s", token.c_str()); + } + } + } + } + tokens.push_back(vocab_to_token_map_[eos]); + return tokens; +} + +std::vector Runner::gen_latent_from_file() { + std::vector tensor_vector; + std::ifstream file("latents.raw", std::ios::binary); + if (!file.is_open()) { + ET_LOG(Error, "Error opening file!"); + return tensor_vector; + } + + // Read the tensor data + float value; + while (file.read(reinterpret_cast(&value), sizeof(float))) { + tensor_vector.push_back(value); + } + file.close(); + return tensor_vector; +} + +std::vector Runner::gen_random_latent(float sigma) { + std::random_device rnd_device; + std::mt19937 mersenne_engine{rnd_device()}; + std::normal_distribution dist{0.0f, 1.0f}; + + constexpr int latent_size = 1 * 64 * 64 * 4; + std::vector random_vector(latent_size); + + for (float& value : random_vector) { + value = dist(mersenne_engine) * sigma; + } + return random_vector; +} + +std::vector Runner::get_time_steps() { + std::vector time_steps(num_time_steps_); + for (int i = 0; i < num_time_steps_; ++i) { + time_steps[i] = (num_train_timesteps_ - 1) * + (1.0f - static_cast(i) / (num_time_steps_ - 1)); + } + return time_steps; +} + +std::vector Runner::get_sigmas(const std::vector& time_steps) { + float start = std::sqrt(beta_start_); + float end = std::sqrt(beta_end_); + std::vector betas(num_train_timesteps_); + float step = 
(end - start) / (num_train_timesteps_ - 1); + for (int i = 0; i < num_train_timesteps_; ++i) { + float value = start + i * step; + betas[i] = 1 - (value * value); + } + + std::vector alphas_cumprod(num_train_timesteps_); + float cumprod = 1.0; + for (int i = 0; i < num_train_timesteps_; ++i) { + cumprod *= betas[i]; + alphas_cumprod[i] = cumprod; + } + + std::vector sigmas(num_train_timesteps_); + for (int i = 0; i < num_train_timesteps_; ++i) { + sigmas[i] = std::sqrt((1.0 - alphas_cumprod[i]) / alphas_cumprod[i]); + } + + std::vector res(time_steps.size()); + for (size_t i = 0; i < time_steps.size(); ++i) { + float index = + static_cast(i) * (sigmas.size() - 1) / (time_steps.size() - 1); + size_t lower_index = static_cast(std::floor(index)); + size_t upper_index = static_cast(std::ceil(index)); + + float weight = index - lower_index; + res[i] = + (1.0 - weight) * sigmas[lower_index] + weight * sigmas[upper_index]; + } + std::reverse(res.begin(), res.end()); + res.push_back(0); + + return res; +} + +void Runner::scale_model_input( + const std::vector& latents, + std::vector& latent_model_input, + float sigma) { + for (int i = 0; i < latents.size(); i++) { + latent_model_input[i] = (latents[i] / std::sqrt(sigma * sigma + 1)); + } +} + +void Runner::quant_tensor( + const std::vector& fp_vec, + std::vector& quant_vec, + float scale, + int offset) { + offset = abs(offset); + for (int i = 0; i < fp_vec.size(); i++) { + quant_vec[i] = static_cast((fp_vec[i] / scale) + offset); + } +} + +void Runner::dequant_tensor( + const std::vector& quant_vec, + std::vector& fp_vec, + float scale, + int offset) { + offset = abs(offset); + for (int i = 0; i < quant_vec.size(); i++) { + fp_vec[i] = (quant_vec[i] - offset) * scale; + } +} + +// Using the same algorithm as EulerDiscreteScheduler in python. 
+void Runner::step( + const std::vector& model_output, + const std::vector& sigmas, + std::vector& sample, + std::vector& prev_sample, + int step_index) { + float sigma = sigmas[step_index]; + float dt = sigmas[step_index + 1] - sigma; + + for (int i = 0; i < sample.size(); ++i) { + float sigma_hat = sample[i] - (sigma * model_output[i]); + prev_sample[i] = (sample[i] - sigma_hat) / sigma; + prev_sample[i] = sample[i] + (prev_sample[i] * dt); + } + sample = prev_sample; +} + +Error Runner::generate(std::string prompt) { + ET_LOG(Info, "Start generating"); + stats_.generate_start_ms = util::time_in_ms(); + + // Start tokenize + stats_.tokenizer_parsing_start_ms = util::time_in_ms(); + std::vector cond_tokens = tokenize(prompt); + cond_tokens.resize(max_tokens_); + std::vector uncond_tokens = tokenize(""); + uncond_tokens.resize(max_tokens_); + stats_.tokenizer_parsing_end_ms = util::time_in_ms(); + + std::vector> method_metas = get_methods_meta(); + + MethodMeta encoder_method_meta = method_metas[0].get(); + // Initialize text_encoder input tensors: cond/uncond tokenized_input[1,77] + ManagedTensor managed_cond_tokens( + cond_tokens.data(), + {1, 77}, + encoder_method_meta.input_tensor_meta(0)->scalar_type()); + ManagedTensor managed_uncond_tokens( + uncond_tokens.data(), + {1, 77}, + encoder_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor cond_tokens_tensor = managed_cond_tokens.get_aliasing_tensor(); + Tensor uncond_tokens_tensor = managed_uncond_tokens.get_aliasing_tensor(); + // Initialize text_encoder output tensors: cond/uncond embedding[1, 77, 1024] + constexpr int emb_size = 1 * 77 * 1024; + std::vector cond_emb_vec(emb_size); + std::vector uncond_emb_vec(emb_size); + std::vector fp_emb_vec(emb_size); + ManagedTensor managed_cond_emb( + cond_emb_vec.data(), + {1, 77, 1024}, + encoder_method_meta.output_tensor_meta(0)->scalar_type()); + ManagedTensor managed_uncond_emb( + uncond_emb_vec.data(), + {1, 77, 1024}, + encoder_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor cond_emb_tensor = managed_cond_emb.get_aliasing_tensor(); + Tensor uncond_emb_tensor = managed_uncond_emb.get_aliasing_tensor(); + modules_[0]->set_output_data_ptr(cond_emb_tensor, 0); + long encoder_start = util::time_in_ms(); + auto cond_res = modules_[0]->forward({cond_tokens_tensor}); + stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start); + modules_[0]->set_output_data_ptr(uncond_emb_tensor, 0); + encoder_start = util::time_in_ms(); + auto uncond_res = modules_[0]->forward({uncond_tokens_tensor}); + stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start); + + // Initialize unet parameters + MethodMeta unet_method_meta = method_metas[1].get(); + std::vector time_steps = get_time_steps(); + std::vector sigmas = get_sigmas(time_steps); + float max_sigma = *std::max_element(sigmas.begin(), sigmas.end()); + std::vector latent; + if (fix_latents_) { + latent = gen_latent_from_file(); + } else { + latent = gen_random_latent(max_sigma); + } + std::vector prev_sample(latent.size()); + + // Initialize unet input tensors + // 1. latent[1,64,64,4] + // 2. time_embedding[1,1280] + // 3. 
cond/uncond embedding[1,77,1024] + std::vector latent_model_input(latent.size()); + std::vector fp_latent_model_input(latent.size()); + ManagedTensor managed_latent( + latent_model_input.data(), + {1, 64, 64, 4}, + unet_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor latent_tensor = managed_latent.get_aliasing_tensor(); + std::vector managed_time_emb_tensors; + std::vector time_emb_tensors; + managed_time_emb_tensors.reserve(num_time_steps_); + time_emb_tensors.reserve(num_time_steps_); + for (int step_index = 0; step_index < num_time_steps_; step_index++) { + managed_time_emb_tensors.emplace_back(ManagedTensor( + time_emb_list_[step_index].data(), + {1, 1280}, + unet_method_meta.input_tensor_meta(1)->scalar_type())); + time_emb_tensors.emplace_back( + managed_time_emb_tensors.back().get_aliasing_tensor()); + } + // requantize text encoders output + dequant_tensor( + cond_emb_vec, + fp_emb_vec, + text_encoder_output_scale_, + text_encoder_output_offset_); + quant_tensor( + fp_emb_vec, + cond_emb_vec, + unet_input_text_emb_scale_, + unet_input_text_emb_offset_); + dequant_tensor( + uncond_emb_vec, + fp_emb_vec, + text_encoder_output_scale_, + text_encoder_output_offset_); + quant_tensor( + fp_emb_vec, + uncond_emb_vec, + unet_input_text_emb_scale_, + unet_input_text_emb_offset_); + + // Initialize unet output tensors: text/uncond noise_pred[1,64,64,4] + std::vector noise_pred_text(latent.size()); + std::vector noise_pred_uncond(latent.size()); + std::vector fp_noise_pred_text(noise_pred_text.size()); + std::vector fp_noise_pred_uncond(noise_pred_uncond.size()); + ManagedTensor managed_noise_pred_text( + noise_pred_text.data(), + {1, 64, 64, 4}, + unet_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor noise_pred_text_tensor = managed_noise_pred_text.get_aliasing_tensor(); + ManagedTensor managed_noise_pred_uncond( + noise_pred_uncond.data(), + {1, 64, 64, 4}, + unet_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor noise_pred_uncond_tensor = + managed_noise_pred_uncond.get_aliasing_tensor(); + + // Execute unet + for (int step_index = 0; step_index < num_time_steps_; step_index++) { + long start_post_process = util::time_in_ms(); + scale_model_input(latent, fp_latent_model_input, sigmas[step_index]); + + quant_tensor( + fp_latent_model_input, + latent_model_input, + unet_input_latent_scale_, + unet_input_latent_offset_); + + stats_.unet_aggregate_post_processing_time += + (util::time_in_ms() - start_post_process); + modules_[1]->set_output_data_ptr(noise_pred_text_tensor, 0); + long start_unet_execution = util::time_in_ms(); + auto cond_res = modules_[1]->forward( + {latent_tensor, time_emb_tensors[step_index], cond_emb_tensor}); + stats_.unet_aggregate_execution_time += + (util::time_in_ms() - start_unet_execution); + modules_[1]->set_output_data_ptr(noise_pred_uncond_tensor, 0); + start_unet_execution = util::time_in_ms(); + auto uncond_res = modules_[1]->forward( + {latent_tensor, + time_emb_tensors[step_index], + uncond_emb_tensor}); // results in noise_pred_uncond_vec + stats_.unet_aggregate_execution_time += + (util::time_in_ms() - start_unet_execution); + + // start unet post processing + start_post_process = util::time_in_ms(); + + dequant_tensor( + noise_pred_text, + fp_noise_pred_text, + unet_output_scale_, + unet_output_offset_); + dequant_tensor( + noise_pred_uncond, + fp_noise_pred_uncond, + unet_output_scale_, + unet_output_offset_); + + for (int i = 0; i < fp_noise_pred_text.size(); i++) { + fp_noise_pred_text[i] = fp_noise_pred_uncond[i] + + 
guidance_scale_ * (fp_noise_pred_text[i] - fp_noise_pred_uncond[i]); + } + step(fp_noise_pred_text, sigmas, latent, prev_sample, step_index); + stats_.unet_aggregate_post_processing_time += + (util::time_in_ms() - start_post_process); + } + + // Start VAE + MethodMeta vae_method_meta = method_metas[2].get(); + // Initialize vae input tensor : latent[1,64,64,4] + std::vector vae_input(latent.size()); + ManagedTensor managed_vae_input( + vae_input.data(), + {1, 64, 64, 4}, + vae_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor vae_input_tensor = managed_vae_input.get_aliasing_tensor(); + // Intialize vae output tensor: output[1,512,512,3] + constexpr int image_size = 1 * 512 * 512 * 3; + std::vector q_out(image_size); + std::vector out(image_size); + ManagedTensor managed_output( + q_out.data(), + {1, 512, 512, 3}, + vae_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor output_tensor = managed_output.get_aliasing_tensor(); + + quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_); + + modules_[2]->set_output_data_ptr(output_tensor, 0); + long start_vae_execution = util::time_in_ms(); + auto vae_res = modules_[2]->forward({vae_input_tensor}); + stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution); + stats_.generate_end_ms = util::time_in_ms(); + + // Dequant uint16 output to fp32 output + dequant_tensor(q_out, out, vae_output_scale_, vae_output_offset_); + + // Saving outputs + auto output_file_name = output_path_ + "/output_0_0.raw"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write( + reinterpret_cast(out.data()), out.size() * sizeof(float)); + fout.close(); + + return Error::Ok; +} + +Error Runner::print_performance() { + ET_LOG(Info, "\tTotal Number of steps:\t\t\t\t%d", num_time_steps_); + + ET_LOG( + Info, + "\tTokenizer Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.tokenizer_load_end_ms - stats_.tokenizer_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tModel Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.model_load_end_ms - stats_.model_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tGenerate Time(Tokenize + Encoder + UNet + VAE):\t%f (seconds)", + ((double)(stats_.generate_end_ms - stats_.generate_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tTokenize Time:\t\t\t\t\t%f (seconds)", + ((double)(stats_.tokenizer_parsing_end_ms - + stats_.tokenizer_parsing_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tText Encoder Execution Time:\t\t\t%f (seconds)", + ((double)(stats_.text_encoder_execution_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tUnet Aggregate (Cond + Uncond) Execution Time:\t%f (seconds)", + ((double)stats_.unet_aggregate_execution_time / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tUnet Average Execution Time:\t\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_execution_time / (num_time_steps_ * 2)) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tUnet Aggregate Post-Processing Time:\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_post_processing_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tUnet Average Post-Processing Time:\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_post_processing_time / + (num_time_steps_ * 2)) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tVAE Execution Time:\t\t\t\t%f 
(seconds)", + ((double)(stats_.vae_execution_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + return Error::Ok; +} + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h new file mode 100644 index 0000000000..e081ab80cc --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple diffusion runner that includes preprocessing and post processing +// logic. The module takes in a string as input and emites a tensor as output. + +#pragma once + +#include +#include +#include + +#include + +namespace torch { +namespace executor { + +class Runner { + public: + explicit Runner( + const std::vector& models_path, + const int num_time_steps, + const float guidance_scale, + const float text_encoder_output_scale, + const int text_encoder_output_offset, + const float unet_input_latent_scale, + const int unet_input_latent_offset, + const float unet_input_text_emb_scale, + const float unet_input_text_emb_offset, + const float unet_output_scale, + const int unet_output_offset, + const float vae_input_scale, + const int vae_input_offset, + const float vae_output_scale, + const int vae_output_offset, + const std::string output_path, + const bool fix_latents); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Model loading time + long model_load_start_ms; + long model_load_end_ms; + + // tokenizer loading time + long tokenizer_load_start_ms = 0; + long tokenizer_load_end_ms = 0; + + // tokenizer parsing time + long tokenizer_parsing_start_ms = 0; + long tokenizer_parsing_end_ms = 0; + + // Total time to run generate + long generate_start_ms = 0; + long generate_end_ms = 0; + + // text encoder execution time + long text_encoder_execution_time = 0; + + // Unet aggregation execution time over n steps for cond + uncond + long unet_aggregate_execution_time = 0; + + // UNet aggregation post processing time over n steps for cond + uncond. + // This is the time from processing unet's output until feeding it into the + // next iteration. 
+ long unet_aggregate_post_processing_time = 0; + + // VAE execution time + long vae_execution_time = 0; + }; + + bool is_loaded() const; + Error load(); + Error init_tokenizer(const std::string& vocab_json_path); + Error print_performance(); + std::vector tokenize(std::string prompt); + std::vector gen_latent_from_file(); + std::vector gen_random_latent(float sigma); + void step( + const std::vector& model_output, + const std::vector& sigmas, + std::vector& sample, + std::vector& prev_sample, + int step_index); + std::vector> get_methods_meta(); + std::vector get_time_steps(); + std::vector get_sigmas(const std::vector& time_steps); + void scale_model_input( + const std::vector& vec, + std::vector& latent_model_input, + float sigma); + Error parse_input_list(std::string& path); + Error generate(std::string prompt); + void quant_tensor( + const std::vector& fp_vec, + std::vector& quant_vec, + float scale, + int offset); + void dequant_tensor( + const std::vector& quant_vec, + std::vector& fp_vec, + float scale, + int offset); + + private: + Stats stats_; + std::vector> modules_; + std::vector> time_emb_list_; + std::unordered_map vocab_to_token_map_; + + std::string output_path_; + int num_time_steps_; + float guidance_scale_; + float text_encoder_output_scale_; + int text_encoder_output_offset_; + float unet_input_latent_scale_; + int unet_input_latent_offset_; + float unet_input_text_emb_scale_; + int unet_input_text_emb_offset_; + float unet_output_scale_; + int unet_output_offset_; + float vae_input_scale_; + int vae_input_offset_; + float vae_output_scale_; + int vae_output_offset_; + const float beta_start_ = 0.00085; + const float beta_end_ = 0.012; + const int num_train_timesteps_ = 1000; + const int max_tokens_ = 77; + const bool fix_latents_ = false; +}; + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py new file mode 100644 index 0000000000..8ec5783131 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py @@ -0,0 +1,22 @@ +import torch +from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline + + +class StableDiffusion: + def __init__(self, seed=42): + self.model_id: str = "stabilityai/stable-diffusion-2-1-base" + self.generator = torch.manual_seed(seed) + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler" + ) + + self.pipe = StableDiffusionPipeline.from_pretrained( + self.model_id, scheduler=self.scheduler, torch_dtype=torch.float32 + ) + self.pipe = self.pipe.to("cpu") + + def __call__(self, prompt, height, width, num_time_steps): + image = self.pipe( + prompt, height, width, num_time_steps, generator=self.generator + ).images[0] + return image
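For reference, the sigma schedule and denoising update implemented in `Runner::get_sigmas`, `Runner::scale_model_input`, and `Runner::step` follow diffusers' `EulerDiscreteScheduler` with the scaled-linear beta schedule and `prediction_type="epsilon"`. The NumPy sketch below is illustrative only (it is not part of the patch, and the helper names are hypothetical); it restates the same math with clearer naming, e.g. the value called `sigma_hat` in `Runner::step` is the predicted original sample.

```python
import numpy as np

# Scheduler constants from runner.h
BETA_START, BETA_END, NUM_TRAIN_TIMESTEPS = 0.00085, 0.012, 1000


def get_sigmas(num_time_steps: int) -> np.ndarray:
    # Scaled-linear betas, cumulative alphas, and the derived noise levels (sigmas).
    betas = np.linspace(BETA_START**0.5, BETA_END**0.5, NUM_TRAIN_TIMESTEPS) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas)
    sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)
    # Resample to num_time_steps points, highest noise first, and end at 0.
    idx = np.linspace(0, NUM_TRAIN_TIMESTEPS - 1, num_time_steps)
    sigmas = np.interp(idx, np.arange(NUM_TRAIN_TIMESTEPS), sigmas)[::-1]
    return np.append(sigmas, 0.0)


def scale_model_input(sample: np.ndarray, sigma: float) -> np.ndarray:
    # Latent scaling applied before every UNet call (Runner::scale_model_input).
    return sample / np.sqrt(sigma**2 + 1.0)


def euler_step(sample: np.ndarray, eps: np.ndarray, sigmas: np.ndarray, i: int) -> np.ndarray:
    # One Euler step with epsilon prediction (Runner::step).
    sigma, dt = sigmas[i], sigmas[i + 1] - sigmas[i]
    pred_original = sample - sigma * eps  # "sigma_hat" in the C++ runner
    derivative = (sample - pred_original) / sigma
    return sample + derivative * dt
```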
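The three context binaries run with fixed 16-bit affine quantization, so activations have to be requantized whenever they cross model boundaries: `get_encodings` reads each graph's per-tensor `scale`/`offset`, `get_quant_data` quantizes the precomputed time embeddings on the host, and `Runner::quant_tensor` / `Runner::dequant_tensor` convert data between the text encoder, UNet, and VAE encodings on device. A minimal sketch of that round trip, assuming the same uint16 range and sign convention as the patch (helper names are hypothetical):

```python
import torch


def quantize_uint16(x: torch.Tensor, scale: float, offset: int) -> torch.Tensor:
    # QNN encodings report the offset as a (typically negative) zero point,
    # so q = x / scale + |offset|, clipped to the uint16 range.
    return (x / scale + abs(offset)).clip(0, 65535).to(torch.uint16)


def dequantize_uint16(q: torch.Tensor, scale: float, offset: int) -> torch.Tensor:
    # Inverse mapping used by the runner: x = (q - |offset|) * scale.
    return (q.to(torch.float32) - abs(offset)) * scale


def requantize(q_emb, enc_out_scale, enc_out_offset, unet_in_scale, unet_in_offset):
    # Example: requantize a text-encoder output so it can be fed to the UNet,
    # mirroring the dequant_tensor/quant_tensor pair in Runner::generate.
    # The scale/offset arguments come from get_encodings().
    fp = dequantize_uint16(q_emb, enc_out_scale, enc_out_offset)
    return quantize_uint16(fp, unet_in_scale, unet_in_offset)
```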
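Finally, the `StableDiffusion` wrapper in stable_diffusion_lib.py is what `broadcast_ut_result` uses to produce a CPU reference image for the PSNR/SSIM check when `--fix_latents` is set. A standalone usage sketch is shown below; the output path, seed, and thresholds are assumptions based on the defaults in this patch:

```python
import numpy as np
import piq
from PIL import Image
from torchvision.transforms import ToTensor

from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import (
    StableDiffusion,
)

prompt = "a photo of an astronaut riding a horse on mars"

# CPU reference image, using the same seed/steps as the --fix_latents on-device run.
sd = StableDiffusion(seed=42)
reference = ToTensor()(sd(prompt, 512, 512, 20)).unsqueeze(0)

# Device output: float32 [1, 512, 512, 3] pulled back by the script as output_0_0.raw.
raw = np.fromfile("outputs/output_0_0.raw", dtype=np.float32).reshape(1, 512, 512, 3)
device_img = ToTensor()(Image.fromarray(np.round(raw[0] * 255).astype(np.uint8))).unsqueeze(0)

print(f"PSNR: {piq.psnr(reference, device_img).item():.3f}")
print(f"SSIM: {piq.ssim(reference, device_img).item():.3f}")
# The CI test accepts PSNR >= 20 and SSIM >= 0.8 for the default prompt.
```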