diff --git a/megatron/arguments.py b/megatron/arguments.py
index c18235a78..c8f6bb3be 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -549,6 +549,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. '
+                       'https://arxiv.org/abs/2205.14135')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 03e6faaec..7ac3ef9aa 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -33,6 +33,17 @@ from .glu_activations import GLU_ACTIVATIONS
 from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
+
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_executor(False)
@@ -113,6 +124,49 @@ def forward(self, hidden_states):
 
         return output, output_bias
 
+
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
 
@@ -138,6 +192,20 @@ def __init__(self, init_method,
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
 
+        self.use_flash_attn = args.use_flash_attn
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            headdim = args.hidden_size / args.num_attention_heads
+            assert headdim <= 128, 'FlashAttention only supports head dimension at most 128'
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
+
         projection_size = args.kv_channels * args.num_attention_heads
 
         # Per attention head and per partition values.
@@ -189,6 +257,12 @@ def __init__(self, init_method,
         # on average it should not be partition dependent.
         self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
 
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
+
+
         # Output.
         self.dense = mpu.RowParallelLinear(
             projection_size,
@@ -262,136 +336,148 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
         if get_key_value:
             present = (key_layer, value_layer)
 
-        # ===================================
-        # Raw attention scores. [b, np, s, s]
-        # ===================================
-
-        # [b, np, sq, sk]
-        output_size = (query_layer.size(1),
-                       query_layer.size(2),
-                       query_layer.size(0),
-                       key_layer.size(0))
-
-        # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
-        # [sk, b, np, hn] -> [sk, b * np, hn]
-        key_layer = key_layer.view(output_size[3],
-                                   output_size[0] * output_size[1], -1)
-
-        # preallocting result tensor: [b * np, sq, sk]
-        if alibi is None:
-            matmul_result = torch.empty(
-                output_size[0]*output_size[1],
-                output_size[2],
-                output_size[3],
-                dtype=query_layer.dtype,
-                device=torch.cuda.current_device())
-        else:
-            matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]
+        if not self.use_flash_attn:
+            # ===================================
+            # Raw attention scores. [b, np, s, s]
+            # ===================================
+
+            # [b, np, sq, sk]
+            output_size = (query_layer.size(1),
+                           query_layer.size(2),
+                           query_layer.size(0),
+                           key_layer.size(0))
+
+            # [sq, b, np, hn] -> [sq, b * np, hn]
+            query_layer = query_layer.view(output_size[2],
+                                           output_size[0] * output_size[1], -1)
+            # [sk, b, np, hn] -> [sk, b * np, hn]
+            key_layer = key_layer.view(output_size[3],
+                                       output_size[0] * output_size[1], -1)
+
+            # preallocting result tensor: [b * np, sq, sk]
+            if alibi is None:
+                matmul_result = torch.empty(
+                    output_size[0]*output_size[1],
+                    output_size[2],
+                    output_size[3],
+                    dtype=query_layer.dtype,
+                    device=torch.cuda.current_device())
+            else:
+                matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]
+
+            # Rotary embeddings
+            if self.position_embedding_type == PositionEmbeddingType.rotary:
+                apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
+
+                seq_len = key_layer.shape[0]
+                offset = 0
+                if layer_past is not None and layer_past.numel() > 0:
+                    offset = layer_past[0].shape[0]
+                    seq_len += offset
+                cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+                query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
+
+            # Raw attention scores. [b * np, sq, sk]
+            if alibi is None:
+                matmul_result = torch.baddbmm(
+                    matmul_result,
+                    query_layer.transpose(0, 1),   # [b * np, sq, hn]
+                    key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                    beta=0.0, alpha=(1.0/self.norm_factor))
+            else:
+                if not hasattr(self, "logged_alibi"):
+                    logger.debug("Using Alibi.")
+                    self.logged_alibi = True
 
-        # Rotary embeddings
-        if self.position_embedding_type == PositionEmbeddingType.rotary:
-            apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
-
-            seq_len = key_layer.shape[0]
-            offset = 0
-            if layer_past is not None and layer_past.numel() > 0:
-                offset = layer_past[0].shape[0]
-                seq_len += offset
-            cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
-            query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
-
-        # Raw attention scores. [b * np, sq, sk]
-        if alibi is None:
-            matmul_result = torch.baddbmm(
-                matmul_result,
-                query_layer.transpose(0, 1),   # [b * np, sq, hn]
-                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-                beta=0.0, alpha=(1.0/self.norm_factor))
-        else:
-            if not hasattr(self, "logged_alibi"):
-                logger.debug("Using Alibi.")
-                self.logged_alibi = True
 
+                if self.apply_query_key_layer_scaling:
+                    beta = 1.0 / self.layer_number
+                else:
+                    beta = 1.0
-            if self.apply_query_key_layer_scaling:
-                beta = 1.0 / self.layer_number
-            else:
-                beta = 1.0
 
+                matmul_result = torch.baddbmm(
+                    matmul_result,
+                    query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                    key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                    beta=beta, alpha=(1.0 / self.norm_factor))
-            matmul_result = torch.baddbmm(
-                matmul_result,
-                query_layer.transpose(0, 1),  # [b * np, sq, hn]
-                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-                beta=beta, alpha=(1.0 / self.norm_factor))
 
+            # change view to [b, np, sq, sk]
+            attention_scores = matmul_result.view(*output_size)
+            # ==================================================
+            # Update attention mask for inference. [b, np, sq, sk]
+            # ==================================================
-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(*output_size)
-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
 
+            if get_key_value:
+                with torch.no_grad():
+                    # TODO @thomasw21 Handle case where `attention_mask` is None
+                    if layer_past is not None:
+                        attention_mask = attention_mask[
+                            ...,
+                            attention_scores.size(3) - 1,
+                            :attention_scores.size(3)].unsqueeze(2)
+                    else:
+                        attention_mask = attention_mask[
+                            ...,
+                            :attention_scores.size(3),
+                            :attention_scores.size(3)]
+
+            # ===========================
+            # Attention probs and dropout
+            # ===========================
+
+            # attention scores and attention mask [b, np, sq, sk]
+            attention_probs = self.scale_mask_softmax(attention_scores,
+                                                      attention_mask)
+
+            # This is actually dropping out entire tokens to attend to, which might
+            # seem a bit unusual, but is taken from the original Transformer paper.
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+
+            # =========================
+            # Context layer. [sq, b, hp]
+            # =========================
+
+            # value_layer -> context layer.
+            # [sk, b, np, hn] --> [b, np, sq, hn]
+
+            # context layer shape: [b, np, sq, hn]
+            output_size = (value_layer.size(1),
+                           value_layer.size(2),
+                           query_layer.size(0),
+                           value_layer.size(3))
+
+            # change view [sk, b * np, hn]
+            value_layer = value_layer.view(value_layer.size(0),
+                                           output_size[0] * output_size[1], -1)
+
+            # change view [b * np, sq, sk]
+            attention_probs = attention_probs.view(output_size[0] * output_size[1],
+                                                   output_size[2], -1)
+
+            # matmul: [b * np, sq, hn]
+            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+            # change view [b, np, sq, hn]
+            context_layer = context_layer.view(*output_size)
+
+            # [b, np, sq, hn] --> [sq, b, np, hn]
+            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+            # [sq, b, np, hn] --> [sq, b, hp]
+            new_context_layer_shape = context_layer.size()[:-2] + \
+                (self.hidden_size_per_partition,)
+            context_layer = context_layer.view(*new_context_layer_shape)
-        if get_key_value:
-            with torch.no_grad():
-                # TODO @thomasw21 Handle case where `attention_mask` is None
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
-        # ===========================
-        # Attention probs and dropout
-        # ===========================
-
-        # attention scores and attention mask [b, np, sq, sk]
-        attention_probs = self.scale_mask_softmax(attention_scores,
-                                                  attention_mask)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-
-        # =========================
-        # Context layer. [sq, b, hp]
-        # =========================
-
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
-
-        # context layer shape: [b, np, sq, hn]
-        output_size = (value_layer.size(1),
-                       value_layer.size(2),
-                       query_layer.size(0),
-                       value_layer.size(3))
-
-        # change view [sk, b * np, hn]
-        value_layer = value_layer.view(value_layer.size(0),
-                                       output_size[0] * output_size[1], -1)
-
-        # change view [b * np, sq, sk]
-        attention_probs = attention_probs.view(output_size[0] * output_size[1],
-                                               output_size[2], -1)
-
-        # matmul: [b * np, sq, hn]
-        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
-
-        # change view [b, np, sq, hn]
-        context_layer = context_layer.view(*output_size)
-
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        else:
+            # =================
+            # Flash Attention
+            # =================
+
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            with mpu.get_cuda_rng_tracker().fork():
+                context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
         # Output. [sq, b, h]
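Editor's note: the `FlashSelfAttention` wrapper added above feeds the "unpadded" FlashAttention interface, which takes all sequences packed into a single (total_tokens, heads, head_dim) tensor plus a cumulative-sequence-length index. Below is a minimal sketch of that packing, using only torch and einops (no call to the fused kernel); the tensor sizes are illustrative and not taken from the patch.

```python
import torch
from einops import rearrange

# Illustrative sizes only (not from the patch).
batch_size, seqlen, num_heads, head_dim = 4, 8, 2, 16

# FlashSelfAttention.forward receives q/k/v as (B, S, H, D).
q = torch.randn(batch_size, seqlen, num_heads, head_dim)

# Pack the batch into one (B*S, H, D) tensor, as the patch does.
q_packed = rearrange(q, 'b s ... -> (b s) ...')
assert q_packed.shape == (batch_size * seqlen, num_heads, head_dim)

# cu_seqlens marks where each sequence starts/ends in the packed tensor:
# [0, S, 2S, ..., B*S]; flash_attn_unpadded_func uses it instead of padding.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  8, 16, 24, 32], dtype=torch.int32)

# Row i of the original batch occupies q_packed[cu_seqlens[i]:cu_seqlens[i + 1]].
assert torch.equal(q_packed[cu_seqlens[0]:cu_seqlens[1]], q[0])
```

Because Megatron trains on fixed-length sequences, every entry of `cu_seqlens` is a multiple of `seqlen`; the same index tensor is reused for queries and keys in the call above.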
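Editor's note: the new `else:` branch in `ParallelAttention.forward` converts Megatron's [sq, b, np, hn] activations into FlashAttention's [b, s, h, d] layout and converts the result back to [sq, b, hp]. A minimal sketch of that round trip follows, with an ordinary softmax(QK^T / sqrt(d)) V causal attention standing in for the fused kernel so it runs on CPU without flash-attn installed; sizes are illustrative.

```python
import math
import torch
from einops import rearrange

sq, b, heads, hn = 8, 2, 4, 16          # illustrative sizes only
query = torch.randn(sq, b, heads, hn)   # Megatron layout: [sq, b, np, hn]
key = torch.randn(sq, b, heads, hn)
value = torch.randn(sq, b, heads, hn)

# [sq, b, np, hn] -> [b, s, h, d], as in the patched forward().
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
           for x in (query, key, value)]

# Reference causal attention in place of flash_attn_unpadded_func:
# scores[b, h, i, j] = q.k / sqrt(d), masked so position i only attends to j <= i.
scores = torch.einsum('bihd,bjhd->bhij', q, k) / math.sqrt(hn)
causal_mask = torch.triu(torch.ones(sq, sq), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float('-inf'))
context = torch.einsum('bhij,bjhd->bihd', scores.softmax(dim=-1), v)  # [b, s, h, d]

# Back to Megatron's [sq, b, hp] with heads merged, as in the patch.
context = rearrange(context, 'b s h d -> s b (h d)').contiguous()
assert context.shape == (sq, b, heads * hn)
```

The fused kernel computes the same quantity without materializing the [b, h, s, s] score matrix, which is where the memory and speed gains come from; dropout is applied inside the kernel, which is why the patch wraps the call in `mpu.get_cuda_rng_tracker().fork()`.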
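Editor's note: `--use-flash-attn` only takes effect when the constructor assertions above hold: self-attention, causal mask, fp16/bf16 activations, and a head dimension of at most 128. A quick check of the head-dimension constraint for a few hypothetical configurations (values below are examples, not from the patch):

```python
# FlashAttention path requires hidden_size / num_attention_heads <= 128.
configs = [
    (4096, 32),   # head dim 128 -> accepted
    (5120, 40),   # head dim 128 -> accepted
    (8192, 32),   # head dim 256 -> rejected by the assert in ParallelAttention
]
for hidden_size, num_attention_heads in configs:
    headdim = hidden_size / num_attention_heads
    status = 'ok' if headdim <= 128 else 'too large for --use-flash-attn'
    print(f'hidden={hidden_size} heads={num_attention_heads} headdim={headdim:.0f}: {status}')
```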