
Continued to add support for the EMA model.
ZeroCool940711 committed Sep 5, 2023 (commit c007058, 1 parent: a1a8de1)
Showing 3 changed files with 46 additions and 7 deletions.
muse_maskgit_pytorch/trainers/maskgit_trainer.py (2 additions, 0 deletions)
@@ -50,6 +50,7 @@ def __init__(
         logging_dir="./results/logs",
         apply_grad_penalty_every=4,
         use_ema=True,
+        ema_vae=None,
         ema_update_after_step=0,
         ema_update_every=1,
         validation_prompts=["a photo of a dog"],
@@ -100,6 +101,7 @@ def __init__(
         if use_ema:
             ema_model = EMA(
                 self.model,
+                ema_model=ema_vae,
                 update_after_step=ema_update_after_step,
                 update_every=ema_update_every,
             )
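For context: the new ema_vae argument is threaded into the EMA wrapper's ema_model parameter. In lucidrains' ema-pytorch, which this trainer appears to build on, ema_model lets the caller supply an already-built copy to hold the EMA weights instead of having the wrapper deep-copy the online model. A minimal sketch of that usage (not part of the commit), with nn.Linear standing in for the real networks:

import torch
from torch import nn
from ema_pytorch import EMA  # assumed dependency of this repo

model = nn.Linear(8, 8)       # stand-in for the model being trained
restored = nn.Linear(8, 8)    # stand-in for a copy loaded from an EMA checkpoint

ema = EMA(
    model,
    ema_model=restored,       # seed the EMA with the restored copy
    update_after_step=0,      # same knobs the trainer exposes
    update_every=1,
)

for _ in range(3):            # stands in for the training loop
    ema.update()              # one EMA update per optimizer step

out = ema(torch.randn(2, 8))  # the wrapper forwards inference to the EMA copy

With this, a previously saved EMA checkpoint can be restored into the seeded copy and the averaging continues from those weights rather than restarting from the online model.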
train_muse_maskgit.py (22 additions, 6 deletions)
@@ -594,13 +594,28 @@ def main():
         accelerator.print("Loading Muse VQGanVAE")
 
         if args.latest_checkpoint:
-            args.vae_path, ema_vae_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
-            print(f"Resuming VAE from latest checkpoint: {args.vae_path if not args.use_ema else ema_vae_path}")
-            #if args.use_ema:
-            #    print(f"Resuming EMA VAE from latest checkpoint: {ema_vae_path}")
+            args.vae_path, ema_model_path = get_latest_checkpoints(args.vae_path, use_ema=args.use_ema)
+            if ema_model_path:
+                ema_vae = VQGanVAE(
+                    dim=args.dim,
+                    vq_codebook_dim=args.vq_codebook_dim,
+                    vq_codebook_size=args.vq_codebook_size,
+                    l2_recon_loss=args.use_l2_recon_loss,
+                    channels=args.channels,
+                    layers=args.layers,
+                    discr_layers=args.discr_layers,
+                    accelerator=accelerator,
+                )
+                print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
+
+                ema_vae.load(ema_model_path, map="cpu")
+            else:
+                ema_vae = None
+
+            print(f"Resuming VAE from latest checkpoint: {args.vae_path}")
         else:
             accelerator.print("Resuming VAE from: ", args.vae_path)
-            ema_vae_path = None
+            ema_vae = None
 
         # use config next to checkpoint if there is one and merge the cli arguments to it
         # the cli arguments will take priority so we can use it to override any value we want.
@@ -621,7 +636,7 @@ def main():
             discr_layers=args.discr_layers,
         ).to(accelerator.device)
 
-        vae.load(args.vae_path if not args.use_ema or not ema_vae_path else ema_vae_path, map="cpu")
+        vae.load(args.vae_path, map="cpu")
 
     elif args.taming_model_path is not None and args.taming_config_path is not None:
         accelerator.print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
@@ -844,6 +859,7 @@ def main():
         results_dir=args.results_dir,
         logging_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
         use_ema=args.use_ema,
+        ema_vae=ema_vae,
         ema_update_after_step=args.ema_update_after_step,
         ema_update_every=args.ema_update_every,
         apply_grad_penalty_every=args.apply_grad_penalty_every,
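The resume block above relies on get_latest_checkpoints returning a (checkpoint, ema_checkpoint) pair, with the second entry None when no EMA file is found. A hypothetical sketch of that contract (the real helper lives elsewhere in this repo; the filename patterns here are assumptions for illustration):

from pathlib import Path
from typing import Optional, Tuple

def get_latest_checkpoints_sketch(path: str, use_ema: bool = False) -> Tuple[Optional[str], Optional[str]]:
    """Hypothetical stand-in for the repo's get_latest_checkpoints."""
    ckpt_dir = Path(path).parent
    by_mtime = lambda p: p.stat().st_mtime  # newest file wins, regardless of naming
    regular = sorted((p for p in ckpt_dir.glob("vae.*.pt") if ".ema." not in p.name), key=by_mtime)
    ema = sorted(ckpt_dir.glob("vae.*.ema.pt"), key=by_mtime) if use_ema else []
    return (
        str(regular[-1]) if regular else None,  # latest regular checkpoint
        str(ema[-1]) if ema else None,          # latest EMA checkpoint, or None
    )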
train_muse_vae.py (22 additions, 1 deletion)
@@ -481,10 +481,27 @@ def main():
 
         if args.latest_checkpoint:
             args.resume_path, ema_model_path = get_latest_checkpoints(args.resume_path, use_ema=args.use_ema, model_type="vae")
-            #print(f"Resuming VAE from latest checkpoint: {args.resume_path if not args.use_ema else ema_model_path}")
+            if ema_model_path:
+                ema_vae = VQGanVAE(
+                    dim=args.dim,
+                    vq_codebook_dim=args.vq_codebook_dim,
+                    vq_codebook_size=args.vq_codebook_size,
+                    l2_recon_loss=args.use_l2_recon_loss,
+                    channels=args.channels,
+                    layers=args.layers,
+                    discr_layers=args.discr_layers,
+                    accelerator=accelerator,
+                )
+                print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
+
+                ema_vae.load(ema_model_path, map="cpu")
+            else:
+                ema_vae = None
+
             print(f"Resuming VAE from latest checkpoint: {args.resume_path}")
         else:
             accelerator.print("Resuming VAE from: ", args.resume_path)
+            ema_vae = None
 
         if load:
             #vae.load(args.resume_path if not args.use_ema or not ema_model_path else ema_model_path, map="cpu")
@@ -515,6 +532,7 @@ def main():
             current_step = 0
     else:
         accelerator.print("Initialising empty VAE")
+
         vae = VQGanVAE(
             dim=args.dim,
             vq_codebook_dim=args.vq_codebook_dim,
@@ -525,6 +543,8 @@ def main():
             accelerator=accelerator,
         )
 
+        ema_vae = None
+
         current_step = 0
 
     # Use the parameters() method to get an iterator over all the learnable parameters of the model
@@ -564,6 +584,7 @@ def main():
         save_model_every=args.save_model_every,
         results_dir=args.results_dir,
         logging_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
+        ema_vae=ema_vae,
         use_ema=args.use_ema,
         ema_beta=args.ema_beta,
         ema_update_after_step=args.ema_update_after_step,
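Both training scripts now duplicate the same build-and-load block. As a possible follow-up (a sketch only, not part of this commit), that block could be hoisted into a shared helper so ema_vae is constructed, loaded, or left as None in one place:

from muse_maskgit_pytorch import VQGanVAE  # the class used throughout the diff

def load_ema_vae(args, accelerator, ema_model_path):
    """Rebuild the VAE and load EMA weights, or return None (hypothetical helper)."""
    if not ema_model_path:
        return None
    ema_vae = VQGanVAE(
        dim=args.dim,
        vq_codebook_dim=args.vq_codebook_dim,
        vq_codebook_size=args.vq_codebook_size,
        l2_recon_loss=args.use_l2_recon_loss,
        channels=args.channels,
        layers=args.layers,
        discr_layers=args.discr_layers,
        accelerator=accelerator,
    )
    print(f"Resuming EMA VAE from latest checkpoint: {ema_model_path}")
    ema_vae.load(ema_model_path, map="cpu")
    return ema_vae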
