From 39f843bf46a45cf363d88f3232c7e6e34bffc94c Mon Sep 17 00:00:00 2001
From: Andrew Powers-Holmes
Date: Tue, 2 May 2023 04:26:41 +0000
Subject: [PATCH] muse-maskgit: more cleanup

---
 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 99 +++++++++-----------
 1 file changed, 44 insertions(+), 55 deletions(-)

diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 27c7216..73d0904 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -1,5 +1,6 @@
 import math
 from functools import partial
+from os import PathLike
 from pathlib import Path
 from random import random
 from typing import Callable, List, Optional, Union
@@ -7,6 +8,7 @@
 import torch
 import torch.nn.functional as F
 import torchvision.transforms as T
+from accelerate import Accelerator
 from beartype import beartype
 from einops import rearrange, repeat
 from torch import einsum, nn
@@ -17,15 +19,14 @@
 from .vqgan_vae import VQGanVAE
 from .vqgan_vae_taming import VQGanVAETaming
 
-# helpers
-
 
+# helpers
 def exists(val):
     return val is not None
 
 
 def default(val, d):
-    return val if exists(val) else d
+    return val if val is not None else d
 
 
 def eval_decorator(fn):
@@ -44,8 +45,6 @@ def l2norm(t):
 
 
 # tensor helpers
-
-
 def get_mask_subset_prob(mask, prob, min_mask=0):
     batch, seq, device = *mask.shape, mask.device
     num_to_mask = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)
@@ -63,8 +62,6 @@ def get_mask_subset_prob(mask, prob, min_mask=0):
 
 
 # classes
-
-
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -119,8 +116,7 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
     def forward(self, x, context=None, context_mask=None):
         assert not (exists(context) ^ self.cross_attend)
 
-        h, is_cross_attn = self.heads, exists(context)
-
+        h = self.heads
         x = self.norm(x)
 
         kv_input = context if self.cross_attend else x
@@ -185,21 +181,19 @@ def forward(self, x, context=None, context_mask=None):
 
 
 # transformer - it's all we need
-
-
 class Transformer(nn.Module):
     def __init__(
         self,
         *,
-        num_tokens,
-        dim,
-        seq_len,
-        dim_out=None,
-        t5_name=DEFAULT_T5_NAME,
-        self_cond=False,
-        add_mask_id=False,
-        cache_path=None,
-        **kwargs
+        num_tokens: int,
+        dim: int,
+        seq_len: int,
+        dim_out: Optional[int] = None,
+        t5_name: str = DEFAULT_T5_NAME,
+        self_cond: bool = False,
+        add_mask_id: bool = False,
+        cache_path: PathLike = None,
+        **kwargs,
     ):
         super().__init__()
         self.dim = dim
@@ -230,7 +224,6 @@ def __init__(
         )
 
         # optional self conditioning
-
         self.self_cond = self_cond
         self.self_cond_to_init_embed = FeedForward(dim)
 
@@ -256,7 +249,7 @@ def forward_with_neg_prompt(
         neg_text_embed: torch.Tensor,
         cond_scale=3.0,
         return_embed=False,
-        **kwargs
+        **kwargs,
     ):
         neg_logits = self.forward(*args, neg_text_embed=neg_text_embed, cond_drop_prob=0.0, **kwargs)
         pos_logits, embed = self.forward(
@@ -343,8 +336,6 @@ def forward(
 
 
 # self critic wrapper
-
-
 class SelfCritic(nn.Module):
     def __init__(self, net):
         super().__init__()
@@ -371,23 +362,21 @@ def forward(self, x, *args, labels=None, **kwargs):
 
 
 # specialized transformers
-
-
 class MaskGitTransformer(Transformer):
     def __init__(self, *args, **kwargs):
-        assert "add_mask_id" not in kwargs
+        if kwargs.pop("add_mask_id", True) is not True:
+            raise ValueError("MaskGitTransformer does not accept add_mask_id argument")
         super().__init__(*args, add_mask_id=True, **kwargs)
 
 
 class TokenCritic(Transformer):
     def __init__(self, *args, **kwargs):
-        assert "dim_out" not in kwargs
+        if kwargs.pop("dim_out", 1) != 1:
kwargs.pop("dim_out", 1) != 1: + raise ValueError("TokenCritic does not accept dim_out argument") super().__init__(*args, dim_out=1, **kwargs) # classifier free guidance functions - - def uniform(shape, min=0, max=1, device=None): return torch.zeros(shape, device=device).float().uniform_(0, 1) @@ -441,50 +430,46 @@ def __init__( self, image_size, transformer: MaskGitTransformer, + accelerator: Optional[Accelerator] = None, noise_schedule: Callable = cosine_schedule, token_critic: Optional[TokenCritic] = None, - self_token_critic=False, + self_token_critic: bool = False, vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None, cond_vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None, - cond_image_size=None, - cond_drop_prob=0.5, - self_cond_prob=0.9, - no_mask_token_prob=0.0, - critic_loss_weight=1.0, + cond_image_size: Optional[int] = None, + cond_drop_prob: float = 0.5, + self_cond_prob: float = 0.9, + no_mask_token_prob: float = 0.0, + critic_loss_weight: float = 1.0, ): super().__init__() - self.vae = vae.copy_for_eval() if exists(vae) else None + self.accelerator = accelerator + + self.vae = vae.copy_for_eval() if vae is not None else None - if exists(cond_vae): + if cond_vae is not None: + if cond_image_size is None: + raise ValueError("cond_image_size must be specified if conditioning") self.cond_vae = cond_vae.eval() else: self.cond_vae = self.vae - assert not ( - exists(cond_vae) and not exists(cond_image_size) - ), "cond_image_size must be specified if conditioning" - self.image_size = image_size self.cond_image_size = cond_image_size self.resize_image_for_cond_image = exists(cond_image_size) - self.cond_drop_prob = cond_drop_prob self.transformer = transformer self.self_cond = transformer.self_cond - assert ( - self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens - ), "transformer num_tokens must be set to be equal to the vae codebook size" + if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens: + raise ValueError("transformer num_tokens must be set to be equal to the vae codebook size") self.mask_id = transformer.mask_id self.noise_schedule = noise_schedule - assert not (self_token_critic and exists(token_critic)) - self.token_critic = token_critic - - if self_token_critic: - self.token_critic = SelfCritic(transformer) - + if token_critic and self_token_critic: + raise ValueError("cannot have both self_token_critic and token_critic") + self.token_critic = SelfCritic(transformer) if self_token_critic else token_critic self.critic_loss_weight = critic_loss_weight # self conditioning @@ -495,12 +480,16 @@ def __init__( self.no_mask_token_prob = no_mask_token_prob def save(self, path): - self.accelerator.save(self.state_dict(), path) + if self.accelerator: + self.accelerator.save(self.state_dict(), path) + else: + torch.save(self.state_dict(), path) def load(self, path): path = Path(path) - assert path.exists() - state_dict = torch.load(str(path)) + if not path.exists() and path.is_file(): + raise ValueError(f"cannot find file {path} (does not exist or is not a file)") + state_dict = torch.load(str(path), map_location="cpu") self.load_state_dict(state_dict) @torch.no_grad()