From 9110d56b32d4d0a4b84130748859f302bc2cbce9 Mon Sep 17 00:00:00 2001 From: Enze Xie Date: Tue, 26 Nov 2024 16:19:48 +0800 Subject: [PATCH 01/14] update README.md. Signed-off-by: lawrence-cj --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6a373c8..3d77bbd 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e. ## 🔥🔥 News +- (🔥 New) \[2024/11\] 1.6B [Sana multi-linguistic models](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing) are released. Multi-language(Emoji & Chinese & English) are supported. - (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released. - (🔥 New) \[2024/11\] Training & Inference & Metrics code are released. - (🔥 New) \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982). @@ -144,7 +145,7 @@ save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1) ``` # Pull related models -huggingface-cli download google/gemma-2b-it +huggingface-cli download google/gemma-2b-it huggingface-cli download google/shieldgemma-2b huggingface-cli download mit-han-lab/dc-ae-f32c32-sana-1.0 huggingface-cli download Efficient-Large-Model/Sana_1600M_1024px @@ -158,7 +159,6 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ - ## đź”› Run inference with TXT or JSON files ```bash From d3f1c7e899db95d41ec401be465bd6918501571b Mon Sep 17 00:00:00 2001 From: junsong Date: Tue, 26 Nov 2024 07:00:32 -0800 Subject: [PATCH 02/14] 1. update app 2. fix the precision bug in model forward; Signed-off-by: lawrence-cj --- app/app_sana.py | 14 +++++++------- app/sana_pipeline.py | 3 +-- diffusion/model/nets/sana_multi_scale.py | 8 ++++---- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/app/app_sana.py b/app/app_sana.py index 4410dde..0e7c451 100755 --- a/app/app_sana.py +++ b/app/app_sana.py @@ -115,6 +115,12 @@ INFER_SPEED = 0 +def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + return img + + def open_db(): db = sqlite3.connect(COUNTER_DB) db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)") @@ -285,13 +291,7 @@ def generate( img = [save_image_sana(img, seed, save_img=save_image) for img in images] print(img) else: - if num_imgs > 1: - nrow = 2 - else: - nrow = 1 - img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1)) - img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() - img = [Image.fromarray(img.astype(np.uint8))] + img = [Image.fromarray(norm_ip(img, -1, 1).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy().astype(np.uint8)) for img in images] torch.cuda.empty_cache() diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py index a3251c1..40487e3 100644 --- a/app/sana_pipeline.py +++ b/app/sana_pipeline.py @@ -271,10 +271,9 @@ def forward( self.latent_size_w, generator=generator, device=self.device, - dtype=self.weight_dtype, ) else: - z = latents.to(self.weight_dtype).to(self.device) + z = latents.to(self.device) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) if self.vis_sampler == "flow_euler": flow_solver = FlowEuler( diff --git a/diffusion/model/nets/sana_multi_scale.py b/diffusion/model/nets/sana_multi_scale.py index b79d230..ec570af 
100755 --- a/diffusion/model/nets/sana_multi_scale.py +++ b/diffusion/model/nets/sana_multi_scale.py @@ -278,9 +278,9 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): y: (N, 1, 120, C) tensor of class labels """ bs = x.shape[0] - dtype = x.dtype - timestep = timestep.to(dtype) - y = y.to(dtype) + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size if self.use_pe: x = self.x_embedder(x) @@ -296,7 +296,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): ) .unsqueeze(0) .to(x.device) - .to(dtype) + .to(self.dtype) ) x += self.pos_embed_ms # (N, T, D), where T = H * W / patch_size ** 2 else: From a1d1e141275a7d944f166908708a25bd6bf33716 Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Tue, 26 Nov 2024 23:45:40 +0800 Subject: [PATCH 03/14] pre-commit; Signed-off-by: lawrence-cj --- app/app_sana.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/app/app_sana.py b/app/app_sana.py index 0e7c451..8c9c30b 100755 --- a/app/app_sana.py +++ b/app/app_sana.py @@ -291,7 +291,19 @@ def generate( img = [save_image_sana(img, seed, save_img=save_image) for img in images] print(img) else: - img = [Image.fromarray(norm_ip(img, -1, 1).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy().astype(np.uint8)) for img in images] + img = [ + Image.fromarray( + norm_ip(img, -1, 1) + .mul(255) + .add_(0.5) + .clamp_(0, 255) + .permute(1, 2, 0) + .to("cpu", torch.uint8) + .numpy() + .astype(np.uint8) + ) + for img in images + ] torch.cuda.empty_cache() From 94e7733746c2098800ef81639b524baffc910888 Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Wed, 27 Nov 2024 00:32:36 +0800 Subject: [PATCH 04/14] add a AdamW optimizer type as config file for reference. Signed-off-by: lawrence-cj --- .../1024ms/Sana_1600M_img1024_AdamW.yaml | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml diff --git a/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml b/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml new file mode 100644 index 0000000..429d998 --- /dev/null +++ b/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml @@ -0,0 +1,104 @@ +data: + data_dir: [data/data_public/dir1] + image_size: 1024 + caption_proportion: + prompt: 1 + external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] + external_clipscore_suffixes: + - _InternVL2-26B_clip_score + - _VILA1-5-13B_clip_score + - _prompt_clip_score + clip_thr_temperature: 0.1 + clip_thr: 25.0 + load_text_feat: false + load_vae_feat: false + transform: default_train + type: SanaWebDatasetMS + sort_dataset: false +# model config +model: + model: SanaMS_1600M_P1_D20 + image_size: 1024 + mixed_precision: fp16 # ['fp16', 'fp32', 'bf16'] + fp32_attention: true + load_from: + resume_from: + aspect_ratio_type: ASPECT_RATIO_1024 + multi_scale: true + #pe_interpolation: 1. 
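+  # NOTE: pe_interpolation is left commented out here; this config sets use_pe: false
+  # further below, so a positional-embedding interpolation factor should not be needed.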
+ attn_type: linear + ffn_type: glumbconv + mlp_acts: + - silu + - silu + - + mlp_ratio: 2.5 + use_pe: false + qk_norm: false + class_dropout_prob: 0.1 + # PAG + pag_applied_layers: + - 8 +# VAE setting +vae: + vae_type: dc-ae + vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0 + scale_factor: 0.41407 + vae_latent_dim: 32 + vae_downsample_rate: 32 + sample_posterior: true +# text encoder +text_encoder: + text_encoder_name: gemma-2-2b-it + y_norm: true + y_norm_scale_factor: 0.01 + model_max_length: 300 + # CHI + chi_prompt: + - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' + - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' + - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' + - 'Here are examples of how to transform or refine prompts:' + - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' + - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' + - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' + - 'User Prompt: ' +# Sana schedule Flow +scheduler: + predict_v: true + noise_schedule: linear_flow + pred_sigma: false + flow_shift: 3.0 + # logit-normal timestep + weighting_scheme: logit_normal + logit_mean: 0.0 + logit_std: 1.0 + vis_sampler: flow_dpm-solver +# training setting +train: + num_workers: 10 + seed: 1 + train_batch_size: 64 + num_epochs: 100 + gradient_accumulation_steps: 1 + grad_checkpointing: true + gradient_clip: 0.1 + optimizer: + lr: 1.0e-4 + type: AdamW + weight_decay: 0.01 + eps: 1.0e-8 + betas: [0.9, 0.999] + lr_schedule: constant + lr_schedule_args: + num_warmup_steps: 2000 + local_save_vis: true # if save log image locally + visualize: true + eval_sampling_steps: 500 + log_interval: 20 + save_model_epochs: 5 + save_model_steps: 500 + work_dir: output/debug + online_metric: false + eval_metric_step: 2000 + online_metric_dir: metric_helper From 7d0d6595215fdcf0df2feb9f01d79579070b0721 Mon Sep 17 00:00:00 2001 From: junsong Date: Wed, 27 Nov 2024 01:37:11 +0800 Subject: [PATCH 05/14] fix all the model input z.to(weight_dtype) bugs; Signed-off-by: lawrence-cj --- scripts/inference.py | 1 - scripts/inference_dpg.py | 8 +++++++- scripts/inference_geneval.py | 8 +++++++- scripts/inference_image_reward.py | 16 +++++++++++----- scripts/interface.py | 2 +- 5 files changed, 26 insertions(+), 9 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 52d21fd..b98e1c8 100755 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -146,7 +146,6 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale latent_size, device=device, generator=generator, - dtype=weight_dtype, ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) diff --git a/scripts/inference_dpg.py b/scripts/inference_dpg.py index bde68c9..f9dc460 100644 --- a/scripts/inference_dpg.py +++ b/scripts/inference_dpg.py @@ -117,7 +117,12 @@ def 
visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0): with torch.no_grad(): n = len(prompts) z = torch.randn( - n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator + n, + config.vae.vae_latent_dim, + latent_size, + latent_size, + device=device, + generator=generator, ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) @@ -432,6 +437,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type): save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type) os.makedirs(save_root, exist_ok=True) if args.if_save_dirname and args.gpu_id == 0: + os.makedirs(f"{work_dir}/metrics", exist_ok=True) # save at work_dir/metrics/tmp_dpg_xxx.txt for metrics testing with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f: print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt") diff --git a/scripts/inference_geneval.py b/scripts/inference_geneval.py index 07cef9f..9b3e78b 100644 --- a/scripts/inference_geneval.py +++ b/scripts/inference_geneval.py @@ -226,7 +226,12 @@ def visualize(sample_steps, cfg_scale, pag_scale): with torch.no_grad(): n = len(prompts) z = torch.randn( - n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator + n, + config.vae.vae_latent_dim, + latent_size, + latent_size, + device=device, + generator=generator, ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) @@ -535,6 +540,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type): save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type) os.makedirs(save_root, exist_ok=True) if args.if_save_dirname and args.gpu_id == 0: + os.makedirs(f"{work_dir}/metrics", exist_ok=True) # save at work_dir/metrics/tmp_geneval_xxx.txt for metrics testing with open(f"{work_dir}/metrics/tmp_geneval_{time.time()}.txt", "w") as f: print(f"save tmp file at {work_dir}/metrics/tmp_geneval_{time.time()}.txt") diff --git a/scripts/inference_image_reward.py b/scripts/inference_image_reward.py index fb3e1a5..71b6f46 100644 --- a/scripts/inference_image_reward.py +++ b/scripts/inference_image_reward.py @@ -118,7 +118,14 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0): # start sampling with torch.no_grad(): n = len(prompts) - z = torch.randn(n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator) + z = torch.randn( + n, + config.vae.vae_latent_dim, + latent_size, + latent_size, + device=device, + generator=generator, + ) model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) if args.sampling_algo == "dpm-solver": @@ -205,7 +212,6 @@ def get_args(): class SanaInference(SanaConfig): config: str = "" model_path: Optional[str] = field(default=None, metadata={"help": "Path to the model file (optional)"}) - version: str = "sigma" txt_file: str = "asset/samples.txt" json_file: Optional[str] = None sample_nums: int = 100_000 @@ -214,7 +220,7 @@ class SanaInference(SanaConfig): cfg_scale: float = 4.5 pag_scale: float = 1.0 sampling_algo: str = field( - default="dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]} + default="flow_dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]} ) seed: int = 0 dataset: str = "custom" @@ -233,7 +239,6 @@ class SanaInference(SanaConfig): default=None, metadata={"help": "A list 
value, like [0, 1.] for ablation"} ) ablation_key: Optional[str] = field(default=None, metadata={"choices": ["step", "cfg_scale", "pag_scale"]}) - debug: bool = False if_save_dirname: bool = field( default=False, metadata={"help": "if save img save dir name at wor_dir/metrics/tmp_time.time().txt for metric testing"}, @@ -244,7 +249,6 @@ class SanaInference(SanaConfig): args = get_args() config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config) - # config = read_config(args.config) args.image_size = config.model.image_size if args.custom_image_size: @@ -311,6 +315,7 @@ class SanaInference(SanaConfig): "linear_head_dim": config.model.linear_head_dim, "pred_sigma": pred_sigma, "learn_sigma": learn_sigma, + "use_fp32_attention": getattr(config.model, "fp32_attention", False), } model = build_model(config.model.model, **model_kwargs).to(device) logger.info( @@ -411,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type): save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type) os.makedirs(save_root, exist_ok=True) if args.if_save_dirname and args.gpu_id == 0: + os.makedirs(f"{work_dir}/metrics", exist_ok=True) # save at work_dir/metrics/tmp_xxx.txt for metrics testing with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f: print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt") diff --git a/scripts/interface.py b/scripts/interface.py index d81199c..bacf28d 100755 --- a/scripts/interface.py +++ b/scripts/interface.py @@ -165,7 +165,7 @@ def generate_img( n = len(prompts) latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate - z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device, dtype=weight_dtype) + z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device) model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks) print(f"Latent Size: {z.shape}") # Sample images: From beb7785dd3b1320546ffdd39d6ed3fa38a62a581 Mon Sep 17 00:00:00 2001 From: Muinez Date: Tue, 26 Nov 2024 23:32:07 +0300 Subject: [PATCH 06/14] add caching & bucketing --- diffusion/utils/config.py | 1 + train_scripts/make_buckets.py | 136 +++++ train_scripts/train_local.py | 939 ++++++++++++++++++++++++++++++++++ train_scripts/train_local.sh | 25 + 4 files changed, 1101 insertions(+) create mode 100644 train_scripts/make_buckets.py create mode 100644 train_scripts/train_local.py create mode 100644 train_scripts/train_local.sh diff --git a/diffusion/utils/config.py b/diffusion/utils/config.py index 209a076..e688e49 100644 --- a/diffusion/utils/config.py +++ b/diffusion/utils/config.py @@ -23,6 +23,7 @@ def __str__(self): @dataclass class DataConfig(BaseConfig): data_dir: List[Optional[str]] = field(default_factory=list) + buckets_file: str = "buckets.json" caption_proportion: Dict[str, int] = field(default_factory=lambda: {"prompt": 1}) external_caption_suffixes: List[str] = field(default_factory=list) external_clipscore_suffixes: List[str] = field(default_factory=list) diff --git a/train_scripts/make_buckets.py b/train_scripts/make_buckets.py new file mode 100644 index 0000000..4a4f64f --- /dev/null +++ b/train_scripts/make_buckets.py @@ -0,0 +1,136 @@ +import torch +from diffusion.model.builder import get_vae, vae_encode +from diffusion.utils.config import SanaConfig +import pyrallis +from PIL import Image +import 
torchvision.transforms as T +import os +import os.path as osp +from torchvision.transforms import InterpolationMode +import json +from torch.utils.data import DataLoader +from tqdm import tqdm +import math +from itertools import chain + +@pyrallis.wrap() +def main(config: SanaConfig) -> None: + preferred_pixel_count = config.model.image_size * config.model.image_size + + min_size = config.model.image_size // 2 + max_size = config.model.image_size * 2 + step = 32 + + ratios_array = [] + while(min_size != max_size): + width = int(preferred_pixel_count / min_size) + if(width % step != 0): + mod = width % step + if(mod < step//2): + width -= mod + else: + width += step - mod + + ratio = min_size / width + + ratios_array.append((ratio, (int(min_size), width))) + min_size += step + + def get_closest_ratio(height: float, width: float): + aspect_ratio = height / width + closest_ratio = min(ratios_array, key=lambda ratio: abs(ratio[0] - aspect_ratio)) + return closest_ratio + + def get_preffered_size(height: float, width: float): + pixel_count = height * width + + scale = math.sqrt(pixel_count / preferred_pixel_count) + return height / scale, width / scale + + class BucketsDataset(torch.utils.data.Dataset): + def __init__(self, data_dir, skip_files): + valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"} + self.files = ([ + osp.join(data_dir, f) for f in os.listdir(data_dir) + if osp.isfile(osp.join(data_dir, f)) and osp.splitext(f)[1].lower() in valid_extensions and osp.join(data_dir, f) not in skip_files ]) + + self.transform = T.Compose([ + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + path = self.files[idx] + img = Image.open(path).convert("RGB") + ratio = get_closest_ratio(img.height, img.width) + prefsize = get_preffered_size(img.height, img.width) + + crop = T.Resize(ratio[1], interpolation=InterpolationMode.BICUBIC) + return { + 'img': self.transform(crop(img)), + 'size': torch.tensor([ratio[1][0], ratio[1][1]]), + 'prefsize': torch.tensor([prefsize[0], prefsize[1]]), + 'ratio': ratio[0], + 'path': path + } + + vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, "cuda").to(torch.float16) + + def encode_images(batch, vae): + with torch.no_grad(): + z = vae_encode( + config.vae.vae_type, vae, batch, + sample_posterior=config.vae.sample_posterior, # Adjust as necessary + device="cuda" + ) + return z + + if os.path.exists(config.data.buckets_file): + with open(config.data.buckets_file, 'r') as json_file: + buckets = json.load(json_file) + existings_images = set(chain.from_iterable(buckets.values())) + else: + buckets = {} + existings_images = set() + + def add_to_list(key, item): + if key in buckets: + buckets[key].append(item) + else: + buckets[key] = [item] + + for path in config.data.data_dir: + print(f'Processing {path}') + dataset = BucketsDataset(path, existings_images) + dataloader = DataLoader(dataset, batch_size=1) + for batch in tqdm(dataloader): + img = batch['img'] + size = batch['size'] + ratio = batch['ratio'] + image_path = batch['path'] + prefsize = batch['prefsize'] + + encoded = encode_images(img.to(torch.half), vae) + + for i in range(0, len(encoded)): + filename_wo_ext = os.path.splitext(os.path.basename(image_path[i]))[0] + add_to_list(str(ratio[i].item()), image_path[i]) + + torch.save({ + 'img': encoded[i].detach().clone(), + 'size': size[i], + 'prefsize': prefsize[i], + 'ratio': ratio[i] + }, f"{path}/{filename_wo_ext}_img.npz") + + with 
open(config.data.buckets_file, 'w') as json_file: + json.dump(buckets, json_file, indent=4) + + for ratio in buckets.keys(): + print(f'{float(ratio):.2f}: {len(buckets[ratio])}') + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py new file mode 100644 index 0000000..084cf83 --- /dev/null +++ b/train_scripts/train_local.py @@ -0,0 +1,939 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import getpass +import hashlib +import json +import os +import os.path as osp +import random +import time +import types +import warnings +from pathlib import Path + +import numpy as np +import pyrallis +import torch +from accelerate import Accelerator, InitProcessGroupKwargs +from accelerate.utils import DistributedType +from PIL import Image +from termcolor import colored +import torch.utils +import torch.utils.data +warnings.filterwarnings("ignore") # ignore warning + +from diffusion import DPMS, FlowEuler, Scheduler +from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode +from diffusion.model.respace import compute_density_for_timestep_sampling +from diffusion.utils.checkpoint import load_checkpoint, save_checkpoint +from diffusion.utils.config import SanaConfig +from diffusion.utils.dist_utils import clip_grad_norm_, flush, get_world_size +from diffusion.utils.logger import LogBuffer, get_root_logger +from diffusion.utils.lr_scheduler import build_lr_scheduler +from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, set_random_seed +from diffusion.utils.optimizer import build_optimizer +import json +import random +import math +import gc + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def set_fsdp_env(): + os.environ["ACCELERATE_USE_FSDP"] = "true" + os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" + os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaBlock" + +image_index = 0 +@torch.inference_mode() +def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None): + + torch.cuda.empty_cache() + vis_sampler = config.scheduler.vis_sampler + model = accelerator.unwrap_model(model).eval() + hw = torch.tensor([[image_size, image_size]], dtype=torch.float, device=device).repeat(1, 1) + ar = torch.tensor([[1.0]], device=device).repeat(1, 1) + null_y = torch.load(null_embed_path, map_location="cpu") + null_y = null_y["uncond_prompt_embeds"].to(device) + + # Create sampling noise: + logger.info("Running validation... 
") + image_logs = [] + + def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): + latents = [] + current_image_logs = [] + + for prompt in validation_prompts: + logger.info(prompt) + z = ( + torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device=device) + if init_z is None + else init_z + ) + embed = torch.load( + osp.join(config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}"), + map_location="cpu", + ) + caption_embs, emb_masks = embed["caption_embeds"].to(device), embed["emb_mask"].to(device) + # caption_embs = caption_embs[:, None] + # emb_masks = emb_masks[:, None] + model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) + + if sampler == "dpm-solver": + dpm_solver = DPMS( + model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=4.5, + model_kwargs=model_kwargs, + ) + denoised = dpm_solver.sample( + z, + steps=14, + order=2, + skip_type="time_uniform", + method="multistep", + ) + elif sampler == "flow_euler": + flow_solver = FlowEuler( + model, condition=caption_embs, uncondition=null_y, cfg_scale=5.5, model_kwargs=model_kwargs + ) + denoised = flow_solver.sample(z, steps=28) + elif sampler == "flow_dpm-solver": + dpm_solver = DPMS( + model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=5.5, + model_type="flow", + model_kwargs=model_kwargs, + schedule="FLOW", + ) + + denoised = dpm_solver.sample( + z, + steps=24, + order=2, + skip_type="time_uniform_flow", + method="multistep", + flow_shift=config.scheduler.flow_shift, + ) + else: + raise ValueError(f"{sampler} not implemented") + + latents.append(denoised) + torch.cuda.empty_cache() + + del_vae = False + if vae is None: + vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) + del_vae = True + for prompt, latent in zip(validation_prompts, latents): + latent = latent.to(torch.float16) + samples = vae_decode(config.vae.vae_type, vae, latent) + samples = ( + torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0] + ) + image = Image.fromarray(samples) + current_image_logs.append({"validation_prompt": prompt + label_suffix, "images": [image]}) + + if del_vae: + vae = None + gc.collect() + torch.cuda.empty_cache() + return current_image_logs + + # First run with original noise + image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler) + + # Second run with init_noise if provided + if init_noise is not None: + init_noise = torch.clone(init_noise).to(device) + image_logs += run_sampling(init_z=init_noise, label_suffix=" w/ init noise", vae=vae, sampler=vis_sampler) + + formatted_images = [] + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + formatted_images.append((validation_prompt, np.asarray(image))) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for validation_prompt, image in formatted_images: + tracker.writer.add_images(validation_prompt, image[None, ...], step, dataformats="NHWC") + elif tracker.name == "wandb": + import wandb + + wandb_images = [] + for validation_prompt, image in formatted_images: + wandb_images.append(wandb.Image(image, caption=validation_prompt, file_type="jpg")) + tracker.log({"validation": wandb_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + def concatenate_images(image_caption, 
images_per_row=5, image_format="webp"): + import io + + images = [log["images"][0] for log in image_caption] + if images[0].size[0] > 1024: + images = [image.resize((1024, 1024)) for image in images] + + widths, heights = zip(*(img.size for img in images)) + max_width = max(widths) + total_height = sum(heights[i : i + images_per_row][0] for i in range(0, len(images), images_per_row)) + + new_im = Image.new("RGB", (max_width * images_per_row, total_height)) + + y_offset = 0 + for i in range(0, len(images), images_per_row): + row_images = images[i : i + images_per_row] + x_offset = 0 + for img in row_images: + new_im.paste(img, (x_offset, y_offset)) + x_offset += max_width + y_offset += heights[i] + webp_image_bytes = io.BytesIO() + new_im.save(webp_image_bytes, format=image_format) + webp_image_bytes.seek(0) + new_im = Image.open(webp_image_bytes) + + return new_im + + if config.train.local_save_vis: + file_format = "webp" + local_vis_save_path = osp.join(config.work_dir, "log_vis") + os.umask(0o000) + os.makedirs(local_vis_save_path, exist_ok=True) + concatenated_image = concatenate_images(image_logs, images_per_row=5, image_format=file_format) + save_path = ( + osp.join(local_vis_save_path, f"vis_{step}.{file_format}") + if init_noise is None + else osp.join(local_vis_save_path, f"vis_{step}_w_init.{file_format}") + ) + concatenated_image.save(save_path) + + del vae + flush() + return image_logs + + +class RatioBucketsDataset(): + def __init__( + self, + buckets_file + ): + with open(buckets_file, 'r') as file: + self.buckets = json.load(file) + + def __getitem__(self, idx): + while True: + loader = random.choice(self.loaders) + + try: + return next(loader) + except StopIteration: + self.loaders.remove(loader) + print(f"bucket ended, {len(self.loaders)}") + + def __len__(self): + return self.size + + def make_loaders(self, batch_size): + self.loaders = [] + self.size = 0 + for bucket in self.buckets.keys(): + dataset = ImageDataset(self.buckets[bucket]) + + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=False) + self.loaders.append(iter(loader)) + self.size += math.ceil(len(dataset) / batch_size) + +class ImageDataset(torch.utils.data.Dataset): + def __init__( + self, + images + ): + self.images = images + + def getdata(self, idx): + path = self.images[idx] + filename_wo_ext = os.path.splitext(os.path.basename(path))[0] + + text_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}.txt") + with open(text_file, 'r') as file: + prompt = file.read() + + cache_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}_img.npz") + cached_data = torch.load(cache_file) + + size = cached_data['prefsize'] + ratio = cached_data['ratio'] + vae_embed = cached_data['img'] + + data_info = { + "img_hw": size, + "aspect_ratio": torch.tensor(ratio.item()), + } + + return ( + vae_embed, + data_info, + prompt, + ) + + def __getitem__(self, idx): + for _ in range(10): + try: + data = self.getdata(idx) + return data + except Exception as e: + print(f"Error details: {str(e)}") + idx = idx + 1 + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.images) + +def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, train_diffusion, logger): + if getattr(config.train, "debug_nan", False): + DebugUnderflowOverflow(model) + logger.info("NaN debugger registered. 
Start to detect overflow during training.") + log_buffer = LogBuffer() + + def check_nan_inf(model): + for name, param in model.named_parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + print(f"NaN/Inf detected in {name}") + + check_nan_inf(model) + + global_step = start_step + 1 + skip_step = max(config.train.skip_step, global_step) % len(dataset) + skip_step = skip_step if skip_step < (len(dataset) - 20) else 0 + loss_nan_timer = 0 + + # Now you train the model + for epoch in range(start_epoch + 1, config.train.num_epochs + 1): + time_start, last_tic = time.time(), time.time() + if skip_step > 1 and accelerator.is_main_process: + logger.info(f"Skipped Steps: {skip_step}") + skip_step = 1 + data_time_start = time.time() + data_time_all = 0 + lm_time_all = 0 + model_time_all = 0 + dataset.make_loaders(config.train.train_batch_size) + for step, batch in enumerate(dataset): + # image, json_info, key = batch + accelerator.wait_for_everyone() + data_time_all += time.time() - data_time_start + z = batch[0].to(accelerator.device) + + accelerator.wait_for_everyone() + + clean_images = z + data_info = batch[1] + + lm_time_start = time.time() + prompts = list(batch[2]) + shuffled_prompts = [] + for prompt in prompts: + tags = prompt.split(",") # Split the string into a list of tags + random.shuffle(tags) # Shuffle the tags + shuffled_prompts.append(",".join(tags)) # Join them back into a string + + if "T5" in config.text_encoder.text_encoder_name: + with torch.no_grad(): + txt_tokens = tokenizer( + shuffled_prompts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] + y_mask = txt_tokens.attention_mask[:, None, None] + elif ( + "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name + ): + with torch.no_grad(): + if not config.text_encoder.chi_prompt: + max_length_all = config.text_encoder.model_max_length + prompt = shuffled_prompts + else: + chi_prompt = "\n".join(config.text_encoder.chi_prompt) + prompt = [chi_prompt + i for i in shuffled_prompts] + num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) + max_length_all = ( + num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 + ) # magic number 2: [bos], [_] + txt_tokens = tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + return_tensors="pt", + ).to(accelerator.device) + select_index = [0] + list( + range(-config.text_encoder.model_max_length + 1, 0) + ) # first bos and end N-1 + y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None][ + :, :, select_index + ] + y_mask = txt_tokens.attention_mask[:, None, None][:, :, :, select_index] + else: + print("error") + exit() + + # Sample a random timestep for each image + bs = clean_images.shape[0] + timesteps = torch.randint( + 0, config.scheduler.train_sampling_steps, (bs,), device=clean_images.device + ).long() + if config.scheduler.weighting_scheme in ["logit_normal"]: + # adapting from diffusers.training_utils + u = compute_density_for_timestep_sampling( + weighting_scheme=config.scheduler.weighting_scheme, + batch_size=bs, + logit_mean=config.scheduler.logit_mean, + logit_std=config.scheduler.logit_std, + mode_scale=None, # not used + ) + timesteps = (u * config.scheduler.train_sampling_steps).long().to(clean_images.device) + grad_norm = None + accelerator.wait_for_everyone() + 
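+            # The block below computes the diffusion training loss under accelerator.accumulate():
+            # a NaN loss is counted via loss_nan_timer and the iteration is skipped (no backward pass),
+            # gradients are clipped to config.train.gradient_clip when syncing, and then the
+            # optimizer and LR scheduler are stepped.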
lm_time_all += time.time() - lm_time_start + model_time_start = time.time() + with accelerator.accumulate(model): + # Predict the noise residual + optimizer.zero_grad() + loss_term = train_diffusion.training_losses( + model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info) + ) + loss = loss_term["loss"].mean() + + # Check if the loss is NaN + if torch.isnan(loss): + loss_nan_timer += 1 + print(f'Skip nan: {loss_nan_timer}') + continue # Skip the rest of the loop iteration if loss is NaN + + accelerator.backward(loss) + if accelerator.sync_gradients: + grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip) + + optimizer.step() + lr_scheduler.step() + accelerator.wait_for_everyone() + model_time_all += time.time() - model_time_start + + lr = lr_scheduler.get_last_lr()[0] + logs = {args.loss_report_name: accelerator.gather(loss).mean().item()} + if grad_norm is not None: + logs.update(grad_norm=accelerator.gather(grad_norm).mean().item()) + log_buffer.update(logs) + if (step + 1) % config.train.log_interval == 0 or (step + 1) == 1: + accelerator.wait_for_everyone() + t = (time.time() - last_tic) / config.train.log_interval + t_d = data_time_all / config.train.log_interval + t_m = model_time_all / config.train.log_interval + t_lm = lm_time_all / config.train.log_interval + avg_time = (time.time() - time_start) / (step + 1) + eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1)))) + eta_epoch = str( + datetime.timedelta( + seconds=int(avg_time * (len(dataset) - step // config.train.train_batch_size - step - 1)) + ) + ) + log_buffer.average() + + current_step = ( + global_step - step // config.train.train_batch_size + ) % len(dataset) + current_step = len(dataset) if current_step == 0 else current_step + info = ( + f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {len(dataset)}, " + f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, " + f"lm:{t_lm:.3f}, lr:{lr:.3e}, " + ) + info += ( + f"s:({model.module.h}, {model.module.w}), " + if hasattr(model, "module") + else f"s:({model.h}, {model.w}), " + ) + + info += ", ".join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()]) + last_tic = time.time() + log_buffer.clear() + data_time_all = 0 + model_time_all = 0 + lm_time_all = 0 + if accelerator.is_main_process: + logger.info(info) + + logs.update(lr=lr) + if accelerator.is_main_process: + accelerator.log(logs, step=global_step) + + global_step += 1 + + if loss_nan_timer > 20: + raise ValueError("Loss is NaN too much times. 
Break here.") + if global_step % config.train.save_model_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + os.umask(0o000) + checkpoints_dir = osp.join(config.work_dir, "checkpoints") + + # Remove all old checkpoint files in the directory + for filename in os.listdir(checkpoints_dir): + file_path = osp.join(checkpoints_dir, filename) + if os.path.isfile(file_path): + os.remove(file_path) + ckpt_saved_path = save_checkpoint( + checkpoints_dir, + epoch=epoch, + step=global_step, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + lr_scheduler=lr_scheduler, + generator=generator, + add_symlink=True, + ) + if config.train.online_metric and global_step % config.train.eval_metric_step == 0 and step > 1: + online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir) + os.makedirs(online_metric_monitor_dir, exist_ok=True) + with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f: + f.write(osp.join(config.work_dir, "config.py") + "\n") + f.write(ckpt_saved_path) + + # if (time.time() - training_start_time) / 3600 > 3.8: + # logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.") + # return + if config.train.visualize and (global_step % config.train.eval_sampling_steps == 0 or (step + 1) == 1): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if validation_noise is not None: + log_validation( + accelerator=accelerator, + config=config, + model=model, + logger=logger, + step=global_step, + device=accelerator.device, + vae=vae, + init_noise=validation_noise, + ) + else: + log_validation( + accelerator=accelerator, + config=config, + model=model, + logger=logger, + step=global_step, + device=accelerator.device, + vae=vae, + ) + + data_time_start = time.time() + + accelerator.wait_for_everyone() + + +@pyrallis.wrap() +def main(cfg: SanaConfig) -> None: + global start_epoch, start_step, vae, generator, num_replicas, rank, training_start_time + global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer + global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path + global image_size, cache_file, total_steps + + config = cfg + args = cfg + # config = read_config(args.config) + + training_start_time = time.time() + load_from = True + if args.resume_from or config.model.resume_from: + load_from = False + config.model.resume_from = dict( + checkpoint=args.resume_from or config.model.resume_from, + load_ema=False, + resume_optimizer=False, + resume_lr_scheduler=False, + ) + + if args.debug: + config.train.log_interval = 1 + config.train.train_batch_size = min(64, config.train.train_batch_size) + args.report_to = "tensorboard" + + os.umask(0o000) + os.makedirs(config.work_dir, exist_ok=True) + + init_handler = InitProcessGroupKwargs() + init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug + # Initialize accelerator and tensorboard logging + if config.train.use_fsdp: + init_train = "FSDP" + from accelerate import FullyShardedDataParallelPlugin + from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig + + set_fsdp_env() + fsdp_plugin = FullyShardedDataParallelPlugin( + state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False), + ) + else: + init_train = "DDP" + fsdp_plugin = None + + accelerator = Accelerator( + mixed_precision=config.model.mixed_precision, + 
gradient_accumulation_steps=config.train.gradient_accumulation_steps, + log_with=args.report_to, + project_dir=osp.join(config.work_dir, "logs"), + fsdp_plugin=fsdp_plugin, + kwargs_handlers=[init_handler], + ) + + log_name = "train_log.log" + logger = get_root_logger(osp.join(config.work_dir, log_name)) + logger.info(accelerator.state) + + config.train.seed = init_random_seed(getattr(config.train, "seed", None)) + set_random_seed(config.train.seed + int(os.environ["LOCAL_RANK"])) + generator = torch.Generator(device="cpu").manual_seed(config.train.seed) + + if accelerator.is_main_process: + pyrallis.dump(config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4) + if args.report_to == "wandb": + import wandb + + wandb.init(project=args.tracker_project_name, name=args.name, resume="allow", id=args.name) + + logger.info(f"Config: \n{config}") + logger.info(f"World_size: {get_world_size()}, seed: {config.train.seed}") + logger.info(f"Initializing: {init_train} for training") + image_size = config.model.image_size + latent_size = int(image_size) // config.vae.vae_downsample_rate + pred_sigma = getattr(config.scheduler, "pred_sigma", True) + learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma + max_length = config.text_encoder.model_max_length + vae = None + validation_noise = ( + torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device="cpu", generator=generator) + if getattr(config.train, "deterministic_validation", False) + else None + ) + + tokenizer = text_encoder = None + if not config.data.load_text_feat: + tokenizer, text_encoder = get_tokenizer_and_text_encoder( + name=config.text_encoder.text_encoder_name, device=accelerator.device + ) + text_encoder.requires_grad_(False) + text_embed_dim = text_encoder.config.hidden_size + else: + text_embed_dim = config.text_encoder.caption_channels + + logger.info(f"vae type: {config.vae.vae_type}") + if config.text_encoder.chi_prompt: + chi_prompt = "\n".join(config.text_encoder.chi_prompt) + logger.info(f"Complex Human Instruct: {chi_prompt}") + + os.makedirs(config.train.null_embed_root, exist_ok=True) + null_embed_path = osp.join( + config.train.null_embed_root, + f"null_embed_diffusers_{config.text_encoder.text_encoder_name}_{max_length}token_{text_embed_dim}.pth", + ) + if config.train.visualize and len(config.train.validation_prompts): + # preparing embeddings for visualization. 
We put it here for saving GPU memory + valid_prompt_embed_suffix = f"{max_length}token_{config.text_encoder.text_encoder_name}_{text_embed_dim}.pth" + validation_prompts = config.train.validation_prompts + skip = True + if config.text_encoder.chi_prompt: + uuid_chi_prompt = hashlib.sha256(chi_prompt.encode()).hexdigest() + else: + uuid_chi_prompt = hashlib.sha256(b"").hexdigest() + config.train.valid_prompt_embed_root = osp.join(config.train.valid_prompt_embed_root, uuid_chi_prompt) + Path(config.train.valid_prompt_embed_root).mkdir(parents=True, exist_ok=True) + + if config.text_encoder.chi_prompt: + # Save complex human instruct to a file + chi_prompt_file = osp.join(config.train.valid_prompt_embed_root, "chi_prompt.txt") + with open(chi_prompt_file, "w", encoding="utf-8") as f: + f.write(chi_prompt) + + for prompt in validation_prompts: + prompt_embed_path = osp.join( + config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" + ) + if not (osp.exists(prompt_embed_path) and osp.exists(null_embed_path)): + skip = False + logger.info("Preparing Visualization prompt embeddings...") + break + if accelerator.is_main_process and not skip: + if config.data.load_text_feat and (tokenizer is None or text_encoder is None): + logger.info(f"Loading text encoder and tokenizer from {config.text_encoder.text_encoder_name} ...") + tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name) + + for prompt in validation_prompts: + prompt_embed_path = osp.join( + config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" + ) + if "T5" in config.text_encoder.text_encoder_name: + txt_tokens = tokenizer( + prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] + caption_emb_mask = txt_tokens.attention_mask + elif ( + "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name + ): + if not config.text_encoder.chi_prompt: + max_length_all = config.text_encoder.model_max_length + else: + chi_prompt = "\n".join(config.text_encoder.chi_prompt) + prompt = chi_prompt + prompt + num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) + max_length_all = ( + num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 + ) # magic number 2: [bos], [_] + + txt_tokens = tokenizer( + prompt, + max_length=max_length_all, + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(accelerator.device) + select_index = [0] + list(range(-config.text_encoder.model_max_length + 1, 0)) + caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][ + :, select_index + ] + caption_emb_mask = txt_tokens.attention_mask[:, select_index] + else: + raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") + + torch.save({"caption_embeds": caption_emb, "emb_mask": caption_emb_mask}, prompt_embed_path) + + null_tokens = tokenizer( + "bad artwork,ugly,sketch,poorly drawn,messy,noisy,score: 0/10,blurry,low quality,old", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + if "T5" in config.text_encoder.text_encoder_name: + null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] + elif "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name: + 
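+            # Note: the unconditional ("null") embedding is encoded from the negative-style
+            # prompt string defined above rather than from an empty string.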
null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] + else: + raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") + torch.save( + {"uncond_prompt_embeds": null_token_emb, "uncond_prompt_embeds_mask": null_tokens.attention_mask}, + null_embed_path, + ) + if config.data.load_text_feat: + del tokenizer + del text_encoder + del null_token_emb + del null_tokens + flush() + + os.environ["AUTOCAST_LINEAR_ATTN"] = "true" if config.model.autocast_linear_attn else "false" + + # 1. build scheduler + train_diffusion = Scheduler( + str(config.scheduler.train_sampling_steps), + noise_schedule=config.scheduler.noise_schedule, + predict_v=config.scheduler.predict_v, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + snr=config.train.snr_loss, + flow_shift=config.scheduler.flow_shift, + ) + predict_info = f"v-prediction: {config.scheduler.predict_v}, noise schedule: {config.scheduler.noise_schedule}" + if "flow" in config.scheduler.noise_schedule: + predict_info += f", flow shift: {config.scheduler.flow_shift}" + if config.scheduler.weighting_scheme in ["logit_normal", "mode"]: + predict_info += ( + f", flow weighting: {config.scheduler.weighting_scheme}, " + f"logit-mean: {config.scheduler.logit_mean}, logit-std: {config.scheduler.logit_std}" + ) + logger.info(predict_info) + + # 2. build models + model_kwargs = { + "pe_interpolation": config.model.pe_interpolation, + "config": config, + "model_max_length": max_length, + "qk_norm": config.model.qk_norm, + "micro_condition": config.model.micro_condition, + "caption_channels": text_embed_dim, + "y_norm": config.text_encoder.y_norm, + "attn_type": config.model.attn_type, + "ffn_type": config.model.ffn_type, + "mlp_ratio": config.model.mlp_ratio, + "mlp_acts": list(config.model.mlp_acts), + "in_channels": config.vae.vae_latent_dim, + "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor, + "use_pe": config.model.use_pe, + "linear_head_dim": config.model.linear_head_dim, + "pred_sigma": pred_sigma, + "learn_sigma": learn_sigma, + } + model = build_model( + config.model.model, + config.train.grad_checkpointing, + getattr(config.model, "fp32_attention", False), + input_size=latent_size, + **model_kwargs, + ).train() + + logger.info( + colored( + f"{model.__class__.__name__}:{config.model.model}, " + f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M", + "green", + attrs=["bold"], + ) + ) + # 2-1. load model + if args.load_from is not None: + config.model.load_from = args.load_from + # if config.model.load_from is not None and load_from: + # _, missing, unexpected, _ = load_checkpoint( + # config.model.load_from, + # model, + # load_ema=config.model.resume_from.get("load_ema", False), + # null_embed_path=null_embed_path, + # ) + # logger.warning(f"Missing keys: {missing}") + # logger.warning(f"Unexpected keys: {unexpected}") + + # prepare for FSDP clip grad norm calculation + if accelerator.distributed_type == DistributedType.FSDP: + for m in accelerator._models: + m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m) + + # 3. 
build dataloader + config.data.data_dir = config.data.data_dir if isinstance(config.data.data_dir, list) else [config.data.data_dir] + config.data.data_dir = [ + data if data.startswith(("https://", "http://", "gs://", "/", "~")) else osp.abspath(osp.expanduser(data)) + for data in config.data.data_dir + ] + num_replicas = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + dataset = RatioBucketsDataset(config.data.buckets_file) + accelerator.wait_for_everyone() + dataset.make_loaders(batch_size=config.train.train_batch_size) + + load_vae_feat = getattr(dataset, "load_vae_feat", False) + load_text_feat = getattr(dataset, "load_text_feat", False) + + optimizer = build_optimizer(model, config.train.optimizer) + + if config.train.lr_schedule_args and config.train.lr_schedule_args.get("num_warmup_steps", None): + config.train.lr_schedule_args["num_warmup_steps"] = ( + config.train.lr_schedule_args["num_warmup_steps"] * num_replicas + ) + lr_scheduler = build_lr_scheduler(config.train, optimizer, dataset, 1) + + logger.warning( + f"{colored(f'Basic Setting: ', 'green', attrs=['bold'])}" + f"lr: {config.train.optimizer['lr']:.9f}, bs: {config.train.train_batch_size}, gc: {config.train.grad_checkpointing}, " + f"gc_accum_step: {config.train.gradient_accumulation_steps}, qk norm: {config.model.qk_norm}, " + f"fp32 attn: {config.model.fp32_attention}, attn type: {config.model.attn_type}, ffn type: {config.model.ffn_type}, " + f"text encoder: {config.text_encoder.text_encoder_name}, captions: {config.data.caption_proportion}, precision: {config.model.mixed_precision}" + ) + + timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + + if accelerator.is_main_process: + tracker_config = dict(vars(config)) + try: + accelerator.init_trackers(args.tracker_project_name, tracker_config) + except: + accelerator.init_trackers(f"tb_{timestamp}") + + start_epoch = 0 + start_step = 0 + total_steps = len(dataset) * config.train.num_epochs + + # Resume training + if config.model.resume_from is not None and config.model.resume_from["checkpoint"] is not None: + rng_state = None + ckpt_path = osp.join(config.work_dir, "checkpoints") + check_flag = osp.exists(ckpt_path) and len(os.listdir(ckpt_path)) != 0 + if config.model.resume_from["checkpoint"] == "latest": + if check_flag: + checkpoints = os.listdir(ckpt_path) + if "latest.pth" in checkpoints and osp.exists(osp.join(ckpt_path, "latest.pth")): + config.model.resume_from["checkpoint"] = osp.realpath(osp.join(ckpt_path, "latest.pth")) + else: + checkpoints = [i for i in checkpoints if i.startswith("epoch_")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.replace(".pth", "").split("_")[3])) + config.model.resume_from["checkpoint"] = osp.join(ckpt_path, checkpoints[-1]) + else: + config.model.resume_from["checkpoint"] = config.model.load_from + + if config.model.resume_from["checkpoint"] is not None: + _, missing, unexpected, rng_state = load_checkpoint( + **config.model.resume_from, + model=model, + optimizer=optimizer if check_flag else None, + lr_scheduler=lr_scheduler if check_flag else None, + null_embed_path=null_embed_path, + ) + + logger.warning(f"Missing keys: {missing}") + logger.warning(f"Unexpected keys: {unexpected}") + + path = osp.basename(config.model.resume_from["checkpoint"]) + try: + start_epoch = int(path.replace(".pth", "").split("_")[1]) - 1 + start_step = int(path.replace(".pth", "").split("_")[3]) + except: + pass + + # resume randomise + if rng_state: + logger.info("resuming randomise") + 
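+            # Restore every saved RNG stream (torch CPU, torch CUDA, numpy, python, and the
+            # torch.Generator used for validation noise) so the resumed run continues with the
+            # same random state it was checkpointed with.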
torch.set_rng_state(rng_state["torch"]) + torch.cuda.set_rng_state_all(rng_state["torch_cuda"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["python"]) + generator.set_state(rng_state["generator"]) # resume generator status + + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. + model = accelerator.prepare(model) + optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler) + + train( + config=config, + args=args, + accelerator=accelerator, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataset=dataset, + train_diffusion=train_diffusion, + logger=logger, + ) + + +if __name__ == "__main__": + main() diff --git a/train_scripts/train_local.sh b/train_scripts/train_local.sh new file mode 100644 index 0000000..e7f88b9 --- /dev/null +++ b/train_scripts/train_local.sh @@ -0,0 +1,25 @@ +#/bin/bash +set -e + +work_dir=output/debug +np=1 + + +if [[ $1 == *.yaml ]]; then + config=$1 + shift +else + config="configs/sana_config/512ms/sample_dataset.yaml" + echo "Only support .yaml files, but get $1. Set to --config_path=$config" +fi + +TRITON_PRINT_AUTOTUNING=1 \ + torchrun --nproc_per_node=$np --master_port=15432 \ + train_scripts/train_local.py \ + --config_path=$config \ + --work_dir=$work_dir \ + --name=tmp \ + --resume_from=latest \ + --report_to=tensorboard \ + --debug=true \ + "$@" From a83189aecb9ae0324e435acd1df4ff7e97ed40f5 Mon Sep 17 00:00:00 2001 From: Muinez <76997923+Muinez@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:44:22 +0300 Subject: [PATCH 07/14] Update README.md --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index 3d77bbd..9d60dac 100644 --- a/README.md +++ b/README.md @@ -218,6 +218,25 @@ bash train_scripts/train.sh \ --train.train_batch_size=8 ``` +Local training with bucketing and VAE embedding caching: +```bash +# Prepare buckets and cache VAE embeds +python train_scripts/make_buckets.py \ + --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \ + --data.data_dir=[asset/example_data] \ + --data.buckets_file=buckets.json + +# Start training with cached VAE embeddings and bucketing +bash train_scripts/train_local.sh \ + configs/sana_config/1024ms/Sana_1600M_img1024.yaml \ + --data.buckets_file=buckets.json \ + --train.train_batch_size=30 +``` +Using the AdamW optimizer, training with a batch size of 30 on 1024x1024 resolution consumes ~48GB VRAM on an NVIDIA A6000 GPU. +Each training iteration takes ~7.5 seconds. + + + # đź’» 4. Metric toolkit Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md). From 335d4452126952b90376a6f7ccd0ce0490a16fa9 Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Wed, 27 Nov 2024 10:56:38 +0800 Subject: [PATCH 08/14] change code license to Apache 2.0. Signed-off-by: lawrence-cj --- LICENSE | 318 ++++++++++++++++++++++++++++++++++-------------------- README.md | 1 + 2 files changed, 202 insertions(+), 117 deletions(-) diff --git a/LICENSE b/LICENSE index 20f57b9..000dde7 100755 --- a/LICENSE +++ b/LICENSE @@ -1,117 +1,201 @@ -Copyright (c) 2019, NVIDIA Corporation. All rights reserved. - - -Nvidia Source Code License-NC - -======================================================================= - -1. Definitions - -“Licensor” means any person or entity that distributes its Work. 
- -“Work” means (a) the original work of authorship made available under -this license, which may include software, documentation, or other -files, and (b) any additions to or derivative works thereof -that are made available under this license. - -“NVIDIA Processors” means any central processing unit (CPU), -graphics processing unit (GPU), field-programmable gate array (FPGA), -application-specific integrated circuit (ASIC) or any combination -thereof designed, made, sold, or provided by NVIDIA or its affiliates. - -The terms “reproduce,” “reproduction,” “derivative works,” and -“distribution” have the meaning as provided under U.S. copyright law; -provided, however, that for the purposes of this license, derivative -works shall not include works that remain separable from, or merely -link (or bind by name) to the interfaces of, the Work. - -Works are “made available” under this license by including in or with -the Work either (a) a copyright notice referencing the applicability -of this license to the Work, or (b) a copy of this license. - -"Safe Model" means ShieldGemma-2B, which is a series of safety -content moderation models designed to moderate four categories of -harmful content: sexually explicit material, dangerous content, -hate speech, and harassment, and which you separately obtain -from Google at https://huggingface.co/google/shieldgemma-2b. - - -2. License Grant - -2.1 Copyright Grant. Subject to the terms and conditions of this -license, each Licensor grants to you a perpetual, worldwide, -non-exclusive, royalty-free, copyright license to use, reproduce, -prepare derivative works of, publicly display, publicly perform, -sublicense and distribute its Work and any resulting derivative -works in any form. - -3. Limitations - -3.1 Redistribution. You may reproduce or distribute the Work only if -(a) you do so under this license, (b) you include a complete copy of -this license with your distribution, and (c) you retain without -modification any copyright, patent, trademark, or attribution notices -that are present in the Work. - -3.2 Derivative Works. You may specify that additional or different -terms apply to the use, reproduction, and distribution of your -derivative works of the Work (“Your Terms”) only if (a) Your Terms -provide that the use limitation in Section 3.3 applies to your -derivative works, and (b) you identify the specific derivative works -that are subject to Your Terms. Notwithstanding Your Terms, this -license (including the redistribution requirements in Section 3.1) -will continue to apply to the Work itself. - -3.3 Use Limitation. The Work and any derivative works thereof only may -be used or intended for use non-commercially and with NVIDIA Processors, -in accordance with Section 3.4, below. Notwithstanding the foregoing, -NVIDIA Corporation and its affiliates may use the Work and any -derivative works commercially. As used herein, “non-commercially” -means for research or evaluation purposes only. - -3.4 You shall filter your input content to the Work and any derivative -works thereof through the Safe Model to ensure that no content described -as Not Safe For Work (NSFW) is processed or generated. You shall not use -the Work to process or generate NSFW content. You are solely responsible -for any damages and liabilities arising from your failure to adequately -filter content in accordance with this section. 
As used herein, -“Not Safe For Work” or “NSFW” means content, videos or website pages -that contain potentially disturbing subject matter, including but not -limited to content that is sexually explicit, dangerous, hate, -or harassment. - -3.5 Patent Claims. If you bring or threaten to bring a patent claim -against any Licensor (including any claim, cross-claim or counterclaim -in a lawsuit) to enforce any patents that you allege are infringed by -any Work, then your rights under this license from such Licensor -(including the grant in Section 2.1) will terminate immediately. - -3.6 Trademarks. This license does not grant any rights to use any -Licensor’s or its affiliates’ names, logos, or trademarks, except as -necessary to reproduce the notices described in this license. - -3.7 Termination. If you violate any term of this license, then your -rights under this license (including the grant in Section 2.1) will -terminate immediately. - -4. Disclaimer of Warranty. - -THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR -NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES -UNDER THIS LICENSE. - -5. Limitation of Liability. - -EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL -THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE -SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, -INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF -OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK -(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, -LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER -DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGES. - -======================================================================= + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2024 Junsong Chen, Jincheng Yu, Enze Xie + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index 3d77bbd..8a5bff8 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e. ## 🔥🔥 News +- (🔥 New) \[2024/11\] Sana code-base license changed to Apache 2.0. - (🔥 New) \[2024/11\] 1.6B [Sana multi-linguistic models](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing) are released. Multi-language(Emoji & Chinese & English) are supported. - (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released. - (🔥 New) \[2024/11\] Training & Inference & Metrics code are released. 
From de0cd7b1fc825d9427f0736bd0daa07d6f9fa5af Mon Sep 17 00:00:00 2001 From: Enze Xie Date: Wed, 27 Nov 2024 15:02:19 +0800 Subject: [PATCH 09/14] update README.md; Signed-off-by: lawrence-cj --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8a5bff8..cda95de 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@   -   +   +   @@ -35,8 +36,9 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e. ## 🔥🔥 News -- (🔥 New) \[2024/11\] Sana code-base license changed to Apache 2.0. -- (🔥 New) \[2024/11\] 1.6B [Sana multi-linguistic models](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing) are released. Multi-language(Emoji & Chinese & English) are supported. +- (🔥 New) \[2024/11/27\] Sana Replicate API is launching at [Sana-API](https://replicate.com/chenxwh/sana). +- (🔥 New) \[2024/11/27\] Sana code-base license changed to Apache 2.0. +- (🔥 New) \[2024/11/26\] 1.6B [Sana multi-linguistic models](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing) are released. Multi-language(Emoji & Chinese & English) are supported. - (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released. - (🔥 New) \[2024/11\] Training & Inference & Metrics code are released. - (🔥 New) \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982). From 16b7ad57bf377f408e103f5a65fe557b1d53aede Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Wed, 27 Nov 2024 18:35:16 +0800 Subject: [PATCH 10/14] pre-commit & and still need to be re-format into current code-base Signed-off-by: lawrence-cj --- README.md | 4 +- train_scripts/make_buckets.py | 116 ++++++++++++++++++---------------- train_scripts/train_local.py | 89 ++++++++++++++------------ 3 files changed, 113 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 277ec91..372d384 100644 --- a/README.md +++ b/README.md @@ -222,6 +222,7 @@ bash train_scripts/train.sh \ ``` Local training with bucketing and VAE embedding caching: + ```bash # Prepare buckets and cache VAE embeds python train_scripts/make_buckets.py \ @@ -235,11 +236,10 @@ bash train_scripts/train_local.sh \ --data.buckets_file=buckets.json \ --train.train_batch_size=30 ``` + Using the AdamW optimizer, training with a batch size of 30 on 1024x1024 resolution consumes ~48GB VRAM on an NVIDIA A6000 GPU. Each training iteration takes ~7.5 seconds. - - # đź’» 4. Metric toolkit Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md). 
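The bucketing step described in the README section above leaves two artifacts behind: a `buckets.json` index that maps each aspect-ratio key (stored as a string) to the list of image paths assigned to that bucket, and a cached `<image name>_img.npz` file next to each image holding the pre-computed VAE latent together with its size/ratio metadata. Below is a minimal inspection sketch; the key names mirror `make_buckets.py` in the diff that follows, while the example paths and the assumption that at least one bucket is non-empty are illustrative only:

```python
import json

import torch

# buckets.json maps aspect-ratio keys (as strings) to lists of image paths,
# e.g. {"1.0": ["asset/example_data/00000.png", ...], "0.75": [...]}
with open("buckets.json") as f:
    buckets = json.load(f)

for ratio, paths in buckets.items():
    print(f"bucket {float(ratio):.2f}: {len(paths)} images")

# Each image has a companion <name>_img.npz written with torch.save,
# holding the cached VAE latent plus the bucket metadata used at train time.
sample_path = next(iter(buckets.values()))[0]            # first image of some bucket
cache_file = sample_path.rsplit(".", 1)[0] + "_img.npz"  # same directory, same stem
cached = torch.load(cache_file)
print(cached["img"].shape, cached["size"], cached["prefsize"], cached["ratio"])
```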
diff --git a/train_scripts/make_buckets.py b/train_scripts/make_buckets.py index 4a4f64f..36c4ed4 100644 --- a/train_scripts/make_buckets.py +++ b/train_scripts/make_buckets.py @@ -1,17 +1,20 @@ -import torch -from diffusion.model.builder import get_vae, vae_encode -from diffusion.utils.config import SanaConfig -import pyrallis -from PIL import Image -import torchvision.transforms as T +import json +import math import os import os.path as osp -from torchvision.transforms import InterpolationMode -import json +from itertools import chain + +import pyrallis +import torch +import torchvision.transforms as T +from PIL import Image from torch.utils.data import DataLoader +from torchvision.transforms import InterpolationMode from tqdm import tqdm -import math -from itertools import chain + +from diffusion.model.builder import get_vae, vae_encode +from diffusion.utils.config import SanaConfig + @pyrallis.wrap() def main(config: SanaConfig) -> None: @@ -22,16 +25,16 @@ def main(config: SanaConfig) -> None: step = 32 ratios_array = [] - while(min_size != max_size): + while min_size != max_size: width = int(preferred_pixel_count / min_size) - if(width % step != 0): - mod = width % step - if(mod < step//2): + if width % step != 0: + mod = width % step + if mod < step // 2: width -= mod else: width += step - mod - ratio = min_size / width + ratio = min_size / width ratios_array.append((ratio, (int(min_size), width))) min_size += step @@ -43,25 +46,31 @@ def get_closest_ratio(height: float, width: float): def get_preffered_size(height: float, width: float): pixel_count = height * width - + scale = math.sqrt(pixel_count / preferred_pixel_count) return height / scale, width / scale class BucketsDataset(torch.utils.data.Dataset): def __init__(self, data_dir, skip_files): valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"} - self.files = ([ - osp.join(data_dir, f) for f in os.listdir(data_dir) - if osp.isfile(osp.join(data_dir, f)) and osp.splitext(f)[1].lower() in valid_extensions and osp.join(data_dir, f) not in skip_files ]) - - self.transform = T.Compose([ - T.ToTensor(), - T.Normalize([0.5], [0.5]), - ]) - + self.files = [ + osp.join(data_dir, f) + for f in os.listdir(data_dir) + if osp.isfile(osp.join(data_dir, f)) + and osp.splitext(f)[1].lower() in valid_extensions + and osp.join(data_dir, f) not in skip_files + ] + + self.transform = T.Compose( + [ + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ] + ) + def __len__(self): return len(self.files) - + def __getitem__(self, idx): path = self.files[idx] img = Image.open(path).convert("RGB") @@ -70,11 +79,11 @@ def __getitem__(self, idx): crop = T.Resize(ratio[1], interpolation=InterpolationMode.BICUBIC) return { - 'img': self.transform(crop(img)), - 'size': torch.tensor([ratio[1][0], ratio[1][1]]), - 'prefsize': torch.tensor([prefsize[0], prefsize[1]]), - 'ratio': ratio[0], - 'path': path + "img": self.transform(crop(img)), + "size": torch.tensor([ratio[1][0], ratio[1][1]]), + "prefsize": torch.tensor([prefsize[0], prefsize[1]]), + "ratio": ratio[0], + "path": path, } vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, "cuda").to(torch.float16) @@ -82,14 +91,16 @@ def __getitem__(self, idx): def encode_images(batch, vae): with torch.no_grad(): z = vae_encode( - config.vae.vae_type, vae, batch, + config.vae.vae_type, + vae, + batch, sample_posterior=config.vae.sample_posterior, # Adjust as necessary - device="cuda" + device="cuda", ) return z if os.path.exists(config.data.buckets_file): - with open(config.data.buckets_file, 
'r') as json_file: + with open(config.data.buckets_file) as json_file: buckets = json.load(json_file) existings_images = set(chain.from_iterable(buckets.values())) else: @@ -101,36 +112,35 @@ def add_to_list(key, item): buckets[key].append(item) else: buckets[key] = [item] - + for path in config.data.data_dir: - print(f'Processing {path}') + print(f"Processing {path}") dataset = BucketsDataset(path, existings_images) dataloader = DataLoader(dataset, batch_size=1) for batch in tqdm(dataloader): - img = batch['img'] - size = batch['size'] - ratio = batch['ratio'] - image_path = batch['path'] - prefsize = batch['prefsize'] + img = batch["img"] + size = batch["size"] + ratio = batch["ratio"] + image_path = batch["path"] + prefsize = batch["prefsize"] encoded = encode_images(img.to(torch.half), vae) - + for i in range(0, len(encoded)): filename_wo_ext = os.path.splitext(os.path.basename(image_path[i]))[0] add_to_list(str(ratio[i].item()), image_path[i]) - - torch.save({ - 'img': encoded[i].detach().clone(), - 'size': size[i], - 'prefsize': prefsize[i], - 'ratio': ratio[i] - }, f"{path}/{filename_wo_ext}_img.npz") - - with open(config.data.buckets_file, 'w') as json_file: + + torch.save( + {"img": encoded[i].detach().clone(), "size": size[i], "prefsize": prefsize[i], "ratio": ratio[i]}, + f"{path}/{filename_wo_ext}_img.npz", + ) + + with open(config.data.buckets_file, "w") as json_file: json.dump(buckets, json_file, indent=4) for ratio in buckets.keys(): - print(f'{float(ratio):.2f}: {len(buckets[ratio])}') + print(f"{float(ratio):.2f}: {len(buckets[ratio])}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py index 084cf83..bc20452 100644 --- a/train_scripts/train_local.py +++ b/train_scripts/train_local.py @@ -29,14 +29,20 @@ import numpy as np import pyrallis import torch +import torch.utils +import torch.utils.data from accelerate import Accelerator, InitProcessGroupKwargs from accelerate.utils import DistributedType from PIL import Image from termcolor import colored -import torch.utils -import torch.utils.data + warnings.filterwarnings("ignore") # ignore warning +import gc +import json +import math +import random + from diffusion import DPMS, FlowEuler, Scheduler from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode from diffusion.model.respace import compute_density_for_timestep_sampling @@ -47,10 +53,6 @@ from diffusion.utils.lr_scheduler import build_lr_scheduler from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, set_random_seed from diffusion.utils.optimizer import build_optimizer -import json -import random -import math -import gc os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -61,10 +63,13 @@ def set_fsdp_env(): os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaBlock" + image_index = 0 + + @torch.inference_mode() def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None): - + torch.cuda.empty_cache() vis_sampler = config.scheduler.vis_sampler model = accelerator.unwrap_model(model).eval() @@ -127,7 +132,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): model_kwargs=model_kwargs, schedule="FLOW", ) - + denoised = dpm_solver.sample( z, steps=24, @@ -141,7 +146,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): latents.append(denoised) 
torch.cuda.empty_cache() - + del_vae = False if vae is None: vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) @@ -236,12 +241,9 @@ def concatenate_images(image_caption, images_per_row=5, image_format="webp"): return image_logs -class RatioBucketsDataset(): - def __init__( - self, - buckets_file - ): - with open(buckets_file, 'r') as file: +class RatioBucketsDataset: + def __init__(self, buckets_file): + with open(buckets_file) as file: self.buckets = json.load(file) def __getitem__(self, idx): @@ -249,29 +251,29 @@ def __getitem__(self, idx): loader = random.choice(self.loaders) try: - return next(loader) + return next(loader) except StopIteration: self.loaders.remove(loader) print(f"bucket ended, {len(self.loaders)}") def __len__(self): return self.size - + def make_loaders(self, batch_size): self.loaders = [] self.size = 0 for bucket in self.buckets.keys(): dataset = ImageDataset(self.buckets[bucket]) - - loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=False) + + loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=False + ) self.loaders.append(iter(loader)) self.size += math.ceil(len(dataset) / batch_size) + class ImageDataset(torch.utils.data.Dataset): - def __init__( - self, - images - ): + def __init__(self, images): self.images = images def getdata(self, idx): @@ -279,16 +281,16 @@ def getdata(self, idx): filename_wo_ext = os.path.splitext(os.path.basename(path))[0] text_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}.txt") - with open(text_file, 'r') as file: + with open(text_file) as file: prompt = file.read() cache_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}_img.npz") cached_data = torch.load(cache_file) - size = cached_data['prefsize'] - ratio = cached_data['ratio'] - vae_embed = cached_data['img'] - + size = cached_data["prefsize"] + ratio = cached_data["ratio"] + vae_embed = cached_data["img"] + data_info = { "img_hw": size, "aspect_ratio": torch.tensor(ratio.item()), @@ -313,6 +315,7 @@ def __getitem__(self, idx): def __len__(self): return len(self.images) + def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, train_diffusion, logger): if getattr(config.train, "debug_nan", False): DebugUnderflowOverflow(model) @@ -358,19 +361,21 @@ def check_nan_inf(model): shuffled_prompts = [] for prompt in prompts: tags = prompt.split(",") # Split the string into a list of tags - random.shuffle(tags) # Shuffle the tags + random.shuffle(tags) # Shuffle the tags shuffled_prompts.append(",".join(tags)) # Join them back into a string if "T5" in config.text_encoder.text_encoder_name: with torch.no_grad(): txt_tokens = tokenizer( - shuffled_prompts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + shuffled_prompts, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt", ).to(accelerator.device) y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] y_mask = txt_tokens.attention_mask[:, None, None] - elif ( - "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name - ): + elif "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name: with torch.no_grad(): if not config.text_encoder.chi_prompt: max_length_all = config.text_encoder.model_max_length @@ -430,13 +435,13 @@ def 
check_nan_inf(model): # Check if the loss is NaN if torch.isnan(loss): loss_nan_timer += 1 - print(f'Skip nan: {loss_nan_timer}') + print(f"Skip nan: {loss_nan_timer}") continue # Skip the rest of the loop iteration if loss is NaN accelerator.backward(loss) if accelerator.sync_gradients: grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip) - + optimizer.step() lr_scheduler.step() accelerator.wait_for_everyone() @@ -462,9 +467,7 @@ def check_nan_inf(model): ) log_buffer.average() - current_step = ( - global_step - step // config.train.train_batch_size - ) % len(dataset) + current_step = (global_step - step // config.train.train_batch_size) % len(dataset) current_step = len(dataset) if current_step == 0 else current_step info = ( f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {len(dataset)}, " @@ -639,7 +642,7 @@ def main(cfg: SanaConfig) -> None: if getattr(config.train, "deterministic_validation", False) else None ) - + tokenizer = text_encoder = None if not config.data.load_text_feat: tokenizer, text_encoder = get_tokenizer_and_text_encoder( @@ -732,7 +735,11 @@ def main(cfg: SanaConfig) -> None: torch.save({"caption_embeds": caption_emb, "emb_mask": caption_emb_mask}, prompt_embed_path) null_tokens = tokenizer( - "bad artwork,ugly,sketch,poorly drawn,messy,noisy,score: 0/10,blurry,low quality,old", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + "bad artwork,ugly,sketch,poorly drawn,messy,noisy,score: 0/10,blurry,low quality,old", + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt", ).to(accelerator.device) if "T5" in config.text_encoder.text_encoder_name: null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] @@ -849,7 +856,7 @@ def main(cfg: SanaConfig) -> None: config.train.lr_schedule_args["num_warmup_steps"] * num_replicas ) lr_scheduler = build_lr_scheduler(config.train, optimizer, dataset, 1) - + logger.warning( f"{colored(f'Basic Setting: ', 'green', attrs=['bold'])}" f"lr: {config.train.optimizer['lr']:.9f}, bs: {config.train.train_batch_size}, gc: {config.train.grad_checkpointing}, " From 8aaa2c9183ae8762ccd2f09bfd829b3e413dc877 Mon Sep 17 00:00:00 2001 From: Muinez Date: Fri, 29 Nov 2024 07:54:35 +0300 Subject: [PATCH 11/14] fix saving error --- train_scripts/train_local.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py index bc20452..f77cd4c 100644 --- a/train_scripts/train_local.py +++ b/train_scripts/train_local.py @@ -504,10 +504,12 @@ def check_nan_inf(model): checkpoints_dir = osp.join(config.work_dir, "checkpoints") # Remove all old checkpoint files in the directory - for filename in os.listdir(checkpoints_dir): - file_path = osp.join(checkpoints_dir, filename) - if os.path.isfile(file_path): - os.remove(file_path) + if(os.path.exists(checkpoints_dir)): + for filename in os.listdir(checkpoints_dir): + file_path = osp.join(checkpoints_dir, filename) + if os.path.isfile(file_path): + os.remove(file_path) + ckpt_saved_path = save_checkpoint( checkpoints_dir, epoch=epoch, From 6fa0af0d628408c2ce9705947bdfdced1c12252f Mon Sep 17 00:00:00 2001 From: Muinez Date: Fri, 29 Nov 2024 16:50:47 +0300 Subject: [PATCH 12/14] fix bugs & minor changes --- train_scripts/train_local.py | 174 ++++++++++++++++++----------------- 1 file changed, 92 insertions(+), 82 deletions(-) diff --git a/train_scripts/train_local.py 
b/train_scripts/train_local.py index f77cd4c..c03231a 100644 --- a/train_scripts/train_local.py +++ b/train_scripts/train_local.py @@ -15,12 +15,9 @@ # SPDX-License-Identifier: Apache-2.0 import datetime -import getpass import hashlib -import json import os import os.path as osp -import random import time import types import warnings @@ -31,9 +28,9 @@ import torch import torch.utils import torch.utils.data +from PIL import Image from accelerate import Accelerator, InitProcessGroupKwargs from accelerate.utils import DistributedType -from PIL import Image from termcolor import colored warnings.filterwarnings("ignore") # ignore warning @@ -44,7 +41,7 @@ import random from diffusion import DPMS, FlowEuler, Scheduler -from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode +from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode from diffusion.model.respace import compute_density_for_timestep_sampling from diffusion.utils.checkpoint import load_checkpoint, save_checkpoint from diffusion.utils.config import SanaConfig @@ -68,8 +65,7 @@ def set_fsdp_env(): @torch.inference_mode() -def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None): - +def idation(accelerator, config, model, step, device, vae=None, init_noise=None): torch.cuda.empty_cache() vis_sampler = config.scheduler.vis_sampler model = accelerator.unwrap_model(model).eval() @@ -82,7 +78,12 @@ def log_validation(accelerator, config, model, logger, step, device, vae=None, i logger.info("Running validation... ") image_logs = [] - def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): + del_vae = False + if vae is None: + vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) + del_vae = True + + def run_sampling(init_z=None, label_suffix="", sampler="dpm-solver"): latents = [] current_image_logs = [] @@ -147,10 +148,6 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): latents.append(denoised) torch.cuda.empty_cache() - del_vae = False - if vae is None: - vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) - del_vae = True for prompt, latent in zip(validation_prompts, latents): latent = latent.to(torch.float16) samples = vae_decode(config.vae.vae_type, vae, latent) @@ -160,19 +157,20 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): image = Image.fromarray(samples) current_image_logs.append({"validation_prompt": prompt + label_suffix, "images": [image]}) - if del_vae: - vae = None - gc.collect() - torch.cuda.empty_cache() return current_image_logs # First run with original noise - image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler) + image_logs += run_sampling(init_z=None, label_suffix="", sampler=vis_sampler) # Second run with init_noise if provided if init_noise is not None: init_noise = torch.clone(init_noise).to(device) - image_logs += run_sampling(init_z=init_noise, label_suffix=" w/ init noise", vae=vae, sampler=vis_sampler) + image_logs += run_sampling(init_z=init_noise, label_suffix=" w/ init noise", sampler=vis_sampler) + + if del_vae: + vae = None + gc.collect() + torch.cuda.empty_cache() formatted_images = [] for log in image_logs: @@ -204,13 +202,13 @@ def concatenate_images(image_caption, images_per_row=5, image_format="webp"): widths, heights = zip(*(img.size for img in 
images)) max_width = max(widths) - total_height = sum(heights[i : i + images_per_row][0] for i in range(0, len(images), images_per_row)) + total_height = sum(heights[i: i + images_per_row][0] for i in range(0, len(images), images_per_row)) new_im = Image.new("RGB", (max_width * images_per_row, total_height)) y_offset = 0 for i in range(0, len(images), images_per_row): - row_images = images[i : i + images_per_row] + row_images = images[i: i + images_per_row] x_offset = 0 for img in row_images: new_im.paste(img, (x_offset, y_offset)) @@ -254,7 +252,7 @@ def __getitem__(self, idx): return next(loader) except StopIteration: self.loaders.remove(loader) - print(f"bucket ended, {len(self.loaders)}") + logger.info(f"bucket ended, {len(self.loaders)}") def __len__(self): return self.size @@ -308,7 +306,7 @@ def __getitem__(self, idx): data = self.getdata(idx) return data except Exception as e: - print(f"Error details: {str(e)}") + logger.error(f"Error details: {str(e)}") idx = idx + 1 raise RuntimeError("Too many bad data.") @@ -316,7 +314,7 @@ def __len__(self): return len(self.images) -def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, train_diffusion, logger): +def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, train_diffusion): if getattr(config.train, "debug_nan", False): DebugUnderflowOverflow(model) logger.info("NaN debugger registered. Start to detect overflow during training.") @@ -325,10 +323,40 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, tr def check_nan_inf(model): for name, param in model.named_parameters(): if torch.isnan(param).any() or torch.isinf(param).any(): - print(f"NaN/Inf detected in {name}") + logger.error(f"NaN/Inf detected in {name}") check_nan_inf(model) + def save_model(save_metric=True): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + os.umask(0o000) + checkpoints_dir = osp.join(config.work_dir, "checkpoints") + + # Remove all old checkpoint files in the directory + if os.path.exists(checkpoints_dir): + for filename in os.listdir(checkpoints_dir): + file_path = osp.join(checkpoints_dir, filename) + if os.path.isfile(file_path): + os.remove(file_path) + + ckpt_saved_path = save_checkpoint( + checkpoints_dir, + epoch=epoch, + step=global_step, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + lr_scheduler=lr_scheduler, + generator=generator, + add_symlink=True, + ) + if save_metric: + online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir) + os.makedirs(online_metric_monitor_dir, exist_ok=True) + with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f: + f.write(osp.join(config.work_dir, "config.py") + "\n") + f.write(ckpt_saved_path) + global_step = start_step + 1 skip_step = max(config.train.skip_step, global_step) % len(dataset) skip_step = skip_step if skip_step < (len(dataset) - 20) else 0 @@ -344,6 +372,7 @@ def check_nan_inf(model): data_time_all = 0 lm_time_all = 0 model_time_all = 0 + optimizer_time_all = 0 dataset.make_loaders(config.train.train_batch_size) for step, batch in enumerate(dataset): # image, json_info, key = batch @@ -385,7 +414,7 @@ def check_nan_inf(model): prompt = [chi_prompt + i for i in shuffled_prompts] num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) max_length_all = ( - num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 + num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 ) # magic number 2: [bos], [_] txt_tokens = 
tokenizer( prompt, @@ -399,10 +428,10 @@ def check_nan_inf(model): ) # first bos and end N-1 y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None][ :, :, select_index - ] + ] y_mask = txt_tokens.attention_mask[:, None, None][:, :, :, select_index] else: - print("error") + logger.error("error") exit() # Sample a random timestep for each image @@ -435,17 +464,19 @@ def check_nan_inf(model): # Check if the loss is NaN if torch.isnan(loss): loss_nan_timer += 1 - print(f"Skip nan: {loss_nan_timer}") + logger.warning(f"Skip nan: {loss_nan_timer}") continue # Skip the rest of the loop iteration if loss is NaN accelerator.backward(loss) if accelerator.sync_gradients: grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip) + model_time_all += time.time() - model_time_start + optimizer_time_start = time.time() optimizer.step() + optimizer_time_all += time.time() - optimizer_time_start lr_scheduler.step() accelerator.wait_for_everyone() - model_time_all += time.time() - model_time_start lr = lr_scheduler.get_last_lr()[0] logs = {args.loss_report_name: accelerator.gather(loss).mean().item()} @@ -457,6 +488,7 @@ def check_nan_inf(model): t = (time.time() - last_tic) / config.train.log_interval t_d = data_time_all / config.train.log_interval t_m = model_time_all / config.train.log_interval + t_opt = optimizer_time_all / config.train.log_interval t_lm = lm_time_all / config.train.log_interval avg_time = (time.time() - time_start) / (step + 1) eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1)))) @@ -471,7 +503,7 @@ def check_nan_inf(model): current_step = len(dataset) if current_step == 0 else current_step info = ( f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {len(dataset)}, " - f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, " + f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, optimizer:{t_opt:.3f}, data:{t_d:.3f}, " f"lm:{t_lm:.3f}, lr:{lr:.3e}, " ) info += ( @@ -485,6 +517,7 @@ def check_nan_inf(model): log_buffer.clear() data_time_all = 0 model_time_all = 0 + optimizer_time_all = 0 lm_time_all = 0 if accelerator.is_main_process: logger.info(info) @@ -498,34 +531,7 @@ def check_nan_inf(model): if loss_nan_timer > 20: raise ValueError("Loss is NaN too much times. 
Break here.") if global_step % config.train.save_model_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - os.umask(0o000) - checkpoints_dir = osp.join(config.work_dir, "checkpoints") - - # Remove all old checkpoint files in the directory - if(os.path.exists(checkpoints_dir)): - for filename in os.listdir(checkpoints_dir): - file_path = osp.join(checkpoints_dir, filename) - if os.path.isfile(file_path): - os.remove(file_path) - - ckpt_saved_path = save_checkpoint( - checkpoints_dir, - epoch=epoch, - step=global_step, - model=accelerator.unwrap_model(model), - optimizer=optimizer, - lr_scheduler=lr_scheduler, - generator=generator, - add_symlink=True, - ) - if config.train.online_metric and global_step % config.train.eval_metric_step == 0 and step > 1: - online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir) - os.makedirs(online_metric_monitor_dir, exist_ok=True) - with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f: - f.write(osp.join(config.work_dir, "config.py") + "\n") - f.write(ckpt_saved_path) + save_model(config.train.online_metric and global_step % config.train.eval_metric_step == 0 and step > 1) # if (time.time() - training_start_time) / 3600 > 3.8: # logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.") @@ -534,22 +540,20 @@ def check_nan_inf(model): accelerator.wait_for_everyone() if accelerator.is_main_process: if validation_noise is not None: - log_validation( + idation( accelerator=accelerator, config=config, model=model, - logger=logger, step=global_step, device=accelerator.device, vae=vae, init_noise=validation_noise, ) else: - log_validation( + idation( accelerator=accelerator, config=config, model=model, - logger=logger, step=global_step, device=accelerator.device, vae=vae, @@ -557,15 +561,18 @@ def check_nan_inf(model): data_time_start = time.time() + if epoch % config.train.save_model_epochs == 0 or epoch == config.train.num_epochs and not config.debug: + save_model() accelerator.wait_for_everyone() + save_model() @pyrallis.wrap() def main(cfg: SanaConfig) -> None: global start_epoch, start_step, vae, generator, num_replicas, rank, training_start_time - global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer + global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer, logger global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path - global image_size, cache_file, total_steps + global image_size, total_steps config = cfg args = cfg @@ -707,7 +714,7 @@ def main(cfg: SanaConfig) -> None: caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] caption_emb_mask = txt_tokens.attention_mask elif ( - "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name + "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name ): if not config.text_encoder.chi_prompt: max_length_all = config.text_encoder.model_max_length @@ -716,7 +723,7 @@ def main(cfg: SanaConfig) -> None: prompt = chi_prompt + prompt num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) max_length_all = ( - num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 + num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 ) # magic number 2: [bos], [_] txt_tokens = tokenizer( @@ -728,8 +735,8 @@ def main(cfg: SanaConfig) -> None: ).to(accelerator.device) select_index 
= [0] + list(range(-config.text_encoder.model_max_length + 1, 0)) caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][ - :, select_index - ] + :, select_index + ] caption_emb_mask = txt_tokens.attention_mask[:, select_index] else: raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") @@ -821,15 +828,17 @@ def main(cfg: SanaConfig) -> None: # 2-1. load model if args.load_from is not None: config.model.load_from = args.load_from - # if config.model.load_from is not None and load_from: - # _, missing, unexpected, _ = load_checkpoint( - # config.model.load_from, - # model, - # load_ema=config.model.resume_from.get("load_ema", False), - # null_embed_path=null_embed_path, - # ) - # logger.warning(f"Missing keys: {missing}") - # logger.warning(f"Unexpected keys: {unexpected}") + if config.model.load_from is not None and load_from: + _, missing, unexpected, _ = load_checkpoint( + config.model.load_from, + model, + load_ema=config.model.resume_from.get("load_ema", False), + null_embed_path=null_embed_path, + ) + if missing: + logger.warning(f"Missing keys: {missing}") + if unexpected: + logger.warning(f"Unexpected keys: {unexpected}") # prepare for FSDP clip grad norm calculation if accelerator.distributed_type == DistributedType.FSDP: @@ -855,7 +864,7 @@ def main(cfg: SanaConfig) -> None: if config.train.lr_schedule_args and config.train.lr_schedule_args.get("num_warmup_steps", None): config.train.lr_schedule_args["num_warmup_steps"] = ( - config.train.lr_schedule_args["num_warmup_steps"] * num_replicas + config.train.lr_schedule_args["num_warmup_steps"] * num_replicas ) lr_scheduler = build_lr_scheduler(config.train, optimizer, dataset, 1) @@ -906,8 +915,10 @@ def main(cfg: SanaConfig) -> None: null_embed_path=null_embed_path, ) - logger.warning(f"Missing keys: {missing}") - logger.warning(f"Unexpected keys: {unexpected}") + if missing: + logger.warning(f"Missing keys: {missing}") + if unexpected: + logger.warning(f"Unexpected keys: {unexpected}") path = osp.basename(config.model.resume_from["checkpoint"]) try: @@ -939,8 +950,7 @@ def main(cfg: SanaConfig) -> None: optimizer=optimizer, lr_scheduler=lr_scheduler, dataset=dataset, - train_diffusion=train_diffusion, - logger=logger, + train_diffusion=train_diffusion ) From ab73314b64f1012b36c4e2cc3db6c5d1d0105c40 Mon Sep 17 00:00:00 2001 From: Muinez Date: Sun, 1 Dec 2024 00:44:06 +0300 Subject: [PATCH 13/14] remove prompts shuffling --- train_scripts/train_local.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py index c03231a..c46d22f 100644 --- a/train_scripts/train_local.py +++ b/train_scripts/train_local.py @@ -387,16 +387,16 @@ def save_model(save_metric=True): lm_time_start = time.time() prompts = list(batch[2]) - shuffled_prompts = [] - for prompt in prompts: - tags = prompt.split(",") # Split the string into a list of tags - random.shuffle(tags) # Shuffle the tags - shuffled_prompts.append(",".join(tags)) # Join them back into a string + # shuffled_prompts = [] + # for prompt in prompts: + # tags = prompt.split(",") # Split the string into a list of tags + # random.shuffle(tags) # Shuffle the tags + # shuffled_prompts.append(",".join(tags)) # Join them back into a string if "T5" in config.text_encoder.text_encoder_name: with torch.no_grad(): txt_tokens = tokenizer( - shuffled_prompts, + prompts, max_length=max_length, padding="max_length", truncation=True, @@ -408,10 +408,10 @@ def 
save_model(save_metric=True): with torch.no_grad(): if not config.text_encoder.chi_prompt: max_length_all = config.text_encoder.model_max_length - prompt = shuffled_prompts + prompt = prompts else: chi_prompt = "\n".join(config.text_encoder.chi_prompt) - prompt = [chi_prompt + i for i in shuffled_prompts] + prompt = [chi_prompt + i for i in prompts] num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) max_length_all = ( num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 From 157b1746cb0284e98785ca921514b65a6450a287 Mon Sep 17 00:00:00 2001 From: Muinez <76997923+Muinez@users.noreply.github.com> Date: Sat, 4 Jan 2025 11:29:31 +0300 Subject: [PATCH 14/14] fix error --- train_scripts/make_buckets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_scripts/make_buckets.py b/train_scripts/make_buckets.py index 36c4ed4..7a84a46 100644 --- a/train_scripts/make_buckets.py +++ b/train_scripts/make_buckets.py @@ -15,6 +15,8 @@ from diffusion.model.builder import get_vae, vae_encode from diffusion.utils.config import SanaConfig +from PIL import PngImagePlugin +PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 @pyrallis.wrap() def main(config: SanaConfig) -> None:
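The `PngImagePlugin.MAX_TEXT_CHUNK` override added in the last patch lifts Pillow's per-chunk cap on decompressed PNG text metadata, which otherwise typically makes `Image.open` raise `ValueError: Decompressed Data Too Large` for images whose embedded text (e.g. generation parameters) exceeds the default limit (about 1 MiB in recent Pillow releases). A minimal sketch of the failure mode and the same workaround follows; the file name and the 2 MB metadata payload are illustrative assumptions, not part of the patch:

```python
from PIL import Image, PngImagePlugin

# Write a PNG whose text metadata decompresses to ~2 MB (over Pillow's default cap).
info = PngImagePlugin.PngInfo()
info.add_text("parameters", "x" * (2 * 1024 * 1024), zip=True)
Image.new("RGB", (64, 64)).save("meta_heavy.png", pnginfo=info)

try:
    Image.open("meta_heavy.png").load()   # expected to fail under the default limit
except ValueError as err:
    print(err)

# Same workaround as the patch: raise the cap before opening such images.
PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024
Image.open("meta_heavy.png").load()
```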