From 01367f642e56c769ad47ad352dfaf725637a0b5e Mon Sep 17 00:00:00 2001
From: Korakoe <56580073+korakoe@users.noreply.github.com>
Date: Sun, 28 May 2023 12:12:18 +0800
Subject: [PATCH] Revert "memory efficient attention"

This reverts commit c9aa0afd6dcbaf31c6a5e266b6994b8171ca2126.
---
 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 42 +-------------------
 setup.py                                     |  1 -
 train_muse_maskgit.py                        |  7 ----
 3 files changed, 2 insertions(+), 48 deletions(-)

diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index ebf2176..4c7aa02 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -15,7 +15,6 @@
 from torch import einsum, nn, isnan
 from tqdm.auto import tqdm
 from transformers import T5EncoderModel, T5Tokenizer
-from memory_efficient_attention_pytorch import Attention as MemAttention
 
 from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
 from .vqgan_vae import VQGanVAE
@@ -182,35 +181,6 @@ def forward(self, x, context=None, context_mask=None):
         return self.norm(x)
 
 
-class MemoryEfficientTransformerBlocks(nn.Module):
-    def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4):
-        super().__init__()
-        self.layers = nn.ModuleList([])
-
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-        self.norm = LayerNorm(dim)
-
-    def forward(self, x, context=None, mask=None):
-        for attn, cross_attn, ff in self.layers:
-            x = attn(x) + x
-
-            x = cross_attn(x, context=context, mask=mask) + x
-
-            x = ff(x) + x
-
-        return self.norm(x)
-
-
 # transformer - it's all we need
 class Transformer(nn.Module):
     def __init__(
@@ -224,7 +194,6 @@ def __init__(
         self_cond: bool = False,
         add_mask_id: bool = False,
         cache_path: PathLike = None,
-        memory_efficient: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -235,12 +204,8 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)
         self.pos_emb = nn.Embedding(seq_len, dim)
         self.seq_len = seq_len
-        self.memory_efficient = memory_efficient
 
-        if memory_efficient:
-            self.transformer_blocks = MemoryEfficientTransformerBlocks(dim=dim, **kwargs)
-        else:
-            self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
+        self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
         self.norm = LayerNorm(dim)
 
         self.dim_out = default(dim_out, num_tokens)
@@ -350,10 +315,7 @@ def forward(
             self_cond_embed = torch.zeros_like(x)
 
         x = x + self.self_cond_to_init_embed(self_cond_embed)
 
-        if self.memory_efficient:
-            embed = self.transformer_blocks(x, context=context, mask=context_mask)
-        else:
-            embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
+        embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
 
         logits = self.to_logits(embed)
diff --git a/setup.py b/setup.py
index 610a64b..815f7f7 100644
--- a/setup.py
+++ b/setup.py
@@ -36,7 +36,6 @@
         "tqdm",
         "vector-quantize-pytorch>=0.10.14",
         "lion-pytorch",
-        "memory-efficient-attention-pytorch=0.1.2"
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index fd524b7..c5223ef 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -288,11 +288,6 @@
     default="Adafactor",
    help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)",
 )
-parser.add_argument(
-    "--memory_efficient",
-    action="store_true",
-    help="whether to use memory efficient attention instead of standard attention",
-)
 parser.add_argument(
     "--weight_decay",
     type=float,
@@ -369,7 +364,6 @@ class Arguments:
     taming_model_path: Optional[str] = None
     taming_config_path: Optional[str] = None
     optimizer: str = "Lion"
-    memory_efficient: bool = False
     weight_decay: float = 0.0
     cache_path: Optional[str] = None
     skip_arrow: bool = False
@@ -478,7 +472,6 @@ def main():
         # name of your T5 model configuration
         t5_name=args.t5_name,
         cache_path=args.cache_path,
-        memory_efficient=args.memory_efficient
     )
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
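
After this revert, Transformer always builds the plain TransformerBlocks path and no longer accepts a memory_efficient flag, which is why the train_muse_maskgit.py hunks also drop the CLI option, the Arguments field, and the constructor kwarg. Below is a minimal usage sketch of the surviving attention path, assuming TransformerBlocks keeps the same (dim, depth, dim_head, heads, ff_mult) constructor and the forward(x, context=None, context_mask=None) signature shown in the hunks above; the sizes and tensors are illustrative only, not taken from the patch:

    import torch
    from muse_maskgit_pytorch.muse_maskgit_pytorch import TransformerBlocks

    # illustrative sizes; Transformer normally projects text embeddings to `dim`
    # before passing them to these blocks as `context`
    blocks = TransformerBlocks(dim=512, depth=2)

    x = torch.randn(1, 256, 512)                        # (batch, image tokens, dim)
    context = torch.randn(1, 77, 512)                   # projected text conditioning
    context_mask = torch.ones(1, 77, dtype=torch.bool)  # True = valid text token

    out = blocks(x, context=context, context_mask=context_mask)  # -> (1, 256, 512)

Note that, unlike the removed MemoryEfficientTransformerBlocks, the cross-attention mask keyword here is context_mask rather than mask, matching the single call left in Transformer.forward after this change.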