diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 1509c67..d3f1fa6 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -486,6 +486,7 @@ def generate( force_not_use_token_critic = False, timesteps = 18, # ideal number of steps is 18 in maskgit paper cond_scale = 3, + critic_noise_scale = 1 ): fmap_size = default(fmap_size, self.vae.get_encoded_fmap_size(self.image_size)) @@ -579,6 +580,9 @@ def generate( ) scores = rearrange(scores, '... 1 -> ...') + + scores = scores + (uniform(scores.shape, device = device) - 0.5) * critic_noise_scale * (steps_until_x0 / timesteps) + else: probs_without_temperature = logits.softmax(dim = -1) diff --git a/setup.py b/setup.py index 63e9fdd..bbd923d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.24', + version = '0.0.25', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',