diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 7679800..185489d 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -31,6 +31,7 @@ def __init__( flip=True, center_crop=True, stream=False, + using_taming=False ): super().__init__() self.dataset = dataset @@ -46,6 +47,7 @@ def __init__( transform_list.append(T.CenterCrop(image_size)) transform_list.append(T.ToTensor()) self.transform = T.Compose(transform_list) + self.using_taming = using_taming def __len__(self): if not self.stream: @@ -55,7 +57,10 @@ def __len__(self): def __getitem__(self, index): image = self.dataset[index][self.image_column] - return self.transform(image) - 0.5 + if self.using_taming: + return self.transform(image) - 0.5 + else: + return self.transform(image) class ImageTextDataset(ImageDataset): @@ -69,6 +74,7 @@ def __init__( flip=True, center_crop=True, stream=False, + using_taming=False ): super().__init__( dataset, @@ -77,6 +83,7 @@ def __init__( flip=flip, center_crop=center_crop, stream=stream, + using_taming=using_taming ) self.caption_column: str = caption_column self.tokenizer: T5Tokenizer = tokenizer @@ -104,7 +111,11 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask - return self.transform(image), input_ids[0], attn_mask[0] + + if self.using_taming: + return self.transform(image) - 0.5, input_ids[0], attn_mask[0] + else: + return self.transform(image), input_ids[0], attn_mask[0] class URLTextDataset(ImageDataset): @@ -117,6 +128,7 @@ def __init__( caption_column="caption", flip=True, center_crop=True, + using_taming=True ): super().__init__( dataset, @@ -124,6 +136,7 @@ def __init__( image_column=image_column, flip=flip, center_crop=center_crop, + using_taming=using_taming ) self.caption_column: str = caption_column self.tokenizer: T5Tokenizer = tokenizer @@ -161,13 +174,17 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask - return self.transform(image), input_ids[0], attn_mask[0] + if self.using_taming: + return self.transform(image) - 0.5, input_ids[0], attn_mask[0] + else: + return self.transform(image), input_ids[0], attn_mask[0] class LocalTextImageDataset(Dataset): - def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True): + def __init__(self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False): super().__init__() self.tokenizer = tokenizer + self.using_taming = using_taming print("Building dataset...") @@ -226,7 +243,10 @@ def __getitem__(self, index): input_ids = encoded.input_ids attn_mask = encoded.attention_mask - return self.transform(image), input_ids[0], attn_mask[0] + if self.using_taming: + return self.transform(image) - 0.5, input_ids[0], attn_mask[0] + else: + return self.transform(image), input_ids[0], attn_mask[0] def get_directory_size(path): diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py index 9374a5b..3f8452f 100644 --- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py +++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py @@ -11,6 +11,8 @@ from lion_pytorch import Lion from torch import nn from torch.optim import Adam, AdamW, Optimizer +from torch_optimizer import AdaBound, AdaMod, AccSGD, AdamP, AggMo, DiffGrad, \ + Lamb, NovoGrad, PID, QHAdam, QHM, RAdam, SGDP, SGDW, Shampoo, SWATS, Yogi from transformers.optimization import Adafactor from torch.utils.data import DataLoader, random_split @@ -136,6 +138,40 @@ def get_optimizer( ) elif optimizer == "Adafactor": return Adafactor(parameters, lr=lr, weight_decay=weight_decay, relative_step=False, scale_parameter=False, **optimizer_kwargs) + elif optimizer == "AccSGD": + return AccSGD(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "AdaBound": + return AdaBound(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "AdaMod": + return AdaMod(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "AdamP": + return AdamP(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "AggMo": + return AggMo(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "DiffGrad": + return DiffGrad(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "Lamb": + return Lamb(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "NovoGrad": + return NovoGrad(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "PID": + return PID(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "QHAdam": + return QHAdam(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "QHM": + return QHM(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "RAdam": + return RAdam(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "SGDP": + return SGDP(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "SGDW": + return SGDW(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "Shampoo": + return Shampoo(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "SWATS": + return SWATS(parameters, lr=lr, weight_decay=weight_decay) + elif optimizer == "Yogi": + return Yogi(parameters, lr=lr, weight_decay=weight_decay) else: raise NotImplementedError(f"{optimizer} optimizer not supported yet.") diff --git a/setup.py b/setup.py index 815f7f7..de927de 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "taming-transformers @ git+https://github.com/neggles/taming-transformers.git@v0.0.2", "transformers", "torchvision", + "torch_optimizer", "tqdm", "vector-quantize-pytorch>=0.10.14", "lion-pytorch", diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 87354c6..1df1333 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -14,6 +14,10 @@ from diffusers.optimization import SchedulerType, get_scheduler from torch.optim import Optimizer +import os +import glob +import re + try: import torch_xla import torch_xla.core.xla_model as xm @@ -287,7 +291,10 @@ "--optimizer", type=str, default="Adafactor", - help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)", + help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion', 'Adafactor', " + "'AdaBound', 'AdaMod', 'AccSGD', 'AdamP', 'AggMo', 'DiffGrad', 'Lamb', " + "'NovoGrad', 'PID', 'QHAdam', 'QHM', 'RAdam', 'SGDP', 'SGDW', 'Shampoo', " + "'SWATS', 'Yogi']. Default: Lion", ) parser.add_argument( "--weight_decay", @@ -311,6 +318,11 @@ action="store_true", help="whether to load a dataset with links instead of image (image column becomes URL column)", ) +parser.add_argument( + "--latest_checkpoint", + action="store_true", + help="Automatically find and use the latest checkpoint in the folder.", + ) parser.add_argument( "--debug", action="store_true", @@ -374,6 +386,7 @@ class Arguments: cache_path: Optional[str] = None skip_arrow: bool = False link: bool = True + latest_checkpoint: bool = False debug: bool = False @@ -438,13 +451,54 @@ def main(): # Load the VAE with accelerator.main_process_first(): if args.vae_path is not None: + load = True accelerator.print(f"Using Muse VQGanVAE, loading from {args.vae_path}") vae = VQGanVAE( dim=args.dim, vq_codebook_size=args.vq_codebook_size, accelerator=accelerator, ) - vae.load(args.vae_path, map="cpu") + + if args.latest_checkpoint: + accelerator.print("Finding latest checkpoint...") + orig_vae_path = args.vae_path + + if os.path.isfile(args.vae_path) or '.pt' in args.vae_path: + # If args.vae_path is a file, split it into directory and filename + args.vae_path, _ = os.path.split(args.vae_path) + + checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) + if checkpoint_files: + latest_checkpoint_file = max(checkpoint_files, + key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + + # Check if latest checkpoint is empty or unreadable + if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): + accelerator.print( + f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + if len(checkpoint_files) > 1: + # Use the second last checkpoint as a fallback + latest_checkpoint_file = max(checkpoint_files[:-1], + key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1))) + accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("No usable checkpoint found.") + load = False + elif latest_checkpoint_file != orig_vae_path: + accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) + + args.vae_path = latest_checkpoint_file + else: + accelerator.print("No checkpoints found in directory: ", args.vae_path) + load = False + else: + accelerator.print("Resuming VAE from: ", args.vae_path) + + if load: + vae.load(args.vae_path, map="cpu") + elif args.taming_model_path is not None and args.taming_config_path is not None: print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}") vae = VQGanVAETaming( @@ -490,16 +544,57 @@ def main(): # load the maskgit transformer from disk if we have previously trained one with accelerator.main_process_first(): if args.resume_path: + load = True accelerator.print(f"Resuming MaskGit from: {args.resume_path}") - maskgit.load(args.resume_path) - resume_from_parts = args.resume_path.split(".") - for i in range(len(resume_from_parts) - 1, -1, -1): - if resume_from_parts[i].isdigit(): - current_step = int(resume_from_parts[i]) - accelerator.print(f"Found step {current_step} for the MaskGit model.") - break - if current_step == 0: - accelerator.print("No step found for the MaskGit model.") + + if args.latest_checkpoint: + accelerator.print("Finding latest checkpoint...") + orig_vae_path = args.resume_path + + if os.path.isfile(args.resume_path) or '.pt' in args.resume_path: + # If args.resume_path is a file, split it into directory and filename + args.resume_path, _ = os.path.split(args.resume_path) + + checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt")) + if checkpoint_files: + latest_checkpoint_file = max(checkpoint_files, + key=lambda x: int(re.search(r'maskgit\.(\d+)\.pt', x).group(1))) + + # Check if latest checkpoint is empty or unreadable + if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK): + accelerator.print( + f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.") + if len(checkpoint_files) > 1: + # Use the second last checkpoint as a fallback + latest_checkpoint_file = max(checkpoint_files[:-1], key=lambda x: int( + re.search(r'maskgit\.(\d+)\.pt', x).group(1))) + accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("No usable checkpoint found.") + load = False + elif latest_checkpoint_file != orig_vae_path: + accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file) + else: + accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path) + + args.resume_path = latest_checkpoint_file + else: + accelerator.print("No checkpoints found in directory: ", args.resume_path) + load = False + else: + accelerator.print("Resuming MaskGit from: ", args.resume_path) + + if load: + maskgit.load(args.resume_path) + + resume_from_parts = args.resume_path.split(".") + for i in range(len(resume_from_parts) - 1, -1, -1): + if resume_from_parts[i].isdigit(): + current_step = int(resume_from_parts[i]) + accelerator.print(f"Found step {current_step} for the MaskGit model.") + break + if current_step == 0: + accelerator.print("No step found for the MaskGit model.") else: accelerator.print("Initialized new empty MaskGit model.") current_step = 0 @@ -513,6 +608,7 @@ def main(): tokenizer=transformer.tokenizer, center_crop=False if args.no_center_crop else True, flip=False if args.no_flip else True, + using_taming=False if not args.taming_model_path else True ) elif args.link: if not args.dataset_name: @@ -526,6 +622,7 @@ def main(): caption_column=args.caption_column, center_crop=False if args.no_center_crop else True, flip=False if args.no_flip else True, + using_taming=False if not args.taming_model_path else True ) else: dataset = ImageTextDataset( @@ -537,6 +634,7 @@ def main(): center_crop=False if args.no_center_crop else True, flip=False if args.no_flip else True, stream=args.streaming, + using_taming=False if not args.taming_model_path else True ) # Create the dataloaders