diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 3e01412..b76b031 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -198,6 +198,27 @@ def forward_with_cond_scale(
 
         return null_logits + (logits - null_logits) * cond_scale
 
+    def forward_with_neg_prompt(
+        self,
+        *args,
+        text_embeds: torch.Tensor,
+        neg_text_embeds: torch.Tensor,
+        cond_scale = 3.,
+        **kwargs
+    ):
+        """
+        Guide the logits away from a negative prompt and toward the positive one.
+
+        Calls `forward` twice with `cond_drop_prob = 0.` - once conditioned on
+        `neg_text_embeds`, once on `text_embeds` - then extrapolates from the
+        negative logits toward the positive logits by a factor of `cond_scale`.
+        Positional `*args` and remaining `**kwargs` are passed through to `forward`.
+        """
+        neg_logits = self.forward(*args, text_embeds = neg_text_embeds, cond_drop_prob = 0., **kwargs)
+        pos_logits = self.forward(*args, text_embeds = text_embeds, cond_drop_prob = 0., **kwargs)
+
+        return neg_logits + (pos_logits - neg_logits) * cond_scale
+
     def forward(
         self,
         x,
@@ -340,6 +361,7 @@ def load(self, path):
     def generate(
         self,
         texts: List[str],
+        negative_texts: Optional[List[str]] = None,
         cond_images: Optional[torch.Tensor] = None,
         fmap_size = None,
         temperature = 1.,
@@ -369,6 +391,17 @@ def generate(
 
         text_embeds = self.transformer.encode_text(texts)
 
+        demask_fn = self.transformer.forward_with_cond_scale
+
+        # negative prompting, as in paper
+
+        neg_text_embeds = None
+        if exists(negative_texts):
+            assert len(texts) == len(negative_texts)
+
+            neg_text_embeds = self.transformer.encode_text(negative_texts)
+            demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)
+
         if self.resize_image_for_cond_image:
             assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
             with torch.no_grad():
@@ -383,7 +416,7 @@ def generate(
 
             ids = ids.scatter(1, masked_indices, self.mask_id)
 
-            logits = self.transformer.forward_with_cond_scale(
+            logits = demask_fn(
                 ids,
                 text_embeds = text_embeds,
                 conditioning_token_ids = cond_ids,
diff --git a/setup.py b/setup.py
index f34404b..0367e8f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
   author = 'Phil Wang',