fix 4K OOM with VAE-tiling (#144)
* add vae tiling function;

Signed-off-by: lawrence-cj <[email protected]>

* update README.md;

Signed-off-by: lawrence-cj <[email protected]>

* change vae tile size to 1024px

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
lawrence-cj authored Jan 12, 2025
1 parent e8f6f32 commit 7056b13
Showing 5 changed files with 41 additions and 13 deletions.
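At a glance: the commit replaces the earlier `patch_conv` workaround with the tiled decoder that ships with diffusers' `AutoencoderDC`. A minimal sketch of the pattern (the checkpoint id is taken from the config changes below; the tile size matches the commit):

```python
# Sketch only: enable tiled VAE decoding so a 4096x4096 decode runs
# in 1024px tiles instead of one full-resolution pass.
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
```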
README.md (1 addition, 0 deletions)
@@ -36,6 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.

## 🔥🔥 News

+- (🔥 New) \[2025/1/12\] DC-AE tiling lets Sana-4K generate 4096x4096px images within 22GB of GPU memory. [\[Guidance\]](asset/docs/model_zoo.md#-3-4k-models)
- (🔥 New) \[2025/1/11\] Sana code-base license changed to Apache 2.0.
- (🔥 New) \[2025/1/10\] Inference Sana with 8-bit quantization. [\[Guidance\]](asset/docs/8bit_sana.md#quantization)
- (🔥 New) \[2025/1/8\] 4K-resolution [Sana models](asset/docs/model_zoo.md) are supported in [Sana-ComfyUI](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) and a [workflow](asset/docs/ComfyUI/Sana_FlowEuler_4K.json) is also provided. [\[4K guidance\]](asset/docs/ComfyUI/comfyui.md)
asset/docs/model_zoo.md (4 additions, 7 deletions)
@@ -77,11 +77,9 @@ image = pipe(
image[0].save('sana.png')
```

-#### 2). For 4K models
+## ❗ 3. 4K models

-4K models need [patch_conv](https://github.com/mit-han-lab/patch_conv) to avoid OOM issue.(80GB GPU is recommended)
-run `pip install patch_conv` first, then
+4K models need VAE tiling to avoid the OOM issue (a 24GB GPU is recommended).

```python
# run `pip install git+https://github.com/huggingface/diffusers` before using Sana in diffusers
@@ -98,10 +96,9 @@ pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

-# for 4096x4096 image generation OOM issue
+# for 4096x4096 image generation OOM issue, feel free to adjust the tile size
if pipe.transformer.config.sample_size == 128:
-    from patch_conv import convert_model
-    pipe.vae = convert_model(pipe.vae, splits=32)
+    pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
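For convenience, here is the updated 4K snippet as one self-contained piece. The diff collapses the `from_pretrained` lines, so the repo id `Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers` below is an assumption; everything else mirrors the hunk above:

```python
import torch
from diffusers import SanaPipeline

# Assumed 4K checkpoint id; the collapsed diff hides the actual call.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

# sample_size 128 x VAE downsample 32 = 4096px: enable tiled decoding
# so the 4K decode fits in ~22GB instead of OOMing.
if pipe.transformer.config.sample_size == 128:
    pipe.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(prompt=prompt).images[0]
image.save("sana_4K.png")
```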
configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml (2 additions, 2 deletions)
@@ -41,8 +41,8 @@ model:
- 8
# VAE setting
vae:
-vae_type: dc-ae
-vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+vae_type: AutoencoderDC
+vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
configs/sana_config/4096ms/Sana_1600M_img4096_bf16.yaml (2 additions, 2 deletions)
@@ -41,8 +41,8 @@ model:
- 8
# VAE setting
vae:
-vae_type: dc-ae
-vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+vae_type: AutoencoderDC
+vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
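The renamed `vae_type` is what routes loading through the new `AutoencoderDC` branch of `get_vae` in `diffusion/model/builder.py` (the next file). A rough sketch of that wiring, with an illustrative latent shape (4096px / 32x downsample = 128, `vae_latent_dim: 32`):

```python
import torch
from diffusion.model.builder import get_vae, vae_decode

# Values copied from the updated YAML above.
vae_type = "AutoencoderDC"
vae_pretrained = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"

vae = get_vae(vae_type, vae_pretrained, device="cuda")

# A hypothetical 4K latent for illustration: 32 channels, 128x128 spatial.
latent = torch.randn(1, 32, 128, 128, device="cuda")
samples = vae_decode(vae_type, vae, latent)  # retries with tiling on OOM
```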
diffusion/model/builder.py (32 additions, 2 deletions)
@@ -15,6 +15,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
+from diffusers import AutoencoderDC
from diffusers.models import AutoencoderKL
from mmcv import Registry
from termcolor import colored
@@ -87,6 +88,10 @@ def get_vae(name, model_path, device="cuda"):
        print(colored(f"[DC-AE] Loading model from {model_path}", attrs=["bold"]))
        dc_ae = DCAE_HF.from_pretrained(model_path).to(device).eval()
        return dc_ae
+    elif "AutoencoderDC" in name:
+        print(colored(f"[AutoencoderDC] Loading model from {model_path}", attrs=["bold"]))
+        dc_ae = AutoencoderDC.from_pretrained(model_path).to(device).eval()
+        return dc_ae
    else:
        print("error load vae")
        exit()
@@ -102,8 +107,14 @@ def vae_encode(name, vae, images, sample_posterior, device):
        z = (z - vae.config.shift_factor) * vae.config.scaling_factor
    elif "dc-ae" in name:
        ae = vae
+        scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
+        z = ae.encode(images.to(device))
+        z = z * scaling_factor
+    elif "AutoencoderDC" in name:
+        ae = vae
+        scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
        z = ae.encode(images.to(device))
-        z = z * ae.cfg.scaling_factor
+        z = z * scaling_factor
    else:
        print("error load vae")
        exit()
@@ -116,7 +127,26 @@ def vae_decode(name, vae, latent):
        samples = vae.decode(latent).sample
    elif "dc-ae" in name:
        ae = vae
-        samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
+        vae_scale_factor = (
+            2 ** (len(ae.config.encoder_block_out_channels) - 1)
+            if hasattr(ae, "config") and ae.config is not None
+            else 32
+        )
+        scaling_factor = ae.cfg.scaling_factor if ae.cfg.scaling_factor else 0.41407
+        if latent.shape[-1] * vae_scale_factor > 4000 or latent.shape[-2] * vae_scale_factor > 4000:
+            from patch_conv import convert_model
+
+            ae = convert_model(ae, splits=4)
+        samples = ae.decode(latent.detach() / scaling_factor)
+    elif "AutoencoderDC" in name:
+        ae = vae
+        scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
+        try:
+            samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
+        except torch.cuda.OutOfMemoryError:
+            print("Warning: ran out of memory during regular VAE decoding; retrying with tiled VAE decoding.")
+            ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
+            samples = ae.decode(latent / scaling_factor, return_dict=False)[0]
    else:
        print("error load vae")
        exit()
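The decode-time fallback is the load-bearing part of the fix: full-resolution decoding is attempted first, and tiling is switched on only when it actually OOMs, so smaller resolutions keep the faster single-pass path. As a standalone pattern (a hypothetical helper, not part of the repo):

```python
import torch
from diffusers import AutoencoderDC

def decode_with_tiling_fallback(ae: AutoencoderDC, latent: torch.Tensor) -> torch.Tensor:
    """Full decode first; on CUDA OOM, retry in 1024px tiles (mirrors vae_decode above)."""
    scaling_factor = ae.config.scaling_factor if ae.config.scaling_factor else 0.41407
    try:
        return ae.decode(latent / scaling_factor, return_dict=False)[0]
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # drop the failed allocation before the retry
        ae.enable_tiling(tile_sample_min_height=1024, tile_sample_min_width=1024)
        return ae.decode(latent / scaling_factor, return_dict=False)[0]
```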
