From b22e14bf71b42b531ee48c22e1e5062824529b1a Mon Sep 17 00:00:00 2001
From: Alejandro Gil
Date: Sun, 18 Jun 2023 09:30:53 -0700
Subject: [PATCH] Removed duplicated code for the maskgit training script.

---
 train_muse_maskgit.py | 59 ------------------------------------------
 1 file changed, 59 deletions(-)

diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index b7697ff..c2272b1 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -684,65 +684,6 @@ def main():
         xformers=xformers,
     )
 
-    # load the maskgit transformer from disk if we have previously trained one
-    if args.resume_path:
-        if args.latest_checkpoint:
-            accelerator.print("Finding latest checkpoint...")
-            orig_vae_path = args.resume_path
-
-            if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
-                # If args.resume_path is a file, split it into directory and filename
-                args.resume_path, _ = os.path.split(args.resume_path)
-
-            checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))
-            if checkpoint_files:
-                latest_checkpoint_file = max(
-                    checkpoint_files,
-                    key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
-                    if not x.endswith("ema.pt")
-                    else -1,
-                )
-
-                # Check if latest checkpoint is empty or unreadable
-                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
-                    latest_checkpoint_file, os.R_OK
-                ):
-                    accelerator.print(
-                        f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable."
-                    )
-                    if len(checkpoint_files) > 1:
-                        # Use the second last checkpoint as a fallback
-                        latest_checkpoint_file = max(
-                            checkpoint_files[:-1],
-                            key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
-                            if not x.endswith("ema.pt")
-                            else -1,
-                        )
-                        accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
-                    else:
-                        accelerator.print("No usable MaskGit checkpoint found.")
-                elif latest_checkpoint_file != orig_vae_path:
-                    accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
-                else:
-                    accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)
-
-                args.resume_path = latest_checkpoint_file
-            else:
-                accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
-        else:
-            accelerator.print("Resuming MaskGit from: ", args.resume_path)
-
-        # use config next to checkpoint if there is one and merge the cli arguments to it
-        # the cli arguments will take priority so we can use it to override any value we want.
-        if os.path.exists(f"{args.resume_path}.yaml"):
-            accelerator.print(
-                "Config file found, reusing config from it. Use cli arguments to override any desired value."
-            )
-            conf = OmegaConf.load(f"{args.resume_path}.yaml")
-            cli_conf = OmegaConf.from_cli()
-            # merge the config file and the cli arguments.
-            conf = OmegaConf.merge(conf, cli_conf)
-
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
         vae=vae,  # vqgan vae
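
Note: the block deleted above is one copy of the checkpoint-resolution logic this commit de-duplicates; the orig_vae_path name inside it suggests it was copy-pasted from a similar VAE resume block. As a minimal sketch of how such logic can live in one place (not part of the patch; the helper name get_latest_checkpoint and its signature are hypothetical, and only the Python standard library is assumed):

import glob
import os
import re


def get_latest_checkpoint(resume_path, pattern="maskgit.*.pt"):
    """Return the newest usable checkpoint under resume_path, or None."""
    # Accept either a checkpoint file or the directory containing it,
    # as the removed block did.
    if os.path.isfile(resume_path) or ".pt" in resume_path:
        resume_path, _ = os.path.split(resume_path)

    def step(path):
        # EMA snapshots sort last; everything else sorts by training step.
        if path.endswith("ema.pt"):
            return -1
        match = re.search(r"\.(\d+)\.pt$", path)
        return int(match.group(1)) if match else -1

    # Walk candidates newest-first, skipping empty or unreadable files;
    # this generalizes the single "second last checkpoint" fallback in
    # the removed block.
    candidates = sorted(glob.glob(os.path.join(resume_path, pattern)), key=step, reverse=True)
    for path in candidates:
        if os.path.getsize(path) > 0 and os.access(path, os.R_OK):
            return path
    return None

The two resume paths could then share this helper and differ only in the pattern they pass, which is the usual payoff of this kind of de-duplication.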