diff --git a/README.md b/README.md
index a2e84dc..150e5ea 100644
--- a/README.md
+++ b/README.md
@@ -227,9 +227,9 @@ images # List[PIL.Image.Image]
 - [x] test end-to-end
 - [x] separate cond_images_or_ids, it is not done right
 - [x] add training code for vae
+- [x] add optional self-conditioning on embeddings
 
 - [ ] hook up accelerate training code for maskgit
-- [ ] add optional self-conditioning on embeddings
 - [ ] combine with token critic paper, already implemented at Phenaki
 
 ## Citations
@@ -241,3 +241,24 @@ images # List[PIL.Image.Image]
     year = {2023}
 }
 ```
+
+```bibtex
+@article{Chen2022AnalogBG,
+    title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
+    author = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2208.04202}
+}
+```
+
+```bibtex
+@misc{jabri2022scalable,
+    title = {Scalable Adaptive Computation for Iterative Generation},
+    author = {Allan Jabri and David Fleet and Ting Chen},
+    year = {2022},
+    eprint = {2212.11972},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
+}
+```
diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 5ba9597..69d90fd 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -1,4 +1,5 @@
 import math
+from random import random
 from functools import partial
 
 import torch
@@ -26,6 +27,15 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
 # classes
 
 class LayerNorm(nn.Module):
@@ -161,6 +171,7 @@ def __init__(
         dim,
         seq_len,
         t5_name = DEFAULT_T5_NAME,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
@@ -184,31 +195,49 @@ def __init__(
 
         self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()
 
+        # optional self conditioning
+
+        self.self_cond = self_cond
+        self.self_cond_to_init_embed = FeedForward(dim)
+
     def forward_with_cond_scale(
         self,
         *args,
         cond_scale = 3.,
+        return_embed = False,
         **kwargs
     ):
-        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
         if cond_scale == 1:
-            return logits
+            return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs)
+
+        logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs)
 
         null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
 
-        return null_logits + (logits - null_logits) * cond_scale
+        scaled_logits = null_logits + (logits - null_logits) * cond_scale
+
+        if return_embed:
+            return scaled_logits, embed
+
+        return scaled_logits
 
     def forward_with_neg_prompt(
         self,
         text_embed: torch.Tensor,
         neg_text_embed: torch.Tensor,
         cond_scale = 3.,
+        return_embed = False,
         **kwargs
     ):
         neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs)
-        pos_logits = self.forward(*args, text_embed = text_embed, cond_drop_prob = 0., **kwargs)
+        pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs)
+
+        logits = neg_logits + (pos_logits - neg_logits) * cond_scale
+
+        if return_embed:
+            return logits, embed
 
-        return neg_logits + (pos_logits - neg_logits) * cond_scale
+        return logits
 
     def forward(
         self,
@@ -216,6 +245,7 @@ def forward(
         return_embed = False,
         labels = None,
         ignore_index = 0,
+        self_cond_embed = None,
         cond_drop_prob = 0.,
         conditioning_token_ids: Optional[torch.Tensor] = None,
         texts: Optional[List[str]] = None,
@@ -254,12 +284,17 @@ def forward(
         x = self.token_emb(x)
         x = x + self.pos_emb(torch.arange(n, device = device))
 
-        x = self.transformer_blocks(x, context = context, context_mask = context_mask)
+        if self.self_cond:
+            if not exists(self_cond_embed):
+                self_cond_embed = torch.zeros_like(x)
+            x = x + self.self_cond_to_init_embed(self_cond_embed)
 
-        if return_embed:
-            return x
+        embed = self.transformer_blocks(x, context = context, context_mask = context_mask)
+
+        logits = self.to_logits(embed)
 
-        logits = self.to_logits(x)
+        if return_embed:
+            return logits, embed
 
         if not exists(labels):
             return logits
@@ -316,7 +351,8 @@ def __init__(
         vae: Optional[VQGanVAE] = None,
         cond_vae: Optional[VQGanVAE] = None,
         cond_image_size = None,
-        cond_drop_prob = 0.5
+        cond_drop_prob = 0.5,
+        self_cond_prob = 0.9
     ):
         super().__init__()
         self.vae = vae.copy_for_eval() if exists(vae) else None
@@ -335,11 +371,15 @@ def __init__(
         self.cond_drop_prob = cond_drop_prob
 
         self.transformer = transformer
+        self.self_cond = transformer.self_cond
 
         assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size'
 
         self.mask_id = transformer.mask_id
         self.noise_schedule = noise_schedule
 
+        # self conditioning
+        self.self_cond_prob = self_cond_prob
+
     def save(self, path):
         torch.save(self.state_dict(), path)
@@ -349,6 +389,8 @@ def load(self, path):
         state_dict = torch.load(str(path))
         self.load_state_dict(state_dict)
 
+    @torch.no_grad()
+    @eval_decorator
     def generate(
         self,
         texts: List[str],
@@ -398,6 +440,8 @@ def generate(
             with torch.no_grad():
                 _, cond_ids, _ = self.cond_vae.encode(cond_images)
 
+        self_cond_embed = None
+
         for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps):
             rand_mask_prob = self.noise_schedule(timestep)
 
@@ -407,13 +451,17 @@ def generate(
 
             ids = ids.scatter(1, masked_indices, self.mask_id)
 
-            logits = demask_fn(
+            logits, embed = demask_fn(
                 ids,
                 text_embeds = text_embeds,
+                self_cond_embed = self_cond_embed,
                 conditioning_token_ids = cond_ids,
-                cond_scale = cond_scale
+                cond_scale = cond_scale,
+                return_embed = True
             )
 
+            self_cond_embed = embed if self.self_cond else None
+
             filtered_logits = top_k(logits, topk_filter_thres)
 
             temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
@@ -507,12 +555,30 @@ def forward(
         x = torch.where(mask, mask_id, ids)
         labels = torch.where(mask, ids, ignore_index)
 
+        # self conditioning
+
+        self_cond_embed = None
+
+        if self.transformer.self_cond and random() < self.self_cond_prob:
+            with torch.no_grad():
+                _, self_cond_embed = self.transformer(
+                    x,
+                    texts = texts,
+                    text_embeds = text_embeds,
+                    conditioning_token_ids = cond_token_ids,
+                    cond_drop_prob = 0.,
+                    return_embed = True
+                )
+
+                self_cond_embed.detach_()
+
         # get loss
 
         ce_loss = self.transformer(
             x,
             texts = texts,
             text_embeds = text_embeds,
+            self_cond_embed = self_cond_embed,
             conditioning_token_ids = cond_token_ids,
             labels = labels,
             cond_drop_prob = cond_drop_prob,
diff --git a/muse_maskgit_pytorch/t5.py b/muse_maskgit_pytorch/t5.py
index 6833f71..ccb68f6 100644
--- a/muse_maskgit_pytorch/t5.py
+++ b/muse_maskgit_pytorch/t5.py
@@ -3,6 +3,9 @@
 import transformers
 from transformers import T5Tokenizer, T5EncoderModel, T5Config
 
+from beartype import beartype
+from typing import List, Union
+
 transformers.logging.set_verbosity_error()
 
 def exists(val):
@@ -53,7 +56,15 @@ def get_encoded_dim(name):
 
 # encoding text
 
-def t5_encode_text(texts, name = DEFAULT_T5_NAME, output_device = None):
+@beartype
+def t5_encode_text(
+    texts: Union[str, List[str]],
+    name = DEFAULT_T5_NAME,
+    output_device = None
+):
+    if isinstance(texts, str):
+        texts = [texts]
+
     t5, tokenizer = get_model_and_tokenizer(name)
 
     if torch.cuda.is_available():
diff --git a/setup.py b/setup.py
index 9b8228b..c2da37e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.19',
+  version = '0.0.20',
  license='MIT',
  description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
  author = 'Phil Wang',
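
For reference, a minimal usage sketch of the two knobs this patch introduces: `self_cond = True` on the transformer and `self_cond_prob` on `MaskGit`. Everything else below (the `VQGanVAE` / `MaskGitTransformer` / `MaskGit` constructor values and entry-point names) is abbreviated from the README's existing example and is illustrative, not part of the patch.

```python
import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGitTransformer, MaskGit

# vae hyperparameters are placeholders -- follow the README's full example for real training
vae = VQGanVAE(dim = 256, vq_codebook_size = 512)

transformer = MaskGitTransformer(
    num_tokens = 512,       # must match the vae codebook size
    seq_len = 256,
    dim = 512,
    depth = 8,
    self_cond = True        # new flag: condition on the embeddings of a prior forward pass
)

base_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25,
    self_cond_prob = 0.9    # new flag: fraction of training steps that use self-conditioning
)

images = torch.randn(4, 3, 256, 256)
texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'waking up to a psychedelic landscape'
]

# during training, the maskgit occasionally re-runs the transformer without gradients
# and feeds the detached embeddings back in as self-conditioning
loss = base_maskgit(images, texts = texts)
loss.backward()

# at inference, generate() threads the embeddings from one demasking step into the next
sampled = base_maskgit.generate(texts = texts, cond_scale = 3.)
```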