diff --git a/muse_maskgit_pytorch/trainers.py b/muse_maskgit_pytorch/trainers.py index a736f2a..7d8fc5d 100644 --- a/muse_maskgit_pytorch/trainers.py +++ b/muse_maskgit_pytorch/trainers.py @@ -288,6 +288,8 @@ def train_step(self): # update discriminator if exists(self.vae.discr): + self.discr_optim.zero_grad() + for _ in range(self.grad_accum_every): img = next(self.dl_iter) img = img.to(device) @@ -302,7 +304,6 @@ def train_step(self): self.accelerator.clip_grad_norm_(self.vae.discr.parameters(), self.discr_max_grad_norm) self.discr_optim.step() - self.discr_optim.zero_grad() # log diff --git a/setup.py b/setup.py index ebbf31c..8c898c8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.17', + version = '0.0.18', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',