Revert "memory efficient attention"
This reverts commit c9aa0af.
korakoe committed May 28, 2023
1 parent c9aa0af commit 01367f6
Showing 3 changed files with 2 additions and 48 deletions.
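For reference, the hunks below show that the reverted feature was a single `memory_efficient` flag on `Transformer` (wired to a `MemoryEfficientTransformerBlocks` stack) plus a matching `--memory_efficient` switch in `train_muse_maskgit.py`. A minimal sketch of how the toggle was used before this revert, assuming `Transformer` is importable from the package and that the keyword names match the constructor shown in the diff; the numeric values are placeholders for illustration, not from this commit:

```python
from muse_maskgit_pytorch import Transformer

# Pre-revert (commit c9aa0af): opt into the memory-efficient attention path.
# num_tokens / seq_len / dim / depth values here are illustrative only.
transformer = Transformer(
    num_tokens=8192,
    seq_len=1024,
    dim=512,
    depth=6,
    memory_efficient=True,  # flag removed by this commit
)

# Post-revert (commit 01367f6): the flag is gone and TransformerBlocks
# (standard attention) is always used.
transformer = Transformer(
    num_tokens=8192,
    seq_len=1024,
    dim=512,
    depth=6,
)
```

On the training side, the equivalent toggle was the `--memory_efficient` argument of `train_muse_maskgit.py`; this revert removes that argument along with the `memory_efficient` field on `Arguments`.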
42 changes: 2 additions & 40 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -15,7 +15,6 @@
 from torch import einsum, nn, isnan
 from tqdm.auto import tqdm
 from transformers import T5EncoderModel, T5Tokenizer
-from memory_efficient_attention_pytorch import Attention as MemAttention

 from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
 from .vqgan_vae import VQGanVAE
@@ -182,35 +181,6 @@ def forward(self, x, context=None, context_mask=None):
         return self.norm(x)


-class MemoryEfficientTransformerBlocks(nn.Module):
-    def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4):
-        super().__init__()
-        self.layers = nn.ModuleList([])
-
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        MemAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-        self.norm = LayerNorm(dim)
-
-    def forward(self, x, context=None, mask=None):
-        for attn, cross_attn, ff in self.layers:
-            x = attn(x) + x
-
-            x = cross_attn(x, context=context, mask=mask) + x
-
-            x = ff(x) + x
-
-        return self.norm(x)
-
-
 # transformer - it's all we need
 class Transformer(nn.Module):
     def __init__(
@@ -224,7 +194,6 @@ def __init__(
         self_cond: bool = False,
         add_mask_id: bool = False,
         cache_path: PathLike = None,
-        memory_efficient: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -235,12 +204,8 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)
         self.pos_emb = nn.Embedding(seq_len, dim)
         self.seq_len = seq_len
-        self.memory_efficient = memory_efficient

-        if memory_efficient:
-            self.transformer_blocks = MemoryEfficientTransformerBlocks(dim=dim, **kwargs)
-        else:
-            self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
+        self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
         self.norm = LayerNorm(dim)

         self.dim_out = default(dim_out, num_tokens)
@@ -350,10 +315,7 @@ def forward(
             self_cond_embed = torch.zeros_like(x)
         x = x + self.self_cond_to_init_embed(self_cond_embed)

-        if self.memory_efficient:
-            embed = self.transformer_blocks(x, context=context, mask=context_mask)
-        else:
-            embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
+        embed = self.transformer_blocks(x, context=context, context_mask=context_mask)

         logits = self.to_logits(embed)

1 change: 0 additions & 1 deletion setup.py
@@ -36,7 +36,6 @@
         "tqdm",
         "vector-quantize-pytorch>=0.10.14",
         "lion-pytorch",
-        "memory-efficient-attention-pytorch=0.1.2"
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
7 changes: 0 additions & 7 deletions train_muse_maskgit.py
@@ -288,11 +288,6 @@
     default="Adafactor",
     help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)",
 )
-parser.add_argument(
-    "--memory_efficient",
-    action="store_true",
-    help="whether to use memory efficient attention instead of standard attention",
-)
 parser.add_argument(
     "--weight_decay",
     type=float,
@@ -369,7 +364,6 @@ class Arguments:
     taming_model_path: Optional[str] = None
     taming_config_path: Optional[str] = None
     optimizer: str = "Lion"
-    memory_efficient: bool = False
     weight_decay: float = 0.0
     cache_path: Optional[str] = None
     skip_arrow: bool = False
@@ -478,7 +472,6 @@ def main():
         # name of your T5 model configuration
         t5_name=args.t5_name,
         cache_path=args.cache_path,
-        memory_efficient=args.memory_efficient
     )
     # (2) pass your trained VAE and the base transformer to MaskGit
     maskgit = MaskGit(
