Skip to content

Commit

Permalink
add self critic option for token critic
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 22, 2023
1 parent c8e6f8c commit 5570c74
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,12 @@ images # List[PIL.Image.Image]
volume = {abs/2209.04439}
}
```

```bibtex
@inproceedings{Nijkamp2021SCRIPTSP,
title = {SCRIPT: Self-Critic PreTraining of Transformers},
author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
booktitle = {North American Chapter of the Association for Computational Linguistics},
year = {2021}
}
```
46 changes: 41 additions & 5 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
**kwargs
):
super().__init__()
self.dim = dim
self.mask_id = num_tokens if add_mask_id else None

self.num_tokens = num_tokens
Expand Down Expand Up @@ -330,6 +331,32 @@ def forward(

return loss, logits

# self critic wrapper

class SelfCritic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.to_pred = nn.Linear(net.dim, 1)

def forward_with_cond_scale(self, x, *args, **kwargs):
_, embeds = self.net.forward_with_cond_scale(x, *args, return_embed = True, **kwargs)
return self.to_pred(embeds)

def forward_with_neg_prompt(self, x, *args, **kwargs):
_, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed = True, **kwargs)
return self.to_pred(embeds)

def forward(self, x, *args, labels = None, **kwargs):
_, embeds = self.net(x, *args, return_embed = True, **kwargs)
logits = self.to_pred(embeds)

if not exists(labels):
return logits

logits = rearrange(logits, '... 1 -> ...')
return F.binary_cross_entropy_with_logits(logits, labels)

# specialized transformers

class MaskGitTransformer(Transformer):
Expand Down Expand Up @@ -389,6 +416,7 @@ def __init__(
transformer: MaskGitTransformer,
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
self_token_critic = False,
vae: Optional[VQGanVAE] = None,
cond_vae: Optional[VQGanVAE] = None,
cond_image_size = None,
Expand Down Expand Up @@ -420,7 +448,12 @@ def __init__(
self.mask_id = transformer.mask_id
self.noise_schedule = noise_schedule

assert not (self_token_critic and exists(token_critic))
self.token_critic = token_critic

if self_token_critic:
self.token_critic = SelfCritic(transformer)

self.critic_loss_weight = critic_loss_weight

# self conditioning
Expand Down Expand Up @@ -632,6 +665,12 @@ def forward(

x = torch.where(mask, mask_id, ids)

# get text embeddings

if exists(texts):
text_embeds = self.transformer.encode_text(texts)
texts = None

# self conditioning

self_cond_embed = None
Expand All @@ -640,7 +679,6 @@ def forward(
with torch.no_grad():
_, self_cond_embed = self.transformer(
x,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_token_ids,
cond_drop_prob = 0.,
Expand All @@ -653,7 +691,6 @@ def forward(

ce_loss, logits = self.transformer(
x,
texts = texts,
text_embeds = text_embeds,
self_cond_embed = self_cond_embed,
conditioning_token_ids = cond_token_ids,
Expand All @@ -673,16 +710,15 @@ def forward(
critic_input = torch.where(mask, sampled_ids, x)
critic_labels = (ids != critic_input).float()

bc_loss = self.token_critic(
bce_loss = self.token_critic(
critic_input,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_token_ids,
labels = critic_labels,
cond_drop_prob = cond_drop_prob
)

return ce_loss + self.critic_loss_weight * bc_loss
return ce_loss + self.critic_loss_weight * bce_loss

# final Muse class

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.23',
version = '0.0.24',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 5570c74

Please sign in to comment.