Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#41 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Removed duplicated code for the maskgit training script.
  • Loading branch information
ZeroCool940711 authored Jun 18, 2023
2 parents 9f2516d + b22e14b commit 87a8a2d
Showing 1 changed file with 0 additions and 59 deletions.
59 changes: 0 additions & 59 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,65 +684,6 @@ def main():
xformers=xformers,
)

# load the maskgit transformer from disk if we have previously trained one
if args.resume_path:
if args.latest_checkpoint:
accelerator.print("Finding latest checkpoint...")
orig_vae_path = args.resume_path

if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
# If args.resume_path is a file, split it into directory and filename
args.resume_path, _ = os.path.split(args.resume_path)

checkpoint_files = glob.glob(os.path.join(args.resume_path, "maskgit.*.pt"))
if checkpoint_files:
latest_checkpoint_file = max(
checkpoint_files,
key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)

# Check if latest checkpoint is empty or unreadable
if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
latest_checkpoint_file, os.R_OK
):
accelerator.print(
f"Warning: latest MaskGit checkpoint {latest_checkpoint_file} is empty or unreadable."
)
if len(checkpoint_files) > 1:
# Use the second last checkpoint as a fallback
latest_checkpoint_file = max(
checkpoint_files[:-1],
key=lambda x: int(re.search(r"maskgit\.(\d+)\.pt$", x).group(1))
if not x.endswith("ema.pt")
else -1,
)
accelerator.print("Using second last MaskGit checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("No usable MaskGit checkpoint found.")
elif latest_checkpoint_file != orig_vae_path:
accelerator.print("Resuming MaskGit from latest checkpoint: ", latest_checkpoint_file)
else:
accelerator.print("Using MaskGit checkpoint specified in resume_path: ", orig_vae_path)

args.resume_path = latest_checkpoint_file
else:
accelerator.print("No MaskGit checkpoints found in directory: ", args.resume_path)
else:
accelerator.print("Resuming MaskGit from: ", args.resume_path)

# use config next to checkpoint if there is one and merge the cli arguments to it
# the cli arguments will take priority so we can use it to override any value we want.
if os.path.exists(f"{args.resume_path}.yaml"):
accelerator.print(
"Config file found, reusing config from it. Use cli arguments to override any desired value."
)
conf = OmegaConf.load(f"{args.resume_path}.yaml")
cli_conf = OmegaConf.from_cli()
# merge the config file and the cli arguments.
conf = OmegaConf.merge(conf, cli_conf)

# (2) pass your trained VAE and the base transformer to MaskGit
maskgit = MaskGit(
vae=vae, # vqgan vae
Expand Down

0 comments on commit 87a8a2d

Please sign in to comment.