Config for a big Reformer2 model, attention chunking and weight donation to improve memory use.

PiperOrigin-RevId: 338114984
Lukasz Kaiser authored and copybara-github committed Oct 20, 2020
1 parent eb9de81 commit 88ed0f4
Showing 4 changed files with 246 additions and 19 deletions.
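
The three diffs shown below thread a new attention_chunk_size argument through the Reformer and configurable-Transformer model builders, and teach the trainer to donate buffers to its jitted update functions. As a quick orientation, a minimal, hypothetical sketch of how the new argument would be passed when building a model (the sizes are illustrative and are not taken from the collapsed config file):

```python
# Hypothetical usage sketch of the argument introduced in this commit.
# All sizes are illustrative; the actual big-Reformer2 config is not shown.
from trax import models

model = models.ReformerLM(
    vocab_size=32000,
    d_model=1024,
    d_ff=4096,
    n_layers=8,
    n_heads=8,
    attention_chunk_size=128,  # > 0: run attention chunked at this size
    mode='train',
)
```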
trax/models/reformer/reformer.py (19 additions, 4 deletions)
@@ -28,7 +28,8 @@

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
n_heads, attention_type, dropout, ff_activation,
ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity, mode):
ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,
attention_chunk_size, mode):
"""Reversible transformer decoder layer.
Args:
@@ -44,14 +45,16 @@ def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
attention_chunk_size: int, if > 0 run attention chunked at this size
mode: str: 'train' or 'eval'
Returns:
the layer.
"""
attention = configurable_transformer.ApplyAttentionLayer(
attention_type, d_model, n_heads, d_attention_key, d_attention_value,
True, False, dropout, dropout, mode)
True, False, dropout, dropout, attention_chunk_size, mode)
attention_half_residual = tl.ReversibleHalfResidual(
tl.LayerNorm(),
attention_layer=attention,
@@ -110,6 +113,7 @@ def ReformerLM(vocab_size,
ff_use_sru=0,
ff_chunk_size=0,
ff_sparsity=0,
attention_chunk_size=0,
mode='train'):
"""Reversible transformer language model (only uses a decoder, no encoder).
@@ -132,6 +136,7 @@ def ReformerLM(vocab_size,
ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
attention_chunk_size: int, if > 0 run attention chunked at this size
mode: str: 'train', 'eval', or 'predict'
Returns:
@@ -163,6 +168,7 @@ def ReformerLM(vocab_size,
ff_use_sru=ff_use_sru,
ff_chunk_size=ff_chunk_size,
ff_sparsity=ff_sparsity,
attention_chunk_size=attention_chunk_size,
mode=mode)
decoder_blocks.append(decoder_block)

@@ -199,6 +205,7 @@ def ReformerShortenLM(vocab_size,
ff_use_sru=0,
ff_chunk_size=0,
ff_sparsity=0,
attention_chunk_size=0,
mode='train'):
"""Reversible transformer language model with shortening.
@@ -233,6 +240,7 @@ def ReformerShortenLM(vocab_size,
ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
attention_chunk_size: int, if > 0 run attention chunked at this size
mode: str: 'train' or 'eval'
Returns:
@@ -273,6 +281,7 @@ def ReformerShortenLM(vocab_size,
ff_use_sru=ff_use_sru,
ff_chunk_size=ff_chunk_size,
ff_sparsity=ff_sparsity,
attention_chunk_size=attention_chunk_size,
mode=mode)
decoder_blocks.append(decoder_block)

@@ -313,7 +322,7 @@ def ReformerShortenLM(vocab_size,

def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
mode='train'):
attention_chunk_size=0, mode='train'):
"""Returns a list of layers that implements a Reformer encoder block.
The input to the layer is a pair, (activations, mask), where the mask was
@@ -331,6 +340,7 @@ def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
attention_chunk_size: int, if > 0 run attention chunked at this size
mode: str: 'train' or 'eval'
Returns:
@@ -345,7 +355,8 @@ def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
attention = configurable_transformer.ApplyAttentionLayer(
attention_type=attention_type, d_model=d_model, n_heads=n_heads,
d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False,
attention_dropout=dropout, output_dropout=dropout, mode=mode)
attention_dropout=dropout, output_dropout=dropout,
attention_chunk_size=attention_chunk_size, mode=mode)
attention_half_residual = tl.ReversibleHalfResidual(
tl.LayerNorm(),
attention_layer=attention,
@@ -542,6 +553,7 @@ def Reformer2(input_vocab_size,
ff_chunk_size=0,
ff_dropout=None,
ff_sparsity=0,
attention_chunk_size=0,
n_layers_forget=0,
mode='train'):
"""Reversible transformer encoder-decoder model.
@@ -577,6 +589,7 @@ def Reformer2(input_vocab_size,
ff_dropout: float: (optional) separate dropout rate at feed-forward
nonlinearity. This is called relu_dropout in T2T.
ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
attention_chunk_size: int, if > 0 run attention chunked at this size
n_layers_forget: how often to have a forgetting block between layers
mode: str: 'train' or 'eval'
@@ -628,6 +641,7 @@ def PositionalEnc(mode):
ff_use_sru=ff_use_sru,
ff_chunk_size=ff_chunk_size,
ff_sparsity=ff_sparsity,
attention_chunk_size=attention_chunk_size,
mode=mode)
for _ in range(n_encoder_layers)]
# pylint: enable=g-complex-comprehension
@@ -660,6 +674,7 @@ def PositionalEnc(mode):
ff_use_sru=ff_use_sru,
ff_chunk_size=ff_chunk_size,
ff_sparsity=ff_sparsity,
attention_chunk_size=attention_chunk_size,
mode=mode)
decoder_blocks.append(decoder_block)

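The reformer.py changes above only plumb attention_chunk_size through to ApplyAttentionLayer, which does the actual chunking (next file). For intuition on why chunking saves memory, a standalone sketch in plain JAX (not the trax implementation; masking and batching are ignored): processing queries in chunks shrinks the attention-logits buffer from (length, length) to (chunk_size, length).

```python
import jax.numpy as jnp
from jax.nn import softmax


def chunked_attention(q, k, v, chunk_size):
  """q, k, v: arrays of shape [length, depth]; returns [length, depth]."""
  depth = q.shape[-1]
  outputs = []
  for start in range(0, q.shape[0], chunk_size):
    q_chunk = q[start:start + chunk_size]                  # [chunk, depth]
    logits = jnp.dot(q_chunk, k.T) / jnp.sqrt(depth)       # [chunk, length]
    outputs.append(jnp.dot(softmax(logits, axis=-1), v))   # [chunk, depth]
  return jnp.concatenate(outputs, axis=0)
```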
trax/models/research/configurable_transformer.py (27 additions, 9 deletions)
@@ -116,7 +116,8 @@ def FeedForwardWithOptions(d_model,

# TODO(lukaszkaiser): unify attention layers API and remove this branch
def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,
masked, attention_dropout, output_dropout, mode):
masked, attention_dropout, output_dropout,
attention_chunk_size, mode):
"""Runs the supplied attention layer."""
try:
attention = attention_type(
@@ -131,7 +132,7 @@ def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,
except TypeError: # No d_qk arguments in less advanced layers.
attention = attention_type(
d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)
return attention
return tl.Chunk(attention, attention_chunk_size)


def ConfigurableTransformerEncoder(vocab_size,
@@ -150,6 +151,7 @@
ff_use_sru=0,
ff_sparsity=0,
ff_sparsity_type='1inN',
attention_chunk_size=0,
attention_type=tl.Attention):
"""Returns a Transformer encoder merged with an N-way categorization head.
@@ -195,6 +197,7 @@ def ConfigurableTransformerEncoder(vocab_size,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
attention_type: The attention layer to use for the encoder part.
Returns:
@@ -211,7 +214,8 @@
encoder_blocks = [
_EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
ff_sparsity, ff_sparsity_type, attention_type)
ff_sparsity, ff_sparsity_type,
attention_chunk_size, attention_type)
for i in range(n_layers)
]
# pylint: enable=g-complex-comprehension
Expand Down Expand Up @@ -247,6 +251,7 @@ def ConfigurableTransformerLM(vocab_size,
ff_use_sru=0,
ff_sparsity=0,
ff_sparsity_type='1inN',
attention_chunk_size=0,
attention_type=tl.CausalAttention):
"""Returns a Transformer language model.
@@ -293,6 +298,7 @@ def ConfigurableTransformerLM(vocab_size,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
attention_type: The attention layer to use for the decoder part.
Returns:
@@ -309,7 +315,8 @@
decoder_blocks = [
_DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
ff_sparsity, ff_sparsity_type, attention_type)
ff_sparsity, ff_sparsity_type,
attention_chunk_size, attention_type)
for i in range(n_layers)
]
# pylint: enable=g-complex-comprehension
@@ -342,6 +349,7 @@ def ConfigurableTransformer(input_vocab_size,
ff_use_sru=0,
ff_sparsity=0,
ff_sparsity_type='1inN',
attention_chunk_size=0,
encoder_attention_type=tl.Attention,
encoder_decoder_attention_type=tl.CausalAttention):
"""Returns a full Transformer model.
@@ -402,6 +410,7 @@ def ConfigurableTransformer(input_vocab_size,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
encoder_attention_type: The attention layer to use for the encoder part.
encoder_decoder_attention_type: The attention layer to use for the
encoder-decoder attention.
@@ -438,7 +447,8 @@ def Embedder(vocab_size): # tokens --> vectors
encoder_blocks = [
_EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
ff_sparsity, ff_sparsity_type, encoder_attention_type)
ff_sparsity, ff_sparsity_type,
attention_chunk_size, encoder_attention_type)
for i in range(n_encoder_layers)
]
# pylint: enable=g-complex-comprehension
@@ -452,7 +462,7 @@ def Embedder(vocab_size): # tokens --> vectors
_EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
mode, ff_activation, ff_dropout, ff_chunk_size,
ff_use_sru, ff_sparsity, ff_sparsity_type,
encoder_decoder_attention_type)
attention_chunk_size, encoder_decoder_attention_type)
for i in range(n_decoder_layers)
]
# pylint: enable=g-complex-comprehension
@@ -485,7 +495,8 @@ def Embedder(vocab_size): # tokens --> vectors

def _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
ff_sparsity, ff_sparsity_type, attention_type):
ff_sparsity, ff_sparsity_type,
attention_chunk_size, attention_type):
"""Returns a list of layers that implements a Transformer encoder block.
The input to the block is a pair, (activations, mask), where the mask was
@@ -515,6 +526,7 @@ def _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
attention_type: The attention layer to use.
Returns:
@@ -530,6 +542,7 @@ def _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
masked=True,
attention_dropout=dropout,
output_dropout=dropout,
attention_chunk_size=attention_chunk_size,
mode=mode)

feed_forward = FeedForwardWithOptions(d_model, d_ff, dropout,
@@ -552,7 +565,8 @@ def _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,

def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
ff_sparsity, ff_sparsity_type, attention_type):
ff_sparsity, ff_sparsity_type,
attention_chunk_size, attention_type):
"""Returns a list of layers that implements a Transformer decoder block.
The input is an activation tensor.
@@ -580,6 +594,7 @@ def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
attention_type: The attention layer to use.
Returns:
@@ -595,6 +610,7 @@ def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
masked=False,
attention_dropout=dropout,
output_dropout=dropout,
attention_chunk_size=attention_chunk_size,
mode=mode)

feed_forward = FeedForwardWithOptions(d_model, d_ff, dropout,
@@ -618,7 +634,7 @@
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
mode, ff_activation, ff_dropout, ff_chunk_size,
ff_use_sru, ff_sparsity, ff_sparsity_type,
attention_type):
attention_chunk_size, attention_type):
"""Returns a list of layers implementing a Transformer encoder-decoder block.
The input is a triple (decoder_activations, mask, encoder_activiations) where
@@ -648,6 +664,7 @@ def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
ff_sparsity_type: string, if ff_sparsity >0,
use SparseFF if ff_sparsity_type=`'1inN'` and
use BlockSparseFF if ff_sparsity_type=`'Block'`
attention_chunk_size: int, if > 0 run attention chunked at this size
attention_type: The attention layer to use.
Returns:
@@ -677,6 +694,7 @@ def _Dropout():
masked=True,
attention_dropout=dropout,
output_dropout=dropout,
attention_chunk_size=attention_chunk_size,
mode=mode)

feed_forward = FeedForwardWithOptions(d_model, d_ff, dropout,
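The behavioral change in configurable_transformer.py is the last line of ApplyAttentionLayer, which now wraps whatever attention layer was built in tl.Chunk; every builder above simply forwards an attention_chunk_size argument to it (default 0, i.e. unchunked, per the docstrings). A sketch of calling the updated helper directly, with made-up argument values:

```python
# Illustrative call of the updated helper; all values are made up.
from trax import layers as tl
from trax.models.research import configurable_transformer as ct

attention = ct.ApplyAttentionLayer(
    attention_type=tl.CausalAttention,
    d_model=512,
    n_heads=8,
    d_qk=64,
    d_v=64,
    causal=True,
    masked=False,
    attention_dropout=0.1,
    output_dropout=0.1,
    attention_chunk_size=64,  # 0 would keep attention unchunked
    mode='train')
# The returned layer is the built attention wrapped in tl.Chunk(..., 64).
```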
trax/optimizers/trainer.py (7 additions, 6 deletions)
@@ -284,12 +284,13 @@ def _make_optimizer(layer):
rev_and_fbos = []
for layer, opt in zip(rev_layers, rev_opts):
rev_and_fbos.append(self._pjit(_reverse_and_fbo_with_layer_and_opt(
layer, opt, self._n_devices)))
self._fbos.append((self._pjit(std_fbo), rev_and_fbos))
layer, opt, self._n_devices), donate_argnums=(1,)))
self._fbos.append(
(self._pjit(std_fbo, donate_argnums=(1,)), rev_and_fbos))

loss_fbo = _fbo_with_layer_and_opt(
self._loss_layer, self._loss_opt, self._n_devices, 'loss')
self._loss_fbo = self._pjit(loss_fbo)
self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1,))

@property
def loss_layer(self):
@@ -313,12 +314,12 @@ def slots(self, slots):
for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots):
opt.slots = slot

def _pjit(self, f):
def _pjit(self, f, donate_argnums=()):
"""JIT f if 1 device is available and pmap if more are available."""
if self._n_devices == 1:
return fastmath.jit(f)
return fastmath.jit(f, donate_argnums=donate_argnums)
else:
return fastmath.pmap(f, axis_name='batch')
return fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums)

def _replicate(self, x):
if self._n_devices > 1:
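In trainer.py, _pjit now takes donate_argnums and forwards it to fastmath.jit or fastmath.pmap, and the forward-backward-optimizer functions are compiled with donate_argnums=(1,), so XLA may reuse the donated argument's device buffers for the outputs instead of keeping old and new values alive at once. A standalone illustration of the underlying JAX feature (not trax code; donation may be ignored on some backends, e.g. CPU):

```python
import jax
import jax.numpy as jnp


def sgd_step(grads, params):
  return params - 0.1 * grads

# donate_argnums=(1,) marks `params` as donatable: XLA may reuse its device
# buffer for the returned array, so old and new weights need not coexist.
sgd_step_donated = jax.jit(sgd_step, donate_argnums=(1,))

params = jnp.ones((4096, 4096), dtype=jnp.float32)
grads = jnp.full((4096, 4096), 1e-3, dtype=jnp.float32)
params = sgd_step_donated(grads, params)
# The original `params` buffer may be invalidated by donation; do not reuse it.
```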
(Diff of the fourth changed file is collapsed and not shown.)
