diff --git a/infer_maskgit.py b/infer_maskgit.py index c01f13e..bda7b1c 100644 --- a/infer_maskgit.py +++ b/infer_maskgit.py @@ -1,7 +1,5 @@ import argparse -import glob import os -import re from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -18,6 +16,9 @@ VQGanVAETaming, get_accelerator, ) +from muse_maskgit_pytorch.utils import ( + get_latest_checkpoints, +) # Create the parser parser = argparse.ArgumentParser() @@ -211,43 +212,10 @@ def main(): print("Loading Muse VQGanVAE") if args.latest_checkpoint: - print("Finding latest VAE checkpoint...") - orig_vae_path = args.vae_path - - if os.path.isfile(args.vae_path) or ".pt" in args.vae_path: - # If args.vae_path is a file, split it into directory and filename - args.vae_path, _ = os.path.split(args.vae_path) - - checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) - if checkpoint_files: - latest_checkpoint_file = max( - checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)) - ) - - # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( - latest_checkpoint_file, os.R_OK - ): - print( - f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable." 
- ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)), - ) - print("Using second last VAE checkpoint: ", latest_checkpoint_file) - else: - print("No usable checkpoint found.") - elif latest_checkpoint_file != orig_vae_path: - print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) - else: - print("Using VAE checkpoint specified in vae_path: ", orig_vae_path) - - args.vae_path = latest_checkpoint_file - else: - print("No VAE checkpoints found in directory: ", args.vae_path) + args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema, model_type="vae") + print(f"Resuming VAE from latest checkpoint: {args.vae_path}") + #if args.use_ema: + # print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}") else: print("Resuming VAE from: ", args.vae_path) @@ -341,61 +309,10 @@ def main(): accelerator.print("Loading Muse MaskGit...") if args.latest_checkpoint: - accelerator.print("Finding latest MaskGit checkpoint...") - orig_vae_path = args.resume_path - - if os.path.isfile(args.resume_path) or ".pt" in args.resume_path: - # If args.resume_path is a file, split it into directory and filename - args.resume_path, _ = os.path.split(args.resume_path) - - if args.cond_image_size: - checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit_superres.*.pt")) - else: - checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt")) - - if checkpoint_files: - if args.cond_image_size: - latest_checkpoint_file = max( - checkpoint_files, - key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)), - ) - else: - latest_checkpoint_file = max( - checkpoint_files, key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)) - ) - - # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not 
os.access( - latest_checkpoint_file, os.R_OK - ): - accelerator.print( - f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable." - ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - if args.cond_image_size: - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)), - ) - else: - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)), - ) - accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("No usable MaskGit checkpoint found.") - load = False - elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path) - - args.resume_path = latest_checkpoint_file - else: - accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path) - load = False + args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size) + print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}") + #if args.use_ema: + # print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}") else: accelerator.print("Resuming MaskGit from: ", args.resume_path) diff --git a/infer_vae.py b/infer_vae.py index 0c96fff..b310ad6 100644 --- a/infer_vae.py +++ b/infer_vae.py @@ -26,6 +26,10 @@ ImageDataset, get_dataset_from_dataroot, ) +from muse_maskgit_pytorch.utils import ( + get_latest_checkpoints, +) + from muse_maskgit_pytorch.vqvae import VQVAE # Create the parser @@ -211,6 +215,8 @@ action="store_true", help="Save the original input.png and output.png images to a subfolder instead of deleting them after the grid is made.", ) 
+parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.") +parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.") @dataclass @@ -373,52 +379,14 @@ def main(): ).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") if args.latest_checkpoint: - accelerator.print("Finding latest checkpoint...") - orig_vae_path = args.vae_path - - if os.path.isfile(args.vae_path) or ".pt" in args.vae_path: - # If args.vae_path is a file, split it into directory and filename - args.vae_path, _ = os.path.split(args.vae_path) - - checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) - if checkpoint_files: - latest_checkpoint_file = max( - checkpoint_files, - key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) - if not x.endswith("ema.pt") - else -1, - ) - - # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( - latest_checkpoint_file, os.R_OK - ): - accelerator.print( - f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable." 
- ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) - if not x.endswith("ema.pt") - else -1, - ) - accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("No usable checkpoint found.") - elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) - - args.vae_path = latest_checkpoint_file - else: - accelerator.print("No checkpoints found in directory: ", args.vae_path) + args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema) + print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_model_path}") + #if args.use_ema: + # print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}") else: accelerator.print("Resuming VAE from: ", args.vae_path) - vae.load(args.vae_path) + vae.load(args.vae_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu") if args.use_paintmind: # load VAE diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index 14261aa..c8917d9 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -236,7 +236,7 @@ def log_validation_images(self, logs, steps): log_imgs.append(Image.open(ema_save_path)) prompts.append("ema") - super().log_validation_images(log_imgs, steps, prompts=prompts) + super().log_validation_images(log_imgs, steps, prompts=prompts) self.model.train() def train(self): diff --git a/muse_maskgit_pytorch/utils.py b/muse_maskgit_pytorch/utils.py new file mode 100644 index 0000000..ef228d4 --- /dev/null +++ b/muse_maskgit_pytorch/utils.py @@ -0,0 +1,122 @@ +from __future__ 
import print_function
+
+import glob
+import os
+import re
+
+import torch
+
+
+def _is_usable(path):
+    """Return True when ``path`` is a non-empty file the process can read."""
+    return os.path.getsize(path) > 0 and os.access(path, os.R_OK)
+
+
+def _newest_usable(candidates, label):
+    """Pick the usable checkpoint with the highest step number.
+
+    Args:
+        candidates: ``(step, path)`` pairs belonging to one checkpoint family.
+        label: Human-readable family name used in the progress messages.
+
+    Returns:
+        The chosen path, or ``None`` when ``candidates`` is empty or no
+        candidate is usable.
+    """
+    # Walk from the newest step down so the fallback is always the next-newest
+    # checkpoint rather than an arbitrary entry of the glob() result.
+    ordered = [path for _, path in sorted(candidates, reverse=True)]
+    for position, path in enumerate(ordered):
+        if _is_usable(path):
+            if position:
+                print(f"Falling back to older {label}: {path}")
+            return path
+        print(f"Warning: {label} {path} is empty or unreadable.")
+    return None
+
+
+def get_latest_checkpoints(resume_path, use_ema=False, model_type="vae", cond_image_size=False):
+    """Find the newest regular and EMA checkpoint files next to ``resume_path``.
+
+    Checkpoints are expected to be named ``<prefix>.<step>.pt`` and
+    ``<prefix>.<step>.ema.pt``, where ``<prefix>`` is ``maskgit_superres`` when
+    ``cond_image_size`` is truthy, otherwise ``vae`` or ``maskgit`` depending on
+    ``model_type``.
+
+    Args:
+        resume_path: A checkpoint directory, or a path to a file inside it.
+        use_ema: Also look for the newest EMA checkpoint.
+        model_type: ``"vae"`` selects VAE checkpoints; any other value selects
+            MaskGit checkpoints.
+        cond_image_size: Truthy to look for super-resolution MaskGit checkpoints.
+
+    Returns:
+        A ``(checkpoint_path, ema_checkpoint_path)`` tuple. The first item is
+        always a path; the second is ``None`` when ``use_ema`` is false or no
+        usable EMA checkpoint exists.
+
+    Raises:
+        FileNotFoundError: If the directory holds no matching regular checkpoint.
+    """
+    if cond_image_size:
+        prefix = "maskgit_superres"
+    else:
+        prefix = "vae" if model_type == "vae" else "maskgit"
+
+    # Only drop the last path component when the argument is not itself a
+    # directory; splitting unconditionally would search a directory's parent.
+    if os.path.isdir(resume_path):
+        search_dir = resume_path
+    else:
+        search_dir = os.path.dirname(resume_path)
+
+    print(f"Finding latest {'VAE' if model_type == 'vae' else 'MaskGit'} checkpoint...")
+
+    # The glob also matches "<prefix>.<step>.ema.pt", so classify every hit by
+    # file name and ignore anything that does not follow the naming scheme.
+    name_pattern = re.compile(rf"{re.escape(prefix)}\.(\d+)(\.ema)?\.pt")
+    regular, ema = [], []
+    for path in glob.glob(os.path.join(search_dir, f"{prefix}.*.pt")):
+        match = name_pattern.fullmatch(os.path.basename(path))
+        if match is None:
+            continue
+        family = ema if match.group(2) else regular
+        family.append((int(match.group(1)), path))
+
+    if not regular:
+        raise FileNotFoundError(
+            f"No {prefix} checkpoints found in directory: {search_dir}"
+        )
+
+    latest_non_ema_checkpoint_file = _newest_usable(regular, "checkpoint")
+    if latest_non_ema_checkpoint_file is None:
+        # Still hand back the newest path so the failure surfaces at load time
+        # with the offending file name instead of a None.
+        print("No usable checkpoint found.")
+        latest_non_ema_checkpoint_file = max(regular)[1]
+
+    latest_ema_checkpoint_file = None
+    if use_ema:
+        # None (never a regular checkpoint) lets callers fall back explicitly.
+        latest_ema_checkpoint_file = _newest_usable(ema, "EMA checkpoint")
+
+    return latest_non_ema_checkpoint_file, latest_ema_checkpoint_file
+
+
+def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
+    """Return a copy of ``ema_state_dict`` without entries equal to the non-EMA ones.
+
+    Args:
+        ema_state_dict: The state dictionary of the EMA model.
+        non_ema_state_dict: The state dictionary of the non-EMA model.
+
+    Returns:
+        A new dict holding only the EMA entries whose key is missing from
+        ``non_ema_state_dict`` or whose tensor differs from it.
+    """
+    # torch.equal is True only for identical size and elements.
+    return {
+        key: value
+        for key, value in ema_state_dict.items()
+        if key not in non_ema_state_dict or not torch.equal(value, non_ema_state_dict[key])
+    }
\ No newline at end of file
diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py index 94fbb62..13cf042 100644 --- a/muse_maskgit_pytorch/vqgan_vae.py +++ b/muse_maskgit_pytorch/vqgan_vae.py @@ -439,7 +439,11 @@ def state_dict(self, *args, **kwargs): @remove_vgg def load_state_dict(self, *args, **kwargs): - return super().load_state_dict(*args, **kwargs) + try: + return super().load_state_dict(*args, **kwargs) + except RuntimeError: + return super().load_state_dict(*args, **{**kwargs, "strict": False}) + def save(self, path): if self.accelerator is not None: diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index b3baaba..9f0ba68 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -1,8 +1,6 @@ import argparse -import glob import logging import os -import re from dataclasses import dataclass from typing import Optional, Union @@ -17,6 +15,9 @@ from omegaconf import OmegaConf from rich import inspect from torch.optim import Optimizer +from muse_maskgit_pytorch.utils import ( + get_latest_checkpoints, +) try: import torch_xla @@ -593,45 +594,13 @@ def main(): accelerator.print("Loading Muse VQGanVAE") if args.latest_checkpoint: - accelerator.print("Finding latest VAE checkpoint...") - orig_vae_path = args.vae_path - - if os.path.isfile(args.vae_path) or ".pt" in args.vae_path: - # If args.vae_path is a file, split it into directory and filename - args.vae_path, _ = os.path.split(args.vae_path) - - checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) - if checkpoint_files: - latest_checkpoint_file = max( - checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)) - ) - - # Check if latest checkpoint is empty or unreadable - if 
os.path.getsize(latest_checkpoint_file) == 0 or not os.access( - latest_checkpoint_file, os.R_OK - ): - accelerator.print( - f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable." - ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)), - ) - accelerator.print("Using second last VAE checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("No usable checkpoint found.") - elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("Using VAE checkpoint specified in vae_path: ", orig_vae_path) - - args.vae_path = latest_checkpoint_file - else: - accelerator.print("No VAE checkpoints found in directory: ", args.vae_path) + args.vae_path, ema_vae_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema) + print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_vae_path}") + #if args.use_ema: + # print(f"Resuming EMA VAE from latest checkpoint: {ema_vae_path}") else: accelerator.print("Resuming VAE from: ", args.vae_path) + ema_vae_path = None # use config next to checkpoint if there is one and merge the cli arguments to it # the cli arguments will take priority so we can use it to override any value we want. 
@@ -651,7 +620,8 @@ def main(): layers=args.layers, discr_layers=args.discr_layers, ).to(accelerator.device) - vae.load(args.vae_path) + + vae.load(args.vae_path if not args.use_ema or not ema_vae_path else ema_vae_path, map="cpu") elif args.taming_model_path is not None and args.taming_config_path is not None: accelerator.print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}") @@ -721,65 +691,10 @@ def main(): accelerator.print("Loading Muse MaskGit...") if args.latest_checkpoint: - accelerator.print("Finding latest MaskGit checkpoint...") - orig_vae_path = args.resume_path - - if os.path.isfile(args.resume_path) or ".pt" in args.resume_path: - # If args.resume_path is a file, split it into directory and filename - args.resume_path, _ = os.path.split(args.resume_path) - - if args.cond_image_size: - checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit_superres.*.pt")) - else: - checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt")) - - if checkpoint_files: - if args.cond_image_size: - latest_checkpoint_file = max( - checkpoint_files, - key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)), - ) - else: - latest_checkpoint_file = max( - checkpoint_files, key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)) - ) - - # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( - latest_checkpoint_file, os.R_OK - ): - accelerator.print( - f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable." 
- ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - if args.cond_image_size: - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"maskgit_superres\.(\d+)\.pt", x).group(1)), - ) - else: - latest_checkpoint_file = max( - checkpoint_files[:-1], - key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt", x).group(1)), - ) - accelerator.print( - "Using second last MaskGit checkpoint: ", latest_checkpoint_file - ) - else: - accelerator.print("No usable MaskGit checkpoint found.") - load = False - elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file) - else: - accelerator.print( - "Using MaskGit checkpoint specified in resume_path: ", orig_vae_path - ) - - args.resume_path = latest_checkpoint_file - else: - accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path) - load = False + args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="maskgit", cond_image_size=args.cond_image_size) + print(f"Resuming MaskGit from latest checkpoint: {args.resume_path}") + #if args.use_ema: + # print(f"Resuming EMA MaskGit from latest checkpoint: {ema_model_path}") else: accelerator.print("Resuming MaskGit from: ", args.resume_path) diff --git a/train_muse_vae.py b/train_muse_vae.py index ec3c5af..91e2712 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -1,7 +1,5 @@ import argparse -import glob import os -import re from dataclasses import dataclass from typing import Optional, Union @@ -9,7 +7,9 @@ from accelerate.utils import ProjectConfiguration from datasets import load_dataset from omegaconf import OmegaConf - +from muse_maskgit_pytorch.utils import ( + get_latest_checkpoints, +) from muse_maskgit_pytorch import ( VQGanVAE, VQGanVAETaming, @@ -466,7 +466,8 @@ def main(): if args.resume_path is not None and len(args.resume_path) > 1: load = True - 
accelerator.print(f"Using Muse VQGanVAE, loading from {args.resume_path}") + + accelerator.print(f"Loading Muse VQGanVAE...") vae = VQGanVAE( dim=args.dim, vq_codebook_dim=args.vq_codebook_dim, @@ -479,48 +480,14 @@ def main(): ) if args.latest_checkpoint: - accelerator.print("Finding latest checkpoint...") - orig_vae_path = args.resume_path - - if os.path.isfile(args.resume_path) or ".pt" in args.resume_path: - # If args.vae_path is a file, split it into directory and filename - args.resume_path, _ = os.path.split(args.resume_path) - - checkpoint_files = glob.glob(os.path.join(args.resume_path, "vae.*.pt")) - if checkpoint_files: - latest_checkpoint_file = max( - checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)) - ) - - # Check if latest checkpoint is empty or unreadable - if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( - latest_checkpoint_file, os.R_OK - ): - accelerator.print( - f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable." 
- ) - if len(checkpoint_files) > 1: - # Use the second last checkpoint as a fallback - latest_checkpoint_file = max( - checkpoint_files[:-1], key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)) - ) - accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("No usable checkpoint found.") - load = False - elif latest_checkpoint_file != orig_vae_path: - accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) - else: - accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) - - args.resume_path = latest_checkpoint_file - else: - accelerator.print("No checkpoints found in directory: ", args.resume_path) - load = False + args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae") + #print(f"Resuming VAE from latest checkpoint: {args.resume_path if not args.use_ema else ema_model_path}") + print(f"Resuming VAE from latest checkpoint: {args.resume_path}") else: accelerator.print("Resuming VAE from: ", args.resume_path) if load: + #vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu") vae.load(args.resume_path, map="cpu") resume_from_parts = args.resume_path.split(".")