Paged Attention support for FA3 #1268

Open
wants to merge 1 commit into base: main
391 changes: 391 additions & 0 deletions hopper/copy_paged_sm90_tma.hpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions hopper/flash.h
@@ -105,6 +105,7 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int page_num_blocks;

// The dropout probability (probability of keeping an activation).
float p_dropout;
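
For orientation: the new page_num_blocks field joins the existing paged-KV members of Flash_fwd_params (block_table, block_table_batch_stride, page_block_size). The sketch below shows the logical-to-physical indexing those fields describe, assuming the usual paged-attention layout; the helper name and arguments are illustrative and not part of this PR. The kernel itself presumably performs the equivalent gather through the TMA copy path added in hopper/copy_paged_sm90_tma.hpp.

import torch

def gather_paged_keys(k_cache, block_table, batch_idx, seqlen_k, page_block_size):
    # Illustrative only: reassemble one sequence's keys from a paged KV cache.
    # k_cache:     (page_num_blocks, page_block_size, num_heads_k, head_size)
    # block_table: (batch_size, max_num_blocks_per_seq), torch.int32
    num_blocks = (seqlen_k + page_block_size - 1) // page_block_size
    block_ids = block_table[batch_idx, :num_blocks].long()
    k = k_cache[block_ids]                                   # gather the physical blocks
    k = k.reshape(-1, k_cache.shape[-2], k_cache.shape[-1])  # flatten back into token order
    return k[:seqlen_k]                                      # trim padding in the last block
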
45 changes: 39 additions & 6 deletions hopper/flash_api.cpp
@@ -54,7 +54,7 @@ void set_params_fprop(Flash_fwd_params &params,
params.is_bf16 = q.dtype() == torch::kBFloat16;
params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
params.is_kv_cache = false;

params.page_num_blocks = 0;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
@@ -214,6 +214,7 @@ void set_params_dgrad(Flash_bwd_params &params,
params.dq_ptr = dq.data_ptr();
params.dk_ptr = dk.data_ptr();
params.dv_ptr = dv.data_ptr();
params.page_num_blocks = 0;
params.dq_row_stride = dq.stride(-3);
params.dk_row_stride = dk.stride(-3);
params.dv_row_stride = dv.stride(-3);
@@ -636,6 +637,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
int max_seqlen_q,
const int max_seqlen_k,
const float softmax_scale,
@@ -665,25 +667,46 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}

const auto sizes = q.sizes();

const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

const int total_q = q.sizes()[0];

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? -1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

CHECK_SHAPE(q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);

if (!paged_KV) {
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}

CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
if (seqused_q.has_value()){
@@ -764,6 +787,17 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
params.total_q = total_q;
params.total_k = total_k;

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.page_num_blocks = k.size(0);
}
params.page_block_size = page_block_size;
params.page_num_blocks = num_blocks;

//printf("mha_varlen_fwd: params.seqlen_k=%d, max_seqlen_k=%d, params.page_num_blocks=%d\n", (int)params.seqlen_k, (int)max_seqlen_k, (int)params.page_num_blocks);
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -778,7 +812,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}

return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}

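Read together, the checks above pin down the layout mha_varlen_fwd expects when block_table_ is supplied: k and v become paged caches of shape (num_blocks, page_block_size, num_heads_k, head_size), block_table is int32 with a contiguous last dimension and shape (batch_size, max_num_blocks_per_seq), and page_block_size must be a multiple of 256. A small Python mirror of those host-side constraints, for illustration only (the helper is not part of the PR):

import torch

def check_paged_kv_layout(q, k_cache, v_cache, block_table, cu_seqlens_q):
    # Illustrative mirror of the host-side checks added to mha_varlen_fwd above.
    batch_size = cu_seqlens_q.numel() - 1
    num_blocks, page_block_size, num_heads_k, head_size = k_cache.shape
    assert block_table.dtype == torch.int32, "block_table must be int32"
    assert block_table.stride(-1) == 1, "block_table needs a contiguous last dim"
    assert block_table.dim() == 2 and block_table.shape[0] == batch_size
    assert v_cache.shape == k_cache.shape
    assert page_block_size % 256 == 0, "paged KV block size must be divisible by 256"
    assert q.shape[1] % num_heads_k == 0, "query heads must be a multiple of KV heads"
    assert q.shape[2] == head_size and head_size <= 256
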
94 changes: 71 additions & 23 deletions hopper/flash_attn_interface.py
@@ -2,19 +2,32 @@

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda

import torch
import torch.nn as nn

# isort: on


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None, gqa_parallel=False):

def _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal,
window_size,
descale_q=None,
descale_k=None,
descale_v=None,
gqa_parallel=False,
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
@@ -28,7 +41,7 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q =
causal,
window_size[0],
window_size[1],
gqa_parallel
gqa_parallel,
)
return out, q, k, v, out_padded, softmax_lse, S_dmask

@@ -46,7 +59,7 @@ def _flash_attn_backward(
softmax_scale,
causal,
window_size,
deterministic=False
deterministic=False,
):
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -68,6 +81,7 @@ def _flash_attn_backward(
)
return dq, dk, dv, softmax_d


def _flash_attn_varlen_forward(
q,
k,
@@ -78,10 +92,14 @@ def _flash_attn_varlen_forward(
max_seqlen_k,
softmax_scale,
causal,
block_table,
window_size=(-1, -1),
seqused_q=None,
seqused_k=None,
):
assert (
block_table is None or k.dtype != torch.float8_e4m3fn
), "Paged Attention / block_table is not supported for fp8 just yet"
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
@@ -93,6 +111,7 @@ def _flash_attn_varlen_forward(
cu_seqlens_k,
seqused_q,
seqused_k,
block_table,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
@@ -189,7 +208,7 @@ def forward(
window_size,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_v=descale_v,
gqa_parallel=gqa_parallel,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
@@ -242,6 +261,7 @@ def forward(
deterministic=False,
seqused_q=None,
seqused_k=None,
block_table=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
Expand All @@ -258,10 +278,18 @@ def forward(
window_size=window_size,
seqused_q=seqused_q,
seqused_k=seqused_k,
block_table=block_table,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
seqused_q, seqused_k
q,
k,
v,
out_padded,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
Expand All @@ -273,7 +301,9 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = (
ctx.saved_tensors
)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
@@ -299,7 +329,22 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
return (
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)


def flash_attn_func(
@@ -377,7 +422,7 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
gqa_parallel
gqa_parallel,
)


@@ -395,6 +440,7 @@ def flash_attn_varlen_func(
deterministic=False,
seqused_q=None,
seqused_k=None,
block_table=None,
):
"""
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -430,6 +476,7 @@ def flash_attn_varlen_func(
query and output tokens in each sequence.
seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
key and value tokens in each sequence.
block_table: (batch_size, max_num_blocks_per_seq), dtype torch.int32. If not None, the KV cache is paged and each row maps a sequence's logical blocks to physical blocks of k and v.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
@@ -450,6 +497,7 @@ def flash_attn_varlen_func(
deterministic,
seqused_q,
seqused_k,
block_table,
)


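The docstring entry above describes the new argument; below is a hedged end-to-end sketch of a paged call. The tensor shapes follow the checks in mha_varlen_fwd (int32 block_table, K/V caches of shape (num_blocks, page_block_size, num_heads_k, head_size), block size a multiple of 256), while the import path, the keyword names used for the leading arguments, and the concrete sizes are placeholders rather than something verified against this PR:

import torch
from flash_attn_interface import flash_attn_varlen_func  # path depends on how the hopper package is installed

nheads, nheads_k, headdim = 8, 8, 128
page_block_size = 256                     # must be a multiple of 256 in this PR
device, dtype = "cuda", torch.bfloat16    # fp8 with block_table is rejected in _flash_attn_varlen_forward

# Two sequences: 5 and 7 query tokens attending to 300 and 600 cached KV tokens.
cu_seqlens_q = torch.tensor([0, 5, 12], dtype=torch.int32, device=device)
cu_seqlens_k = torch.tensor([0, 300, 900], dtype=torch.int32, device=device)
q = torch.randn(12, nheads, headdim, device=device, dtype=dtype)

# Paged KV cache: (num_blocks, page_block_size, num_heads_k, head_size).
max_blocks_per_seq = 3                    # ceil(600 / 256)
num_blocks = 2 * max_blocks_per_seq
k_cache = torch.randn(num_blocks, page_block_size, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)

# Each row maps one sequence's logical blocks to physical blocks of the cache.
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(2, max_blocks_per_seq)

out = flash_attn_varlen_func(
    q, k_cache, v_cache,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=7,
    max_seqlen_k=600,
    causal=True,
    block_table=block_table,
)
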
@@ -571,15 +619,15 @@ def flash_attn_with_kvcache(
"""

# unimplemented kwargs
k=None
v=None
rotary_cos=None
rotary_sin=None
cache_leftpad=None
block_table=None
softcap=0.0
rotary_interleaved=True
alibi_slopes=None
k = None
v = None
rotary_cos = None
rotary_sin = None
cache_leftpad = None
block_table = None
softcap = 0.0
rotary_interleaved = True
alibi_slopes = None

assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
Expand All @@ -599,7 +647,7 @@ def flash_attn_with_kvcache(
if q.shape[2] == k_cache.shape[2]:
gqa_parallel = False
if max_seqlen_k_hint is None:
max_seqlen_k_hint = k_cache.shape[1]
max_seqlen_k_hint = k_cache.shape[1]
out, softmax_lse = flashattn_hopper_cuda.fwd_kvcache(
q,
k_cache,
Expand All @@ -625,6 +673,6 @@ def flash_attn_with_kvcache(
rotary_interleaved,
num_splits,
max_seqlen_k_hint,
gqa_parallel
gqa_parallel,
)
return (out, softmax_lse) if return_softmax_lse else out