Add FlashAttention #357

Open - wants to merge 1 commit into base: main
megatron/arguments.py (3 additions, 0 deletions)
@@ -549,6 +549,9 @@ def _add_training_args(parser):
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='use FlashAttention implementation of attention. '
                            'https://arxiv.org/abs/2205.14135')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
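
The new --use-flash-attn switch is a plain store_true flag. As a minimal, standalone illustration of how such a flag behaves (this is not Megatron's full parser, only a reference sketch):

import argparse

# Minimal illustration only; Megatron builds its real parser in megatron/arguments.py.
parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn', action='store_true',
                    help='use FlashAttention implementation of attention.')
print(parser.parse_args([]).use_flash_attn)                    # False (default)
print(parser.parse_args(['--use-flash-attn']).use_flash_attn)  # True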
megatron/model/transformer.py (211 additions, 125 deletions)
@@ -33,6 +33,17 @@
from .glu_activations import GLU_ACTIVATIONS
from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb

try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None


# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
@@ -113,6 +124,49 @@ def forward(self, hidden_states):
        return output, output_bias


class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensors containing the query, key, and value. (B, S, H, D)
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        batch_size, seqlen = q.shape[0], q.shape[1]
        # Flatten batch and sequence into a single "total tokens" dimension, as expected
        # by the unpadded FlashAttention kernel: (B, S, H, D) -> (B * S, H, D).
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        max_s = seqlen
        # Cumulative sequence lengths of the flattened batch: [0, S, 2*S, ..., B*S].
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                  device=q.device)
        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
            self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=self.causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output
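
For reference, a minimal sketch of exercising this wrapper on its own. It assumes flash-attn and einops are installed, a CUDA device is available, and the FlashSelfAttention class above is in scope; the shapes are hypothetical.

# Hypothetical shapes: batch 2, sequence 1024, 16 heads, head dim 64 -> (B, S, H, D).
flash_attn_module = FlashSelfAttention(causal=True, attention_dropout=0.0)
q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_module(q, k, v)   # same (B, S, H, D) layout as the inputs
assert out.shape == (2, 1024, 16, 64)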



class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

@@ -138,6 +192,20 @@ def __init__(self, init_method,
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        self.use_flash_attn = args.use_flash_attn
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            headdim = args.hidden_size / args.num_attention_heads
            assert headdim <= 128, 'FlashAttention only supports head dimension at most 128'
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')
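
As a quick sanity check of the head-dimension constraint above, using hypothetical model sizes rather than anything taken from this PR:

# Hypothetical (hidden_size, num_attention_heads) pairs and the head dimension they imply.
for hidden_size, num_heads in [(4096, 32), (8192, 64), (5120, 32)]:
    headdim = hidden_size / num_heads
    print(hidden_size, num_heads, headdim, headdim <= 128)
# 4096/32 and 8192/64 give 128 (accepted); 5120/32 gives 160 and would trip the assert.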

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
@@ -189,6 +257,12 @@ def __init__(self, init_method,
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=args.attention_dropout
            )


        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
@@ -262,136 +336,148 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
        if get_key_value:
            present = (key_layer, value_layer)

        if not self.use_flash_attn:
            # ===================================
            # Raw attention scores. [b, np, s, s]
            # ===================================

            # [b, np, sq, sk]
            output_size = (query_layer.size(1),
                           query_layer.size(2),
                           query_layer.size(0),
                           key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2],
                                           output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3],
                                       output_size[0] * output_size[1], -1)

            # preallocating result tensor: [b * np, sq, sk]
            if alibi is None:
                matmul_result = torch.empty(
                    output_size[0]*output_size[1],
                    output_size[2],
                    output_size[3],
                    dtype=query_layer.dtype,
                    device=torch.cuda.current_device())
            else:
                matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

            # Rotary embeddings
            if self.position_embedding_type == PositionEmbeddingType.rotary:
                apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb

                seq_len = key_layer.shape[0]
                offset = 0
                if layer_past is not None and layer_past.numel() > 0:
                    offset = layer_past[0].shape[0]
                    seq_len += offset
                cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
                query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

            # Raw attention scores. [b * np, sq, sk]
            if alibi is None:
                matmul_result = torch.baddbmm(
                    matmul_result,
                    query_layer.transpose(0, 1),  # [b * np, sq, hn]
                    key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                    beta=0.0, alpha=(1.0/self.norm_factor))
            else:
                if not hasattr(self, "logged_alibi"):
                    logger.debug("Using Alibi.")
                    self.logged_alibi = True
                if self.apply_query_key_layer_scaling:
                    beta = 1.0 / self.layer_number
                else:
                    beta = 1.0

                matmul_result = torch.baddbmm(
                    matmul_result,
                    query_layer.transpose(0, 1),  # [b * np, sq, hn]
                    key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                    beta=beta, alpha=(1.0 / self.norm_factor))
            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ==================================================
            # Update attention mask for inference. [b, np, sq, sk]
            # ==================================================
            if get_key_value:
                with torch.no_grad():
                    # TODO @thomasw21 Handle case where `attention_mask` is None
                    if layer_past is not None:
                        attention_mask = attention_mask[
                            ...,
                            attention_scores.size(3) - 1,
                            :attention_scores.size(3)].unsqueeze(2)
                    else:
                        attention_mask = attention_mask[
                            ...,
                            :attention_scores.size(3),
                            :attention_scores.size(3)]

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            attention_probs = self.scale_mask_softmax(attention_scores,
                                                      attention_mask)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1),
                           value_layer.size(2),
                           query_layer.size(0),
                           value_layer.size(3))

            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0),
                                           output_size[0] * output_size[1], -1)

            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                   output_size[2], -1)

            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)

            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + \
                (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        else:
            # =================
            # Flash Attention
            # =================

            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            with mpu.get_cuda_rng_tracker().fork():
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
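
For reference, the layout round trip performed here (Megatron's [sq, b, np, hn] layout to FlashAttention's [b, s, h, d] and back to [sq, b, h * d]) can be checked shape-wise with plain einops on dummy tensors; the sizes are hypothetical and no GPU or flash-attn install is needed.

import torch
from einops import rearrange

sq, b, np_heads, hn = 1024, 2, 16, 64          # hypothetical sequence, batch, heads, head dim
query = torch.randn(sq, b, np_heads, hn)       # Megatron layout [sq, b, np, hn]
q = rearrange(query, 's b ... -> b s ...')     # FlashAttention layout [b, s, h, d]
ctx = q                                        # stand-in for core_attention_flash(q, k, v)
out = rearrange(ctx, 'b s h d -> s b (h d)')   # back to [sq, b, h * d]
assert out.shape == (sq, b, np_heads * hn)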

        # =================
        # Output. [sq, b, h]