Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding temp code to run SANA training with litdata #132

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions configs/sana_config/512ms/Sana_600M_img256.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
data:
data_dir: [data/data_public/dir1]
image_size: 256
type: MorphicStreamingDataset
sort_dataset: false
# model config
model:
model: SanaMS_600M_P1_D28
image_size: 256
mixed_precision: bf16
fp32_attention: true
load_from:
resume_from:
aspect_ratio_type: ASPECT_RATIO_256
multi_scale: false
#pe_interpolation: 1.
attn_type: linear
linear_head_dim: 32
ffn_type: glumbconv
mlp_acts:
- silu
- silu
- null
mlp_ratio: 2.5
use_pe: true
qk_norm: true
class_dropout_prob: 0.1
# VAE setting
vae:
vae_type: dc-ae
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
sample_posterior: true
# text encoder
text_encoder:
# text_encoder_name:
text_encoder_name: T5
y_norm: true
y_norm_scale_factor: 0.01
model_max_length: 300
# CHI
chi_prompt:
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
- 'Here are examples of how to transform or refine prompts:'
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
- 'User Prompt: '
# Sana schedule Flow
scheduler:
predict_v: true
noise_schedule: linear_flow
pred_sigma: false
flow_shift: 1.0
# logit-normal timestep
weighting_scheme: logit_normal
logit_mean: 0.0
logit_std: 1.0
vis_sampler: flow_dpm-solver
# training setting
train:
use_fsdp: true
num_workers: 10
seed: 1
train_batch_size: 256
num_epochs: 100
gradient_accumulation_steps: 1
grad_checkpointing: false
gradient_clip: 0.1
optimizer:
betas:
- 0.9
- 0.999
- 0.9999
eps:
- 1.0e-30
- 1.0e-16
lr: 0.0001
type: CAMEWrapper
weight_decay: 0.0
lr_schedule: constant
lr_schedule_args:
num_warmup_steps: 2000
local_save_vis: true # if save log image locally
visualize: false
eval_sampling_steps: 500
log_interval: 20
save_model_epochs: 5
save_model_steps: 500
work_dir: output/debug
online_metric: false
eval_metric_step: 2000
online_metric_dir: metric_helper
1 change: 1 addition & 0 deletions diffusion/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .sana_data import SanaImgDataset, SanaWebDataset
from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
from .morphic_data import MorphicStreamingDataset
from .utils import *
33 changes: 33 additions & 0 deletions diffusion/data/datasets/morphic_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from functools import partial
from numpy.random import choice
import functools
import torchvision.transforms.functional as TF
import litdata as ld



def select_sketch_sample(keys, weights):
    """Randomly draw one sketch-style key from *keys* according to *weights*.

    Args:
        keys: Sequence of candidate sketch keys (e.g. dataset dict keys).
        weights: Per-key sampling probabilities; must sum to 1.

    Returns:
        A single element of *keys*, drawn with probability ``weights``.

    Raises:
        ValueError: If the weights do not sum (approximately) to 1.
    """
    # Explicit, float-tolerant validation instead of `assert sum(weights) == 1`:
    # asserts are stripped under `python -O`, and float weight lists such as
    # [0.4, 0.3, 0.3] may not sum to exactly 1.0 due to rounding.
    total = sum(weights)
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"weights must sum to 1, got {total!r}")
    return choice(keys, p=weights)


class MorphicStreamingDataset(ld.StreamingDataset):
    """litdata streaming dataset yielding image / caption / sketch triples.

    Each item is preprocessed for fixed 256x256 training: the image is
    center-cropped, grayscale images are expanded to 3 channels, pixel
    values are rescaled from [0, 255] to [-1, 1], and one of three sketch
    variants is sampled per item with fixed probabilities.
    """

    # Fixed spatial size applied to both image and sketch crops.
    _CROP_SIZE = (256, 256)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Candidate sketch variants stored per sample, and their
        # sampling probabilities (must sum to 1).
        self._keys = ["sketch_refined", "sketch_anime", "sketch_opensketch"]
        self._weights = [0.4, 0.3, 0.3]
        # Build the sampler once here instead of on every __getitem__ call;
        # keys/weights never change after construction.
        self._sketch_sampler = functools.partial(
            select_sketch_sample, keys=self._keys, weights=self._weights
        )
        # Square images only; consumed by the training loop's bucketing logic.
        self.aspect_ratio = 1.0

    def __getitem__(self, index):
        """Return a dict with 'image', 'text', and 'sketch' entries.

        NOTE(review): assumes the underlying sample is a dict holding a
        uint8 CHW image tensor under 'image', a caption string under
        'caption', and croppable sketch images under the sketch keys —
        confirm against the dataset writer.
        """
        sample = super().__getitem__(index)
        output = {}
        # (TODO) Check if this preprocessing is correct
        output["image"] = TF.center_crop(sample["image"], self._CROP_SIZE)
        output["text"] = sample["caption"]
        # Expand single-channel (grayscale) images to 3 channels.
        if output["image"].shape[0] == 1:
            output["image"] = output["image"].repeat(3, 1, 1)
        # Rescale uint8 pixel range [0, 255] to [-1, 1].
        output["image"] = (output["image"] / 127.5) - 1.0
        # Pick one sketch variant for this item and shape it as (1, 256, 256).
        sketch_key = self._sketch_sampler()
        output["sketch"] = TF.to_tensor(
            TF.center_crop(sample[sketch_key], self._CROP_SIZE)
        ).reshape((1,) + self._CROP_SIZE)
        return output

93 changes: 50 additions & 43 deletions train_scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@
from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, read_config, set_random_seed
from diffusion.utils.optimizer import auto_scale_lr, build_optimizer

from diffusion.data.datasets import MorphicStreamingDataset

import litdata as ld

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_fsdp_env():
    """Configure Accelerate's FSDP behavior via environment variables.

    Must run before the Accelerator is constructed, since Accelerate reads
    these variables at initialization time.
    """
    os.environ["ACCELERATE_USE_FSDP"] = "true"
    os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
    os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
    # Wrap SanaMSBlock (the multi-scale transformer block), matching the
    # SanaMS_600M model configured for this run.
    os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaMSBlock"


@torch.inference_mode()
Expand All @@ -84,7 +88,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"):
torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device=device)
if init_z is None
else init_z
)
).to(torch.bfloat16)
embed = torch.load(
osp.join(config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}"),
map_location="cpu",
Expand Down Expand Up @@ -273,13 +277,13 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
# Now you train the model
for epoch in range(start_epoch + 1, config.train.num_epochs + 1):
time_start, last_tic = time.time(), time.time()
sampler = (
train_dataloader.batch_sampler.sampler
if (num_replicas > 1 or config.model.multi_scale)
else train_dataloader.sampler
)
sampler.set_epoch(epoch)
sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0))
# sampler = (
# train_dataloader.batch_sampler.sampler
# if (num_replicas > 1 or config.model.multi_scale)
# else train_dataloader.sampler
# )
# sampler.set_epoch(epoch)
# sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0))
if skip_step > 1 and accelerator.is_main_process:
logger.info(f"Skipped Steps: {skip_step}")
skip_step = 1
Expand All @@ -302,14 +306,14 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
enabled=(config.model.mixed_precision == "fp16" or config.model.mixed_precision == "bf16"),
):
z = vae_encode(
config.vae.vae_type, vae, batch[0], config.vae.sample_posterior, accelerator.device
config.vae.vae_type, vae, batch['image'], config.vae.sample_posterior, accelerator.device
)

accelerator.wait_for_everyone()
vae_time_all += time.time() - vae_time_start

clean_images = z
data_info = batch[3]
data_info = batch['text']

lm_time_start = time.time()
if load_text_feat:
Expand All @@ -319,7 +323,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
if "T5" in config.text_encoder.text_encoder_name:
with torch.no_grad():
txt_tokens = tokenizer(
batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
batch['text'], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
).to(accelerator.device)
y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None]
y_mask = txt_tokens.attention_mask[:, None, None]
Expand Down Expand Up @@ -409,20 +413,20 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
datetime.timedelta(
seconds=int(
avg_time
* (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step - 1)
* (train_dataloader_len // config.train.train_batch_size - step - 1)
)
)
)
log_buffer.average()

current_step = (
global_step - sampler.step_start // config.train.train_batch_size
global_step // config.train.train_batch_size
) % train_dataloader_len
current_step = train_dataloader_len if current_step == 0 else current_step
info = (
f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {train_dataloader_len}, "
f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, "
f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}, Cap: {batch[5][0]}, "
f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}"
)
info += (
f"s:({model.module.h}, {model.module.w}), "
Expand Down Expand Up @@ -504,7 +508,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
# for internal, refactor dataloader logic to remove the ad-hoc implementation
if (
config.model.multi_scale
and (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step) < 30
and (train_dataloader_len // config.train.train_batch_size - step) < 30
):
global_step = epoch * train_dataloader_len
logger.info("Early stop current iteration")
Expand Down Expand Up @@ -541,7 +545,7 @@ def main(cfg: SanaConfig) -> None:
global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer
global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path
global image_size, cache_file, total_steps

import uuid
config = cfg
args = cfg
# config = read_config(args.config)
Expand Down Expand Up @@ -576,6 +580,7 @@ def main(cfg: SanaConfig) -> None:
set_fsdp_env()
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),

)
else:
init_train = "DDP"
Expand All @@ -602,8 +607,8 @@ def main(cfg: SanaConfig) -> None:
pyrallis.dump(config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4)
if args.report_to == "wandb":
import wandb

wandb.init(project=args.tracker_project_name, name=args.name, resume="allow", id=args.name)
uniqueid = str(uuid.uuid4())[:4]
wandb.init(project=args.tracker_project_name , name=args.name + '_' + uniqueid , resume="allow", id=args.name + '_' + uniqueid)

logger.info(f"Config: \n{config}")
logger.info(f"World_size: {get_world_size()}, seed: {config.train.seed}")
Expand Down Expand Up @@ -814,17 +819,18 @@ def main(cfg: SanaConfig) -> None:
]
num_replicas = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
dataset = build_dataset(
asdict(config.data),
resolution=image_size,
aspect_ratio_type=config.model.aspect_ratio_type,
real_prompt_ratio=config.train.real_prompt_ratio,
max_length=max_length,
config=config,
caption_proportion=config.data.caption_proportion,
sort_dataset=config.data.sort_dataset,
vae_downsample_rate=config.vae.vae_downsample_rate,
)
# dataset = build_dataset(
# asdict(config.data),
# resolution=image_size,
# aspect_ratio_type=config.model.aspect_ratio_type,
# real_prompt_ratio=config.train.real_prompt_ratio,
# max_length=max_length,
# config=config,
# caption_proportion=config.data.caption_proportion,
# sort_dataset=config.data.sort_dataset,
# vae_downsample_rate=config.vae.vae_downsample_rate,
# )
dataset = MorphicStreamingDataset('/data/mharikum/laion/filtered_laion_2B_256-data-litdata/')
accelerator.wait_for_everyone()
if config.model.multi_scale:
drop_last = True
Expand Down Expand Up @@ -861,13 +867,14 @@ def main(cfg: SanaConfig) -> None:
logger.info(f"rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}")
else:
sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank)
train_dataloader = build_dataloader(
dataset,
num_workers=config.train.num_workers,
batch_size=config.train.train_batch_size,
shuffle=False,
sampler=sampler,
)
# train_dataloader = build_dataloader(
# dataset,
# num_workers=config.train.num_workers,
# batch_size=config.train.train_batch_size,
# shuffle=False,
# sampler=sampler,
# )
train_dataloader = ld.StreamingDataLoader(dataset,num_workers=config.train.num_workers,drop_last=True,batch_size=config.train.train_batch_size,shuffle=False)
train_dataloader_len = len(train_dataloader)
load_vae_feat = getattr(train_dataloader.dataset, "load_vae_feat", False)
load_text_feat = getattr(train_dataloader.dataset, "load_text_feat", False)
Expand Down Expand Up @@ -896,12 +903,12 @@ def main(cfg: SanaConfig) -> None:

timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())

if accelerator.is_main_process:
tracker_config = dict(vars(config))
try:
accelerator.init_trackers(args.tracker_project_name, tracker_config)
except:
accelerator.init_trackers(f"tb_{timestamp}")
# if accelerator.is_main_process:
# tracker_config = dict(vars(config))
# try:
# accelerator.init_trackers(args.tracker_project_name, tracker_config)
# except:
# accelerator.init_trackers(f"tb_{timestamp}")

start_epoch = 0
start_step = 0
Expand Down