diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 69d90fd..0984659 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -36,6 +36,23 @@ def inner(model, *args, **kwargs): return out return inner +# tensor helpers + +def get_mask_subset_prob(mask, prob, min_mask = 0): + batch, seq, device = *mask.shape, mask.device + num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask) + logits = torch.rand((batch, seq), device = device) + logits = logits.masked_fill(~mask, -1) + + randperm = logits.argsort(dim = -1).float() + + num_padding = (~mask).sum(dim = -1, keepdim = True) + randperm -= num_padding + + subset_mask = randperm < num_to_mask + subset_mask.masked_fill_(~mask, False) + return subset_mask + # classes class LayerNorm(nn.Module): @@ -352,7 +369,8 @@ def __init__( cond_vae: Optional[VQGanVAE] = None, cond_image_size = None, cond_drop_prob = 0.5, - self_cond_prob = 0.9 + self_cond_prob = 0.9, + no_mask_token_prob = 0. ): super().__init__() self.vae = vae.copy_for_eval() if exists(vae) else None @@ -380,6 +398,10 @@ def __init__( # self conditioning self.self_cond_prob = self_cond_prob + # percentage of tokens to be [mask]ed to remain the same token, so that transformer produces better embeddings across all tokens as done in original BERT paper + # may be needed for self conditioning + self.no_mask_token_prob = no_mask_token_prob + def save(self, path): torch.save(self.state_dict(), path) @@ -551,9 +573,13 @@ def forward( mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1') mask_id = self.transformer.mask_id + labels = torch.where(mask, ids, ignore_index) + + if self.no_mask_token_prob > 0.: + no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob) + mask &= ~no_mask_mask x = torch.where(mask, mask_id, ids) - labels = torch.where(mask, ids, ignore_index) # self conditioning diff --git a/setup.py b/setup.py index c2da37e..77c0064 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.20', + version = '0.0.21', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',