diff --git a/muse_maskgit_pytorch/__init__.py b/muse_maskgit_pytorch/__init__.py
index 4587d40..c9169fd 100644
--- a/muse_maskgit_pytorch/__init__.py
+++ b/muse_maskgit_pytorch/__init__.py
@@ -1,4 +1,14 @@
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
-from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
+from muse_maskgit_pytorch.muse_maskgit_pytorch import (
+    Transformer,
+    MaskGit,
+    Muse,
+    MaskGitTransformer,
+    TokenCritic,
+)
 
-from muse_maskgit_pytorch.trainers import VQGanVAETrainer, MaskGitTrainer, get_accelerator
+from muse_maskgit_pytorch.trainers import (
+    VQGanVAETrainer,
+    MaskGitTrainer,
+    get_accelerator,
+)
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 7a44224..2f951ef 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -10,8 +10,10 @@ from torch.utils.data import Dataset, DataLoader, random_split
 import os
 from tqdm import tqdm
 
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+
 class ImageDataset(Dataset):
     def __init__(self, dataset, image_size, image_column="image"):
         super().__init__()
@@ -31,11 +33,19 @@ def __len__(self):
         return len(self.dataset)
 
     def __getitem__(self, index):
-        image= self.dataset[index][self.image_column]
+        image = self.dataset[index][self.image_column]
         return self.transform(image)
 
+
 class ImageTextDataset(ImageDataset):
-    def __init__(self, dataset, image_size, tokenizer, image_column="image", caption_column="caption"):
+    def __init__(
+        self,
+        dataset,
+        image_size,
+        tokenizer,
+        image_column="image",
+        caption_column="caption",
+    ):
         super().__init__(dataset, image_size=image_size, image_column=image_column)
         self.caption_column = caption_column
         self.tokenizer = tokenizer
@@ -65,7 +75,10 @@ def __getitem__(self, index):
         attn_mask = encoded.attention_mask
         return self.transform(image), input_ids[0], attn_mask[0]
 
-def get_dataset_from_dataroot(data_root, image_column="image", caption_column="caption", save_path="dataset"):
+
+def get_dataset_from_dataroot(
+    data_root, image_column="image", caption_column="caption", save_path="dataset"
+):
     if os.path.exists(save_path):
         return load_from_disk(save_path)
     image_paths = list(Path(data_root).rglob("*.[jJ][pP][gG]"))
@@ -74,7 +87,7 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c
     for image_path in tqdm(image_paths):
         caption_path = image_path.with_suffix(".txt")
         if os.path.exists(str(caption_path)):
-            captions = caption_path.read_text(encoding="utf-8").split('\n')
+            captions = caption_path.read_text(encoding="utf-8").split("\n")
             captions = list(filter(lambda t: len(t) > 0, captions))
         else:
            captions = []
@@ -86,24 +99,27 @@ def get_dataset_from_dataroot(data_root, image_column="image", caption_column="c
     dataset.save_to_disk(save_path)
     return dataset
 
+
 def split_dataset_into_dataloaders(dataset, valid_frac=0.05, seed=42, batch_size=1):
     if valid_frac > 0:
         train_size = int((1 - valid_frac) * len(dataset))
         valid_size = len(dataset) - train_size
-        dataset, validation_dataset = random_split(dataset, [train_size, valid_size], generator = torch.Generator().manual_seed(seed))
-        print(f'training with dataset of {len(dataset)} samples and validating with randomly splitted {len(validation_dataset)} samples')
+        dataset, validation_dataset = random_split(
+            dataset,
+            [train_size, valid_size],
+            generator=torch.Generator().manual_seed(seed),
+        )
+        print(
+            f"training with dataset of {len(dataset)} samples and validating with randomly split 
{len(validation_dataset)} samples" + ) else: validation_dataset = dataset - print(f'training with shared training and valid dataset of {len(dataset)} samples') - dataloader = DataLoader( - dataset, - batch_size = batch_size, - shuffle = True - ) + print( + f"training with shared training and valid dataset of {len(dataset)} samples" + ) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) validation_dataloader = DataLoader( - validation_dataset, - batch_size = batch_size, - shuffle = True + validation_dataset, batch_size=batch_size, shuffle=True ) return dataloader, validation_dataloader diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 4a24f1f..f4bf0d1 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -15,17 +15,26 @@ from beartype import beartype from muse_maskgit_pytorch.vqgan_vae import VQGanVAE -from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME, get_model_and_tokenizer +from muse_maskgit_pytorch.t5 import ( + t5_encode_text, + get_encoded_dim, + DEFAULT_T5_NAME, + get_model_and_tokenizer, +) from pathlib import Path from tqdm.auto import tqdm + # helpers + def exists(val): return val is not None + def default(val, d): return val if exists(val) else d + def eval_decorator(fn): def inner(model, *args, **kwargs): was_training = model.training @@ -33,70 +42,72 @@ def inner(model, *args, **kwargs): out = fn(model, *args, **kwargs) model.train(was_training) return out + return inner + def l2norm(t): - return F.normalize(t, dim = -1) + return F.normalize(t, dim=-1) + # tensor helpers -def get_mask_subset_prob(mask, prob, min_mask = 0): + +def get_mask_subset_prob(mask, prob, min_mask=0): batch, seq, device = *mask.shape, mask.device - num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask) - logits = torch.rand((batch, seq), device = device) + num_to_mask = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask) + logits = torch.rand((batch, seq), device=device) logits = logits.masked_fill(~mask, -1) - randperm = logits.argsort(dim = -1).float() + randperm = logits.argsort(dim=-1).float() - num_padding = (~mask).sum(dim = -1, keepdim = True) + num_padding = (~mask).sum(dim=-1, keepdim=True) randperm -= num_padding subset_mask = randperm < num_to_mask subset_mask.masked_fill_(~mask, False) return subset_mask + # classes + class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.ones(dim)) - self.register_buffer('beta', torch.zeros(dim)) + self.register_buffer("beta", torch.zeros(dim)) def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + class GEGLU(nn.Module): - """ https://arxiv.org/abs/2002.05202 """ + """https://arxiv.org/abs/2002.05202""" def forward(self, x): - x, gate = x.chunk(2, dim = -1) + x, gate = x.chunk(2, dim=-1) return gate * F.gelu(x) -def FeedForward(dim, mult = 4): - """ https://arxiv.org/abs/2110.09456 """ + +def FeedForward(dim, mult=4): + """https://arxiv.org/abs/2110.09456""" inner_dim = int(dim * mult * 2 / 3) return nn.Sequential( LayerNorm(dim), - nn.Linear(dim, inner_dim * 2, bias = False), + nn.Linear(dim, inner_dim * 2, bias=False), GEGLU(), LayerNorm(inner_dim), - nn.Linear(inner_dim, dim, bias = False) + nn.Linear(inner_dim, dim, bias=False), ) + class Attention(nn.Module): - def __init__( - self, - dim, - dim_head = 64, - heads = 8, - cross_attend = False, - scale = 8 - ): + def 
__init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): super().__init__() self.scale = scale - self.heads = heads + self.heads = heads inner_dim = dim_head * heads self.cross_attend = cross_attend @@ -104,20 +115,15 @@ def __init__( self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) - self.to_q = nn.Linear(dim, inner_dim, bias = False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) - self.to_out = nn.Linear(inner_dim, dim, bias = False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward( - self, - x, - context = None, - context_mask = None - ): + def forward(self, x, context=None, context_mask=None): assert not (exists(context) ^ self.cross_attend) h, is_cross_attn = self.heads, exists(context) @@ -126,69 +132,70 @@ def forward( kv_input = context if self.cross_attend else x - q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)) + q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) nk, nv = self.null_kv - nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = x.shape[0]), (nk, nv)) + nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) - k = torch.cat((nk, k), dim = -2) - v = torch.cat((nv, v), dim = -2) + k = torch.cat((nk, k), dim=-2) + v = torch.cat((nv, v), dim=-2) q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale - sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale if exists(context_mask): - context_mask = rearrange(context_mask, 'b j -> b 1 1 j') - context_mask = F.pad(context_mask, (1, 0), value = True) + context_mask = rearrange(context_mask, "b j -> b 1 1 j") + context_mask = F.pad(context_mask, (1, 0), value=True) mask_value = -torch.finfo(sim.dtype).max sim = sim.masked_fill(~context_mask, mask_value) - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h j d -> b h i d', attn, v) + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b h j d -> b h i d", attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) + class TransformerBlocks(nn.Module): - def __init__( - self, - *, - dim, - depth, - dim_head = 64, - heads = 8, - ff_mult = 4 - ): + def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): - self.layers.append(nn.ModuleList([ - Attention(dim = dim, dim_head = dim_head, heads = heads), - Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True), - FeedForward(dim = dim, mult = ff_mult) - ])) + self.layers.append( + nn.ModuleList( + [ + Attention(dim=dim, dim_head=dim_head, heads=heads), + Attention( + dim=dim, dim_head=dim_head, heads=heads, cross_attend=True + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) self.norm = LayerNorm(dim) - def forward(self, x, context = None, context_mask = None): + def forward(self, x, context=None, context_mask=None): for attn, cross_attn, ff in self.layers: x = attn(x) + x - x = cross_attn(x, context = context, context_mask = context_mask) + x + x = cross_attn(x, context=context, 
context_mask=context_mask) + x x = ff(x) + x return self.norm(x) + # transformer - it's all we need + class Transformer(nn.Module): def __init__( self, @@ -196,10 +203,10 @@ def __init__( num_tokens, dim, seq_len, - dim_out = None, - t5_name = DEFAULT_T5_NAME, - self_cond = False, - add_mask_id = False, + dim_out=None, + t5_name=DEFAULT_T5_NAME, + self_cond=False, + add_mask_id=False, **kwargs ): super().__init__() @@ -211,20 +218,24 @@ def __init__( self.pos_emb = nn.Embedding(seq_len, dim) self.seq_len = seq_len - self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs) + self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs) self.norm = LayerNorm(dim) self.dim_out = default(dim_out, num_tokens) - self.to_logits = nn.Linear(dim, self.dim_out, bias = False) + self.to_logits = nn.Linear(dim, self.dim_out, bias=False) # text conditioning - self.t5, self.tokenizer= get_model_and_tokenizer(t5_name) + self.t5, self.tokenizer = get_model_and_tokenizer(t5_name) self.t5.eval() - self.encode_text = partial(t5_encode_text, tokenizer = self.tokenizer, t5=self.t5) + self.encode_text = partial(t5_encode_text, tokenizer=self.tokenizer, t5=self.t5) text_embed_dim = get_encoded_dim(t5_name) - self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity() + self.text_embed_proj = ( + nn.Linear(text_embed_dim, dim, bias=False) + if text_embed_dim != dim + else nn.Identity() + ) # optional self conditioning @@ -232,18 +243,18 @@ def __init__( self.self_cond_to_init_embed = FeedForward(dim) def forward_with_cond_scale( - self, - *args, - cond_scale = 3., - return_embed = False, - **kwargs + self, *args, cond_scale=3.0, return_embed=False, **kwargs ): if cond_scale == 1: - return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs) + return self.forward( + *args, return_embed=return_embed, cond_drop_prob=0.0, **kwargs + ) - logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs) + logits, embed = self.forward( + *args, return_embed=True, cond_drop_prob=0.0, **kwargs + ) - null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) + null_logits = self.forward(*args, cond_drop_prob=1.0, **kwargs) scaled_logits = null_logits + (logits - null_logits) * cond_scale @@ -257,12 +268,20 @@ def forward_with_neg_prompt( *args, text_embed: torch.Tensor, neg_text_embed: torch.Tensor, - cond_scale = 3., - return_embed = False, + cond_scale=3.0, + return_embed=False, **kwargs ): - neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs) - pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs) + neg_logits = self.forward( + *args, neg_text_embed=neg_text_embed, cond_drop_prob=0.0, **kwargs + ) + pos_logits, embed = self.forward( + *args, + return_embed=True, + text_embed=text_embed, + cond_drop_prob=0.0, + **kwargs + ) scaled_logits = neg_logits + (pos_logits - neg_logits) * cond_scale @@ -274,15 +293,15 @@ def forward_with_neg_prompt( def forward( self, x, - return_embed = False, - return_logits = False, - labels = None, - ignore_index = 0, - self_cond_embed = None, - cond_drop_prob = 0., + return_embed=False, + return_logits=False, + labels=None, + ignore_index=0, + self_cond_embed=None, + cond_drop_prob=0.0, conditioning_token_ids: Optional[torch.Tensor] = None, texts: Optional[List[str]] = None, - text_embeds: Optional[torch.Tensor] = None + text_embeds: Optional[torch.Tensor] = None, ): 
device, b, n = x.device, *x.shape assert n <= self.seq_len @@ -296,33 +315,37 @@ def forward( context = self.text_embed_proj(text_embeds) - context_mask = (text_embeds != 0).any(dim = -1) + context_mask = (text_embeds != 0).any(dim=-1) # classifier free guidance - if self.training and cond_drop_prob > 0.: - mask = prob_mask_like((b, 1), 1. - cond_drop_prob, device) + if self.training and cond_drop_prob > 0.0: + mask = prob_mask_like((b, 1), 1.0 - cond_drop_prob, device) context_mask = context_mask & mask # concat conditioning image token ids if needed if exists(conditioning_token_ids): - conditioning_token_ids = rearrange(conditioning_token_ids, 'b ... -> b (...)') + conditioning_token_ids = rearrange( + conditioning_token_ids, "b ... -> b (...)" + ) cond_token_emb = self.token_emb(conditioning_token_ids) - context = torch.cat((context, cond_token_emb), dim = -2) - context_mask = F.pad(context_mask, (0, conditioning_token_ids.shape[-1]), value = True) + context = torch.cat((context, cond_token_emb), dim=-2) + context_mask = F.pad( + context_mask, (0, conditioning_token_ids.shape[-1]), value=True + ) # embed tokens x = self.token_emb(x) - x = x + self.pos_emb(torch.arange(n, device = device)) + x = x + self.pos_emb(torch.arange(n, device=device)) if self.self_cond: if not exists(self_cond_embed): self_cond_embed = torch.zeros_like(x) x = x + self.self_cond_to_init_embed(self_cond_embed) - embed = self.transformer_blocks(x, context = context, context_mask = context_mask) + embed = self.transformer_blocks(x, context=context, context_mask=context_mask) logits = self.to_logits(embed) @@ -333,17 +356,23 @@ def forward( return logits if self.dim_out == 1: - loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels) + loss = F.binary_cross_entropy_with_logits( + rearrange(logits, "... 1 -> ..."), labels + ) else: - loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index) + loss = F.cross_entropy( + rearrange(logits, "b n c -> b c n"), labels, ignore_index=ignore_index + ) if not return_logits: return loss return loss, logits + # self critic wrapper + class SelfCritic(nn.Module): def __init__(self, net): super().__init__() @@ -351,74 +380,93 @@ def __init__(self, net): self.to_pred = nn.Linear(net.dim, 1) def forward_with_cond_scale(self, x, *args, **kwargs): - _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed = True, **kwargs) + _, embeds = self.net.forward_with_cond_scale( + x, *args, return_embed=True, **kwargs + ) return self.to_pred(embeds) def forward_with_neg_prompt(self, x, *args, **kwargs): - _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed = True, **kwargs) + _, embeds = self.net.forward_with_neg_prompt( + x, *args, return_embed=True, **kwargs + ) return self.to_pred(embeds) - def forward(self, x, *args, labels = None, **kwargs): - _, embeds = self.net(x, *args, return_embed = True, **kwargs) + def forward(self, x, *args, labels=None, **kwargs): + _, embeds = self.net(x, *args, return_embed=True, **kwargs) logits = self.to_pred(embeds) if not exists(labels): return logits - logits = rearrange(logits, '... 1 -> ...') + logits = rearrange(logits, "... 
1 -> ...") return F.binary_cross_entropy_with_logits(logits, labels) + # specialized transformers + class MaskGitTransformer(Transformer): def __init__(self, *args, **kwargs): - assert 'add_mask_id' not in kwargs - super().__init__(*args, add_mask_id = True, **kwargs) + assert "add_mask_id" not in kwargs + super().__init__(*args, add_mask_id=True, **kwargs) + class TokenCritic(Transformer): def __init__(self, *args, **kwargs): - assert 'dim_out' not in kwargs - super().__init__(*args, dim_out = 1, **kwargs) + assert "dim_out" not in kwargs + super().__init__(*args, dim_out=1, **kwargs) + # classifier free guidance functions -def uniform(shape, min = 0, max = 1, device = None): - return torch.zeros(shape, device = device).float().uniform_(0, 1) -def prob_mask_like(shape, prob, device = None): +def uniform(shape, min=0, max=1, device=None): + return torch.zeros(shape, device=device).float().uniform_(0, 1) + + +def prob_mask_like(shape, prob, device=None): if prob == 1: - return torch.ones(shape, device = device, dtype = torch.bool) + return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: - return torch.zeros(shape, device = device, dtype = torch.bool) + return torch.zeros(shape, device=device, dtype=torch.bool) else: - return uniform(shape, device = device) < prob + return uniform(shape, device=device) < prob + # sampling helpers -def log(t, eps = 1e-20): - return torch.log(t.clamp(min = eps)) + +def log(t, eps=1e-20): + return torch.log(t.clamp(min=eps)) + def gumbel_noise(t): noise = torch.zeros_like(t).uniform_(0, 1) return -log(-log(noise)) -def gumbel_sample(t, temperature = 1., dim = -1): - return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) -def top_k(logits, thres = 0.9): +def gumbel_sample(t, temperature=1.0, dim=-1): + return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) + + +def top_k(logits, thres=0.9): k = math.ceil((1 - thres) * logits.shape[-1]) - val, ind = logits.topk(k, dim = -1) - probs = torch.full_like(logits, float('-inf')) + val, ind = logits.topk(k, dim=-1) + probs = torch.full_like(logits, float("-inf")) probs.scatter_(2, ind, val) return probs + # noise schedules + def cosine_schedule(t): return torch.cos(t * math.pi * 0.5) + # main maskgit classes + @beartype class MaskGit(nn.Module): def __init__( @@ -427,14 +475,14 @@ def __init__( transformer: MaskGitTransformer, noise_schedule: Callable = cosine_schedule, token_critic: Optional[TokenCritic] = None, - self_token_critic = False, + self_token_critic=False, vae: Optional[VQGanVAE] = None, cond_vae: Optional[VQGanVAE] = None, - cond_image_size = None, - cond_drop_prob = 0.5, - self_cond_prob = 0.9, - no_mask_token_prob = 0., - critic_loss_weight = 1. 
+ cond_image_size=None, + cond_drop_prob=0.5, + self_cond_prob=0.9, + no_mask_token_prob=0.0, + critic_loss_weight=1.0, ): super().__init__() self.vae = vae.copy_for_eval() if exists(vae) else None @@ -444,7 +492,9 @@ def __init__( else: self.cond_vae = self.vae - assert not (exists(cond_vae) and not exists(cond_image_size)), 'cond_image_size must be specified if conditioning' + assert not ( + exists(cond_vae) and not exists(cond_image_size) + ), "cond_image_size must be specified if conditioning" self.image_size = image_size self.cond_image_size = cond_image_size @@ -454,7 +504,11 @@ def __init__( self.transformer = transformer self.self_cond = transformer.self_cond - assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size' + assert ( + self.vae.codebook_size + == self.cond_vae.codebook_size + == transformer.num_tokens + ), "transformer num_tokens must be set to be equal to the vae codebook size" self.mask_id = transformer.mask_id self.noise_schedule = noise_schedule @@ -490,14 +544,14 @@ def generate( texts: List[str], negative_texts: Optional[List[str]] = None, cond_images: Optional[torch.Tensor] = None, - fmap_size = None, - temperature = 1., - topk_filter_thres = 0.9, - can_remask_prev_masked = False, - force_not_use_token_critic = False, - timesteps = 18, # ideal number of steps is 18 in maskgit paper - cond_scale = 3, - critic_noise_scale = 1 + fmap_size=None, + temperature=1.0, + topk_filter_thres=0.9, + can_remask_prev_masked=False, + force_not_use_token_critic=False, + timesteps=18, # ideal number of steps is 18 in maskgit paper + cond_scale=3, + critic_noise_scale=1, ): fmap_size = default(fmap_size, self.vae.get_encoded_fmap_size(self.image_size)) @@ -505,14 +559,14 @@ def generate( device = next(self.parameters()).device - seq_len = fmap_size ** 2 + seq_len = fmap_size**2 batch_size = len(texts) shape = (batch_size, seq_len) - ids = torch.full(shape, self.mask_id, dtype = torch.long, device = device) - scores = torch.zeros(shape, dtype = torch.float32, device = device) + ids = torch.full(shape, self.mask_id, dtype=torch.long, device=device) + scores = torch.zeros(shape, dtype=torch.float32, device=device) starting_temperature = temperature @@ -536,78 +590,93 @@ def generate( assert len(texts) == len(negative_texts) neg_text_embeds = self.transformer.encode_text(negative_texts) - demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds) + demask_fn = partial( + self.transformer.forward_with_neg_prompt, + neg_text_embeds=neg_text_embeds, + ) if use_token_critic: - token_critic_fn = partial(self.token_critic.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds) + token_critic_fn = partial( + self.token_critic.forward_with_neg_prompt, + neg_text_embeds=neg_text_embeds, + ) if self.resize_image_for_cond_image: - assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit' + assert exists( + cond_images + ), "conditioning image must be passed in to generate for super res maskgit" with torch.no_grad(): _, cond_ids, _ = self.cond_vae.encode(cond_images) self_cond_embed = None - for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps): - + for timestep, steps_until_x0 in tqdm( + zip( + torch.linspace(0, 1, timesteps, device=device), + reversed(range(timesteps)), + ), + total=timesteps, + ): rand_mask_prob = 
self.noise_schedule(timestep) num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1) - masked_indices = scores.topk(num_token_masked, dim = -1).indices + masked_indices = scores.topk(num_token_masked, dim=-1).indices ids = ids.scatter(1, masked_indices, self.mask_id) logits, embed = demask_fn( ids, - text_embeds = text_embeds, - self_cond_embed = self_cond_embed, - conditioning_token_ids = cond_ids, - cond_scale = cond_scale, - return_embed = True + text_embeds=text_embeds, + self_cond_embed=self_cond_embed, + conditioning_token_ids=cond_ids, + cond_scale=cond_scale, + return_embed=True, ) self_cond_embed = embed if self.self_cond else None filtered_logits = top_k(logits, topk_filter_thres) - temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed + temperature = starting_temperature * ( + steps_until_x0 / timesteps + ) # temperature is annealed - pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) + pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) is_mask = ids == self.mask_id - ids = torch.where( - is_mask, - pred_ids, - ids - ) + ids = torch.where(is_mask, pred_ids, ids) if use_token_critic: scores = token_critic_fn( ids, - text_embeds = text_embeds, - conditioning_token_ids = cond_ids, - cond_scale = cond_scale + text_embeds=text_embeds, + conditioning_token_ids=cond_ids, + cond_scale=cond_scale, ) - scores = rearrange(scores, '... 1 -> ...') + scores = rearrange(scores, "... 1 -> ...") - scores = scores + (uniform(scores.shape, device = device) - 0.5) * critic_noise_scale * (steps_until_x0 / timesteps) + scores = scores + ( + uniform(scores.shape, device=device) - 0.5 + ) * critic_noise_scale * (steps_until_x0 / timesteps) else: - probs_without_temperature = logits.softmax(dim = -1) + probs_without_temperature = logits.softmax(dim=-1) scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None]) - scores = rearrange(scores, '... 1 -> ...') + scores = rearrange(scores, "... 
1 -> ...") if not can_remask_prev_masked: scores = scores.masked_fill(~is_mask, -1e5) else: - assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token' + assert ( + self.no_mask_token_prob > 0.0 + ), "without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token" # get ids - ids = rearrange(ids, 'b (i j) -> b i j', i = fmap_size, j = fmap_size) + ids = rearrange(ids, "b (i j) -> b i j", i=fmap_size, j=fmap_size) if not exists(self.vae): return ids @@ -618,63 +687,85 @@ def generate( def forward( self, images_or_ids: torch.Tensor, - ignore_index = -1, + ignore_index=-1, cond_images: Optional[torch.Tensor] = None, cond_token_ids: Optional[torch.Tensor] = None, texts: Optional[List[str]] = None, text_embeds: Optional[torch.Tensor] = None, - cond_drop_prob = None, - train_only_generator = False, - sample_temperature = None + cond_drop_prob=None, + train_only_generator=False, + sample_temperature=None, ): # tokenize if needed if images_or_ids.dtype == torch.float: - assert exists(self.vae), 'vqgan vae must be passed in if training from raw images' - assert all([height_or_width == self.image_size for height_or_width in images_or_ids.shape[-2:]]), 'the image you passed in is not of the correct dimensions' + assert exists( + self.vae + ), "vqgan vae must be passed in if training from raw images" + assert all( + [ + height_or_width == self.image_size + for height_or_width in images_or_ids.shape[-2:] + ] + ), "the image you passed in is not of the correct dimensions" with torch.no_grad(): _, ids, _ = self.vae.encode(images_or_ids) else: - assert not self.resize_image_for_cond_image, 'you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer' + assert ( + not self.resize_image_for_cond_image + ), "you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer" ids = images_or_ids # take care of conditioning image if specified if self.resize_image_for_cond_image: - cond_images_or_ids = F.interpolate(images_or_ids, self.cond_image_size, mode = 'nearest') + cond_images_or_ids = F.interpolate( + images_or_ids, self.cond_image_size, mode="nearest" + ) # get some basic variables - ids = rearrange(ids, 'b ... -> b (...)') + ids = rearrange(ids, "b ... 
-> b (...)") - batch, seq_len, device, cond_drop_prob = *ids.shape, ids.device, default(cond_drop_prob, self.cond_drop_prob) + batch, seq_len, device, cond_drop_prob = ( + *ids.shape, + ids.device, + default(cond_drop_prob, self.cond_drop_prob), + ) # tokenize conditional images if needed - assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids' + assert not ( + exists(cond_images) and exists(cond_token_ids) + ), "if conditioning on low resolution, cannot pass in both images and token ids" if exists(cond_images): - assert exists(self.cond_vae), 'cond vqgan vae must be passed in' - assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]]) + assert exists(self.cond_vae), "cond vqgan vae must be passed in" + assert all( + [ + height_or_width == self.cond_image_size + for height_or_width in cond_images.shape[-2:] + ] + ) with torch.no_grad(): _, cond_token_ids, _ = self.cond_vae.encode(cond_images) # prepare mask - rand_time = uniform((batch,), device = device) + rand_time = uniform((batch,), device=device) rand_mask_probs = self.noise_schedule(rand_time) - num_token_masked = (seq_len * rand_mask_probs).round().clamp(min = 1) + num_token_masked = (seq_len * rand_mask_probs).round().clamp(min=1) mask_id = self.mask_id - batch_randperm = torch.rand((batch, seq_len), device = device).argsort(dim = -1) - mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1') + batch_randperm = torch.rand((batch, seq_len), device=device).argsort(dim=-1) + mask = batch_randperm < rearrange(num_token_masked, "b -> b 1") mask_id = self.transformer.mask_id labels = torch.where(mask, ids, ignore_index) - if self.no_mask_token_prob > 0.: + if self.no_mask_token_prob > 0.0: no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob) mask &= ~no_mask_mask @@ -694,10 +785,10 @@ def forward( with torch.no_grad(): _, self_cond_embed = self.transformer( x, - text_embeds = text_embeds, - conditioning_token_ids = cond_token_ids, - cond_drop_prob = 0., - return_embed = True + text_embeds=text_embeds, + conditioning_token_ids=cond_token_ids, + cond_drop_prob=0.0, + return_embed=True, ) self_cond_embed.detach_() @@ -706,13 +797,13 @@ def forward( ce_loss, logits = self.transformer( x, - text_embeds = text_embeds, - self_cond_embed = self_cond_embed, - conditioning_token_ids = cond_token_ids, - labels = labels, - cond_drop_prob = cond_drop_prob, - ignore_index = ignore_index, - return_logits = True + text_embeds=text_embeds, + self_cond_embed=self_cond_embed, + conditioning_token_ids=cond_token_ids, + labels=labels, + cond_drop_prob=cond_drop_prob, + ignore_index=ignore_index, + return_logits=True, ) if not exists(self.token_critic) or train_only_generator: @@ -720,30 +811,30 @@ def forward( # token critic loss - sampled_ids = gumbel_sample(logits, temperature = default(sample_temperature, random())) + sampled_ids = gumbel_sample( + logits, temperature=default(sample_temperature, random()) + ) critic_input = torch.where(mask, sampled_ids, x) critic_labels = (ids != critic_input).float() bce_loss = self.token_critic( critic_input, - text_embeds = text_embeds, - conditioning_token_ids = cond_token_ids, - labels = critic_labels, - cond_drop_prob = cond_drop_prob + text_embeds=text_embeds, + conditioning_token_ids=cond_token_ids, + labels=critic_labels, + cond_drop_prob=cond_drop_prob, ) return ce_loss + self.critic_loss_weight * bce_loss + # final Muse class + @beartype class Muse(nn.Module): - 
def __init__( - self, - base: MaskGit, - superres: MaskGit - ): + def __init__(self, base: MaskGit, superres: MaskGit): super().__init__() self.base_maskgit = base.eval() @@ -754,31 +845,31 @@ def __init__( def forward( self, texts: List[str], - cond_scale = 3., - temperature = 1., - timesteps = 18, - superres_timesteps = None, - return_lowres = False, - return_pil_images = True + cond_scale=3.0, + temperature=1.0, + timesteps=18, + superres_timesteps=None, + return_lowres=False, + return_pil_images=True, ): lowres_image = self.base_maskgit.generate( - texts = texts, - cond_scale = cond_scale, - temperature = temperature, - timesteps = timesteps + texts=texts, + cond_scale=cond_scale, + temperature=temperature, + timesteps=timesteps, ) superres_image = self.superres_maskgit.generate( - texts = texts, - cond_scale = cond_scale, - cond_images = lowres_image, - temperature = temperature, - timesteps = default(superres_timesteps, timesteps) + texts=texts, + cond_scale=cond_scale, + cond_images=lowres_image, + temperature=temperature, + timesteps=default(superres_timesteps, timesteps), ) - + if return_pil_images: lowres_image = list(map(T.ToPILImage(), lowres_image)) - superres_image = list(map(T.ToPILImage(), superres_image)) + superres_image = list(map(T.ToPILImage(), superres_image)) if not return_lowres: return superres_image diff --git a/muse_maskgit_pytorch/t5.py b/muse_maskgit_pytorch/t5.py index 036e8a9..cdb68a1 100644 --- a/muse_maskgit_pytorch/t5.py +++ b/muse_maskgit_pytorch/t5.py @@ -8,27 +8,32 @@ transformers.logging.set_verbosity_error() + def exists(val): return val is not None + # config MAX_LENGTH = 256 -DEFAULT_T5_NAME = 'google/t5-v1_1-base' +DEFAULT_T5_NAME = "google/t5-v1_1-base" T5_CONFIGS = {} # singleton globals + def get_tokenizer(name): tokenizer = T5Tokenizer.from_pretrained(name) return tokenizer + def get_model(name): model = T5EncoderModel.from_pretrained(name) return model + def get_model_and_tokenizer(name): global T5_CONFIGS @@ -39,7 +44,8 @@ def get_model_and_tokenizer(name): if "tokenizer" not in T5_CONFIGS[name]: T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name) - return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer'] + return T5_CONFIGS[name]["model"], T5_CONFIGS[name]["tokenizer"] + def get_encoded_dim(name): if name not in T5_CONFIGS: @@ -54,41 +60,39 @@ def get_encoded_dim(name): assert False return config.d_model + # encoding text -def t5_encode_text_from_encoded(input_ids, - attn_mask, - t5, - output_device): + +def t5_encode_text_from_encoded(input_ids, attn_mask, t5, output_device): device = t5.device input_ids, attn_mask = input_ids.to(device), attn_mask.to(device) with torch.no_grad(): - output = t5(input_ids = input_ids, attention_mask = attn_mask) + output = t5(input_ids=input_ids, attention_mask=attn_mask) encoded_text = output.last_hidden_state.detach() attn_mask = attn_mask.bool() - encoded_text = encoded_text.masked_fill(attn_mask[..., None], 0.) 
+        encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.0)  # zero out embeddings at padding positions
 
     if not exists(output_device):
         return encoded_text
 
     encoded_text.to(output_device)
     return encoded_text
+
+
 @beartype
-def t5_encode_text(
-    texts: Union[str, List[str]],
-    tokenizer,
-    t5,
-    output_device = None
-):
+def t5_encode_text(texts: Union[str, List[str]], tokenizer, t5, output_device=None):
     if isinstance(texts, str):
         texts = [texts]
 
     encoded = tokenizer.batch_encode_plus(
         texts,
-        return_tensors = "pt",
-        padding = 'max_length',
-        max_length = MAX_LENGTH,
-        truncation = True
+        return_tensors="pt",
+        padding="max_length",
+        max_length=MAX_LENGTH,
+        truncation=True,
+    )
+    return t5_encode_text_from_encoded(
+        encoded["input_ids"], encoded["attention_mask"], t5, output_device
     )
-    return t5_encode_text_from_encoded(encoded["input_ids"], encoded["attention_mask"], t5, output_device)
diff --git a/muse_maskgit_pytorch/trainers/__init__.py b/muse_maskgit_pytorch/trainers/__init__.py
index ee7e8e0..4ba398f 100644
--- a/muse_maskgit_pytorch/trainers/__init__.py
+++ b/muse_maskgit_pytorch/trainers/__init__.py
@@ -1,3 +1,3 @@
 from muse_maskgit_pytorch.trainers.vqvae_trainers import VQGanVAETrainer
 from muse_maskgit_pytorch.trainers.maskgit_trainer import MaskGitTrainer
-from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_accelerator
\ No newline at end of file
+from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_accelerator
diff --git a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
index fbde16a..0c4f10f 100644
--- a/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
+++ b/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py
@@ -1,4 +1,3 @@
-
 from pathlib import Path
 from shutil import rmtree
 
@@ -18,17 +17,24 @@ import numpy as np
 
+
 try:
     import wandb
 except:
     None
 
+
+
 def noop(*args, **kwargs):
     pass
 
+
+
 # helper functions
 
+
 def identity(t, *args, **kwargs):
     return t
 
+
 def cycle(dl):
     while True:
         for data in dl:
@@ -66,6 +72,8 @@ def get_accelerator(**accelerate_kwargs):
     accelerator = Accelerator(**accelerate_kwargs)
     return accelerator
 
+
+
 def split_dataset(dataset, valid_frac, accelerator, seed=42):
     if valid_frac > 0:
         train_size = int((1 - valid_frac) * len(dataset))
@@ -85,8 +93,10 @@ def split_dataset(dataset, valid_frac, accelerator, seed=42):
         )
     return ds, valid_ds
 
+
 # main trainer class
 
+
 @beartype
 class BaseAcceleratedTrainer(nn.Module):
     def __init__(
@@ -107,7 +117,7 @@ def __init__(
         clear_previous_experiments=False,
     ):
         super().__init__()
-        self.model=None
+        self.model = None
         # instantiate accelerator
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.accelerator = accelerator
@@ -144,6 +154,7 @@ def save(self, path):
             optim=self.optim.state_dict(),
         )
         torch.save(pkg, path)
+
     def load(self, path):
         path = Path(path)
         assert path.exists()
@@ -154,30 +165,41 @@ def load(self, path):
         self.optim.load_state_dict(pkg["optim"])
         return pkg
 
+
     def log_validation_images(self, images, step, prompt=None):
         for tracker in self.accelerator.trackers:
             if tracker.name == "tensorboard":
                 np_images = np.stack([np.asarray(img) for img in images])
-                tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
+                tracker.writer.add_images(
+                    "validation", np_images, step, dataformats="NHWC"
+                )
             if tracker.name == "wandb":
                 tracker.log(
                     {
                         "validation": [
-                            wandb.Image(image, caption=f"{i}"+"" if prompt else f": {prompt}")
+                            wandb.Image(
+                                image, caption=f"{i}" + (f": {prompt}" if prompt else "")
+                            )
                            for i, image in enumerate(images)
                        ]
                    }
                )
+
    def print(self, msg):
        self.accelerator.print(msg)
+
    def log(self, log_dict):
        self.accelerator.log(log_dict)
+
    def prepare(self, *args):
        return self.accelerator.prepare(*args)
+
    def get_state_dict(self, model):
        return self.accelerator.get_state_dict(model)
+
    def unwrap_model(self, model):
        return self.accelerator.unwrap_model(model)
+
    @property
    def device(self):
        return self.accelerator.device
@@ -198,8 +220,9 @@ def is_local_main(self):
         return self.accelerator.is_local_main_process
 
     def train_step(self):
-        raise NotImplementedError("You are calling train_step on the base trainer with no models")
-
+        raise NotImplementedError(
+            "You are calling train_step on the base trainer with no models"
+        )
 
     def train(self, log_fn=noop):
         self.model.train()
@@ -209,4 +232,3 @@ def train(self, log_fn=noop):
             log_fn(logs)
         self.writer.close()
         self.print("training complete")
-
diff --git a/muse_maskgit_pytorch/trainers/maskgit_trainer.py b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
index e69b152..55c1f1e 100644
--- a/muse_maskgit_pytorch/trainers/maskgit_trainer.py
+++ b/muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -1,4 +1,3 @@
-
 from pathlib import Path
 from shutil import rmtree
 
@@ -20,15 +19,22 @@ from ema_pytorch import EMA
 
 from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit
-from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
+from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
+    BaseAcceleratedTrainer,
+)
 from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
+from diffusers.optimization import get_scheduler
 import torch.nn.functional as F
 
+
+
 def noop(*args, **kwargs):
     pass
 
+
 def exists(val):
     return val is not None
 
+
 class MaskGitTrainer(BaseAcceleratedTrainer):
     def __init__(
         self,
@@ -48,26 +53,38 @@ def __init__(
         logging_dir="./results/logs",
         apply_grad_penalty_every=4,
         lr=3e-4,
+        lr_scheduler_type="constant",
+        lr_warmup_steps=500,
         use_ema=True,
         ema_beta=0.995,
         ema_update_after_step=0,
         ema_update_every=1,
         log_model_every=100,
         validation_prompt="a photo of a dog",
-        clear_previous_experiments=False
+        clear_previous_experiments=False,
     ):
-        super().__init__(dataloader, valid_dataloader, accelerator, current_step=current_step, num_train_steps=num_train_steps,\
-            gradient_accumulation_steps=gradient_accumulation_steps, max_grad_norm=max_grad_norm, save_results_every=save_results_every, \
-            save_model_every=save_model_every, results_dir=results_dir, logging_dir=logging_dir, apply_grad_penalty_every=apply_grad_penalty_every,\
-            clear_previous_experiments=clear_previous_experiments)
-        self.log_model_every=log_model_every
-        self.batch_size=batch_size
+        super().__init__(
+            dataloader,
+            valid_dataloader,
+            accelerator,
+            current_step=current_step,
+            num_train_steps=num_train_steps,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            max_grad_norm=max_grad_norm,
+            save_results_every=save_results_every,
+            save_model_every=save_model_every,
+            results_dir=results_dir,
+            logging_dir=logging_dir,
+            apply_grad_penalty_every=apply_grad_penalty_every,
+            clear_previous_experiments=clear_previous_experiments,
+        )
+        self.log_model_every = log_model_every
+        self.batch_size = batch_size
 
         # maskgit
         self.model = maskgit
         self.model.vae.requires_grad_(False)
         self.model.transformer.t5.requires_grad_(False)
-
         all_parameters = set(maskgit.parameters())
 
         # don't train the vae
@@ -79,6 +96,13 @@ def __init__(
 
         self.optim = Adam(transformer_parameters, lr=lr)
 
+        self.lr_scheduler = get_scheduler(
+            lr_scheduler_type,
+            optimizer=self.optim,
+            
num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, + num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, + ) + # prepare with accelerator ( @@ -86,8 +110,9 @@ def __init__( self.optim, self.dl, self.valid_dl, + self.lr_scheduler, ) = self.prepare( - self.model, self.optim, self.dl, self.valid_dl + self.model, self.optim, self.dl, self.valid_dl, self.lr_scheduler ) self.use_ema = use_ema @@ -99,15 +124,23 @@ def __init__( update_every=ema_update_every, ) self.ema_model = self.prepare(self.ema_model) - def log_validation_images(self, validation_prompt, step, cond_image=None, cond_scale=3, temperature=1): - image = self.model.generate([validation_prompt], cond_images=cond_image, cond_scale=cond_scale, temperature=temperature) + + def log_validation_images( + self, validation_prompt, step, cond_image=None, cond_scale=3, temperature=1 + ): + image = self.model.generate( + [validation_prompt], + cond_images=cond_image, + cond_scale=cond_scale, + temperature=temperature, + ) super().log_validation_images([image], step, validation_prompt) + def train_step(self): device = self.device steps = int(self.steps.item()) apply_grad_penalty = not (steps % self.apply_grad_penalty_every) - if self.use_ema: ema_model = self.ema_model.module if self.is_distributed else self.ema_model self.model.train() @@ -115,41 +148,54 @@ def train_step(self): train_loss = 0 with self.accelerator.accumulate(self.model): imgs, input_ids, attn_mask = next(self.dl_iter) - imgs, input_ids, attn_mask = imgs.to(device), input_ids.to(device), attn_mask.to(device) - text_embeds = t5_encode_text_from_encoded(input_ids, attn_mask, self.model.transformer.t5, device) - loss = self.model( - imgs, - text_embeds=text_embeds + imgs, input_ids, attn_mask = ( + imgs.to(device), + input_ids.to(device), + attn_mask.to(device), ) + text_embeds = t5_encode_text_from_encoded( + input_ids, attn_mask, self.model.transformer.t5, device + ) + loss = self.model(imgs, text_embeds=text_embeds) avg_loss = self.accelerator.gather(loss.repeat(self.batch_size)).mean() train_loss += avg_loss.item() / self.gradient_accumulation_steps self.accelerator.backward(loss) if exists(self.max_grad_norm): - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) + self.lr_scheduler.step() self.optim.step() self.optim.zero_grad() if self.accelerator.sync_gradients: self.steps += 1 if self.use_ema: ema_model.update() - logs = {"loss": train_loss} - + logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]} self.accelerator.log(logs, steps) if steps % self.save_model_every == 0: state_dict = self.accelerator.unwrap_model(self.model).state_dict() - maskgit_save_name = 'maskgit_superres' if self.model.cond_image_size else 'maskgit' - model_path = str(self.results_dir / f'{maskgit_save_name}.{steps}.pt') + maskgit_save_name = ( + "maskgit_superres" if self.model.cond_image_size else "maskgit" + ) + model_path = str(self.results_dir / f"{maskgit_save_name}.{steps}.pt") self.accelerator.save(state_dict, model_path) if self.use_ema: - ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() - model_path = str(self.results_dir / f'{maskgit_save_name}.{steps}.ema.pt') + ema_state_dict = self.accelerator.unwrap_model( + self.ema_model + ).state_dict() + model_path = str( + self.results_dir / f"{maskgit_save_name}.{steps}.ema.pt" + ) self.accelerator.save(ema_state_dict, model_path) - 
self.print(f"{steps}: saving model to {str(self.results_dir)}")
 
         if steps % self.log_model_every == 0:
             cond_image = None
             if self.model.cond_image_size:
-                cond_image =F.interpolate(imgs[0], 256)
-            self.log_validation_images(self.validation_prompt, self.steps, cond_image=cond_image)
-            return logs
\ No newline at end of file
+                cond_image = F.interpolate(imgs[:1], 256)  # keep the batch dim for interpolate
+            self.log_validation_images(
+                self.validation_prompt, self.steps, cond_image=cond_image
+            )
+        return logs
diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
index 163ffc8..cfb8019 100644
--- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py
+++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -1,4 +1,3 @@
-
 from pathlib import Path
 from shutil import rmtree
 from datetime import datetime
@@ -21,20 +20,27 @@ from ema_pytorch import EMA
 import numpy as np
 
-from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
+from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
+    BaseAcceleratedTrainer,
+)
 from diffusers.optimization import get_scheduler
 
+
 def noop(*args, **kwargs):
     pass
 
+
 def accum_log(log, new_logs):
     for key, new_value in new_logs.items():
         old_value = log.get(key, 0.0)
         log[key] = old_value + new_value
     return log
 
+
+
 def exists(val):
     return val is not None
 
+
 class VQGanVAETrainer(BaseAcceleratedTrainer):
     def __init__(
         self,
@@ -53,19 +59,30 @@ def __init__(
         logging_dir="./results/logs",
         apply_grad_penalty_every=4,
         lr=3e-4,
-        lr_scheduler_type='constant',
-        lr_warmup_steps= 500,
+        lr_scheduler_type="constant",
+        lr_warmup_steps=500,
         discr_max_grad_norm=None,
         use_ema=True,
         ema_beta=0.995,
         ema_update_after_step=0,
         ema_update_every=1,
-        clear_previous_experiments=False
+        clear_previous_experiments=False,
     ):
-        super().__init__(dataloader, valid_dataloader, accelerator, current_step=current_step, num_train_steps=num_train_steps,\
-            gradient_accumulation_steps=gradient_accumulation_steps, max_grad_norm=max_grad_norm, save_results_every=save_results_every, \
-            save_model_every=save_model_every, results_dir=results_dir, logging_dir=logging_dir, apply_grad_penalty_every=apply_grad_penalty_every,\
-            clear_previous_experiments=clear_previous_experiments)
+        super().__init__(
+            dataloader,
+            valid_dataloader,
+            accelerator,
+            current_step=current_step,
+            num_train_steps=num_train_steps,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            max_grad_norm=max_grad_norm,
+            save_results_every=save_results_every,
+            save_model_every=save_model_every,
+            results_dir=results_dir,
+            logging_dir=logging_dir,
+            apply_grad_penalty_every=apply_grad_penalty_every,
+            clear_previous_experiments=clear_previous_experiments,
+        )
 
         # vae
         self.model = vae
@@ -77,20 +94,20 @@ def __init__(
         # optimizers
         self.optim = Adam(vae_parameters, lr=lr)
         self.discr_optim = Adam(discr_parameters, lr=lr)
-        
+
         self.lr_scheduler = get_scheduler(
-            lr_scheduler_type, 
-            optimizer=self.optim, 
-            num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, 
-            num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, 
+            lr_scheduler_type,
+            optimizer=self.optim,
+            num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
+            num_training_steps=self.num_train_steps * self.gradient_accumulation_steps,
         )
-        
+
         self.lr_scheduler_discr = get_scheduler(
             lr_scheduler_type,
             optimizer=self.discr_optim,
             num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
             
num_training_steps=self.num_train_steps * self.gradient_accumulation_steps, - ) + ) self.discr_max_grad_norm = discr_max_grad_norm @@ -102,8 +119,16 @@ def __init__( self.discr_optim, self.dl, self.valid_dl, + self.lr_scheduler, + self.lr_scheduler_discr, ) = self.prepare( - self.model, self.optim, self.discr_optim, self.dl, self.valid_dl + self.model, + self.optim, + self.discr_optim, + self.dl, + self.valid_dl, + self.lr_scheduler, + self.lr_scheduler_discr, ) self.model.train() @@ -120,6 +145,7 @@ def __init__( def load(self, path): pkg = super().load(path) self.discr_optim.load_state_dict(pkg["discr_optim"]) + def save(self, path): if not self.is_local_main_process: return @@ -130,6 +156,7 @@ def save(self, path): discr_optim=self.discr_optim.state_dict(), ) torch.save(pkg, path) + def log_validation_images(self, models_to_evaluate, logs, steps): log_imgs = [] for model, filename in models_to_evaluate: @@ -138,23 +165,24 @@ def log_validation_images(self, models_to_evaluate, logs, steps): valid_data = next(self.valid_dl_iter) valid_data = valid_data.to(self.device) - recons = model(valid_data, return_recons = True) + recons = model(valid_data, return_recons=True) # else save a grid of images - imgs_and_recons = torch.stack((valid_data, recons), dim = 0) - imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...') + imgs_and_recons = torch.stack((valid_data, recons), dim=0) + imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") - imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.) - grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1)) + imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) + grid = make_grid( + imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1) + ) - logs['reconstructions'] = grid - save_file = str(self.results_dir / f'{filename}.png') + logs["reconstructions"] = grid + save_file = str(self.results_dir / f"{filename}.png") save_image(grid, save_file) log_imgs.append(np.asarray(Image.open(save_file))) super().log_validation_images(log_imgs, steps) - def train_step(self): device = self.device @@ -178,18 +206,20 @@ def train_step(self): with self.accelerator.autocast(): loss = self.model( - img, - add_gradient_penalty = apply_grad_penalty, - return_loss = True + img, add_gradient_penalty=apply_grad_penalty, return_loss=True ) self.accelerator.backward(loss / self.gradient_accumulation_steps) - accum_log(logs, {'Train/vae_loss': loss.item() / self.gradient_accumulation_steps}) + accum_log( + logs, {"Train/vae_loss": loss.item() / self.gradient_accumulation_steps} + ) if exists(self.max_grad_norm): - self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) - + self.accelerator.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) + self.lr_scheduler.step() self.lr_scheduler_discr.step() self.optim.step() @@ -204,21 +234,31 @@ def train_step(self): img = next(self.dl_iter) img = img.to(device) - loss = self.model(img, return_discr_loss = True) + loss = self.model(img, return_discr_loss=True) self.accelerator.backward(loss / self.gradient_accumulation_steps) - accum_log(logs, {'Train/discr_loss': loss.item() / self.gradient_accumulation_steps}) + accum_log( + logs, + { + "Train/discr_loss": loss.item() + / self.gradient_accumulation_steps + }, + ) if exists(self.discr_max_grad_norm): - self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm) + self.accelerator.clip_grad_norm_( + discr.parameters(), 
self.discr_max_grad_norm
+                )
             self.discr_optim.step()
 
         # log
-
-        self.print(f"{steps}: vae loss: {logs['Train/vae_loss']} - discr loss: {logs['Train/discr_loss']} - lr: {self.lr_scheduler.get_last_lr()[0]}")
-        logs['lr'] = self.lr_scheduler.get_last_lr()[0]
+
+        self.print(
+            f"{steps}: vae loss: {logs['Train/vae_loss']} - discr loss: {logs['Train/discr_loss']} - lr: {self.lr_scheduler.get_last_lr()[0]}"
+        )
+        logs["lr"] = self.lr_scheduler.get_last_lr()[0]
         self.accelerator.log(logs, step=steps)
 
         # update exponential moving averaged generator
@@ -232,25 +272,28 @@ def train_step(self):
             vaes_to_evaluate = ((self.model, str(steps)),)
 
             if self.use_ema:
-                vaes_to_evaluate = ((ema_model.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
+                vaes_to_evaluate = (
+                    (ema_model.ema_model, f"{steps}.ema"),
+                ) + vaes_to_evaluate
 
             self.log_validation_images(vaes_to_evaluate, logs, steps)
-            self.print(f'{steps}: saving to {str(self.results_dir)}')
-
+            self.print(f"{steps}: saving to {str(self.results_dir)}")
 
         # save model every so often
         self.accelerator.wait_for_everyone()
         if self.is_main and (steps % self.save_model_every) == 0:
             state_dict = self.accelerator.unwrap_model(self.model).state_dict()
-            model_path = str(self.results_dir / f'vae.{steps}.pt')
+            model_path = str(self.results_dir / f"vae.{steps}.pt")
             self.accelerator.save(state_dict, model_path)
 
             if self.use_ema:
-                ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
-                model_path = str(self.results_dir / f'vae.{steps}.ema.pt')
+                ema_state_dict = self.accelerator.unwrap_model(
+                    self.ema_model
+                ).state_dict()
+                model_path = str(self.results_dir / f"vae.{steps}.ema.pt")
                 self.accelerator.save(ema_state_dict, model_path)
 
-            self.print(f'{steps}: saving model to {str(self.results_dir)}')
+            self.print(f"{steps}: saving model to {str(self.results_dir)}")
 
         self.steps += 1
         return logs
 
diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py
index 7089402..070e2bf 100644
--- a/muse_maskgit_pytorch/vqgan_vae.py
+++ b/muse_maskgit_pytorch/vqgan_vae.py
@@ -22,14 +22,18 @@
 
 # helper functions
 
+
 def exists(val):
     return val is not None
 
+
 def default(val, d):
     return val if exists(val) else d
 
+
 # decorators
 
+
 def eval_decorator(fn):
     def inner(model, *args, **kwargs):
         was_training = model.training
@@ -37,15 +41,17 @@ def inner(model, *args, **kwargs):
         out = fn(model, *args, **kwargs)
         model.train(was_training)
         return out
+
     return inner
 
+
 def remove_vgg(fn):
     @wraps(fn)
     def inner(self, *args, **kwargs):
-        has_vgg = hasattr(self, '_vgg')
+        has_vgg = hasattr(self, "_vgg")
         if has_vgg:
             vgg = self._vgg
-            delattr(self, '_vgg')
+            delattr(self, "_vgg")
 
         out = fn(self, *args, **kwargs)
 
@@ -53,125 +59,153 @@ def inner(self, *args, **kwargs):
             self._vgg = vgg
 
         return out
+
     return inner
 
+
 # keyword argument helpers
 
+
 def pick_and_pop(keys, d):
     values = list(map(lambda key: d.pop(key), keys))
     return dict(zip(keys, values))
 
+
 def group_dict_by_key(cond, d):
-    return_val = [dict(),dict()]
+    return_val = [dict(), dict()]
     for key in d.keys():
         match = bool(cond(key))
         ind = int(not match)
         return_val[ind][key] = d[key]
     return (*return_val,)
 
+
 def string_begins_with(prefix, string_input):
     return string_input.startswith(prefix)
 
+
 def group_by_key_prefix(prefix, d):
     return group_dict_by_key(partial(string_begins_with, prefix), d)
 
+
 def groupby_prefix_and_trim(prefix, d):
-    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+    kwargs_with_prefix, kwargs = group_dict_by_key(
+        partial(string_begins_with, prefix), d
+    )
+    kwargs_without_prefix = dict(
+        map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
+    )
     return kwargs_without_prefix, kwargs
 
+
 # tensor helper functions
 
-def log(t, eps = 1e-10):
+
+def log(t, eps=1e-10):
     return torch.log(t + eps)
 
-def gradient_penalty(images, output, weight = 10):
+
+def gradient_penalty(images, output, weight=10):
     batch_size = images.shape[0]
     gradients = torch_grad(
-        outputs = output,
-        inputs = images,
-        grad_outputs = torch.ones(output.size(), device = images.device),
-        create_graph = True,
-        retain_graph = True,
-        only_inputs = True
+        outputs=output,
+        inputs=images,
+        grad_outputs=torch.ones(output.size(), device=images.device),
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True,
     )[0]
 
-    gradients = rearrange(gradients, 'b ... -> b (...)')
-    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+    gradients = rearrange(gradients, "b ... -> b (...)")
+    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
 
-def leaky_relu(p = 0.1):
-    return nn.LeakyReLU(0.1)
+def leaky_relu(p=0.1):
+    return nn.LeakyReLU(p)
 
-def safe_div(numer, denom, eps = 1e-8):
-    return numer / denom.clamp(min = eps)
+
+def safe_div(numer, denom, eps=1e-8):
+    return numer / denom.clamp(min=eps)
+
 
 # gan losses
 
+
 def hinge_discr_loss(fake, real):
     return (F.relu(1 + fake) + F.relu(1 - real)).mean()
 
+
 def hinge_gen_loss(fake):
     return -fake.mean()
 
+
 def bce_discr_loss(fake, real):
     return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
 
+
 def bce_gen_loss(fake):
     return -log(torch.sigmoid(fake)).mean()
 
+
 def grad_layer_wrt_loss(loss, layer):
     return torch_grad(
-        outputs = loss,
-        inputs = layer,
-        grad_outputs = torch.ones_like(loss),
-        retain_graph = True
+        outputs=loss,
+        inputs=layer,
+        grad_outputs=torch.ones_like(loss),
+        retain_graph=True,
     )[0].detach()
 
+
 # vqgan vae
 
+
 class LayerNormChan(nn.Module):
-    def __init__(
-        self,
-        dim,
-        eps = 1e-5
-    ):
+    def __init__(self, dim, eps=1e-5):
         super().__init__()
         self.eps = eps
         self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma
+        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
+        mean = torch.mean(x, dim=1, keepdim=True)
+        return (x - mean) * var.clamp(min=self.eps).rsqrt() * self.gamma
+
 
 # discriminator
 
+
 class Discriminator(nn.Module):
-    def __init__(
-        self,
-        dims,
-        channels = 3,
-        groups = 16,
-        init_kernel_size = 5
-    ):
+    def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
         super().__init__()
         dim_pairs = zip(dims[:-1], dims[1:])
 
-        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
+        self.layers = MList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        channels,
+                        dims[0],
+                        init_kernel_size,
+                        padding=init_kernel_size // 2,
+                    ),
+                    leaky_relu(),
+                )
+            ]
+        )
 
         for dim_in, dim_out in dim_pairs:
-            self.layers.append(nn.Sequential(
-                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
-                nn.GroupNorm(groups, dim_out),
-                leaky_relu()
-            ))
+            self.layers.append(
+                nn.Sequential(
+                    nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1),
+                    nn.GroupNorm(groups, dim_out),
+                    leaky_relu(),
+                )
+            )
 
         dim = dims[-1]
-        self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
-            nn.Conv2d(dim, dim, 1),
-            leaky_relu(),
-            nn.Conv2d(dim, 1, 4)
+        self.to_logits = nn.Sequential(  # return 5 x 5, for PatchGAN-esque training
+            nn.Conv2d(dim, dim, 1), leaky_relu(), nn.Conv2d(dim, 1, 4)
         )
 
     def forward(self, x):
@@ -180,30 +214,36 @@ def forward(self, x):
 
         return self.to_logits(x)
 
+
 # resnet encoder / decoder
 
+
 class ResnetEncDec(nn.Module):
     def __init__(
         self,
         dim,
         *,
-        channels = 3,
-        layers = 4,
-        layer_mults = None,
-        num_resnet_blocks = 1,
-        resnet_groups = 16,
-        first_conv_kernel_size = 5
+        channels=3,
+        layers=4,
+        layer_mults=None,
+        num_resnet_blocks=1,
+        resnet_groups=16,
+        first_conv_kernel_size=5,
     ):
         super().__init__()
-        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
+        assert (
+            dim % resnet_groups == 0
+        ), f"dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)"
 
         self.layers = layers
 
         self.encoders = MList([])
         self.decoders = MList([])
 
-        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
-        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
+        layer_mults = default(layer_mults, list(map(lambda t: 2**t, range(layers))))
+        assert (
+            len(layer_mults) == layers
+        ), "layer multipliers must be equal to designated number of layers"
 
         layer_dims = [dim * mult for mult in layer_mults]
         dims = (dim, *layer_dims)
@@ -218,21 +258,43 @@ def __init__(
         if not isinstance(num_resnet_blocks, tuple):
             num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
 
-        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
-
-        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
-            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
+        assert (
+            len(num_resnet_blocks) == layers
+        ), "number of resnet blocks config must be equal to number of layers"
+
+        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(
+            range(layers), dim_pairs, num_resnet_blocks
+        ):
+            append(
+                self.encoders,
+                nn.Sequential(
+                    nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), leaky_relu()
+                ),
+            )
+            prepend(
+                self.decoders,
+                nn.Sequential(
+                    nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()
+                ),
+            )
 
             for _ in range(layer_num_resnet_blocks):
-                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
-                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
-
-        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
+                append(self.encoders, ResBlock(dim_out, groups=resnet_groups))
+                prepend(self.decoders, GLUResBlock(dim_out, groups=resnet_groups))
+
+        prepend(
+            self.encoders,
+            nn.Conv2d(
+                channels,
+                dim,
+                first_conv_kernel_size,
+                padding=first_conv_kernel_size // 2,
+            ),
+        )
         append(self.decoders, nn.Conv2d(dim, channels, 1))
 
     def get_encoded_fmap_size(self, image_size):
-        return image_size // (2 ** self.layers)
+        return image_size // (2**self.layers)
 
     @property
     def last_dec_layer(self):
@@ -248,87 +310,88 @@ def decode(self, x):
             x = dec(x)
         return x
 
+
 class GLUResBlock(nn.Module):
-    def __init__(self, chan, groups = 16):
+    def __init__(self, chan, groups=16):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Conv2d(chan, chan * 2, 3, padding = 1),
-            nn.GLU(dim = 1),
+            nn.Conv2d(chan, chan * 2, 3, padding=1),
+            nn.GLU(dim=1),
             nn.GroupNorm(groups, chan),
-            nn.Conv2d(chan, chan * 2, 3, padding = 1),
-            nn.GLU(dim = 1),
+            nn.Conv2d(chan, chan * 2, 3, padding=1),
+            nn.GLU(dim=1),
             nn.GroupNorm(groups, chan),
-            nn.Conv2d(chan, chan, 1)
+            nn.Conv2d(chan, chan, 1),
         )
 
     def forward(self, x):
         return self.net(x) + x
 
+
 class ResBlock(nn.Module):
-    def __init__(self, chan, groups = 16):
+    def __init__(self, chan, groups=16):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Conv2d(chan, chan, 3, padding = 1),
+            nn.Conv2d(chan, chan, 3, padding=1),
             nn.GroupNorm(groups, chan),
             leaky_relu(),
-            nn.Conv2d(chan, chan, 3, padding = 1),
+            nn.Conv2d(chan, chan, 3, padding=1),
             nn.GroupNorm(groups, chan),
             leaky_relu(),
-            nn.Conv2d(chan, chan, 1)
+            nn.Conv2d(chan, chan, 1),
         )
 
     def forward(self, x):
         return self.net(x) + x
 
+
 # main vqgan-vae classes
 
+
 class VQGanVAE(nn.Module):
     def __init__(
         self,
         *,
         dim,
-        channels = 3,
-        layers = 4,
-        l2_recon_loss = False,
-        use_hinge_loss = True,
-        vgg = None,
-        vq_codebook_dim = 256,
-        vq_codebook_size = 512,
-        vq_decay = 0.8,
-        vq_commitment_weight = 1.,
-        vq_kmeans_init = True,
-        vq_use_cosine_sim = True,
-        use_vgg_and_gan = True,
-        discr_layers = 4,
-        **kwargs
+        channels=3,
+        layers=4,
+        l2_recon_loss=False,
+        use_hinge_loss=True,
+        vgg=None,
+        vq_codebook_dim=256,
+        vq_codebook_size=512,
+        vq_decay=0.8,
+        vq_commitment_weight=1.0,
+        vq_kmeans_init=True,
+        vq_use_cosine_sim=True,
+        use_vgg_and_gan=True,
+        discr_layers=4,
+        **kwargs,
    ):
         super().__init__()
-        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
-        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
+        vq_kwargs, kwargs = groupby_prefix_and_trim("vq_", kwargs)
+        encdec_kwargs, kwargs = groupby_prefix_and_trim("encdec_", kwargs)
 
         self.channels = channels
         self.codebook_size = vq_codebook_size
-        self.dim_divisor = 2 ** layers
+        self.dim_divisor = 2**layers
 
         enc_dec_klass = ResnetEncDec
 
         self.enc_dec = enc_dec_klass(
-            dim = dim,
-            channels = channels,
-            layers = layers,
-            **encdec_kwargs
+            dim=dim, channels=channels, layers=layers, **encdec_kwargs
         )
 
         self.vq = VQ(
-            dim = self.enc_dec.encoded_dim,
-            codebook_dim = vq_codebook_dim,
-            codebook_size = vq_codebook_size,
-            decay = vq_decay,
-            commitment_weight = vq_commitment_weight,
-            accept_image_fmap = True,
-            kmeans_init = vq_kmeans_init,
-            use_cosine_sim = vq_use_cosine_sim,
-            **vq_kwargs
+            dim=self.enc_dec.encoded_dim,
+            codebook_dim=vq_codebook_dim,
+            codebook_size=vq_codebook_size,
+            decay=vq_decay,
+            commitment_weight=vq_commitment_weight,
+            accept_image_fmap=True,
+            kmeans_init=vq_kmeans_init,
+            use_cosine_sim=vq_use_cosine_sim,
+            **vq_kwargs,
         )
 
         # reconstruction loss
@@ -351,11 +414,11 @@ def __init__(
 
         # gan related losses
 
-        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
+        layer_mults = list(map(lambda t: 2**t, range(discr_layers)))
         layer_dims = [dim * mult for mult in layer_mults]
         dims = (dim, *layer_dims)
 
-        self.discr = Discriminator(dims = dims, channels = channels)
+        self.discr = Discriminator(dims=dims, channels=channels)
 
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
@@ -369,7 +432,7 @@ def vgg(self):
         if exists(self._vgg):
             return self._vgg
 
-        vgg = torchvision.models.vgg16(pretrained = True)
+        vgg = torchvision.models.vgg16(pretrained=True)
         vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
         self._vgg = vgg.to(self.device)
         return self._vgg
@@ -421,7 +484,7 @@ def encode(self, fmap):
 
     def decode_from_ids(self, ids):
         codes = self.codebook[ids]
         fmap = self.vq.project_out(codes)
-        fmap = rearrange(fmap, 'b h w c -> b c h w')
+        fmap = rearrange(fmap, "b h w c -> b c h w")
         return self.decode(fmap)
 
     def decode(self, fmap):
@@ -430,17 +493,21 @@ def decode(self, fmap):
     def forward(
         self,
         img,
-        return_loss = False,
-        return_discr_loss = False,
-        return_recons = False,
-        add_gradient_penalty = True
+        return_loss=False,
+        return_discr_loss=False,
+        return_recons=False,
+        add_gradient_penalty=True,
     ):
         batch, channels, height, width, device = *img.shape, img.device
 
-        for dim_name, size in (('height', height), ('width', width)):
-            assert (size % self.dim_divisor) == 0, f'{dim_name} must be divisible by {self.dim_divisor}'
+        for dim_name, size in (("height", height), ("width", width)):
+            assert (
+                size % self.dim_divisor
+            ) == 0, f"{dim_name} must be divisible by {self.dim_divisor}"
 
-        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
+        assert (
+            channels == self.channels
+        ), "number of channels on image or sketch is not equal to the channels set on this VQGanVAE"
 
         fmap, indices, commit_loss = self.encode(img)
@@ -449,12 +516,14 @@ def forward(
         if not return_loss and not return_discr_loss:
             return fmap
 
-        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
+        assert (
+            return_loss ^ return_discr_loss
+        ), "you should either return autoencoder loss or discriminator loss, but not both"
 
         # whether to return discriminator loss
 
         if return_discr_loss:
-            assert exists(self.discr), 'discriminator must exist to train it'
+            assert exists(self.discr), "discriminator must exist to train it"
 
             fmap.detach_()
             img.requires_grad_()
@@ -491,7 +560,10 @@ def forward(
 
             if img.shape[1] == 1:
                 # handle grayscale for vgg
-                img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
+                img_vgg_input, fmap_vgg_input = map(
+                    lambda t: repeat(t, "b 1 ... -> b c ...", c=3),
+                    (img_vgg_input, fmap_vgg_input),
+                )
 
             img_vgg_feats = self.vgg(img_vgg_input)
             recon_vgg_feats = self.vgg(fmap_vgg_input)
@@ -505,11 +577,15 @@ def forward(
 
             last_dec_layer = self.enc_dec.last_dec_layer
 
-            norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
-            norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
+            norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
+            norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
+                perceptual_loss, last_dec_layer
+            ).norm(p=2)
 
-            adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
-            adaptive_weight.clamp_(max = 1e4)
+            adaptive_weight = safe_div(
+                norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss
+            )
+            adaptive_weight.clamp_(max=1e4)
 
         # combine losses
 
diff --git a/setup.py b/setup.py
index 21ed312..e8c1b14 100644
--- a/setup.py
+++ b/setup.py
@@ -1,43 +1,42 @@
 from setuptools import setup, find_packages
 
 setup(
-  name = 'muse-maskgit-pytorch',
-  packages = find_packages(exclude=[]),
-  version = '0.1.0',
-  license='MIT',
-  description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
-  author = 'Phil Wang',
-  author_email = 'lucidrains@gmail.com',
-  long_description_content_type = 'text/markdown',
-  url = 'https://github.com/lucidrains/muse-maskgit-pytorch',
-  keywords = [
-    'artificial intelligence',
-    'deep learning',
-    'transformers',
-    'attention mechanism',
-    'text-to-image'
-  ],
-  install_requires=[
-    'accelerate',
-    'diffusers',
-    'datasets',
-    'beartype',
-    'einops>=0.6',
-    'ema-pytorch',
-    'pillow',
-    'sentencepiece',
-    'torch>=1.6',
-    'transformers',
-    'torch>=1.6',
-    'torchvision',
-    'tqdm',
-    'vector-quantize-pytorch>=0.10.14'
-  ],
-  classifiers=[
-    'Development Status :: 4 - Beta',
-    'Intended Audience :: Developers',
-    'Topic :: Scientific/Engineering :: Artificial Intelligence',
-    'License :: OSI Approved :: MIT License',
-    'Programming Language :: Python :: 3.6',
-  ],
+    name="muse-maskgit-pytorch",
+    packages=find_packages(exclude=[]),
+    version="0.1.0",
+    license="MIT",
+    description="MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch",
+    author="Phil Wang",
+    author_email="lucidrains@gmail.com",
+    long_description_content_type="text/markdown",
+    url="https://github.com/lucidrains/muse-maskgit-pytorch",
+    keywords=[
+        "artificial intelligence",
+        "deep learning",
+        "transformers",
+        "attention mechanism",
+        "text-to-image",
+    ],
+    install_requires=[
+        "accelerate",
+        "diffusers",
+        "datasets",
+        "beartype",
+        "einops>=0.6",
+        "ema-pytorch",
+        "pillow",
+        "sentencepiece",
+        "torch>=1.6",
+        "transformers",
+        "torchvision",
+        "tqdm",
+        "vector-quantize-pytorch>=0.10.14",
+    ],
+    classifiers=[
+        "Development Status :: 4 - Beta",
+        "Intended Audience :: Developers",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3.6",
+    ],
 )
 
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index f253645..dc4c54c 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -9,36 +9,48 @@
     MaskGitTrainer,
     MaskGit,
     MaskGitTransformer,
-    get_accelerator
+    get_accelerator,
+)
+from muse_maskgit_pytorch.dataset import (
+    get_dataset_from_dataroot,
+    ImageTextDataset,
+    split_dataset_into_dataloaders,
 )
-from muse_maskgit_pytorch.dataset import get_dataset_from_dataroot, ImageTextDataset, split_dataset_into_dataloaders
 import argparse
 
+
 def parse_args():
     # Create the parser
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '--dataset_save_path', type=str, default="dataset", help="Path to save the dataset if you are making one from a directory"
-    )
-    parser.add_argument(
-        "--clear_previous_experiments", action="store_true", help="Whether to clear previous experiments."
+        "--dataset_save_path",
+        type=str,
+        default="dataset",
+        help="Path to save the dataset if you are making one from a directory",
     )
     parser.add_argument(
-        "--num_tokens", type=int, default=256, help="Number of tokens. Must be same as codebook size above"
+        "--clear_previous_experiments",
+        action="store_true",
+        help="Whether to clear previous experiments.",
     )
     parser.add_argument(
-        "--seq_len", type=int, default=1024, help="The sequence length. Must be equivalent to fmap_size ** 2 in vae"
+        "--num_tokens",
+        type=int,
+        default=256,
+        help="Number of tokens. Must be same as codebook size above",
     )
     parser.add_argument(
-        "--depth", type=int, default=2, help="The depth of model"
+        "--seq_len",
+        type=int,
+        default=1024,
+        help="The sequence length. Must be equivalent to fmap_size ** 2 in vae",
     )
+    parser.add_argument("--depth", type=int, default=2, help="The depth of model")
     parser.add_argument(
         "--dim_head", type=int, default=64, help="Attention head dimension"
     )
-    parser.add_argument(
-        "--heads", type=int, default=8, help="Attention heads"
-    )
+    parser.add_argument("--heads", type=int, default=8, help="Attention heads")
     parser.add_argument(
         "--ff_mult", type=int, default=4, help="Feed forward expansion factor"
     )
@@ -49,34 +61,40 @@ def parse_args():
         "--cond_image_size", type=int, default=None, help="Conditional image size."
     )
     parser.add_argument(
-        "--validation_prompt", type=str, default="A photo of a dog", help="Validation prompt."
+        "--validation_prompt",
+        type=str,
+        default="A photo of a dog",
+        help="Validation prompt.",
     )
     parser.add_argument(
         "--max_grad_norm", type=float, default=None, help="Max gradient norm."
     )
-    parser.add_argument(
-        "--seed", type=int, default=42, help="Seed."
-    )
+    parser.add_argument("--seed", type=int, default=42, help="Seed.")
     parser.add_argument(
         "--valid_frac", type=float, default=0.05, help="validation fraction."
     )
-    parser.add_argument(
-        "--use_ema", action="store_true", help="Whether to use ema."
-    )
-    parser.add_argument(
-        "--ema_beta", type=float, default=0.995, help="Ema beta."
-    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.")
+    parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.")
     parser.add_argument(
         "--ema_update_after_step", type=int, default=1, help="Ema update after step."
     )
     parser.add_argument(
-        "--ema_update_every", type=int, default=1, help="Ema update every this number of steps."
+        "--ema_update_every",
+        type=int,
+        default=1,
+        help="Ema update every this number of steps.",
     )
     parser.add_argument(
-        "--apply_grad_penalty_every", type=int, default=4, help="Apply gradient penalty every this number of steps."
+        "--apply_grad_penalty_every",
+        type=int,
+        default=4,
+        help="Apply gradient penalty every this number of steps.",
     )
     parser.add_argument(
-        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+        "--image_column",
+        type=str,
+        default="image",
+        help="The column of the dataset containing an image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -98,7 +116,7 @@ def parse_args():
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
-        help="Precision to train on."
+ help="Precision to train on.", ) parser.add_argument( "--results_dir", @@ -124,7 +142,7 @@ def parse_args(): "--dataset_name", type=str, default=None, - help="Name of the huggingface dataset used." + help="Name of the huggingface dataset used.", ) parser.add_argument( "--train_data_dir", @@ -142,7 +160,10 @@ def parse_args(): parser.add_argument("--batch_size", type=int, default=1, help="Batch Size.") parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate.") parser.add_argument( - "--gradient_accumulation_steps", type=int, default=1, help="Gradient Accumulation." + "--gradient_accumulation_steps", + type=int, + default=1, + help="Gradient Accumulation.", ) parser.add_argument( "--log_model_every", @@ -163,13 +184,30 @@ def parse_args(): help="Save the model every this number of steps.", ) parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.") - parser.add_argument("--cond_drop_prob", type=float, default=0.5, help="Conditional dropout, for classifier free guidance.") + parser.add_argument( + "--cond_drop_prob", + type=float, + default=0.5, + help="Conditional dropout, for classifier free guidance.", + ) parser.add_argument( "--image_size", type=int, default=256, help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it", ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.", + ) parser.add_argument( "--resume_path", type=str, @@ -179,58 +217,69 @@ def parse_args(): # Parse the argument return parser.parse_args() + def main(): args = parse_args() - accelerator = get_accelerator(log_with=args.log_with, gradient_accumulation_steps=args.gradient_accumulation_steps,mixed_precision=args.mixed_precision, logging_dir=args.logging_dir) + accelerator = get_accelerator( + log_with=args.log_with, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + logging_dir=args.logging_dir, + ) if accelerator.is_main_process: accelerator.init_trackers("muse_maskgit", config=vars(args)) if args.train_data_dir: - dataset = get_dataset_from_dataroot(args.train_data_dir, image_column=args.image_column, caption_column=args.caption_column, save_path=args.dataset_save_path) + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + ) elif args.dataset_name: dataset = load_dataset(args.dataset_name)["train"] + vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size).to( + accelerator.device + ) - vae = VQGanVAE( - dim = args.dim, - vq_codebook_size = args.vq_codebook_size - ).to(accelerator.device) - - print ('Resuming VAE from: ', args.vae_path) - vae.load(args.vae_path) # you will want to load the exponentially moving averaged VAE + print("Resuming VAE from: ", args.vae_path) + vae.load( + args.vae_path + ) # you will want to load the exponentially moving averaged VAE # then you plug the vae and transformer into your MaskGit as so # (1) create your transformer / attention network transformer = MaskGitTransformer( - num_tokens = args.num_tokens, # 
must be same as codebook size above - seq_len = args.seq_len, # must be equivalent to fmap_size ** 2 in vae - dim = args.dim, # model dimension - depth = args.depth, # depth - dim_head = args.dim_head, # attention head dimension - heads = args.heads, # attention heads, - ff_mult = args.ff_mult, # feedforward expansion factor - t5_name = args.t5_name, # name of your T5 + num_tokens=args.num_tokens, # must be same as codebook size above + seq_len=args.seq_len, # must be equivalent to fmap_size ** 2 in vae + dim=args.dim, # model dimension + depth=args.depth, # depth + dim_head=args.dim_head, # attention head dimension + heads=args.heads, # attention heads, + ff_mult=args.ff_mult, # feedforward expansion factor + t5_name=args.t5_name, # name of your T5 ).to(accelerator.device) transformer.t5.to(accelerator.device) # (2) pass your trained VAE and the base transformer to MaskGit maskgit = MaskGit( - vae = vae, # vqgan vae - transformer = transformer, # transformer - image_size = args.image_size, # image size - cond_drop_prob = args.cond_drop_prob, # conditional dropout, for classifier free guidance - cond_image_size = args.cond_image_size + vae=vae, # vqgan vae + transformer=transformer, # transformer + image_size=args.image_size, # image size + cond_drop_prob=args.cond_drop_prob, # conditional dropout, for classifier free guidance + cond_image_size=args.cond_image_size, ).to(accelerator.device) # load the maskgit transformer from disk if we have previously trained one if args.resume_path: - print (f'Resuming MaskGit from: {args.resume_path}') + print(f"Resuming MaskGit from: {args.resume_path}") maskgit.load(args.resume_path) - resume_from_parts = args.resume_path.split('.') - for i in range(len(resume_from_parts)-1, -1, -1): + resume_from_parts = args.resume_path.split(".") + for i in range(len(resume_from_parts) - 1, -1, -1): if resume_from_parts[i].isdigit(): current_step = int(resume_from_parts[i]) print(f"Found step {current_step} for the MaskGit model.") @@ -238,11 +287,19 @@ def main(): if current_step == 0: print("No step found for the MaskGit model.") - dataset = ImageTextDataset(dataset, args.image_size, transformer.tokenizer, image_column=args.image_column, caption_column=args.caption_column) - dataloader, validation_dataloader = split_dataset_into_dataloaders(dataset, args.valid_frac, args.seed, args.batch_size) + dataset = ImageTextDataset( + dataset, + args.image_size, + transformer.tokenizer, + image_column=args.image_column, + caption_column=args.caption_column, + ) + dataloader, validation_dataloader = split_dataset_into_dataloaders( + dataset, args.valid_frac, args.seed, args.batch_size + ) trainer = MaskGitTrainer( - maskgit,\ + maskgit, dataloader, validation_dataloader, accelerator, @@ -250,6 +307,8 @@ def main(): num_train_steps=args.num_train_steps, batch_size=args.batch_size, lr=args.lr, + lr_scheduler=args.lr_scheduler, + lr_warmup_steps=args.lr_warmup_steps, max_grad_norm=args.max_grad_norm, save_results_every=args.save_results_every, save_model_every=args.save_model_every, @@ -263,12 +322,11 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, validation_prompt=args.validation_prompt, log_model_every=args.log_model_every, - clear_previous_experiments=args.clear_previous_experiments + clear_previous_experiments=args.clear_previous_experiments, ) trainer.train() - if __name__ == "__main__": main() diff --git a/train_muse_vae.py b/train_muse_vae.py index d0ba35d..01242e6 100644 --- a/train_muse_vae.py +++ b/train_muse_vae.py @@ -4,53 +4,66 @@ 
 from pathlib import Path
 from datasets import load_dataset
 import os
-from muse_maskgit_pytorch import (
-    VQGanVAE,
-    VQGanVAETrainer,
-    get_accelerator
+from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer, get_accelerator
+from muse_maskgit_pytorch.dataset import (
+    get_dataset_from_dataroot,
+    ImageDataset,
+    split_dataset_into_dataloaders,
 )
-from muse_maskgit_pytorch.dataset import get_dataset_from_dataroot, ImageDataset, split_dataset_into_dataloaders
 import argparse
 
+
+
 def parse_args():
     # Create the parser
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '--dataset_save_path', type=str, default="dataset", help="Path to save the dataset if you are making one from a directory"
+        "--dataset_save_path",
+        type=str,
+        default="dataset",
+        help="Path to save the dataset if you are making one from a directory",
     )
     parser.add_argument(
-        "--clear_previous_experiments", action="store_true", help="Whether to clear previous experiments."
+        "--clear_previous_experiments",
+        action="store_true",
+        help="Whether to clear previous experiments.",
     )
     parser.add_argument(
         "--max_grad_norm", type=float, default=None, help="Max gradient norm."
     )
     parser.add_argument(
-        "--discr_max_grad_norm", type=float, default=None, help="Max gradient norm for discriminator."
-    )
-    parser.add_argument(
-        "--seed", type=int, default=42, help="Seed."
+        "--discr_max_grad_norm",
+        type=float,
+        default=None,
+        help="Max gradient norm for discriminator.",
     )
+    parser.add_argument("--seed", type=int, default=42, help="Seed.")
     parser.add_argument(
         "--valid_frac", type=float, default=0.05, help="validation fraction."
     )
-    parser.add_argument(
-        "--use_ema", action="store_true", help="Whether to use ema."
-    )
-    parser.add_argument(
-        "--ema_beta", type=float, default=0.995, help="Ema beta."
-    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.")
+    parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.")
     parser.add_argument(
         "--ema_update_after_step", type=int, default=1, help="Ema update after step."
     )
     parser.add_argument(
-        "--ema_update_every", type=int, default=1, help="Ema update every this number of steps."
+        "--ema_update_every",
+        type=int,
+        default=1,
+        help="Ema update every this number of steps.",
     )
     parser.add_argument(
-        "--apply_grad_penalty_every", type=int, default=4, help="Apply gradient penalty every this number of steps."
+        "--apply_grad_penalty_every",
+        type=int,
+        default=4,
+        help="Apply gradient penalty every this number of steps.",
    )
     parser.add_argument(
-        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+        "--image_column",
+        type=str,
+        default="image",
+        help="The column of the dataset containing an image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -72,7 +85,7 @@ def parse_args():
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
-        help="Precision to train on."
+        help="Precision to train on.",
     )
     parser.add_argument(
         "--results_dir",
@@ -92,7 +105,7 @@ def parse_args():
         "--dataset_name",
         type=str,
         default=None,
-        help="Name of the huggingface dataset used."
+        help="Name of the huggingface dataset used.",
     )
     parser.add_argument(
         "--train_data_dir",
@@ -110,7 +123,10 @@ def parse_args():
     parser.add_argument("--batch_size", type=int, default=1, help="Batch Size.")
     parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate.")
     parser.add_argument(
-        "--gradient_accumulation_steps", type=int, default=1, help="Gradient Accumulation."
+ "--gradient_accumulation_steps", + type=int, + default=1, + help="Gradient Accumulation.", ) parser.add_argument( "--save_results_every", @@ -131,8 +147,18 @@ def parse_args(): default=256, help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it", ) - parser.add_argument("--lr_scheduler", type=str, default="constant", help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]') - parser.add_argument("--lr_warmup_steps", type=int, default=0, help='Number of steps for the warmup in the lr scheduler.') + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.", + ) parser.add_argument( "--resume_path", type=str, @@ -142,23 +168,34 @@ def parse_args(): # Parse the argument return parser.parse_args() + def main(): args = parse_args() - accelerator = get_accelerator(log_with=args.log_with, gradient_accumulation_steps=args.gradient_accumulation_steps,mixed_precision=args.mixed_precision, logging_dir=args.logging_dir) + accelerator = get_accelerator( + log_with=args.log_with, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + logging_dir=args.logging_dir, + ) if accelerator.is_main_process: accelerator.init_trackers("muse_vae", config=vars(args)) if args.train_data_dir: - dataset = get_dataset_from_dataroot(args.train_data_dir, image_column=args.image_column, caption_column=args.caption_column, save_path=args.dataset_save_path) + dataset = get_dataset_from_dataroot( + args.train_data_dir, + image_column=args.image_column, + caption_column=args.caption_column, + save_path=args.dataset_save_path, + ) elif args.dataset_name: dataset = load_dataset(args.dataset_name)["train"] vae = VQGanVAE(dim=args.dim, vq_codebook_size=args.vq_codebook_size) - + if args.resume_path: - print (f'Resuming VAE from: {args.resume_path}') + print(f"Resuming VAE from: {args.resume_path}") vae.load(args.resume_path) - resume_from_parts = args.resume_path.split('.') - for i in range(len(resume_from_parts)-1, -1, -1): + resume_from_parts = args.resume_path.split(".") + for i in range(len(resume_from_parts) - 1, -1, -1): if resume_from_parts[i].isdigit(): current_step = int(resume_from_parts[i]) print(f"Found step {current_step} for the VAE model.") @@ -168,7 +205,9 @@ def main(): dataset = ImageDataset(dataset, args.image_size, image_column=args.image_column) # dataloader - dataloader, validation_dataloader = split_dataset_into_dataloaders(dataset, args.valid_frac, args.seed, args.batch_size) + dataloader, validation_dataloader = split_dataset_into_dataloaders( + dataset, args.valid_frac, args.seed, args.batch_size + ) trainer = VQGanVAETrainer( vae, dataloader, @@ -177,8 +216,8 @@ def main(): current_step=current_step, num_train_steps=args.num_train_steps, lr=args.lr, - lr_scheduler = args.lr_scheduler, - lr_warmup_steps = args.lr_warmup_steps, + lr_scheduler=args.lr_scheduler, + lr_warmup_steps=args.lr_warmup_steps, max_grad_norm=args.max_grad_norm, discr_max_grad_norm=args.discr_max_grad_norm, save_results_every=args.save_results_every, @@ 
-191,12 +230,11 @@ def main(): ema_update_every=args.ema_update_every, apply_grad_penalty_every=args.apply_grad_penalty_every, gradient_accumulation_steps=args.gradient_accumulation_steps, - clear_previous_experiments=args.clear_previous_experiments + clear_previous_experiments=args.clear_previous_experiments, ) trainer.train() - if __name__ == "__main__": main()
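
Reviewer note: the new --lr_scheduler / --lr_warmup_steps flags are threaded from both training scripts into VQGanVAETrainer and MaskGitTrainer, but the trainer-side wiring sits outside the hunks above. The following is a minimal sketch (not part of the patch) of how a trainer can turn those two values into the self.lr_scheduler object whose get_last_lr()[0] is logged in train_step. It assumes the get_scheduler helper from diffusers, which is already listed in install_requires; `optimizer` and `num_train_steps` are illustrative stand-ins for the trainer's own attributes.

    # Sketch only -- assumes diffusers' get_scheduler helper;
    # `optimizer` and `num_train_steps` mirror the trainer's existing attributes.
    import torch
    from diffusers.optimization import get_scheduler

    model = torch.nn.Linear(8, 8)  # stand-in for the VAE / MaskGit transformer
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # --lr
    num_train_steps = 1000  # --num_train_steps

    lr_scheduler = get_scheduler(
        "constant",  # --lr_scheduler: linear, cosine, cosine_with_restarts,
                     # polynomial, constant, or constant_with_warmup
        optimizer=optimizer,
        num_warmup_steps=0,  # --lr_warmup_steps
        num_training_steps=num_train_steps,
    )

    for step in range(num_train_steps):
        optimizer.step()
        lr_scheduler.step()  # advance the schedule once per optimizer step
        optimizer.zero_grad()
        current_lr = lr_scheduler.get_last_lr()[0]  # the value logged by train_step

With "constant_with_warmup" and --lr_warmup_steps 500, for example, the learning rate ramps linearly from 0 to --lr over the first 500 steps and stays flat afterwards, so the warmup flag only has an effect when a warmup-capable scheduler type is chosen.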