From f7ea95700016de2e9a66db1c95441a416a0f59ff Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 11 Jan 2023 17:01:25 -0800 Subject: [PATCH] fix a discriminator gradient issue thanks to @apoorv2904 --- muse_maskgit_pytorch/trainers.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/muse_maskgit_pytorch/trainers.py b/muse_maskgit_pytorch/trainers.py index a736f2a..7d8fc5d 100644 --- a/muse_maskgit_pytorch/trainers.py +++ b/muse_maskgit_pytorch/trainers.py @@ -288,6 +288,8 @@ def train_step(self): # update discriminator if exists(self.vae.discr): + self.discr_optim.zero_grad() + for _ in range(self.grad_accum_every): img = next(self.dl_iter) img = img.to(device) @@ -302,7 +304,6 @@ def train_step(self): self.accelerator.clip_grad_norm_(self.vae.discr.parameters(), self.discr_max_grad_norm) self.discr_optim.step() - self.discr_optim.zero_grad() # log diff --git a/setup.py b/setup.py index ebbf31c..8c898c8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.17', + version = '0.0.18', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',