diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 0984659..f0224a3 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -498,15 +498,15 @@ def generate( ids ) - scores = 1 - logits.gather(2, pred_ids[..., None]) + probs_without_temperature = logits.softmax(dim = -1) + + scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None]) scores = rearrange(scores, '... 1 -> ...') if not can_remask_prev_masked: - # without doing MLM type 15% random or non-masked predictions - # non-masked tokens may not get correct logits (scores) - # but not sure - scores = scores.masked_fill(~is_mask, -1e5) + else: + assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token' # get ids diff --git a/setup.py b/setup.py index 77c0064..dac1aa8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.21', + version = '0.0.22', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',