just make token critic work, make fast later
lucidrains committed Jan 20, 2023
1 parent 862b8e9, commit 8211513
Showing 3 changed files with 79 additions and 16 deletions.
README.md (10 additions, 0 deletions)
@@ -262,3 +262,13 @@ images # List[PIL.Image.Image]
primaryClass = {cs.LG}
}
```

```bibtex
@article{Lezama2022ImprovedMI,
title = {Improved Masked Image Generation with Token-Critic},
author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
journal = {ArXiv},
year = {2022},
volume = {abs/2209.04439}
}
```
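For orientation, a hedged sketch of how the token critic added by this commit might be wired up, in the construction style of the repository README. The `token_critic`, `critic_loss_weight`, `train_only_generator`, and `force_not_use_token_critic` keywords come from the diff below; the constructor arguments (dims, depth, codebook size) and the import path for `TokenCritic` are illustrative assumptions, on the basis that `TokenCritic` is a thin `Transformer` wrapper taking the same arguments as `MaskGitTransformer`.

```python
import torch
# TokenCritic is defined in muse_maskgit_pytorch.py; exposing it from the package is assumed here
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer, TokenCritic

vae = VQGanVAE(dim = 256, vq_codebook_size = 512)   # toy VQ-VAE, sizes illustrative

transformer = MaskGitTransformer(
    num_tokens = 512,     # should match the VAE codebook size
    seq_len = 256,
    dim = 512,
    depth = 8
)

critic = TokenCritic(     # assumed to mirror the generator's constructor arguments
    num_tokens = 512,
    seq_len = 256,
    dim = 512,
    depth = 8
)

maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    token_critic = critic,       # new keyword from this commit
    critic_loss_weight = 1.,     # weight on the critic's binary cross entropy term
    image_size = 256,
    cond_drop_prob = 0.25
)

images = torch.randn(4, 3, 256, 256)
texts = ['a photo of a dog'] * 4

loss = maskgit(images, texts = texts)                                         # ce_loss + critic_loss_weight * bc_loss
gen_only_loss = maskgit(images, texts = texts, train_only_generator = True)   # skip the critic term

# at sampling time the critic scores drive remasking by default;
# the new flag falls back to the old confidence-based scores
sampled = maskgit.generate(texts = ['a photo of a dog'], cond_scale = 3.)
sampled_plain = maskgit.generate(
    texts = ['a photo of a dog'],
    cond_scale = 3.,
    force_not_use_token_critic = True
)
```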
muse_maskgit_pytorch/muse_maskgit_pytorch.py (68 additions, 15 deletions)
@@ -263,6 +263,7 @@ def forward(
self,
x,
return_embed = False,
return_logits = False,
labels = None,
ignore_index = 0,
self_cond_embed = None,
@@ -320,10 +321,14 @@ def forward(
return logits

if self.dim_out == 1:
-     return F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
-
- logits = rearrange(logits, 'b n c -> b c n')
- return F.cross_entropy(logits, labels, ignore_index = ignore_index)
+     loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
+ else:
+     loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)
+
+ if not return_logits:
+     return loss
+
+ return loss, logits

# specialized transformers
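The change to `Transformer.forward` above makes the loss branch assign into `loss` and optionally return the raw logits alongside it, which `MaskGit.forward` needs further down in order to sample fake tokens for the critic. A self-contained toy that mirrors just that control flow (not the repository's actual Transformer):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ToyHead(nn.Module):
    """Minimal stand-in for the loss branch above: dim_out = 1 would be a per-token
    binary critic head, otherwise a categorical head over the codebook."""
    def __init__(self, dim = 16, dim_out = 32):
        super().__init__()
        self.dim_out = dim_out
        self.to_logits = nn.Linear(dim, dim_out)

    def forward(self, embed, labels, return_logits = False, ignore_index = 0):
        logits = self.to_logits(embed)   # (b, n, dim_out)

        if self.dim_out == 1:
            loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
        else:
            loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)

        if not return_logits:
            return loss

        return loss, logits

embed = torch.randn(2, 8, 16)
labels = torch.randint(0, 32, (2, 8))
loss, logits = ToyHead()(embed, labels, return_logits = True)
```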

@@ -383,12 +388,14 @@ def __init__(
image_size,
transformer: MaskGitTransformer,
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
vae: Optional[VQGanVAE] = None,
cond_vae: Optional[VQGanVAE] = None,
cond_image_size = None,
cond_drop_prob = 0.5,
self_cond_prob = 0.9,
- no_mask_token_prob = 0.
+ no_mask_token_prob = 0.,
+ critic_loss_weight = 1.
):
super().__init__()
self.vae = vae.copy_for_eval() if exists(vae) else None
@@ -413,6 +420,9 @@ def __init__(
self.mask_id = transformer.mask_id
self.noise_schedule = noise_schedule

self.token_critic = token_critic
self.critic_loss_weight = critic_loss_weight

# self conditioning
self.self_cond_prob = self_cond_prob

@@ -440,6 +450,7 @@ def generate(
temperature = 1.,
topk_filter_thres = 0.9,
can_remask_prev_masked = False,
force_not_use_token_critic = False,
timesteps = 18, # ideal number of steps is 18 in maskgit paper
cond_scale = 3,
):
@@ -466,6 +477,13 @@

demask_fn = self.transformer.forward_with_cond_scale

# whether to use token critic for scores

use_token_critic = exists(self.token_critic) and not force_not_use_token_critic

if use_token_critic:
token_critic_fn = self.token_critic.forward_with_cond_scale

# negative prompting, as in paper

neg_text_embeds = None
@@ -475,6 +493,9 @@
neg_text_embeds = self.transformer.encode_text(negative_texts)
demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)

if use_token_critic:
token_critic_fn = partial(self.token_critic.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)

if self.resize_image_for_cond_image:
assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
with torch.no_grad():
@@ -516,15 +537,25 @@ def generate(
ids
)

- probs_without_temperature = logits.softmax(dim = -1)
-
- scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
- scores = rearrange(scores, '... 1 -> ...')
-
- if not can_remask_prev_masked:
-     scores = scores.masked_fill(~is_mask, -1e5)
- else:
-     assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token'
+ if use_token_critic:
+     scores = token_critic_fn(
+         ids,
+         text_embeds = text_embeds,
+         conditioning_token_ids = cond_ids,
+         cond_scale = cond_scale
+     )
+
+     scores = rearrange(scores, '... 1 -> ...')
+ else:
+     probs_without_temperature = logits.softmax(dim = -1)
+
+     scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
+     scores = rearrange(scores, '... 1 -> ...')
+
+     if not can_remask_prev_masked:
+         scores = scores.masked_fill(~is_mask, -1e5)
+     else:
+         assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token'

# get ids
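The scores computed above are the critic's per-token logits when the critic is in use, and one minus the predicted token probability otherwise; either way, the sampler later uses them (outside this hunk) to decide which positions get remasked for the next refinement step. A toy sketch of that idea, assuming a higher score means the token is less trusted and should be remasked:

```python
import torch

def remask_by_score(ids, scores, num_to_mask, mask_id):
    """Remask the num_to_mask highest-scoring (least trusted) tokens per sequence."""
    masked_indices = scores.topk(num_to_mask, dim = -1).indices   # (b, k)
    return ids.scatter(1, masked_indices, mask_id)

b, n, mask_id = 2, 16, 4096
ids = torch.randint(0, 4096, (b, n))   # current predictions for every position
scores = torch.rand(b, n)              # critic scores or (1 - confidence)

ids = remask_by_score(ids, scores, num_to_mask = 6, mask_id = mask_id)
```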

@@ -544,7 +575,9 @@ def forward(
cond_token_ids: Optional[torch.Tensor] = None,
texts: Optional[List[str]] = None,
text_embeds: Optional[torch.Tensor] = None,
- cond_drop_prob = None
+ cond_drop_prob = None,
+ train_only_generator = False,
+ sample_temperature = None
):
# tokenize if needed

@@ -618,18 +651,38 @@

# get loss

- ce_loss = self.transformer(
+ ce_loss, logits = self.transformer(
x,
texts = texts,
text_embeds = text_embeds,
self_cond_embed = self_cond_embed,
conditioning_token_ids = cond_token_ids,
labels = labels,
cond_drop_prob = cond_drop_prob,
- ignore_index = ignore_index
+ ignore_index = ignore_index,
+ return_logits = True
)

if not exists(self.token_critic) or train_only_generator:
return ce_loss

# token critic loss

sampled_ids = gumbel_sample(logits, temperature = default(sample_temperature, random()))

critic_input = torch.where(mask, sampled_ids, x)
critic_labels = (ids != critic_input).float()

bc_loss = self.token_critic(
critic_input,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_token_ids,
labels = critic_labels,
cond_drop_prob = cond_drop_prob
)

- return ce_loss
+ return ce_loss + self.critic_loss_weight * bc_loss

# final Muse class

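The new critic term in `MaskGit.forward` above samples "fake" tokens from the generator's own logits, splices them into the masked positions, and trains the critic to flag every position that no longer matches the ground truth. A self-contained toy of that target construction; the `gumbel_sample` below is a plain Gumbel-max stand-in assumed to behave like the repository's helper of the same name:

```python
import torch
import torch.nn.functional as F

def gumbel_sample(logits, temperature = 1.0, eps = 1e-10):
    # simple Gumbel-max sampling, a stand-in for the repo's helper
    noise = -torch.log(-torch.log(torch.rand_like(logits).clamp(min = eps)).clamp(min = eps))
    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

b, n, codebook_size = 2, 16, 512

ids = torch.randint(0, codebook_size, (b, n))   # ground-truth token ids
mask = torch.rand(b, n) < 0.5                   # positions masked for the generator
x = torch.where(mask, torch.full_like(ids, codebook_size), ids)   # generator input, mask id spliced in
logits = torch.randn(b, n, codebook_size)       # pretend generator output

# splice sampled tokens into the masked slots, keep the generator's input elsewhere
sampled_ids = gumbel_sample(logits, temperature = 1.0)
critic_input = torch.where(mask, sampled_ids, x)

# critic target: 1 wherever the spliced sequence disagrees with the ground truth
critic_labels = (ids != critic_input).float()

# the critic is trained with binary cross entropy against these labels
critic_logits = torch.randn(b, n)               # pretend critic output
bc_loss = F.binary_cross_entropy_with_logits(critic_logits, critic_labels)
```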
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.22',
+ version = '0.0.23',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
