Commit e1b7f93

Update train_muse_vae.py
korakoe committed Jun 4, 2023
1 parent 27cff6f · commit e1b7f93
Showing 1 changed file with 66 additions and 6 deletions.
72 changes: 66 additions & 6 deletions train_muse_vae.py
@@ -14,6 +14,10 @@
     split_dataset_into_dataloaders,
 )

+import os
+import glob
+import re
+


 def parse_args():
     # Create the parser
@@ -232,6 +236,11 @@ def parse_args():
         action="store_true",
         help="Whether to skip saving the dataset to Arrow files",
     )
+    parser.add_argument(
+        "--latest_checkpoint",
+        action="store_true",
+        help="Whether to use the latest checkpoint",
+    )
    # Parse the argument
    return parser.parse_args()
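
The new --latest_checkpoint option is a plain store_true switch. A hypothetical invocation (--vae_path is the script's existing argument; the path is a placeholder):

    python train_muse_vae.py --vae_path results/vae --latest_checkpoint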

@@ -281,19 +290,70 @@ def main():
     else:
         dataset = load_dataset(args.dataset_name)["train"]

-    vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size)
-    if args.taming_model_path:
-        print("Loading Taming VQGanVAE")
+    if args.vae_path is not None:
+        load = True
+        accelerator.print(f"Using Muse VQGanVAE, loading from {args.vae_path}")
+        vae = VQGanVAE(
+            dim=args.dim,
+            vq_codebook_size=args.vq_codebook_size,
+            accelerator=accelerator,
+        )
+
+        if args.latest_checkpoint:
+            accelerator.print("Finding latest checkpoint...")
+            orig_vae_path = args.vae_path
+
+            if os.path.isfile(args.vae_path) or '.pt' in args.vae_path:
+                # If args.vae_path is a file, split it into directory and filename
+                args.vae_path, _ = os.path.split(args.vae_path)
+
+            checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
+            if checkpoint_files:
+                latest_checkpoint_file = max(checkpoint_files,
+                                             key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+
+                # Check if latest checkpoint is empty or unreadable
+                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(latest_checkpoint_file, os.R_OK):
+                    accelerator.print(
+                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable.")
+                    if len(checkpoint_files) > 1:
+                        # Use the second last checkpoint as a fallback
+                        latest_checkpoint_file = max(checkpoint_files[:-1],
+                                                     key=lambda x: int(re.search(r'vae\.(\d+)\.pt', x).group(1)))
+                        accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
+                    else:
+                        accelerator.print("No usable checkpoint found.")
+                        load = False
+                elif latest_checkpoint_file != orig_vae_path:
+                    accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
+                else:
+                    accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path)
+
+                args.vae_path = latest_checkpoint_file
+            else:
+                accelerator.print("No checkpoints found in directory: ", args.vae_path)
+                load = False
+        else:
+            accelerator.print("Resuming VAE from: ", args.vae_path)
+
+        if load:
+            vae.load(args.vae_path, map="cpu")
+
+    elif args.taming_model_path is not None and args.taming_config_path is not None:
+        print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
         vae = VQGanVAETaming(
             vqgan_model_path=args.taming_model_path,
             vqgan_config_path=args.taming_config_path,
+            accelerator=accelerator,
         )
         args.num_tokens = vae.codebook_size
         args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
-    elif args.resume_path:
-        accelerator.print(f"Resuming VAE from: {args.resume_path}")
-        vae.load(args.resume_path)
+    else:
+        raise ValueError(
+            "You must pass either vae_path or taming_model_path + taming_config_path (but not both)"
+        )

+    if load:
         resume_from_parts = args.resume_path.split(".")
         for i in range(len(resume_from_parts) - 1, -1, -1):
             if resume_from_parts[i].isdigit():
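
A minimal standalone sketch of the checkpoint-selection rule introduced above, assuming checkpoints are saved as vae.<step>.pt (the file names below are hypothetical; the script itself discovers them with glob under --vae_path):

    import re

    # Hypothetical checkpoint names; the script collects these from disk via glob.
    checkpoint_files = ["vae.500.pt", "vae.1000.pt", "vae.2000.pt"]

    # Pick the file with the highest step number embedded in its name.
    latest = max(checkpoint_files,
                 key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)))
    print(latest)  # vae.2000.pt

The change also guards against the newest file being zero-byte or unreadable, falling back to the second-to-last checkpoint when more than one exists.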
