diff --git a/examples/05_stable_diffusion/README.md b/examples/05_stable_diffusion/README.md
index 224d799ec..1f62403de 100644
--- a/examples/05_stable_diffusion/README.md
+++ b/examples/05_stable_diffusion/README.md
@@ -6,7 +6,7 @@ In this example, we show how to build fast AIT modules for CLIP, UNet, VAE model
 
 First, clone, build, and install AITemplate [per the README instructions](https://github.com/facebookincubator/AITemplate#clone-the-code).
 
-This AIT stable diffusion example depends on `diffusers`, `transformers`, `torch` and `click`.
+This AIT stable diffusion example depends on `diffusers`, `transformers`, `torch` and `click`. Verify that your installed library versions match the tested versions below.
 
 We have tested transformers 4.21/4.22/4.23, diffusers 0.3/0.4 and torch 1.11/1.12.
 
@@ -30,6 +30,11 @@ python3 examples/05_stable_diffusion/compile.py --token ACCESS_TOKEN
 ```
 It generates three folders: `./tmp/CLIPTextModel`, `./tmp/UNet2DConditionModel`, `./tmp/AutoencoderKL`. In each folder, there is a `test.so` file which is the generated AIT module for the model.
 
+Compile the img2img models:
+```
+python3 examples/05_stable_diffusion/compile.py --img2img True --token ACCESS_TOKEN
+```
+
 #### Multi-GPU profiling
 AIT needs to do profiling to select the best algorithms for CUTLASS and CK.
 To enable multiple GPUs for profiling, use the environment variable `CUDA_VISIBLE_DEVICES` on NVIDIA platform and `HIP_VISIBLE_DEVICES` on AMD platform.
@@ -50,6 +55,12 @@ Run AIT models with an example image:
 python3 examples/05_stable_diffusion/demo.py --token ACCESS_TOKEN
 ```
 
+Img2img demo:
+
+```
+python3 examples/05_stable_diffusion/demo_img2img.py --token ACCESS_TOKEN
+```
+
 Check the resulted image: `example_ait.png`
 
 
@@ -131,10 +142,20 @@ _OOM = Out of Memory_
 | 16 | 7906 | 0.49 |
 
 
+## IMG2IMG
+
+### A100-40GB / CUDA 11.6, 40 steps
+
+| Module   | PT Latency (ms) | AIT Latency (ms) |
+|----------|-----------------|------------------|
+| Pipeline | 4163.60         | 1785.46          |
+
+
 ### Note for Performance Results
 
 - For all benchmarks we render the images of size 512x512
+- For the img2img model, only a fixed 512x768 input is supported by default; stay tuned for dynamic shape support
 - For NVIDIA A100, our test cluster doesn't allow to lock frequency. We make warm up longer to collect more stable results, but it is expected to have small variance to the results with locked frequency.
 - To benchmark MI-250 1 GCD, we lock the frequency with command `rocm-smi -d x --setperfdeterminism 1700`, where `x` is the GPU id.
 - Performance results are what we can reproduced & take reference. It should not be used for other purposes. 
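Note: the fixed 512x768 img2img resolution above maps directly onto the latent width that `compile.py` (next diff) selects via `width = 96 if img2img else 64`, and onto the 768x512 resize done in the demo. A minimal sanity-check sketch of that relationship, assuming the standard Stable Diffusion VAE downsampling factor of 8; the `latent_hw` helper is illustrative only and not part of the example code:

```
# Illustrative only: relate the pixel resolution in the README to the latent
# width that compile.py passes to the AIT UNet/VAE modules.
VAE_SCALE_FACTOR = 8  # Stable Diffusion's VAE downsamples height and width by 8


def latent_hw(height_px, width_px):
    """Return the latent (height, width) for a given pixel resolution."""
    return height_px // VAE_SCALE_FACTOR, width_px // VAE_SCALE_FACTOR


# text2img default: 512x512 images -> 64x64 latents (width = 64)
assert latent_hw(512, 512) == (64, 64)
# img2img default in this example: 512x768 images -> 64x96 latents,
# which is why compile.py uses width = 96 when --img2img is set.
assert latent_hw(512, 768) == (64, 96)
```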
diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py index 63f00c297..4c6288a84 100644 --- a/examples/05_stable_diffusion/compile.py +++ b/examples/05_stable_diffusion/compile.py @@ -317,9 +317,10 @@ def compile_vae( @click.command() @click.option("--token", default="", help="access token") @click.option("--batch-size", default=1, help="batch size") +@click.option("--img2img", default=False, help="compile img2img models") @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") -def compile_diffusers(token, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True): +def compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, convert_conv_to_gemm=True): logging.getLogger().setLevel(logging.INFO) np.random.seed(0) torch.manual_seed(4896) @@ -338,16 +339,19 @@ def compile_diffusers(token, batch_size, use_fp16_acc=True, convert_conv_to_gemm use_auth_token=access_token, ).to("cuda") + width = 96 if img2img else 64 + # CLIP compile_clip(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) # UNet compile_unet( batch_size=batch_size * 2, + ww=width, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm, ) # VAE - compile_vae(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + compile_vae(batch_size=batch_size, width=width, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) if __name__ == "__main__": diff --git a/examples/05_stable_diffusion/demo_img2img.py b/examples/05_stable_diffusion/demo_img2img.py new file mode 100644 index 000000000..5a9f8d0d6 --- /dev/null +++ b/examples/05_stable_diffusion/demo_img2img.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from io import BytesIO + +import click +import requests +import torch +from PIL import Image + +from aitemplate.testing.benchmark_pt import benchmark_torch_function +from pipeline_stable_diffusion_img2img_ait import StableDiffusionImg2ImgAITPipeline + + +@click.command() +@click.option("--token", default="", help="access token") +@click.option( + "--prompt", default="A fantasy landscape, trending on artstation", help="prompt" +) +@click.option( + "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" +) +def run(token, prompt, benchmark): + + # load the pipeline + device = "cuda" + model_id_or_path = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionImg2ImgAITPipeline.from_pretrained( + model_id_or_path, + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=token, + ) + pipe = pipe.to(device) + + # let's download an initial image + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((768, 512)) + + with torch.autocast("cuda"): + images = pipe( + prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5 + ).images + if benchmark: + args = (prompt, init_image) + t = benchmark_torch_function(10, pipe, *args) + print(f"sd e2e: {t} ms") + + images[0].save("fantasy_landscape_ait.png") + + +if __name__ == "__main__": + run() diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py new file mode 100644 index 000000000..eb5142615 --- /dev/null +++ b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py @@ -0,0 +1,398 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import os +from typing import List, Optional, Union + +import numpy as np + +import PIL +import torch +from aitemplate.compiler import Model + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionImg2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class StableDiffusionImg2ImgAITPipeline(StableDiffusionImg2ImgPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+    ):
+        # super().__init__()
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+        workdir = "tmp/"
+        self.clip_ait_exe = self.init_ait_module(
+            model_name="CLIPTextModel", workdir=workdir
+        )
+        self.unet_ait_exe = self.init_ait_module(
+            model_name="UNet2DConditionModel", workdir=workdir
+        )
+        self.vae_ait_exe = self.init_ait_module(
+            model_name="AutoencoderKL", workdir=workdir
+        )
+
+    def init_ait_module(
+        self,
+        model_name,
+        workdir,
+    ):
+        mod = Model(os.path.join(workdir, model_name, "test.so"))
+        return mod
+
+    def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states):
+        exe_module = self.unet_ait_exe
+        timesteps_pt = timesteps.expand(latent_model_input.shape[0])
+        inputs = {
+            "input0": latent_model_input.permute((0, 2, 3, 1))
+            .contiguous()
+            .cuda()
+            .half(),
+            "input1": timesteps_pt.cuda().half(),
+            "input2": encoder_hidden_states.cuda().half(),
+        }
+        ys = []
+        num_outputs = len(exe_module.get_output_name_to_index_map())
+        for i in range(num_outputs):
+            shape = exe_module.get_output_maximum_shape(i)
+            ys.append(torch.empty(shape).cuda().half())
+        exe_module.run_with_tensors(inputs, ys, graph_mode=True)
+        noise_pred = ys[0].permute((0, 3, 1, 2)).float()
+        return noise_pred
+
+    def clip_inference(self, input_ids, seqlen=64):
+        exe_module = self.clip_ait_exe
+        bs = input_ids.shape[0]
+        position_ids = torch.arange(seqlen).expand((bs, -1)).cuda()
+        inputs = {
+            "input0": input_ids,
+            "input1": position_ids,
+        }
+        ys = []
+        num_outputs = len(exe_module.get_output_name_to_index_map())
+        for i in range(num_outputs):
+            shape = exe_module.get_output_maximum_shape(i)
+            ys.append(torch.empty(shape).cuda().half())
+        exe_module.run_with_tensors(inputs, ys, graph_mode=True)
+        return ys[0].float()
+
+    def vae_inference(self, vae_input):
+        exe_module = self.vae_ait_exe
+        inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()]
+        ys = []
+        num_outputs = len(exe_module.get_output_name_to_index_map())
+        for i in range(num_outputs):
+            shape = exe_module.get_output_maximum_shape(i)
+            ys.append(torch.empty(shape).cuda().half())
+        exe_module.run_with_tensors(inputs, ys, graph_mode=True)
+        vae_out = ys[0].permute((0, 3, 1, 2)).float()
+        return vae_out
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter will be modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+
+        if strength < 0 or strength > 1:
+            raise ValueError(
+                f"The value of strength should be in [0.0, 1.0] but is {strength}"
+            )
+
+        # set timesteps
+        accepts_offset = "offset" in set(
+            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
+        )
+        extra_set_kwargs = {}
+        offset = 0
+        if accepts_offset:
+            offset = 1
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = preprocess(init_image)
+
+        # encode the init image into latents and scale the latents
+        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size)
+
+        # get the original timestep using init_timestep
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size,
+                dtype=torch.long,
+                device=self.device,
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor(
+                [timesteps] * batch_size, dtype=torch.long, device=self.device
+            )
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(
+            self.device
+        )
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=64,  # self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.clip_inference(text_input.input_ids.to(self.device))
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size,
+                padding="max_length",
+                max_length=max_length,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.clip_inference(
+                uncond_input.input_ids.to(self.device)
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        latents = init_latents
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+            t_index = t_start + i
+
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = (
+                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            )
+
+            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[t_index]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)
+                latent_model_input = latent_model_input.to(self.unet.dtype)
+                t = t.to(self.unet.dtype)
+
+            # predict the noise residual
+            noise_pred = self.unet_inference(
+                latent_model_input, t, encoder_hidden_states=text_embeddings
+            )
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(
+                    noise_pred, t_index, latents, **extra_step_kwargs
+                ).prev_sample
+            else:
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs
+                ).prev_sample
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae_inference(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        # run safety checker
+        safety_checker_input = self.feature_extractor(
+            self.numpy_to_pil(image), return_tensors="pt"
+        ).to(self.device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values
+        )
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
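Note: for intuition on how `strength` interacts with `num_inference_steps` in the `__call__` implementation above, here is a small worked sketch of the `init_timestep`/`t_start` arithmetic. It uses the demo default `strength=0.75`, the pipeline default of 50 steps, and assumes the scheduler's `set_timesteps` accepts an `offset` of 1 (as with the default PNDM scheduler for these weights); it is illustrative only, not part of the example code.

```
# Worked example of the strength -> effective-steps arithmetic in __call__.
num_inference_steps = 50
strength = 0.75
offset = 1

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
effective_steps = num_inference_steps - t_start

print(init_timestep, t_start, effective_steps)  # 38 13 37 -> 37 denoising iterations run
```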