Commit a076be0
increase chances that self conditioning ends up working
lucidrains committed Jan 19, 2023
1 parent 38b748a commit a076be0
Showing 2 changed files with 29 additions and 3 deletions.
30 changes: 28 additions & 2 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -36,6 +36,23 @@ def inner(model, *args, **kwargs):
         return out
     return inner

+# tensor helpers
+
+def get_mask_subset_prob(mask, prob, min_mask = 0):
+    batch, seq, device = *mask.shape, mask.device
+    num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
+    logits = torch.rand((batch, seq), device = device)
+    logits = logits.masked_fill(~mask, -1)
+
+    randperm = logits.argsort(dim = -1).float()
+
+    num_padding = (~mask).sum(dim = -1, keepdim = True)
+    randperm -= num_padding
+
+    subset_mask = randperm < num_to_mask
+    subset_mask.masked_fill_(~mask, False)
+    return subset_mask
+
 # classes

 class LayerNorm(nn.Module):
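A note on the new helper: get_mask_subset_prob takes a boolean mask and randomly selects roughly prob of each row's True positions, never selecting a position that is False in the input (min_mask floors the per-row count). A quick sanity-check sketch, assuming the helper above is in scope:

import torch

# toy batch: 2 rows of 6 tokens, True marks positions currently [mask]ed
mask = torch.tensor([
    [True, True, True, True, False, False],
    [True, True, False, False, False, False],
])

subset = get_mask_subset_prob(mask, prob = 0.5)

assert not subset[~mask].any()  # never selects outside the input mask
assert subset[0].sum() == 2     # half of row 0's four True positions
assert subset[1].sum() == 1     # half of row 1's two True positions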
@@ -352,7 +369,8 @@ def __init__(
         cond_vae: Optional[VQGanVAE] = None,
         cond_image_size = None,
         cond_drop_prob = 0.5,
-        self_cond_prob = 0.9
+        self_cond_prob = 0.9,
+        no_mask_token_prob = 0.
     ):
         super().__init__()
         self.vae = vae.copy_for_eval() if exists(vae) else None
@@ -380,6 +398,10 @@ def __init__(
         # self conditioning
         self.self_cond_prob = self_cond_prob

+        # percentage of tokens to be [mask]ed that will instead remain the same token, so that the transformer produces better embeddings across all tokens, as done in the original BERT paper
+        # may be needed for self conditioning
+        self.no_mask_token_prob = no_mask_token_prob
+
     def save(self, path):
         torch.save(self.state_dict(), path)

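The trick follows the original BERT paper, where a fraction of the positions selected for prediction keep their original token instead of being replaced by [mask], forcing the model to produce useful embeddings at every position. A minimal sketch of the mechanism, using a plain Bernoulli draw in place of the rank-based get_mask_subset_prob above (mask_id, the vocabulary size, and the 10% rate are illustrative assumptions):

import torch

mask_id = 1024                        # hypothetical [mask] token id
ids = torch.randint(0, 1024, (1, 8))  # toy sequence of codebook indices
mask = torch.rand(1, 8) < 0.5         # positions chosen as prediction targets

keep = mask & (torch.rand(1, 8) < 0.1)  # ~10% of targets keep their token
input_mask = mask & ~keep

x = torch.where(input_mask, mask_id, ids)  # kept positions stay visible
labels = torch.where(mask, ids, -100)      # all targets are still supervised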
@@ -551,9 +573,13 @@ def forward(
         mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

         mask_id = self.transformer.mask_id
+        labels = torch.where(mask, ids, ignore_index)
+
+        if self.no_mask_token_prob > 0.:
+            no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob)
+            mask &= ~no_mask_mask

         x = torch.where(mask, mask_id, ids)
-        labels = torch.where(mask, ids, ignore_index)

         # self conditioning

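The ordering in this hunk is the point of the commit: labels is now computed from the full mask before a subset is exempted, so exempted positions stay visible in the input yet still contribute to the loss. A small self-contained check of that property, assuming get_mask_subset_prob from the first hunk is in scope (ignore_index = -1 and the toy shapes are assumptions):

import torch

ignore_index = -1
mask_id = 1024
ids = torch.randint(0, 1024, (2, 16))
mask = torch.rand(2, 16) < 0.5

labels = torch.where(mask, ids, ignore_index)  # computed BEFORE shrinking the mask

no_mask_mask = get_mask_subset_prob(mask, 0.25)
mask &= ~no_mask_mask

x = torch.where(mask, mask_id, ids)

# exempted tokens remain their original ids in the input, yet are supervised
assert (x[no_mask_mask] == ids[no_mask_mask]).all()
assert (labels[no_mask_mask] == ids[no_mask_mask]).all()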
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.20',
+  version = '0.0.21',
   license='MIT',
   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
   author = 'Phil Wang',
