Skip to content

Commit

Permalink
Added error handling for the VAE state_dict, in some occasions we mig…
Browse files Browse the repository at this point in the history
…ht end up with it throwing a RuntimeError, in this case we fallback to using strict=False when loading the state_dict.
  • Loading branch information
ZeroCool940711 committed Sep 5, 2023
1 parent fce8fd1 commit 3f4684d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,11 @@ def state_dict(self, *args, **kwargs):

@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
try:
return super().load_state_dict(*args, **kwargs)
except RuntimeError:
return super().load_state_dict(*args, **kwargs, strict=False)


def save(self, path):
if self.accelerator is not None:
Expand Down

0 comments on commit 3f4684d

Please sign in to comment.