diff --git a/README.md b/README.md
index 3cab9a0..e6091d6 100644
--- a/README.md
+++ b/README.md
@@ -262,3 +262,13 @@ images # List[PIL.Image.Image]
     primaryClass = {cs.LG}
 }
 ```
+
+```bibtex
+@article{Lezama2022ImprovedMI,
+    title = {Improved Masked Image Generation with Token-Critic},
+    author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2209.04439}
+}
+```
diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 2856551..b166eb5 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -263,6 +263,7 @@ def forward(
         self,
         x,
         return_embed = False,
+        return_logits = False,
         labels = None,
         ignore_index = 0,
         self_cond_embed = None,
@@ -320,10 +321,14 @@ def forward(
             return logits
 
         if self.dim_out == 1:
-            return F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
+            loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
+        else:
+            loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)
+
+        if not return_logits:
+            return loss
 
-        logits = rearrange(logits, 'b n c -> b c n')
-        return F.cross_entropy(logits, labels, ignore_index = ignore_index)
+        return loss, logits
 
 # specialized transformers
 
@@ -383,12 +388,14 @@ def __init__(
         image_size,
         transformer: MaskGitTransformer,
         noise_schedule: Callable = cosine_schedule,
+        token_critic: Optional[TokenCritic] = None,
         vae: Optional[VQGanVAE] = None,
         cond_vae: Optional[VQGanVAE] = None,
         cond_image_size = None,
         cond_drop_prob = 0.5,
         self_cond_prob = 0.9,
-        no_mask_token_prob = 0.
+        no_mask_token_prob = 0.,
+        critic_loss_weight = 1.
     ):
         super().__init__()
         self.vae = vae.copy_for_eval() if exists(vae) else None
@@ -413,6 +420,9 @@ def __init__(
         self.mask_id = transformer.mask_id
         self.noise_schedule = noise_schedule
 
+        self.token_critic = token_critic
+        self.critic_loss_weight = critic_loss_weight
+
         # self conditioning
 
         self.self_cond_prob = self_cond_prob
@@ -440,6 +450,7 @@ def generate(
         temperature = 1.,
         topk_filter_thres = 0.9,
         can_remask_prev_masked = False,
+        force_not_use_token_critic = False,
         timesteps = 18, # ideal number of steps is 18 in maskgit paper
         cond_scale = 3,
     ):
@@ -466,6 +477,13 @@ def generate(
 
         demask_fn = self.transformer.forward_with_cond_scale
 
+        # whether to use token critic for scores
+
+        use_token_critic = exists(self.token_critic) and not force_not_use_token_critic
+
+        if use_token_critic:
+            token_critic_fn = self.token_critic.forward_with_cond_scale
+
         # negative prompting, as in paper
 
         neg_text_embeds = None
@@ -475,6 +493,9 @@ def generate(
             neg_text_embeds = self.transformer.encode_text(negative_texts)
             demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)
 
+            if use_token_critic:
+                token_critic_fn = partial(self.token_critic.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)
+
         if self.resize_image_for_cond_image:
             assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
             with torch.no_grad():
@@ -516,15 +537,25 @@ def generate(
                 ids
             )
 
-            probs_without_temperature = logits.softmax(dim = -1)
-
-            scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
-            scores = rearrange(scores, '... 1 -> ...')
+            if use_token_critic:
+                scores = token_critic_fn(
+                    ids,
+                    text_embeds = text_embeds,
+                    conditioning_token_ids = cond_ids,
+                    cond_scale = cond_scale
+                )
 
-            if not can_remask_prev_masked:
-                scores = scores.masked_fill(~is_mask, -1e5)
+                scores = rearrange(scores, '... 1 -> ...')
             else:
-                assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token'
+                probs_without_temperature = logits.softmax(dim = -1)
+
+                scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
+                scores = rearrange(scores, '... 1 -> ...')
+
+                if not can_remask_prev_masked:
+                    scores = scores.masked_fill(~is_mask, -1e5)
+                else:
+                    assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token'
 
         # get ids
 
@@ -544,7 +575,9 @@ def forward(
         cond_token_ids: Optional[torch.Tensor] = None,
         texts: Optional[List[str]] = None,
         text_embeds: Optional[torch.Tensor] = None,
-        cond_drop_prob = None
+        cond_drop_prob = None,
+        train_only_generator = False,
+        sample_temperature = None
     ):
         # tokenize if needed
 
@@ -618,7 +651,7 @@ def forward(
 
         # get loss
 
-        ce_loss = self.transformer(
+        ce_loss, logits = self.transformer(
             x,
             texts = texts,
             text_embeds = text_embeds,
@@ -626,10 +659,30 @@ def forward(
             conditioning_token_ids = cond_token_ids,
             labels = labels,
             cond_drop_prob = cond_drop_prob,
-            ignore_index = ignore_index
+            ignore_index = ignore_index,
+            return_logits = True
+        )
+
+        if not exists(self.token_critic) or train_only_generator:
+            return ce_loss
+
+        # token critic loss
+
+        sampled_ids = gumbel_sample(logits, temperature = default(sample_temperature, random()))
+
+        critic_input = torch.where(mask, sampled_ids, x)
+        critic_labels = (ids != critic_input).float()
+
+        bc_loss = self.token_critic(
+            critic_input,
+            texts = texts,
+            text_embeds = text_embeds,
+            conditioning_token_ids = cond_token_ids,
+            labels = critic_labels,
+            cond_drop_prob = cond_drop_prob
         )
 
-        return ce_loss
+        return ce_loss + self.critic_loss_weight * bc_loss
 
 # final Muse class
diff --git a/setup.py b/setup.py
index dac1aa8..8ff885a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.22',
+  version = '0.0.23',
   license='MIT',
   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
   author = 'Phil Wang',
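The diff above wires an optional Token-Critic (Lezama et al., 2022) into `MaskGit`. During training, tokens are resampled from the generator's own logits, and the critic is trained with binary cross entropy to flag which positions were replaced (`critic_labels = (ids != critic_input).float()`); that loss is added to the generator's cross entropy, scaled by `critic_loss_weight`. At sampling time, the critic's scores take the place of the generator's softmax confidence when deciding which tokens to remask at each step. Below is a minimal sketch of how the new arguments might be used. Only `token_critic`, `critic_loss_weight`, `train_only_generator`, `sample_temperature`, and `force_not_use_token_critic` come from this diff; the constructor sizes and the idea of training directly on token ids without a `VQGanVAE` are illustrative assumptions modeled on the repository's existing README examples, not something this patch prescribes.

```python
import torch

# importing from the module shown in the diff; TokenCritic may also be
# re-exported from the package root, depending on the version installed
from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit, MaskGitTransformer, TokenCritic

# generator and critic operate over the same token vocabulary and sequence length;
# the sizes below are illustrative, not prescribed by the diff

transformer = MaskGitTransformer(
    num_tokens = 512,   # codebook size of the (separately trained) vqgan vae
    seq_len = 256,      # number of image tokens per image
    dim = 512,
    depth = 8
)

token_critic = TokenCritic(
    num_tokens = 512,
    seq_len = 256,
    dim = 512,
    depth = 6
)

maskgit = MaskGit(
    image_size = 256,
    transformer = transformer,
    token_critic = token_critic,    # new: attach the critic
    critic_loss_weight = 1.,        # new: weight on the critic's binary cross entropy term
    cond_drop_prob = 0.25
)

# training directly on pre-computed token ids keeps this sketch self-contained;
# with a VQGanVAE attached you would pass raw images instead, as in the README

texts = ['a cat', 'a dog', 'a bird', 'a fish']
ids = torch.randint(0, 512, (4, 256))

# with a critic attached, the returned loss is the generator's cross entropy
# plus critic_loss_weight * the critic's BCE on tokens resampled from the
# generator's logits; pass train_only_generator = True to skip the critic term

loss = maskgit(ids, texts = texts)
loss.backward()

# at sampling time the critic's scores decide which tokens get remasked each step;
# force_not_use_token_critic = True falls back to the old confidence-based scores, e.g.
#
#   images = maskgit.generate(texts = texts, cond_scale = 3., force_not_use_token_critic = True)
#
# (generating actual images additionally requires a VQGanVAE so the final ids can be decoded)
```

The `critic_loss_weight` default of `1.` simply sums the two objectives; lowering it keeps the critic from dominating the generator's cross entropy early in training, while `force_not_use_token_critic` makes it easy to A/B the critic-guided and confidence-guided sampling paths with the same checkpoint.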