bet on self conditioning, given multiple papers having success with it. improvise and adapt it to maskgit
lucidrains committed Jan 16, 2023
1 parent 3f9de5c commit 38b748a
Showing 4 changed files with 113 additions and 15 deletions.
23 changes: 22 additions & 1 deletion README.md
@@ -227,9 +227,9 @@ images # List[PIL.Image.Image]
- [x] test end-to-end
- [x] separate cond_images_or_ids, it is not done right
- [x] add training code for vae
- [x] add optional self-conditioning on embeddings

- [ ] hook up accelerate training code for maskgit
- [ ] add optional self-conditioning on embeddings
- [ ] combine with token critic paper, already implemented at <a href="https://github.com/lucidrains/phenaki-pytorch">Phenaki</a>

## Citations
@@ -241,3 +241,24 @@ images # List[PIL.Image.Image]
year = {2023}
}
```

```bibtex
@article{Chen2022AnalogBG,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.04202}
}
```

```bibtex
@misc{jabri2022scalable,
title = {Scalable Adaptive Computation for Iterative Generation},
author = {Allan Jabri and David Fleet and Ting Chen},
year = {2022},
eprint = {2212.11972},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
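Both of the papers cited above lean on self-conditioning: on a random fraction of training steps the network is first run without gradients, and its own prediction (here, the final transformer embeddings) is fed back as an extra input to a second, gradient-carrying pass; at sampling time, the prediction from the previous refinement step plays that role. Below is a minimal, generic sketch of that recipe. `ToyDenoiser`, `train_step`, the sizes, and the 0.9 probability are illustrative stand-ins, not this repository's API.

```python
# illustrative sketch of self-conditioning; not the muse-maskgit-pytorch API
from random import random

import torch
import torch.nn.functional as F
from torch import nn

class ToyDenoiser(nn.Module):
    def __init__(self, num_tokens = 256, dim = 64):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        self.self_cond_proj = nn.Linear(dim, dim)  # projects the previous pass's hidden states
        layer = nn.TransformerEncoderLayer(dim, nhead = 4, batch_first = True)
        self.body = nn.TransformerEncoder(layer, num_layers = 2)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, ids, self_cond_embed = None):
        x = self.embed(ids)
        if self_cond_embed is None:
            self_cond_embed = torch.zeros_like(x)  # first pass: condition on zeros
        x = x + self.self_cond_proj(self_cond_embed)
        hidden = self.body(x)
        return self.to_logits(hidden), hidden

def train_step(model, ids, labels, self_cond_prob = 0.9):
    self_cond_embed = None
    if random() < self_cond_prob:
        with torch.no_grad():  # first pass yields embeddings but no gradients
            _, self_cond_embed = model(ids)
        self_cond_embed = self_cond_embed.detach()
    logits, _ = model(ids, self_cond_embed = self_cond_embed)  # only the second pass is trained
    return F.cross_entropy(logits.transpose(1, 2), labels)

model = ToyDenoiser()
ids = torch.randint(0, 256, (2, 16))
loss = train_step(model, ids, ids)  # toy target: reconstruct the input ids
loss.backward()
```

At sampling time the same hook is reused across refinement steps: the embeddings returned at step t are passed back in as `self_cond_embed` at step t + 1, which is what the changes to `MaskGit.generate` below wire up.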
90 changes: 78 additions & 12 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -1,4 +1,5 @@
import math
from random import random
from functools import partial

import torch
@@ -26,6 +27,15 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
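# note: eval_decorator temporarily puts the model in eval mode for the wrapped call,
# then restores whatever training mode it was in; it is paired with @torch.no_grad()
# on MaskGit.generate further down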

# classes

class LayerNorm(nn.Module):
@@ -161,6 +171,7 @@ def __init__(
dim,
seq_len,
t5_name = DEFAULT_T5_NAME,
self_cond = False,
**kwargs
):
super().__init__()
@@ -184,38 +195,57 @@ def __init__(

self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()

# optional self conditioning

self.self_cond = self_cond
self.self_cond_to_init_embed = FeedForward(dim)

def forward_with_cond_scale(
self,
*args,
cond_scale = 3.,
return_embed = False,
**kwargs
):
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
if cond_scale == 1:
return logits
return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs)

logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs)

null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
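# classifier-free guidance: the conditioned logits are pushed away from these
# unconditional (fully text-dropped) logits below; with cond_scale = 3 the output
# is null + 3 * (cond - null)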

return null_logits + (logits - null_logits) * cond_scale
scaled_logits = null_logits + (logits - null_logits) * cond_scale

if return_embed:
return scaled_logits, embed

return scaled_logits

def forward_with_neg_prompt(
self,
text_embed: torch.Tensor,
neg_text_embed: torch.Tensor,
cond_scale = 3.,
return_embed = False,
**kwargs
):
neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs)
pos_logits = self.forward(*args, text_embed = text_embed, cond_drop_prob = 0., **kwargs)
pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs)

logits = neg_logits + (pos_logits - neg_logits) * cond_scale

if return_embed:
return scaled_logits, embed

return neg_logits + (pos_logits - neg_logits) * cond_scale
return scaled_logits

def forward(
self,
x,
return_embed = False,
labels = None,
ignore_index = 0,
self_cond_embed = None,
cond_drop_prob = 0.,
conditioning_token_ids: Optional[torch.Tensor] = None,
texts: Optional[List[str]] = None,
@@ -254,12 +284,17 @@ def forward(
x = self.token_emb(x)
x = x + self.pos_emb(torch.arange(n, device = device))

x = self.transformer_blocks(x, context = context, context_mask = context_mask)
if self.self_cond:
if not exists(self_cond_embed):
self_cond_embed = torch.zeros_like(x)
x = x + self.self_cond_to_init_embed(self_cond_embed)
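# self-conditioning hook: embeddings from a previous forward pass (or zeros when none
# are provided) are projected by a FeedForward and added to the token embeddings
# before the transformer blocks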

if return_embed:
return x
embed = self.transformer_blocks(x, context = context, context_mask = context_mask)

logits = self.to_logits(embed)

logits = self.to_logits(x)
if return_embed:
return logits, embed

if not exists(labels):
return logits
@@ -316,7 +351,8 @@ def __init__(
vae: Optional[VQGanVAE] = None,
cond_vae: Optional[VQGanVAE] = None,
cond_image_size = None,
cond_drop_prob = 0.5
cond_drop_prob = 0.5,
self_cond_prob = 0.9
):
super().__init__()
self.vae = vae.copy_for_eval() if exists(vae) else None
@@ -335,11 +371,15 @@ def __init__(
self.cond_drop_prob = cond_drop_prob

self.transformer = transformer
self.self_cond = transformer.self_cond
assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size'

self.mask_id = transformer.mask_id
self.noise_schedule = noise_schedule

# self conditioning
self.self_cond_prob = self_cond_prob
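# fraction of training steps that see the model's own embeddings as conditioning;
# defaults to 0.9, i.e. most training steps use self conditioning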

def save(self, path):
torch.save(self.state_dict(), path)

@@ -349,6 +389,8 @@ def load(self, path):
state_dict = torch.load(str(path))
self.load_state_dict(state_dict)

@torch.no_grad()
@eval_decorator
def generate(
self,
texts: List[str],
@@ -398,6 +440,8 @@ def generate(
with torch.no_grad():
_, cond_ids, _ = self.cond_vae.encode(cond_images)

self_cond_embed = None

for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps):

rand_mask_prob = self.noise_schedule(timestep)
@@ -407,13 +451,17 @@

ids = ids.scatter(1, masked_indices, self.mask_id)

logits = demask_fn(
logits, embed = demask_fn(
ids,
text_embeds = text_embeds,
self_cond_embed = self_cond_embed,
conditioning_token_ids = cond_ids,
cond_scale = cond_scale
cond_scale = cond_scale,
return_embed = True
)

self_cond_embed = embed if self.self_cond else None
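# carry the transformer embeddings from this demasking step into the next iteration's
# self_cond_embed, so later steps can refine the model's own earlier predictions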

filtered_logits = top_k(logits, topk_filter_thres)

temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
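# as steps_until_x0 shrinks the temperature approaches zero, so late demasking
# steps sample nearly greedily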
@@ -507,12 +555,30 @@ def forward(
x = torch.where(mask, mask_id, ids)
labels = torch.where(mask, ids, ignore_index)

# self conditioning

self_cond_embed = None

if self.transformer.self_cond and random() < self.self_cond_prob:
with torch.no_grad():
_, self_cond_embed = self.transformer(
x,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_token_ids,
cond_drop_prob = 0.,
return_embed = True
)

self_cond_embed.detach_()
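# the first pass runs under no_grad and its embeddings are detached, so the
# self-conditioning signal adds no gradient path; only the conditioned second
# pass below is trained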

# get loss

ce_loss = self.transformer(
x,
texts = texts,
text_embeds = text_embeds,
self_cond_embed = self_cond_embed,
conditioning_token_ids = cond_token_ids,
labels = labels,
cond_drop_prob = cond_drop_prob,
13 changes: 12 additions & 1 deletion muse_maskgit_pytorch/t5.py
@@ -3,6 +3,9 @@
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

from beartype import beartype
from typing import List, Union

transformers.logging.set_verbosity_error()

def exists(val):
@@ -53,7 +56,15 @@ def get_encoded_dim(name):

# encoding text

def t5_encode_text(texts, name = DEFAULT_T5_NAME, output_device = None):
@beartype
def t5_encode_text(
texts: Union[str, List[str]],
name = DEFAULT_T5_NAME,
output_device = None
):
if isinstance(texts, str):
texts = [texts]
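# a bare string is normalized to a single-element list, so e.g.
# t5_encode_text('a photo of a dog') behaves like t5_encode_text(['a photo of a dog'])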

t5, tokenizer = get_model_and_tokenizer(name)

if torch.cuda.is_available():
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.19',
version = '0.0.20',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
