From 5570c742d96c2f646b978a76cf3ef0014d6e3197 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 22 Jan 2023 09:18:55 -0800
Subject: [PATCH] add self critic option for token critic

---
 README.md                                    |  9 ++++
 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 46 +++++++++++++++++---
 setup.py                                     |  2 +-
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 8b33a46..70544c6 100644
--- a/README.md
+++ b/README.md
@@ -272,3 +272,12 @@ images # List[PIL.Image.Image]
     volume  = {abs/2209.04439}
 }
 ```
+
+```bibtex
+@inproceedings{Nijkamp2021SCRIPTSP,
+    title     = {SCRIPT: Self-Critic PreTraining of Transformers},
+    author    = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
+    booktitle = {North American Chapter of the Association for Computational Linguistics},
+    year      = {2021}
+}
+```
diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index b166eb5..1509c67 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -194,6 +194,7 @@ def __init__(
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
         self.mask_id = num_tokens if add_mask_id else None
 
         self.num_tokens = num_tokens
@@ -330,6 +331,32 @@ def forward(
 
         return loss, logits
 
+# self critic wrapper
+
+class SelfCritic(nn.Module):
+    def __init__(self, net):
+        super().__init__()
+        self.net = net
+        self.to_pred = nn.Linear(net.dim, 1)
+
+    def forward_with_cond_scale(self, x, *args, **kwargs):
+        _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed = True, **kwargs)
+        return self.to_pred(embeds)
+
+    def forward_with_neg_prompt(self, x, *args, **kwargs):
+        _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed = True, **kwargs)
+        return self.to_pred(embeds)
+
+    def forward(self, x, *args, labels = None, **kwargs):
+        _, embeds = self.net(x, *args, return_embed = True, **kwargs)
+        logits = self.to_pred(embeds)
+
+        if not exists(labels):
+            return logits
+
+        logits = rearrange(logits, '... 1 -> ...')
+        return F.binary_cross_entropy_with_logits(logits, labels)
+
 # specialized transformers
 
 class MaskGitTransformer(Transformer):
@@ -389,6 +416,7 @@ def __init__(
         transformer: MaskGitTransformer,
         noise_schedule: Callable = cosine_schedule,
         token_critic: Optional[TokenCritic] = None,
+        self_token_critic = False,
         vae: Optional[VQGanVAE] = None,
         cond_vae: Optional[VQGanVAE] = None,
         cond_image_size = None,
@@ -420,7 +448,12 @@ def __init__(
         self.mask_id = transformer.mask_id
         self.noise_schedule = noise_schedule
 
+        assert not (self_token_critic and exists(token_critic))
         self.token_critic = token_critic
+
+        if self_token_critic:
+            self.token_critic = SelfCritic(transformer)
+
         self.critic_loss_weight = critic_loss_weight
 
         # self conditioning
@@ -632,6 +665,12 @@ def forward(
 
         x = torch.where(mask, mask_id, ids)
 
+        # get text embeddings
+
+        if exists(texts):
+            text_embeds = self.transformer.encode_text(texts)
+            texts = None
+
         # self conditioning
 
         self_cond_embed = None
@@ -640,7 +679,6 @@ def forward(
             with torch.no_grad():
                 _, self_cond_embed = self.transformer(
                     x,
-                    texts = texts,
                     text_embeds = text_embeds,
                     conditioning_token_ids = cond_token_ids,
                     cond_drop_prob = 0.,
@@ -653,7 +691,6 @@ def forward(
 
         ce_loss, logits = self.transformer(
             x,
-            texts = texts,
             text_embeds = text_embeds,
             self_cond_embed = self_cond_embed,
             conditioning_token_ids = cond_token_ids,
@@ -673,16 +710,15 @@ def forward(
         critic_input = torch.where(mask, sampled_ids, x)
         critic_labels = (ids != critic_input).float()
 
-        bc_loss = self.token_critic(
+        bce_loss = self.token_critic(
             critic_input,
-            texts = texts,
             text_embeds = text_embeds,
             conditioning_token_ids = cond_token_ids,
             labels = critic_labels,
             cond_drop_prob = cond_drop_prob
         )
 
-        return ce_loss + self.critic_loss_weight * bc_loss
+        return ce_loss + self.critic_loss_weight * bce_loss
 
 # final Muse class
 
diff --git a/setup.py b/setup.py
index 8ff885a..63e9fdd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.24',
  license='MIT',
  description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
  author = 'Phil Wang',
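
For reference, a minimal usage sketch of the new `self_token_critic` option. The `vae`, `transformer`, `token_critic`, and `self_token_critic` arguments appear in the diff above; the surrounding constructor arguments (`VQGanVAE` settings, `image_size`, `cond_drop_prob`, the `MaskGitTransformer` hyperparameters) follow the repo README and should be treated as assumptions, not part of this patch:

```python
# usage sketch -- only self_token_critic (and its exclusivity with token_critic)
# comes from this patch; other arguments follow the repo README and may differ
# across versions
import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

vae = VQGanVAE(dim = 256, vq_codebook_size = 512)

transformer = MaskGitTransformer(
    num_tokens = 512,   # must match the vae codebook size
    seq_len = 256,      # (256 / 16) ** 2 latent tokens for 256px images
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25,
    self_token_critic = True   # wraps the transformer in SelfCritic; cannot be combined with token_critic =
)

images = torch.randn(4, 3, 256, 256)
texts = ['a dog chasing a ball', 'an oil painting of a castle', 'a red car', 'a mountain lake']

# the returned loss is ce_loss + critic_loss_weight * bce_loss, per the forward diff above
loss = maskgit(images, texts = texts)
loss.backward()
```

Since `SelfCritic` shares the generator's trunk and only adds a `nn.Linear(dim, 1)` head, the critic comes at near-zero extra parameter cost, in the spirit of the SCRIPT citation added to the README.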