[WIP] Sparse seqparallel #9

Open · wants to merge 8 commits into base: main
19 changes: 2 additions & 17 deletions csrc/flash_attn/CMakeLists.txt
@@ -13,23 +13,8 @@ include_directories(
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
)

#file(GLOB SOURCES_CU "src/*.cu")
#file(GLOB SOURCES_CPP "src/*.cpp")
set(SOURCES_CU
src/fmha_fwd_hdim32.cu
src/fmha_fwd_hdim64.cu
src/fmha_fwd_hdim128.cu
src/fmha_bwd_hdim32.cu
src/fmha_bwd_hdim64.cu
src/fmha_bwd_hdim128.cu
src/fmha_fwd_with_mask_bias_hdim32.cu
src/fmha_fwd_with_mask_bias_hdim64.cu
src/fmha_fwd_with_mask_bias_hdim128.cu
src/fmha_bwd_with_mask_bias_hdim32.cu
src/fmha_bwd_with_mask_bias_hdim64.cu
src/fmha_bwd_with_mask_bias_hdim128.cu
src/utils.cu)
set(SOURCES_CPP src/cuda_utils.cpp)
file(GLOB SOURCES_CU "src/*.cu")
file(GLOB SOURCES_CPP "src/*.cpp")

#add_library(flashattn OBJECT
add_library(flashattn SHARED
208 changes: 208 additions & 0 deletions csrc/flash_attn/flash_attn.cpp
@@ -497,6 +497,214 @@ bool flash_attn_bwd(
FLASHATTNLIB_END_FUNC
}

bool flash_attn_fwd_block(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *out, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence
const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence
const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16)
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
void *softmax_lse_ptr, // softmax log_sum_exp
void *softmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset
) {
// printf("forward seed %jd offset %jd\b", seed, offset);
FLASHATTNLIB_BEGIN_FUNC

auto dprops = GetDeviceProperties(-1);
ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); // block-sparse kernels require Ampere (sm8x)
bool is_dropout = p_dropout > 0.0;

const bool return_softmax = (softmax_ptr != nullptr);
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

ASSERT_CHECK(batch_size > 0);
ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);

int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
if( max_seqlen_k <= 256 ) {
max_seqlen_k = 256;
}
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
bool loop = max_seqlen_k > 256;

void* o_tmp_ptr = workspace_ptr;
// when out is nullptr, only compute and return the required workspace size
if (out == nullptr) {
if (loop) {
*workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
} else {
*workspace_size = 0;
}
return true;
}

if (return_softmax) {
SetZero(softmax_ptr, 2, {batch_size, num_heads, max_seqlen_q, max_seqlen_k}, stream); // float16
}

set_params_fprop(launch_params.params,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
head_size,
const_cast<void*>(q),
const_cast<void*>(k),
const_cast<void*>(v),
const_cast<void*>(out),
const_cast<void*>(cu_seqlens_q),
const_cast<void*>(cu_seqlens_k),
loop ? o_tmp_ptr : nullptr,
return_softmax ? softmax_ptr : nullptr,
softmax_lse_ptr,
p_dropout,
softmax_scale,
is_causal,
is_bf16,
/*num_splits=*/1);
launch_params.params.blockmask = static_cast<int*>(const_cast<void*>(blockmask));

run_fmha_block_sm80(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread + offset;

if( is_dropout ) {
launch_params.params.philox_args = PhiloxCudaState(seed, counter_offset);
}

run_fmha_block_sm80(launch_params, /*configure=*/false);

return true;

FLASHATTNLIB_END_FUNC
}

bool flash_attn_bwd_block(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *out, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *dout, // total_q x num_heads x head_size
const void *cu_seqlens_q, // int32, batch_size+1
const void *cu_seqlens_k, // int32, batch_size+1
const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16)
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
void *softmax_lse_ptr,
void *dsoftmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset
) {
// printf("backward seed %jd offset %jd\b", seed, offset);

FLASHATTNLIB_BEGIN_FUNC

auto dprops = GetDeviceProperties(-1);
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0);
auto launch = &run_fmha_block_dgrad_sm80;

bool is_dropout = p_dropout > 0.0;

ASSERT_CHECK(batch_size > 0);
ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
ASSERT_CHECK(is_sm80);
}

int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
if( max_seqlen_k <= 256 ) {
max_seqlen_k = 256;
}
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
bool loop = max_seqlen_k > 256;

void *dq_tmp_ptr = workspace_ptr;
// when out is nullptr, only compute and return the required workspace size
if (out == nullptr) {
if (loop) {
*workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
} else {
*workspace_size = 0;
}
return true;
}

FMHA_dgrad_params params;

set_params_dgrad(params,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
head_size,
const_cast<void*>(q),
const_cast<void*>(k),
const_cast<void*>(v),
const_cast<void*>(out),
dq, dk, dv,
const_cast<void*>(cu_seqlens_q),
const_cast<void*>(cu_seqlens_k),
loop ? dq_tmp_ptr : nullptr,
const_cast<void*>(dout),
softmax_lse_ptr,
dsoftmax_ptr,
p_dropout,
softmax_scale,
is_causal,
is_bf16,
/*num_splits=*/1);
params.blockmask = static_cast<int*>(const_cast<void*>(blockmask));

// The RNG state is reset in Python after this kernel, so the counter offset
// here doesn't matter; we just choose an arbitrary number.
int64_t counter_offset = 4 + offset;

if( is_dropout ) {
params.philox_args = PhiloxCudaState(seed, counter_offset);
}

launch(params, stream);

return true;

FLASHATTNLIB_END_FUNC
}

#ifdef __cplusplus
}
#endif
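The two block-sparse entry points above share a two-pass workspace protocol: a call with `out == nullptr` only writes the required scratch size into `*workspace_size` (non-zero when `max_seqlen_k > 256` and the looped path needs an `o_tmp` / `dq_tmp` buffer), and a second call performs the actual launch. A caller-side sketch for the forward path is below; the wrapper name, the softmax scale choice, and the error handling are assumptions for illustration, not part of this change.

```cpp
// Caller-side sketch (illustration only, not part of this PR): shows the
// two-pass workspace protocol used by flash_attn_fwd_block above.
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
#include "flash_attn.h"

bool fwd_block_example(const void* q, const void* k, const void* v, void* out,
                       const void* cu_seqlens_q, const void* cu_seqlens_k,
                       const void* blockmask, int total_q, int total_k,
                       int batch_size, int num_heads, int head_size,
                       int max_seqlen_q, int max_seqlen_k,
                       void* softmax_lse, cudaStream_t stream) {
    const float softmax_scale = 1.0f / std::sqrt(static_cast<float>(head_size));
    uint64_t workspace_size = 0;

    // Pass 1: out == nullptr, so only *workspace_size is written.
    if (!flash_attn_fwd_block(q, k, v, /*out=*/nullptr, cu_seqlens_q, cu_seqlens_k,
                              blockmask, total_q, total_k, batch_size, num_heads,
                              head_size, max_seqlen_q, max_seqlen_k,
                              /*p_dropout=*/0.0f, softmax_scale,
                              /*is_causal=*/false, /*is_bf16=*/false,
                              softmax_lse, /*softmax_ptr=*/nullptr,
                              /*workspace_ptr=*/nullptr, &workspace_size,
                              stream, /*seed=*/0, /*offset=*/0)) {
        return false;  // details available via flash_attn_error()
    }

    void* workspace = nullptr;
    if (workspace_size > 0 && cudaMalloc(&workspace, workspace_size) != cudaSuccess) {
        return false;
    }

    // Pass 2: real launch, reusing the same arguments plus the workspace.
    bool ok = flash_attn_fwd_block(q, k, v, out, cu_seqlens_q, cu_seqlens_k,
                                   blockmask, total_q, total_k, batch_size,
                                   num_heads, head_size, max_seqlen_q, max_seqlen_k,
                                   0.0f, softmax_scale, false, false,
                                   softmax_lse, nullptr, workspace, &workspace_size,
                                   stream, 0, 0);
    if (workspace != nullptr) {
        cudaFree(workspace);
    }
    return ok;
}
```

The backward path follows the same pattern, with `dq`/`dk`/`dv`/`dout` and `dsoftmax_ptr` added to the argument list.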
60 changes: 60 additions & 0 deletions csrc/flash_attn/flash_attn.h
@@ -138,6 +138,66 @@ bool flash_attn_bwd_with_bias_and_mask(
const int64_t* bias_dims
);

bool flash_attn_fwd_block(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *out, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence
const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence
const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16)
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
void *softmax_lse_ptr, // softmax log_sum_exp
void *softmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset
);

bool flash_attn_bwd_block(
const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const void *out, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void *dout, // total_q x num_heads x head_size
const void *cu_seqlens_q, // int32, batch_size+1
const void *cu_seqlens_k, // int32, batch_size+1
const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16)
const int total_q,
const int total_k,
const int batch_size,
const int num_heads,
const int head_size,
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool is_bf16,
void *softmax_lse_ptr,
void *dsoftmax_ptr,
void *workspace_ptr,
uint64_t *workspace_size,
cudaStream_t stream,
uint64_t seed,
uint64_t offset
);

void flash_attn_set_error(const char *msg);

const char *flash_attn_error();
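The header comments only pin down the blockmask shape: int32 entries laid out as `(seqlen_k / 256) x (seqlen_q / 16)`, with the sequence lengths rounded up the same way the entry points round `max_seqlen_k` and `max_seqlen_q`. A sizing sketch under those assumptions (the meaning of individual entries is defined by the block-sparse kernels, not by this header):

```cpp
// Sizing sketch for the int32 blockmask documented above. Illustrative only;
// entry semantics come from the block-sparse kernels, not this header.
#include <cuda_runtime.h>

int* alloc_blockmask(int max_seqlen_q, int max_seqlen_k) {
    // Mirror the rounding done inside the entry points: K is tiled in
    // 256-wide blocks, Q in 16-wide blocks.
    const int k_blocks = (max_seqlen_k + 256 - 1) / 256;
    const int q_blocks = (max_seqlen_q + 16 - 1) / 16;
    int* blockmask = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&blockmask),
               sizeof(int) * k_blocks * q_blocks);
    return blockmask;
}
```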
4 changes: 2 additions & 2 deletions csrc/flash_attn/src/fmha.h
@@ -214,6 +214,6 @@ bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream);
bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream);

void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
void run_fmha_block_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
void run_fmha_block_dgrad_sm80(FMHA_dgrad_params &params, cudaStream_t stream);
3 changes: 2 additions & 1 deletion csrc/flash_attn/src/fmha/gemm.h
@@ -135,10 +135,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
}

// Multiply by another fragment.
template <typename elem_type>
inline __device__ void hmul(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
this->reg(ii) = fmha::hmul2<elem_type>(this->reg(ii), other.reg(ii));
}
}

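Since `hmul` is now a member template, call sites have to supply the element type explicitly, and inside other templated device code the `template` keyword is needed for the dependent member. A hypothetical call site, not taken from this PR:

```cpp
// Hypothetical call site illustrating the new explicit element type; the
// function and parameter names are assumptions for the example.
template <typename Fragment_t, typename elem_type>
inline __device__ void apply_scale(Fragment_t& acc, const Fragment_t& scale) {
    // 'template' disambiguates the member template of a dependent type.
    acc.template hmul<elem_type>(scale);
}
```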
36 changes: 27 additions & 9 deletions csrc/flash_attn/src/fmha/utils.h
@@ -272,7 +272,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
template<typename T=__half >
static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b);

template<>
inline __device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) {
// uint32_t c;
// asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
// return c;
@@ -281,6 +285,18 @@ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
return reinterpret_cast<uint32_t(&)>(result);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) {
// bf16 counterpart of the __half specialization above.
__nv_bfloat162 result = __hmul2(reinterpret_cast<const __nv_bfloat162 (&)>(a),
reinterpret_cast<const __nv_bfloat162 (&)>(b));
return reinterpret_cast<uint32_t(&)>(result);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
@@ -292,23 +308,25 @@ static inline __device__ uint2 hmul4(uint2 a, uint2 b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename elem_type=__half >
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
c.x = hmul2<elem_type>(a.x, b.x);
c.y = hmul2<elem_type>(a.y, b.y);
c.z = hmul2<elem_type>(a.z, b.z);
c.w = hmul2<elem_type>(a.w, b.w);
return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename elem_type=__half >
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
c.x = hmul2<elem_type>(a, b.x);
c.y = hmul2<elem_type>(a, b.y);
c.z = hmul2<elem_type>(a, b.z);
c.w = hmul2<elem_type>(a, b.w);
return c;
}

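With the element type threaded through `hmul2`/`hmul8`, callers select the fp16 or bf16 path via the template argument; the `__nv_bfloat16` specialization only exists when compiling for `__CUDA_ARCH__ >= 800`. A minimal demo kernel (an assumed test harness, not part of this change):

```cpp
// Demo kernel (assumed harness, not part of this PR): each uint4 packs eight
// fp16 or bf16 values; the template argument selects which __hmul2 overload
// fmha::hmul2 uses underneath. Instantiating the bf16 path requires compiling
// for sm80 or newer because of the __CUDA_ARCH__ guard above.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "fmha/utils.h"

template <typename elem_type>
__global__ void hmul8_demo(const uint4* a, const uint4* b, uint4* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = fmha::hmul8<elem_type>(a[i], b[i]);
    }
}

// Example launches:
//   hmul8_demo<__half><<<grid, block, 0, stream>>>(a, b, c, n);
//   hmul8_demo<__nv_bfloat16><<<grid, block, 0, stream>>>(a, b, c, n);  // sm80+
```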