From 9110d56b32d4d0a4b84130748859f302bc2cbce9 Mon Sep 17 00:00:00 2001
From: Enze Xie
Date: Tue, 26 Nov 2024 16:19:48 +0800
Subject: [PATCH 1/7] update README.md.

Signed-off-by: lawrence-cj
---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 6a373c8..3d77bbd 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
 ## 🔥🔥 News
 
+- (🔥 New) \[2024/11\] 1.6B [Sana multilingual models](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing) are released. Multiple languages (emoji, Chinese, and English) are supported.
 - (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released.
 - (🔥 New) \[2024/11\] Training & Inference & Metrics code are released.
 - (🔥 New) \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982).
@@ -144,7 +145,7 @@ save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1)
 ```
 # Pull related models
-huggingface-cli download google/gemma-2b-it 
+huggingface-cli download google/gemma-2b-it
 huggingface-cli download google/shieldgemma-2b
 huggingface-cli download mit-han-lab/dc-ae-f32c32-sana-1.0
 huggingface-cli download Efficient-Large-Model/Sana_1600M_1024px
@@ -158,7 +159,6 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
 
-
 ## 🔛 Run inference with TXT or JSON files
 
 ```bash

From 1808a02199d0eb8242df55c3a66494a14fff3296 Mon Sep 17 00:00:00 2001
From: Johnny
Date: Tue, 26 Nov 2024 11:37:57 +0100
Subject: [PATCH 2/7] Update pyproject.toml

---
 pyproject.toml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9eed756..26004e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "gradio",
     "image-reward",
     "ipdb",
-    "mmcv==1.7.2",
+    "mmcv==2.2.0",
     "omegaconf",
     "opencv-python",
     "optimum",
@@ -38,13 +38,13 @@ dependencies = [
     "tensorboard",
     "tensorboardX",
     "timm",
-    "torchaudio==2.4.0",
-    "torchvision==0.19",
+    "torchaudio==2.5.0",
+    "torchvision==0.20.0",
     "transformers",
-    "triton==3.0.0",
+    "triton==3.1.0",
     "wandb",
     "webdataset",
-    "xformers==0.0.27.post2",
+    "xformers==0.0.28.post3",
     "yapf",
     "spaces",
     "matplotlib",

From b1be116a8541f9224b45fd0a25ff6071551db07c Mon Sep 17 00:00:00 2001
From: Johnny
Date: Tue, 26 Nov 2024 12:23:12 +0100
Subject: [PATCH 3/7] Update pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 26004e2..2e1788e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
     "triton==3.1.0",
     "wandb",
     "webdataset",
-    "xformers==0.0.28.post3",
+    "xformers",
     "yapf",
     "spaces",
     "matplotlib",
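Patches 2 and 3 above repin the compiled stack (`mmcv`, `torchaudio`, `torchvision`, `triton`) and deliberately leave `xformers` unpinned so the resolver can pick a wheel matching the installed `torch`. Below is a minimal sketch for double-checking an environment against the new pins; the script and its hand-copied `pins` table are illustrative, not part of the repo:

```python
# Sanity-check installed versions against the pins from pyproject.toml above.
# Illustrative only: the pin table is copied by hand, not parsed from the file.
from importlib.metadata import PackageNotFoundError, version

pins = {"mmcv": "2.2.0", "torchaudio": "2.5.0", "torchvision": "0.20.0", "triton": "3.1.0"}
for name, expected in pins.items():
    try:
        installed = version(name)
        flag = "ok" if installed == expected else f"expected {expected}"
        print(f"{name}=={installed} ({flag})")
    except PackageNotFoundError:
        print(f"{name} is not installed")
```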
From d3f1c7e899db95d41ec401be465bd6918501571b Mon Sep 17 00:00:00 2001
From: junsong
Date: Tue, 26 Nov 2024 07:00:32 -0800
Subject: [PATCH 4/7] 1. update app 2. fix the precision bug in model forward;

Signed-off-by: lawrence-cj
---
 app/app_sana.py                          | 14 +++++++-------
 app/sana_pipeline.py                     |  3 +--
 diffusion/model/nets/sana_multi_scale.py |  8 ++++----
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 4410dde..0e7c451 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -115,6 +115,12 @@ INFER_SPEED = 0
 
 
+def norm_ip(img, low, high):
+    img.clamp_(min=low, max=high)
+    img.sub_(low).div_(max(high - low, 1e-5))
+    return img
+
+
 def open_db():
     db = sqlite3.connect(COUNTER_DB)
     db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
@@ -285,13 +291,7 @@ def generate(
         img = [save_image_sana(img, seed, save_img=save_image) for img in images]
         print(img)
     else:
-        if num_imgs > 1:
-            nrow = 2
-        else:
-            nrow = 1
-        img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
-        img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
-        img = [Image.fromarray(img.astype(np.uint8))]
+        img = [Image.fromarray(norm_ip(img, -1, 1).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy().astype(np.uint8)) for img in images]
 
     torch.cuda.empty_cache()
diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py
index a3251c1..40487e3 100644
--- a/app/sana_pipeline.py
+++ b/app/sana_pipeline.py
@@ -271,10 +271,9 @@ def forward(
                 self.latent_size_w,
                 generator=generator,
                 device=self.device,
-                dtype=self.weight_dtype,
             )
         else:
-            z = latents.to(self.weight_dtype).to(self.device)
+            z = latents.to(self.device)
         model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
         if self.vis_sampler == "flow_euler":
             flow_solver = FlowEuler(
diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py
index b79d230..ec570af 100755
--- a/diffusion/model/nets/sana_multi_scale.py
+++ b/diffusion/model/nets/sana_multi_scale.py
@@ -278,9 +278,9 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
             y: (N, 1, 120, C) tensor of class labels
         """
         bs = x.shape[0]
-        dtype = x.dtype
-        timestep = timestep.to(dtype)
-        y = y.to(dtype)
+        x = x.to(self.dtype)
+        timestep = timestep.to(self.dtype)
+        y = y.to(self.dtype)
         self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
         if self.use_pe:
             x = self.x_embedder(x)
@@ -296,7 +296,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
                 )
                 .unsqueeze(0)
                 .to(x.device)
-                .to(dtype)
+                .to(self.dtype)
             )
             x += self.pos_embed_ms  # (N, T, D), where T = H * W / patch_size ** 2
         else:

From a1d1e141275a7d944f166908708a25bd6bf33716 Mon Sep 17 00:00:00 2001
From: lawrence-cj
Date: Tue, 26 Nov 2024 23:45:40 +0800
Subject: [PATCH 5/7] pre-commit;

Signed-off-by: lawrence-cj
---
 app/app_sana.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/app/app_sana.py b/app/app_sana.py
index 0e7c451..8c9c30b 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -291,7 +291,19 @@ def generate(
         img = [save_image_sana(img, seed, save_img=save_image) for img in images]
         print(img)
     else:
-        img = [Image.fromarray(norm_ip(img, -1, 1).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy().astype(np.uint8)) for img in images]
+        img = [
+            Image.fromarray(
+                norm_ip(img, -1, 1)
+                .mul(255)
+                .add_(0.5)
+                .clamp_(0, 255)
+                .permute(1, 2, 0)
+                .to("cpu", torch.uint8)
+                .numpy()
+                .astype(np.uint8)
+            )
+            for img in images
+        ]
 
     torch.cuda.empty_cache()
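For context, the `norm_ip` helper added in patch 4 (and reformatted by patch 5) replaces the old `make_grid(..., normalize=True)` path with a per-image rescale from the model's [-1, 1] output to an 8-bit PIL image. A standalone sketch of that conversion, with a random tensor standing in for a decoded VAE output:

```python
import torch
from PIL import Image

def norm_ip(img, low, high):
    # In-place: clamp to [low, high], then rescale to [0, 1].
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))
    return img

decoded = torch.rand(3, 64, 64) * 2 - 1  # stand-in for a decoded image in [-1, 1]
array = (
    norm_ip(decoded, -1, 1)
    .mul(255)
    .add_(0.5)
    .clamp_(0, 255)
    .permute(1, 2, 0)  # CHW -> HWC for PIL
    .to("cpu", torch.uint8)
    .numpy()
)
Image.fromarray(array).save("output.png")
```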
From 94e7733746c2098800ef81639b524baffc910888 Mon Sep 17 00:00:00 2001
From: lawrence-cj
Date: Wed, 27 Nov 2024 00:32:36 +0800
Subject: [PATCH 6/7] add an AdamW optimizer config file for reference.

Signed-off-by: lawrence-cj
---
 .../1024ms/Sana_1600M_img1024_AdamW.yaml | 104 ++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml

diff --git a/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml b/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml
new file mode 100644
index 0000000..429d998
--- /dev/null
+++ b/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml
@@ -0,0 +1,104 @@
+data:
+  data_dir: [data/data_public/dir1]
+  image_size: 1024
+  caption_proportion:
+    prompt: 1
+  external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
+  external_clipscore_suffixes:
+  - _InternVL2-26B_clip_score
+  - _VILA1-5-13B_clip_score
+  - _prompt_clip_score
+  clip_thr_temperature: 0.1
+  clip_thr: 25.0
+  load_text_feat: false
+  load_vae_feat: false
+  transform: default_train
+  type: SanaWebDatasetMS
+  sort_dataset: false
+# model config
+model:
+  model: SanaMS_1600M_P1_D20
+  image_size: 1024
+  mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
+  fp32_attention: true
+  load_from:
+  resume_from:
+  aspect_ratio_type: ASPECT_RATIO_1024
+  multi_scale: true
+  #pe_interpolation: 1.
+  attn_type: linear
+  ffn_type: glumbconv
+  mlp_acts:
+  - silu
+  - silu
+  -
+  mlp_ratio: 2.5
+  use_pe: false
+  qk_norm: false
+  class_dropout_prob: 0.1
+  # PAG
+  pag_applied_layers:
+  - 8
+# VAE setting
+vae:
+  vae_type: dc-ae
+  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+  scale_factor: 0.41407
+  vae_latent_dim: 32
+  vae_downsample_rate: 32
+  sample_posterior: true
+# text encoder
+text_encoder:
+  text_encoder_name: gemma-2-2b-it
+  y_norm: true
+  y_norm_scale_factor: 0.01
+  model_max_length: 300
+  # CHI
+  chi_prompt:
+  - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+  - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+  - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+  - 'Here are examples of how to transform or refine prompts:'
+  - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+  - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+  - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+  - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+  predict_v: true
+  noise_schedule: linear_flow
+  pred_sigma: false
+  flow_shift: 3.0
+  # logit-normal timestep
+  weighting_scheme: logit_normal
+  logit_mean: 0.0
+  logit_std: 1.0
+  vis_sampler: flow_dpm-solver
+# training setting
+train:
+  num_workers: 10
+  seed: 1
+  train_batch_size: 64
+  num_epochs: 100
+  gradient_accumulation_steps: 1
+  grad_checkpointing: true
+  gradient_clip: 0.1
+  optimizer:
+    lr: 1.0e-4
+    type: AdamW
+    weight_decay: 0.01
+    eps: 1.0e-8
+    betas: [0.9, 0.999]
+  lr_schedule: constant
+  lr_schedule_args:
+    num_warmup_steps: 2000
+  local_save_vis: true # if save log image locally
+  visualize: true
+  eval_sampling_steps: 500
+  log_interval: 20
+  save_model_epochs: 5
+  save_model_steps: 500
+  work_dir: output/debug
+  online_metric: false
+  eval_metric_step: 2000
+  online_metric_dir: metric_helper
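The `train.optimizer` block above is plain AdamW with every hyperparameter spelled out. The repo's optimizer builder is not part of this series, so the mapping below is only a sketch of how such a block conventionally becomes a `torch.optim.AdamW` instance; `build_optimizer` is a hypothetical helper:

```python
import torch

def build_optimizer(model: torch.nn.Module, cfg: dict) -> torch.optim.Optimizer:
    # Hypothetical helper: resolve `type` against torch.optim and pass the
    # remaining keys through as keyword arguments.
    cfg = dict(cfg)  # copy so the pop below does not mutate the caller's config
    opt_cls = getattr(torch.optim, cfg.pop("type"))
    return opt_cls(model.parameters(), **cfg)

optimizer_cfg = {  # mirrors train.optimizer in the YAML above
    "type": "AdamW",
    "lr": 1.0e-4,
    "weight_decay": 0.01,
    "eps": 1.0e-8,
    "betas": (0.9, 0.999),
}
optimizer = build_optimizer(torch.nn.Linear(8, 8), optimizer_cfg)  # toy model
```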
From 7d0d6595215fdcf0df2feb9f01d79579070b0721 Mon Sep 17 00:00:00 2001
From: junsong
Date: Wed, 27 Nov 2024 01:37:11 +0800
Subject: [PATCH 7/7] fix all the model input z.to(weight_dtype) bugs;

Signed-off-by: lawrence-cj
---
 scripts/inference.py              |  1 -
 scripts/inference_dpg.py          |  8 +++++++-
 scripts/inference_geneval.py      |  8 +++++++-
 scripts/inference_image_reward.py | 16 +++++++++++-----
 scripts/interface.py              |  2 +-
 5 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/scripts/inference.py b/scripts/inference.py
index 52d21fd..b98e1c8 100755
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -146,7 +146,6 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
                 latent_size,
                 device=device,
                 generator=generator,
-                dtype=weight_dtype,
             )
             model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
diff --git a/scripts/inference_dpg.py b/scripts/inference_dpg.py
index bde68c9..f9dc460 100644
--- a/scripts/inference_dpg.py
+++ b/scripts/inference_dpg.py
@@ -117,7 +117,12 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
         with torch.no_grad():
             n = len(prompts)
             z = torch.randn(
-                n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
+                n,
+                config.vae.vae_latent_dim,
+                latent_size,
+                latent_size,
+                device=device,
+                generator=generator,
             )
 
             model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
@@ -432,6 +437,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_dpg_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
diff --git a/scripts/inference_geneval.py b/scripts/inference_geneval.py
index 07cef9f..9b3e78b 100644
--- a/scripts/inference_geneval.py
+++ b/scripts/inference_geneval.py
@@ -226,7 +226,12 @@ def visualize(sample_steps, cfg_scale, pag_scale):
         with torch.no_grad():
             n = len(prompts)
             z = torch.randn(
-                n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
+                n,
+                config.vae.vae_latent_dim,
+                latent_size,
+                latent_size,
+                device=device,
+                generator=generator,
             )
"aspect_ratio": ar}, mask=emb_masks) @@ -535,6 +540,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type): save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type) os.makedirs(save_root, exist_ok=True) if args.if_save_dirname and args.gpu_id == 0: + os.makedirs(f"{work_dir}/metrics", exist_ok=True) # save at work_dir/metrics/tmp_geneval_xxx.txt for metrics testing with open(f"{work_dir}/metrics/tmp_geneval_{time.time()}.txt", "w") as f: print(f"save tmp file at {work_dir}/metrics/tmp_geneval_{time.time()}.txt") diff --git a/scripts/inference_image_reward.py b/scripts/inference_image_reward.py index fb3e1a5..71b6f46 100644 --- a/scripts/inference_image_reward.py +++ b/scripts/inference_image_reward.py @@ -118,7 +118,14 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0): # start sampling with torch.no_grad(): n = len(prompts) - z = torch.randn(n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator) + z = torch.randn( + n, + config.vae.vae_latent_dim, + latent_size, + latent_size, + device=device, + generator=generator, + ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) if args.sampling_algo == "dpm-solver": @@ -205,7 +212,6 @@ def get_args(): class SanaInference(SanaConfig): config: str = "" model_path: Optional[str] = field(default=None, metadata={"help": "Path to the model file (optional)"}) - version: str = "sigma" txt_file: str = "asset/samples.txt" json_file: Optional[str] = None sample_nums: int = 100_000 @@ -214,7 +220,7 @@ class SanaInference(SanaConfig): cfg_scale: float = 4.5 pag_scale: float = 1.0 sampling_algo: str = field( - default="dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]} + default="flow_dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]} ) seed: int = 0 dataset: str = "custom" @@ -233,7 +239,6 @@ class SanaInference(SanaConfig): default=None, metadata={"help": "A list value, like [0, 1.] 
for ablation"} ) ablation_key: Optional[str] = field(default=None, metadata={"choices": ["step", "cfg_scale", "pag_scale"]}) - debug: bool = False if_save_dirname: bool = field( default=False, metadata={"help": "if save img save dir name at wor_dir/metrics/tmp_time.time().txt for metric testing"}, @@ -244,7 +249,6 @@ class SanaInference(SanaConfig): args = get_args() config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config) - # config = read_config(args.config) args.image_size = config.model.image_size if args.custom_image_size: @@ -311,6 +315,7 @@ class SanaInference(SanaConfig): "linear_head_dim": config.model.linear_head_dim, "pred_sigma": pred_sigma, "learn_sigma": learn_sigma, + "use_fp32_attention": getattr(config.model, "fp32_attention", False), } model = build_model(config.model.model, **model_kwargs).to(device) logger.info( @@ -411,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type): save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type) os.makedirs(save_root, exist_ok=True) if args.if_save_dirname and args.gpu_id == 0: + os.makedirs(f"{work_dir}/metrics", exist_ok=True) # save at work_dir/metrics/tmp_xxx.txt for metrics testing with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f: print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt") diff --git a/scripts/interface.py b/scripts/interface.py index d81199c..bacf28d 100755 --- a/scripts/interface.py +++ b/scripts/interface.py @@ -165,7 +165,7 @@ def generate_img( n = len(prompts) latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate - z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device, dtype=weight_dtype) + z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device) model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks) print(f"Latent Size: {z.shape}") # Sample images: