
Commit

memory efficient attention
korakoe committed May 15, 2023
1 parent 3a987f4 commit c9aa0af
Showing 3 changed files with 48 additions and 2 deletions.
42 changes: 40 additions & 2 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -15,6 +15,7 @@
from torch import einsum, nn, isnan
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
from memory_efficient_attention_pytorch import Attention as MemAttention

from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
from .vqgan_vae import VQGanVAE
@@ -181,6 +182,35 @@ def forward(self, x, context=None, context_mask=None):
return self.norm(x)


class MemoryEfficientTransformerBlocks(nn.Module):
def __init__(self, *, dim, depth, dim_head=64, heads=8, ff_mult=4):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
MemAttention(dim=dim, dim_head=dim_head, heads=heads),
MemAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)

self.norm = LayerNorm(dim)

def forward(self, x, context=None, mask=None):
for attn, cross_attn, ff in self.layers:
x = attn(x) + x

x = cross_attn(x, context=context, mask=mask) + x

x = ff(x) + x

return self.norm(x)


# transformer - it's all we need
class Transformer(nn.Module):
def __init__(
@@ -194,6 +224,7 @@ def __init__(
self_cond: bool = False,
add_mask_id: bool = False,
cache_path: PathLike = None,
memory_efficient: bool = False,
**kwargs,
):
super().__init__()
@@ -204,8 +235,12 @@ def __init__(
self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)
self.pos_emb = nn.Embedding(seq_len, dim)
self.seq_len = seq_len
self.memory_efficient = memory_efficient

self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
if memory_efficient:
self.transformer_blocks = MemoryEfficientTransformerBlocks(dim=dim, **kwargs)
else:
self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
self.norm = LayerNorm(dim)

self.dim_out = default(dim_out, num_tokens)
@@ -315,7 +350,10 @@ def forward(
self_cond_embed = torch.zeros_like(x)
x = x + self.self_cond_to_init_embed(self_cond_embed)

embed = self.transformer_blocks(x, context=context, context_mask=context_mask)
if self.memory_efficient:
embed = self.transformer_blocks(x, context=context, mask=context_mask)
else:
embed = self.transformer_blocks(x, context=context, context_mask=context_mask)

logits = self.to_logits(embed)

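For context, a minimal sketch of the new code path (sizes below are hypothetical; only the names and parameters visible in this diff are assumed). The new block can be exercised on its own, and passing memory_efficient=True to Transformer simply swaps it in for TransformerBlocks:

import torch
from muse_maskgit_pytorch.muse_maskgit_pytorch import MemoryEfficientTransformerBlocks

# Illustrative sizes only.
blocks = MemoryEfficientTransformerBlocks(dim=512, depth=2, dim_head=64, heads=8, ff_mult=4)

x = torch.randn(2, 256, 512)             # token embeddings (batch, seq, dim)
context = torch.randn(2, 77, 512)        # e.g. encoded text conditioning
context_mask = torch.ones(2, 77).bool()  # True = attend to this context position

out = blocks(x, context=context, mask=context_mask)  # -> (2, 256, 512)

Note that this forward takes the keyword mask rather than context_mask, which is why Transformer.forward above branches on self.memory_efficient.
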
1 change: 1 addition & 0 deletions setup.py
@@ -36,6 +36,7 @@
"tqdm",
"vector-quantize-pytorch>=0.10.14",
"lion-pytorch",
"memory-efficient-attention-pytorch=0.1.2"
],
classifiers=[
"Development Status :: 4 - Beta",
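When installing dependencies by hand rather than through setup.py, the equivalent (assumed) command is pip install "memory-efficient-attention-pytorch>=0.1.2", matching the requirement added above.
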
7 changes: 7 additions & 0 deletions train_muse_maskgit.py
@@ -288,6 +288,11 @@
default="Adafactor",
help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion', 'Adafactor']. Default: Adafactor (paper recommended)",
)
parser.add_argument(
"--memory_efficient",
action="store_true",
help="whether to use memory efficient attention instead of standard attention",
)
parser.add_argument(
"--weight_decay",
type=float,
@@ -364,6 +369,7 @@ class Arguments:
taming_model_path: Optional[str] = None
taming_config_path: Optional[str] = None
optimizer: str = "Lion"
memory_efficient: bool = False
weight_decay: float = 0.0
cache_path: Optional[str] = None
skip_arrow: bool = False
@@ -472,6 +478,7 @@ def main():
# name of your T5 model configuration
t5_name=args.t5_name,
cache_path=args.cache_path,
memory_efficient=args.memory_efficient
)
# (2) pass your trained VAE and the base transformer to MaskGit
maskgit = MaskGit(
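With this wiring, memory efficient attention is opt-in from the command line, e.g. python train_muse_maskgit.py --memory_efficient (remaining flags elided); omitting the flag keeps the default of False and the existing TransformerBlocks path.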
