diff --git a/configs/sana_config/512ms/Sana_600M_img256.yaml b/configs/sana_config/512ms/Sana_600M_img256.yaml
new file mode 100644
index 0000000..2b082ce
--- /dev/null
+++ b/configs/sana_config/512ms/Sana_600M_img256.yaml
@@ -0,0 +1,97 @@
+data:
+  data_dir: [data/data_public/dir1]
+  image_size: 256
+  type: MorphicStreamingDataset
+  sort_dataset: false
+# model config
+model:
+  model: SanaMS_600M_P1_D28
+  image_size: 256
+  mixed_precision: bf16
+  fp32_attention: true
+  load_from:
+  resume_from:
+  aspect_ratio_type: ASPECT_RATIO_256
+  multi_scale: false
+  #pe_interpolation: 1.
+  attn_type: linear
+  linear_head_dim: 32
+  ffn_type: glumbconv
+  mlp_acts:
+    - silu
+    - silu
+    - null
+  mlp_ratio: 2.5
+  use_pe: true
+  qk_norm: true
+  class_dropout_prob: 0.1
+# VAE setting
+vae:
+  vae_type: dc-ae
+  vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
+  scale_factor: 0.41407
+  vae_latent_dim: 32
+  vae_downsample_rate: 32
+  sample_posterior: true
+# text encoder
+text_encoder:
+  # text_encoder_name:
+  text_encoder_name: T5
+  y_norm: true
+  y_norm_scale_factor: 0.01
+  model_max_length: 300
+  # CHI
+  chi_prompt:
+    - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
+    - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
+    - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
+    - 'Here are examples of how to transform or refine prompts:'
+    - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
+    - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
+    - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
+    - 'User Prompt: '
+# Sana schedule Flow
+scheduler:
+  predict_v: true
+  noise_schedule: linear_flow
+  pred_sigma: false
+  flow_shift: 1.0
+  # logit-normal timestep
+  weighting_scheme: logit_normal
+  logit_mean: 0.0
+  logit_std: 1.0
+  vis_sampler: flow_dpm-solver
+# training setting
+train:
+  use_fsdp: true
+  num_workers: 10
+  seed: 1
+  train_batch_size: 256
+  num_epochs: 100
+  gradient_accumulation_steps: 1
+  grad_checkpointing: false
+  gradient_clip: 0.1
+  optimizer:
+    betas:
+      - 0.9
+      - 0.999
+      - 0.9999
+    eps:
+      - 1.0e-30
+      - 1.0e-16
+    lr: 0.0001
+    type: CAMEWrapper
+    weight_decay: 0.0
+  lr_schedule: constant
+  lr_schedule_args:
+    num_warmup_steps: 2000
+  local_save_vis: true # if save log image locally
+  visualize: false
+  eval_sampling_steps: 500
+  log_interval: 20
+  save_model_epochs: 5
+  save_model_steps: 500
+  work_dir: output/debug
+  online_metric: false
+  eval_metric_step: 2000
+  online_metric_dir: metric_helper
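Note on the config above: the `chi_prompt` entries are assembled into a single complex-human-instruction prefix that is prepended to the user prompt before T5 encoding. A minimal sketch of inspecting the new file, assuming plain PyYAML loading; the newline join is an assumption about downstream use, not something this diff shows:

```python
# Sketch only: load the new config and assemble the CHI prefix.
import yaml

with open("configs/sana_config/512ms/Sana_600M_img256.yaml") as f:
    cfg = yaml.safe_load(f)

chi_prefix = "\n".join(cfg["text_encoder"]["chi_prompt"])  # assumed join
full_prompt = chi_prefix + "A cat sleeping"  # string handed to the T5 encoder
print(cfg["scheduler"]["noise_schedule"], cfg["train"]["optimizer"]["type"])
```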
diff --git a/diffusion/data/datasets/__init__.py b/diffusion/data/datasets/__init__.py
index 1cd1ec6..3e2cfc1 100755
--- a/diffusion/data/datasets/__init__.py
+++ b/diffusion/data/datasets/__init__.py
@@ -1,3 +1,4 @@
 from .sana_data import SanaImgDataset, SanaWebDataset
 from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
+from .morphic_data import MorphicStreamingDataset
 from .utils import *
diff --git a/diffusion/data/datasets/morphic_data.py b/diffusion/data/datasets/morphic_data.py
new file mode 100644
index 0000000..e57b755
--- /dev/null
+++ b/diffusion/data/datasets/morphic_data.py
@@ -0,0 +1,36 @@
+from functools import partial
+
+import litdata as ld
+import torchvision.transforms.functional as TF
+from numpy.random import choice
+
+
+def select_sketch_sample(keys, weights):
+    # Compare with a tolerance: a plain `== 1` equality check can fail on
+    # float rounding (e.g. sum([0.4, 0.3, 0.3])).
+    assert abs(sum(weights) - 1.0) < 1e-6
+    return choice(keys, p=weights)
+
+
+class MorphicStreamingDataset(ld.StreamingDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._keys = ['sketch_refined', 'sketch_anime', 'sketch_opensketch']
+        self._weights = [0.4, 0.3, 0.3]
+        self.aspect_ratio = 1.0
+
+    def __getitem__(self, index):
+        sample = super().__getitem__(index)
+        output = {}
+        sketch_sampler = partial(select_sketch_sample, keys=self._keys, weights=self._weights)
+        # (TODO) Check if this preprocessing is correct
+        output['image'] = TF.center_crop(sample['image'], (256, 256))
+        output['text'] = sample['caption']
+        # Expand grayscale images to three channels.
+        if output['image'].shape[0] == 1:
+            output['image'] = output['image'].repeat(3, 1, 1)
+        # Map uint8 pixel values from [0, 255] to [-1, 1].
+        output['image'] = (output['image'] / 127.5) - 1.0
+        # Randomly pick one of the three sketch variants for this sample.
+        output['sketch'] = TF.to_tensor(TF.center_crop(sample[sketch_sampler()], (256, 256))).reshape((1, 256, 256))
+        return output
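A quick smoke test for `MorphicStreamingDataset`, assuming litdata shards that carry `image`, `caption`, and the three `sketch_*` keys; the path is a placeholder:

```python
# Sketch only: check the shapes and ranges produced by __getitem__.
from diffusion.data.datasets import MorphicStreamingDataset

ds = MorphicStreamingDataset("/path/to/litdata-shards")  # placeholder path
sample = ds[0]
assert sample["image"].shape == (3, 256, 256)   # RGB, values in [-1, 1]
assert sample["sketch"].shape == (1, 256, 256)  # one randomly chosen sketch
print(sample["text"][:80])
```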
"SanaBlock" + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaMSBlock" @torch.inference_mode() @@ -84,7 +88,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device=device) if init_z is None else init_z - ) + ).to(torch.bfloat16) embed = torch.load( osp.join(config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}"), map_location="cpu", @@ -273,13 +277,13 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal # Now you train the model for epoch in range(start_epoch + 1, config.train.num_epochs + 1): time_start, last_tic = time.time(), time.time() - sampler = ( - train_dataloader.batch_sampler.sampler - if (num_replicas > 1 or config.model.multi_scale) - else train_dataloader.sampler - ) - sampler.set_epoch(epoch) - sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0)) + # sampler = ( + # train_dataloader.batch_sampler.sampler + # if (num_replicas > 1 or config.model.multi_scale) + # else train_dataloader.sampler + # ) + # sampler.set_epoch(epoch) + # sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0)) if skip_step > 1 and accelerator.is_main_process: logger.info(f"Skipped Steps: {skip_step}") skip_step = 1 @@ -302,14 +306,14 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal enabled=(config.model.mixed_precision == "fp16" or config.model.mixed_precision == "bf16"), ): z = vae_encode( - config.vae.vae_type, vae, batch[0], config.vae.sample_posterior, accelerator.device + config.vae.vae_type, vae, batch['image'], config.vae.sample_posterior, accelerator.device ) accelerator.wait_for_everyone() vae_time_all += time.time() - vae_time_start clean_images = z - data_info = batch[3] + data_info = batch['text'] lm_time_start = time.time() if load_text_feat: @@ -319,7 +323,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal if "T5" in config.text_encoder.text_encoder_name: with torch.no_grad(): txt_tokens = tokenizer( - batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + batch['text'], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" ).to(accelerator.device) y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] y_mask = txt_tokens.attention_mask[:, None, None] @@ -409,20 +413,20 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal datetime.timedelta( seconds=int( avg_time - * (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step - 1) + * (train_dataloader_len // config.train.train_batch_size - step - 1) ) ) ) log_buffer.average() current_step = ( - global_step - sampler.step_start // config.train.train_batch_size + global_step // config.train.train_batch_size ) % train_dataloader_len current_step = train_dataloader_len if current_step == 0 else current_step info = ( f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {train_dataloader_len}, " f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, " - f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}, Cap: {batch[5][0]}, " + f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}" ) info += ( f"s:({model.module.h}, {model.module.w}), " @@ -504,7 +508,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal # for internal, refactor 
@@ -504,7 +508,7 @@ train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
             # for internal, refactor dataloader logic to remove the ad-hoc implementation
             if (
                 config.model.multi_scale
-                and (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step) < 30
+                and (train_dataloader_len // config.train.train_batch_size - step) < 30
             ):
                 global_step = epoch * train_dataloader_len
                 logger.info("Early stop current iteration")
@@ -541,7 +545,7 @@ def main(cfg: SanaConfig) -> None:
     global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer
     global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path
     global image_size, cache_file, total_steps
-
+    import uuid
     config = cfg
     args = cfg
     # config = read_config(args.config)
@@ -576,6 +580,7 @@
         set_fsdp_env()
         fsdp_plugin = FullyShardedDataParallelPlugin(
             state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
+
         )
     else:
         init_train = "DDP"
@@ -602,8 +607,8 @@
     pyrallis.dump(config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4)
     if args.report_to == "wandb":
         import wandb
-
-        wandb.init(project=args.tracker_project_name, name=args.name, resume="allow", id=args.name)
+        uniqueid = str(uuid.uuid4())[:4]
+        wandb.init(project=args.tracker_project_name, name=args.name + '_' + uniqueid, resume="allow", id=args.name + '_' + uniqueid)
 
     logger.info(f"Config: \n{config}")
     logger.info(f"World_size: {get_world_size()}, seed: {config.train.seed}")
@@ -814,17 +819,18 @@
         ]
     num_replicas = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
-    dataset = build_dataset(
-        asdict(config.data),
-        resolution=image_size,
-        aspect_ratio_type=config.model.aspect_ratio_type,
-        real_prompt_ratio=config.train.real_prompt_ratio,
-        max_length=max_length,
-        config=config,
-        caption_proportion=config.data.caption_proportion,
-        sort_dataset=config.data.sort_dataset,
-        vae_downsample_rate=config.vae.vae_downsample_rate,
-    )
+    # dataset = build_dataset(
+    #     asdict(config.data),
+    #     resolution=image_size,
+    #     aspect_ratio_type=config.model.aspect_ratio_type,
+    #     real_prompt_ratio=config.train.real_prompt_ratio,
+    #     max_length=max_length,
+    #     config=config,
+    #     caption_proportion=config.data.caption_proportion,
+    #     sort_dataset=config.data.sort_dataset,
+    #     vae_downsample_rate=config.vae.vae_downsample_rate,
+    # )
+    dataset = MorphicStreamingDataset('/data/mharikum/laion/filtered_laion_2B_256-data-litdata/')
     accelerator.wait_for_everyone()
     if config.model.multi_scale:
         drop_last = True
@@ -861,13 +867,14 @@
         logger.info(f"rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}")
     else:
         sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank)
-        train_dataloader = build_dataloader(
-            dataset,
-            num_workers=config.train.num_workers,
-            batch_size=config.train.train_batch_size,
-            shuffle=False,
-            sampler=sampler,
-        )
+        # train_dataloader = build_dataloader(
+        #     dataset,
+        #     num_workers=config.train.num_workers,
+        #     batch_size=config.train.train_batch_size,
+        #     shuffle=False,
+        #     sampler=sampler,
+        # )
+        train_dataloader = ld.StreamingDataLoader(dataset, num_workers=config.train.num_workers, drop_last=True, batch_size=config.train.train_batch_size, shuffle=False)
     train_dataloader_len = len(train_dataloader)
     load_vae_feat = getattr(train_dataloader.dataset, "load_vae_feat", False)
     load_text_feat = getattr(train_dataloader.dataset, "load_text_feat", False)
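The `DistributedRangedSampler` wiring is dropped because litdata's `StreamingDataLoader` shards the stream across ranks and workers on its own, so no explicit sampler is passed. A standalone sketch of the equivalent loader setup; the path is a placeholder:

```python
# Sketch only: StreamingDataLoader handles distributed sharding internally.
import litdata as ld

from diffusion.data.datasets import MorphicStreamingDataset

loader = ld.StreamingDataLoader(
    MorphicStreamingDataset("/path/to/litdata-shards"),  # placeholder path
    batch_size=256,
    num_workers=10,
    shuffle=False,
    drop_last=True,
)
```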
@@ -896,12 +903,12 @@ def main(cfg: SanaConfig) -> None:
 
     timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
 
-    if accelerator.is_main_process:
-        tracker_config = dict(vars(config))
-        try:
-            accelerator.init_trackers(args.tracker_project_name, tracker_config)
-        except:
-            accelerator.init_trackers(f"tb_{timestamp}")
+    # if accelerator.is_main_process:
+    #     tracker_config = dict(vars(config))
+    #     try:
+    #         accelerator.init_trackers(args.tracker_project_name, tracker_config)
+    #     except:
+    #         accelerator.init_trackers(f"tb_{timestamp}")
 
     start_epoch = 0
     start_step = 0
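One caveat with commenting out `sampler.set_start(...)`: mid-epoch resume is lost unless the loader's own state is checkpointed. `StreamingDataLoader` exposes `state_dict()`/`load_state_dict()` for exactly this; where to hook them into the checkpoint save/load path is an assumption, not part of this diff:

```python
# Sketch only: mid-epoch resume via litdata instead of sampler.set_start().
state = train_dataloader.state_dict()    # save next to the model checkpoint
# ... later, after rebuilding the dataloader on resume ...
train_dataloader.load_state_dict(state)  # continue from the saved position
```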