Skip to content

Commit

Permalink
fix a discriminator gradient issue thanks to @apoorv2904
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 12, 2023
1 parent 19769e1 commit f7ea957
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion muse_maskgit_pytorch/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def train_step(self):
# update discriminator

if exists(self.vae.discr):
self.discr_optim.zero_grad()

for _ in range(self.grad_accum_every):
img = next(self.dl_iter)
img = img.to(device)
Expand All @@ -302,7 +304,6 @@ def train_step(self):
self.accelerator.clip_grad_norm_(self.vae.discr.parameters(), self.discr_max_grad_norm)

self.discr_optim.step()
self.discr_optim.zero_grad()

# log

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.17',
version = '0.0.18',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit f7ea957

Please sign in to comment.