diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 4c7aa02..ebf2176 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -15,6 +15,7 @@
 from torch import einsum, nn, isnan
 from tqdm.auto import tqdm
 from transformers import T5EncoderModel, T5Tokenizer
+from memory_efficient_attention_pytorch import Attention as MemAttention
 
 from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
 from .vqgan_vae import VQGanVAE
@@ -181,6 +182,35 @@ def forward(self, x, context=None, context_mask=None):
         return self.norm(x)
 
 
+class MemoryEfficientTransformerBlocks(nn.Module):
+    def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        self.norm = LayerNorm(dim)
+
+    def forward(self, x, context=None, mask=None):
+        for attn, cross_attn, ff in self.layers:
+            x = attn(x) + x
+
+            x = cross_attn(x, context=context, mask=mask) + x
+
+            x = ff(x) + x
+
+        return self.norm(x)
+
+
 # transformer - it's all we need
 class Transformer(nn.Module):
     def __init__(
@@ -194,6 +224,7 @@ def __init__(
         self_cond: bool = False,
         add_mask_id: bool = False,
         cache_path: PathLike = None,
+        memory_efficient: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -204,8 +235,12 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)
         self.pos_emb = nn.Embedding(seq_len, dim)
         self.seq_len = seq_len
+        self.memory_efficient = memory_efficient
 
-        self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
+        if memory_efficient:
+            self.transformer_blocks = MemoryEfficientTransformerBlocks(dim=dim, **kwargs)
+        else:
+            self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
         self.norm = LayerNorm(dim)
 
         self.dim_out = default(dim_out, num_tokens)
@@ -315,7 +350,10 @@ def forward(
                 self_cond_embed = torch.zeros_like(x)
 
             x = x + self.self_cond_to_init_embed(self_cond_embed)
 
-        embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
+        if self.memory_efficient:
+            embed = self.transformer_blocks(x, context=context, mask=context_mask)
+        else:
+            embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
 
         logits = self.to_logits(embed)
diff --git a/setup.py b/setup.py
index 815f7f7..610a64b 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,7 @@
         "tqdm",
         "vector-quantize-pytorch>=0.10.14",
         "lion-pytorch",
+        "memory-efficient-attention-pytorch>=0.1.2",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index c5223ef..fd524b7 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -288,6 +288,11 @@
     default="Adafactor",
     help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)",
 )
+parser.add_argument(
+    "--memory_efficient",
+    action="store_true",
+    help="Whether to use memory-efficient attention instead of standard attention.",
+)
 parser.add_argument(
     "--weight_decay",
     type=float,
@@ -364,6 +369,7 @@ class Arguments:
     taming_model_path: Optional[str] = None
     taming_config_path: Optional[str] = None
     optimizer: str = "Lion"
+    memory_efficient: bool = False
     weight_decay: float = 0.0
     cache_path: Optional[str] = None
     skip_arrow: bool = False
@@ -472,6 +478,7 @@ def main():
         # name of your T5 model configuration
         t5_name=args.t5_name,
         cache_path=args.cache_path,
+        memory_efficient=args.memory_efficient,
     )
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
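
Usage sketch for the new option, kept minimal and hedged: it assumes only the Transformer constructor arguments visible in the hunks above (num_tokens, seq_len, dim, plus the block kwargs forwarded through **kwargs) and leaves every other constructor argument at its default; the concrete values below are illustrative, not recommendations.

# minimal sketch: build the base transformer with the new memory_efficient flag
from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer

transformer = Transformer(
    num_tokens=8192,        # codebook size (illustrative value)
    seq_len=1024,           # image-token sequence length (illustrative value)
    dim=512,
    depth=8,                # forwarded via **kwargs to the block class
    dim_head=64,
    heads=8,
    ff_mult=4,
    memory_efficient=True,  # selects MemoryEfficientTransformerBlocks instead of TransformerBlocks
)

From the training script, the same path is enabled by adding the new flag alongside the existing arguments, e.g. python train_muse_maskgit.py --memory_efficient ...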