From 7056b137237ff09d82e9ece3f9972f3360ab0bb1 Mon Sep 17 00:00:00 2001
From: Junsong Chen
Date: Sun, 12 Jan 2025 15:25:44 +0800
Subject: [PATCH] fix 4K OOM with VAE-tiling (#144)

* add vae tiling function;

Signed-off-by: lawrence-cj

* update README.md;

Signed-off-by: lawrence-cj

* change vae tile size to 1024px

Signed-off-by: lawrence-cj

---------

Signed-off-by: lawrence-cj
---
 README.md                                   |  1 +
 asset/docs/model_zoo.md                     | 11 +++---
 .../2048ms/Sana_1600M_img2048_bf16.yaml     |  4 +--
 .../4096ms/Sana_1600M_img4096_bf16.yaml     |  4 +--
 diffusion/model/builder.py                  | 34 +++++++++++++++++--
 5 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index a14afda..5d66d98 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e
 
 ## 🔥🔥 News
 
+- (🔥 New) \[2025/1/12\] DC-AE tiling lets Sana-4K generate 4096x4096px images within 22GB GPU memory. [\[Guidance\]](asset/docs/model_zoo.md#-3-4k-models)
 - (🔥 New) \[2025/1/11\] Sana code-base license changed to Apache 2.0.
 - (🔥 New) \[2025/1/10\] Inference Sana with 8bit quantization.[\[Guidance\]](asset/docs/8bit_sana.md#quantization)
 - (🔥 New) \[2025/1/8\] 4K resolution [Sana models](asset/docs/model_zoo.md) is supported in [Sana-ComfyUI](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) and [work flow](asset/docs/ComfyUI/Sana_FlowEuler_4K.json) is also prepared. [\[4K guidance\]](asset/docs/ComfyUI/comfyui.md)
diff --git a/asset/docs/model_zoo.md b/asset/docs/model_zoo.md
index 55d76f2..01ea915 100644
--- a/asset/docs/model_zoo.md
+++ b/asset/docs/model_zoo.md
@@ -77,11 +77,9 @@ image = pipe(
 image[0].save('sana.png')
 ```
 
-#### 2). For 4K models
+## ❗ 3. 4K models
 
-4K models need [patch_conv](https://github.com/mit-han-lab/patch_conv) to avoid OOM issue.(80GB GPU is recommended)
-
-run `pip install patch_conv` first, then
+4K models need VAE tiling to avoid OOM issues (a 24GB GPU is recommended).
 
 ```python
 # run `pip install git+https://github.com/huggingface/diffusers` before use Sana in diffusers
@@ -98,10 +96,9 @@
 pipe.to("cuda")
 pipe.vae.to(torch.bfloat16)
 pipe.text_encoder.to(torch.bfloat16)
-# for 4096x4096 image generation OOM issue
+# for 4096x4096 image generation OOM issues, feel free to adjust the tile size
 if pipe.transformer.config.sample_size == 128:
-    from patch_conv import convert_model
-    pipe.vae = convert_model(pipe.vae, splits=32)
+    pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
 
 prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
 image = pipe(
diff --git a/configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml b/configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml
index 993459b..b519f24 100644
--- a/configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml
+++ b/configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml
@@ -41,8 +41,8 @@ model:
       - 8
 # VAE setting
 vae:
-  vae_type: dc-ae
-  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+  vae_type: AutoencoderDC
+  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
   scale_factor: 0.41407
   vae_latent_dim: 32
   vae_downsample_rate: 32
diff --git a/configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml b/configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml
index 96c3eb5..88002d8 100644
--- a/configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml
+++ b/configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml
@@ -41,8 +41,8 @@ model:
       - 8
 # VAE setting
 vae:
-  vae_type: dc-ae
-  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+  vae_type: AutoencoderDC
+  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
   scale_factor: 0.41407
   vae_latent_dim: 32
   vae_downsample_rate: 32
diff --git a/diffusion/model/builder.py b/diffusion/model/builder.py
index 751f2d5..89c7822 100755
--- a/diffusion/model/builder.py
+++ b/diffusion/model/builder.py
@@ -15,6 +15,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
+from diffusers import AutoencoderDC
 from diffusers.models import AutoencoderKL
 from mmcv import Registry
 from termcolor import colored
@@ -87,6 +88,10 @@ def get_vae(name, model_path, device="cuda"):
         print(colored(f"[DC-AE] Loading model from {model_path}", attrs=["bold"]))
         dc_ae = DCAE_HF.from_pretrained(model_path).to(device).eval()
         return dc_ae
+    elif "AutoencoderDC" in name:
+        print(colored(f"[AutoencoderDC] Loading model from {model_path}", attrs=["bold"]))
+        dc_ae = AutoencoderDC.from_pretrained(model_path).to(device).eval()
+        return dc_ae
     else:
         print("error load vae")
         exit()
@@ -102,8 +107,14 @@ def vae_encode(name, vae, images, sample_posterior, device):
         z = (z - vae.config.shift_factor) * vae.config.scaling_factor
     elif "dc-ae" in name:
         ae = vae
+        scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
+        z = ae.encode(images.to(device))
+        z = z * scaling_factor
+    elif "AutoencoderDC" in name:
+        ae = vae
+        scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
         z = ae.encode(images.to(device))
-        z = z * ae.cfg.scaling_factor
+        z = z * scaling_factor
     else:
         print("error load vae")
         exit()
@@ -116,7 +127,26 @@ def vae_decode(name, vae, latent):
         samples = vae.decode(latent).sample
     elif "dc-ae" in name:
         ae = vae
-        samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
+        vae_scale_factor = (
+            2 ** (len(ae.config.encoder_block_out_channels) - 1)
+            if hasattr(ae, "config") and ae.config is not None
+            else 32
+        )
+        scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
+        if latent.shape[-1] * vae_scale_factor > 4000 or latent.shape[-2] * vae_scale_factor > 4000:
+            from patch_conv import convert_model
+
+            ae = convert_model(ae, splits=4)
+        samples = ae.decode(latent.detach() / scaling_factor)
+    elif "AutoencoderDC" in name:
+        ae = vae
+        scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
+        try:
+            samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
+        except torch.cuda.OutOfMemoryError:
+            print("Warning: Ran out of memory during regular VAE decoding; retrying with tiled VAE decoding.")
+            ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
+            samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
     else:
         print("error load vae")
         exit()
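
Note (not part of the patch): the decode-time fallback added to `vae_decode` can be exercised standalone. The sketch below is illustrative only; it assumes the `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers` checkpoint referenced in the configs above, a CUDA device, and a random 1x32x128x128 latent standing in for a Sana-4K output (4096px / 32x downsampling = 128).

import torch
from diffusers import AutoencoderDC

# Load the deep-compression autoencoder in bf16, matching the pipeline setup
# in model_zoo.md.
ae = (
    AutoencoderDC.from_pretrained(
        "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.bfloat16
    )
    .to("cuda")
    .eval()
)

# Placeholder latent: a 4096x4096 image maps to a 32-channel 128x128 latent
# with this f32c32 VAE.
latent = torch.randn(1, 32, 128, 128, dtype=torch.bfloat16, device="cuda")
scaling_factor = ae.config.scaling_factor or 0.41407

with torch.no_grad():
    try:
        # Regular decoding materializes the full-resolution activations at once.
        samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
    except torch.cuda.OutOfMemoryError:
        # Fallback: decode in 1024px tiles and blend the overlaps, trading a
        # little speed for a much smaller peak-memory footprint.
        ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
        samples = ae.decode(latent / scaling_factor, return_dict=False)[0]

Calling `enable_tiling` up front (as the model_zoo.md example now does) makes the try/except unnecessary; the lazy fallback in `vae_decode` simply avoids the tiling overhead whenever a plain decode fits in memory.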