From a1a8de1c049b254781585da5ac1ec7e502cf4af1 Mon Sep 17 00:00:00 2001 From: Alejandro Gil Date: Tue, 5 Sep 2023 12:42:11 -0700 Subject: [PATCH] Added option to use the ema_model argument when initializing the EMA from the normal VAE, so that we can pass an already-initialized EMA VAE to the trainer. --- muse_maskgit_pytorch/trainers/vqvae_trainers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index c8917d9..ea939f0 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -55,6 +55,7 @@ def __init__( lr_warmup_steps=500, discr_max_grad_norm=None, use_ema=True, + ema_vae=None, ema_beta=0.995, ema_update_after_step=0, ema_update_every=1, @@ -155,6 +156,7 @@ def __init__( if use_ema: self.ema_model = EMA( vae, + ema_model=ema_vae, update_after_step=ema_update_after_step, update_every=ema_update_every, )