Skip to content

Commit

Permalink
add negative prompting
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 10, 2023
1 parent 277239a commit 5bdcefa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
26 changes: 25 additions & 1 deletion muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@ def forward_with_cond_scale(

return null_logits + (logits - null_logits) * cond_scale

def forward_with_neg_prompt(
self,
text_embed: torch.Tensor,
neg_text_embed: torch.Tensor,
cond_scale = 3.,
**kwargs
):
neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs)
pos_logits = self.forward(*args, text_embed = text_embed, cond_drop_prob = 0., **kwargs)

return neg_logits + (pos_logits - neg_logits) * cond_scale

def forward(
self,
x,
Expand Down Expand Up @@ -340,6 +352,7 @@ def load(self, path):
def generate(
self,
texts: List[str],
negative_texts: Optional[List[str]] = None,
cond_images: Optional[torch.Tensor] = None,
fmap_size = None,
temperature = 1.,
Expand Down Expand Up @@ -369,6 +382,17 @@ def generate(

text_embeds = self.transformer.encode_text(texts)

demask_fn = self.transformer.forward_with_cond_scale

# negative prompting, as in paper

neg_text_embeds = None
if exists(negative_texts):
assert len(texts) == len(negative_texts)

neg_text_embeds = self.transformer.encode_text(negative_texts)
demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)

if self.resize_image_for_cond_image:
assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
with torch.no_grad():
Expand All @@ -383,7 +407,7 @@ def generate(

ids = ids.scatter(1, masked_indices, self.mask_id)

logits = self.transformer.forward_with_cond_scale(
logits = demask_fn(
ids,
text_embeds = text_embeds,
conditioning_token_ids = cond_ids,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.14',
version = '0.0.15',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 5bdcefa

Please sign in to comment.