From 48d9be3fddf8b11393d0a9211899d370fe176124 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Fri, 26 May 2023 05:12:56 +0000 Subject: [PATCH 1/7] add sparse api --- csrc/flash_attn/flash_attn.cpp | 206 +++++++++++++++++++++++++++++++++ csrc/flash_attn/flash_attn.h | 58 ++++++++++ 2 files changed, 264 insertions(+) diff --git a/csrc/flash_attn/flash_attn.cpp b/csrc/flash_attn/flash_attn.cpp index 42f2644b4..ef815d86c 100644 --- a/csrc/flash_attn/flash_attn.cpp +++ b/csrc/flash_attn/flash_attn.cpp @@ -497,6 +497,212 @@ bool flash_attn_bwd( FLASHATTNLIB_END_FUNC } +bool flash_attn_fwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + void *softmax_lse_ptr, // softmax log_sum_exp + void *softmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +) { + // printf("forward seed %jd offset %jd\b", seed, offset); + FLASHATTNLIB_BEGIN_FUNC + + auto dprops = GetDeviceProperties(-1); + ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); + bool is_dropout = p_dropout > 0.0; + + const bool return_softmax = (softmax_ptr != nullptr); + Launch_params launch_params(dprops, stream, is_dropout, return_softmax); + + ASSERT_CHECK(batch_size > 0); + ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); + + int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; + if( max_seqlen_k <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > 256; + + void* o_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + if (loop) { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } else { + *workspace_size = 0; + } + return true; + } + + if (return_softmax) { + SetZero(softmax_ptr, 2, {batch_size, num_heads, max_seqlen_q, max_seqlen_k}, stream); // float16 + } + + set_params_fprop(launch_params.params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + const_cast(cu_seqlens_q), + const_cast(cu_seqlens_k), + loop ? o_tmp_ptr : nullptr, + return_softmax ? 
softmax_ptr : nullptr, + softmax_lse_ptr, + p_dropout, + softmax_scale, + is_causal, + /*is_bf16*/false, + /*num_splits=*/1); + launch_params.params.blockmask = static_cast(const_cast(blockmask)); + + run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true); + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = launch_params.elts_per_thread + offset; + + if( is_dropout ) { + launch_params.params.philox_args = PhiloxCudaState(seed, counter_offset); + } + + run_fmha_block_fp16_sm80(launch_params, /*configure=*/false); + + return true; + + FLASHATTNLIB_END_FUNC +} + +bool flash_attn_bwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *dout, // total_q x num_heads, x head_size + const void *cu_seqlens_q, // int32, batch_size+1 + const void *cu_seqlens_k, // int32, batch_size+1 + const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + void *softmax_lse_ptr, + void *dsoftmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +) { + // printf("backward seed %jd offset %jd\b", seed, offset); + + FLASHATTNLIB_BEGIN_FUNC + + auto dprops = GetDeviceProperties(-1); + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); + auto launch = &run_fmha_block_dgrad_fp16_sm80; + + bool is_dropout = p_dropout > 0.0; + + ASSERT_CHECK(batch_size > 0); + ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); + if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well + ASSERT_CHECK(is_sm80); + } + + int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; + if( max_seqlen_k <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > 256; + + void *dq_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + if (loop) { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } else { + *workspace_size = 0; + } + return true; + } + + FMHA_dgrad_params params; + + set_params_dgrad(params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + dq, dk, dv, + const_cast(cu_seqlens_q), + const_cast(cu_seqlens_k), + loop ? 
dq_tmp_ptr : nullptr, + const_cast(dout), + softmax_lse_ptr, + dsoftmax_ptr, + p_dropout, + softmax_scale, + is_causal, + /*is_bf16*/false, + /*num_splits=*/1); + params.blockmask = static_cast(const_cast(blockmask)); + + // We're gonna reset the rng state in Python after this kernel, so the counter offset + // here doesn't matter at all. We just choose an arbitrary number; + int64_t counter_offset = 4 + offset; + + if( is_dropout ) { + params.philox_args = PhiloxCudaState(seed, counter_offset); + } + + launch(params, stream); + + return true; + + FLASHATTNLIB_END_FUNC +} + #ifdef __cplusplus } #endif diff --git a/csrc/flash_attn/flash_attn.h b/csrc/flash_attn/flash_attn.h index 48dfacd19..fecdcb8fd 100644 --- a/csrc/flash_attn/flash_attn.h +++ b/csrc/flash_attn/flash_attn.h @@ -138,6 +138,64 @@ bool flash_attn_bwd_with_bias_and_mask( const int64_t* bias_dims ); +bool flash_attn_fwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const void *blockmask, // int32, (seqlen / 256, seqlen / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + void *softmax_lse_ptr, // softmax log_sum_exp + void *softmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +); + +bool flash_attn_bwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *dout, // total_q x num_heads, x head_size + const void *cu_seqlens_q, // int32, batch_size+1 + const void *cu_seqlens_k, // int32, batch_size+1 + const void *blockmask, // int32, (seqlen / 256, seqlen / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + void *softmax_lse_ptr, + void *dsoftmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +); + void flash_attn_set_error(const char *msg); const char *flash_attn_error(); From be5cade50b0f93fa5c19c75fadec38db8270a9c7 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Fri, 26 May 2023 12:30:08 +0000 Subject: [PATCH 2/7] include cu file --- csrc/flash_attn/CMakeLists.txt | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/csrc/flash_attn/CMakeLists.txt b/csrc/flash_attn/CMakeLists.txt index 
f6bb280cd..7267d9232 100644 --- a/csrc/flash_attn/CMakeLists.txt +++ b/csrc/flash_attn/CMakeLists.txt @@ -13,23 +13,8 @@ include_directories( ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ) -#file(GLOB SOURCES_CU "src/*.cu") -#file(GLOB SOURCES_CPP "src/*.cpp") -set(SOURCES_CU - src/fmha_fwd_hdim32.cu - src/fmha_fwd_hdim64.cu - src/fmha_fwd_hdim128.cu - src/fmha_bwd_hdim32.cu - src/fmha_bwd_hdim64.cu - src/fmha_bwd_hdim128.cu - src/fmha_fwd_with_mask_bias_hdim32.cu - src/fmha_fwd_with_mask_bias_hdim64.cu - src/fmha_fwd_with_mask_bias_hdim128.cu - src/fmha_bwd_with_mask_bias_hdim32.cu - src/fmha_bwd_with_mask_bias_hdim64.cu - src/fmha_bwd_with_mask_bias_hdim128.cu - src/utils.cu) -set(SOURCES_CPP src/cuda_utils.cpp) +file(GLOB SOURCES_CU "src/*.cu") +file(GLOB SOURCES_CPP "src/*.cpp") #add_library(flashattn OBJECT add_library(flashattn SHARED From c51e94408e0832c4982d6e233657ba770748c696 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Sat, 27 May 2023 10:13:47 +0000 Subject: [PATCH 3/7] support bf16 for block --- csrc/flash_attn/flash_attn.cpp | 12 ++-- csrc/flash_attn/flash_attn.h | 2 + csrc/flash_attn/src/fmha.h | 4 +- .../fmha_block_dgrad_fp16_kernel_loop.sm80.cu | 63 ------------------ .../src/fmha_block_dgrad_kernel_1xN_loop.h | 50 ++++++++------ .../src/fmha_block_dgrad_kernel_loop.sm80.cu | 66 +++++++++++++++++++ ...m80.cu => fmha_block_fprop_kernel.sm80.cu} | 39 ++++++----- .../src/fmha_block_fprop_kernel_1xN.h | 17 +++-- 8 files changed, 139 insertions(+), 114 deletions(-) delete mode 100644 csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu create mode 100644 csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu rename csrc/flash_attn/src/{fmha_block_fprop_fp16_kernel.sm80.cu => fmha_block_fprop_kernel.sm80.cu} (67%) diff --git a/csrc/flash_attn/flash_attn.cpp b/csrc/flash_attn/flash_attn.cpp index ef815d86c..ebdd18900 100644 --- a/csrc/flash_attn/flash_attn.cpp +++ b/csrc/flash_attn/flash_attn.cpp @@ -515,6 +515,7 @@ bool flash_attn_fwd_block( const float p_dropout, const float softmax_scale, const bool is_causal, + const bool is_bf16, void *softmax_lse_ptr, // softmax log_sum_exp void *softmax_ptr, void *workspace_ptr, @@ -576,11 +577,11 @@ bool flash_attn_fwd_block( p_dropout, softmax_scale, is_causal, - /*is_bf16*/false, + is_bf16, /*num_splits=*/1); launch_params.params.blockmask = static_cast(const_cast(blockmask)); - run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true); + run_fmha_block_sm80(launch_params, /*configure=*/ true); // number of times random will be generated per thread, to offset philox counter in thc random // state int64_t counter_offset = launch_params.elts_per_thread + offset; @@ -589,7 +590,7 @@ bool flash_attn_fwd_block( launch_params.params.philox_args = PhiloxCudaState(seed, counter_offset); } - run_fmha_block_fp16_sm80(launch_params, /*configure=*/false); + run_fmha_block_sm80(launch_params, /*configure=*/false); return true; @@ -618,6 +619,7 @@ bool flash_attn_bwd_block( const float p_dropout, const float softmax_scale, const bool is_causal, + const bool is_bf16, void *softmax_lse_ptr, void *dsoftmax_ptr, void *workspace_ptr, @@ -634,7 +636,7 @@ bool flash_attn_bwd_block( bool is_sm80 = dprops->major == 8 && dprops->minor == 0; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); - auto launch = &run_fmha_block_dgrad_fp16_sm80; + auto launch = &run_fmha_block_dgrad_sm80; bool is_dropout = p_dropout > 0.0; @@ -684,7 +686,7 @@ bool flash_attn_bwd_block( p_dropout, 
softmax_scale, is_causal, - /*is_bf16*/false, + is_bf16, /*num_splits=*/1); params.blockmask = static_cast(const_cast(blockmask)); diff --git a/csrc/flash_attn/flash_attn.h b/csrc/flash_attn/flash_attn.h index fecdcb8fd..6ad8014d9 100644 --- a/csrc/flash_attn/flash_attn.h +++ b/csrc/flash_attn/flash_attn.h @@ -156,6 +156,7 @@ bool flash_attn_fwd_block( const float p_dropout, const float softmax_scale, const bool is_causal, + const bool is_bf16, void *softmax_lse_ptr, // softmax log_sum_exp void *softmax_ptr, void *workspace_ptr, @@ -187,6 +188,7 @@ bool flash_attn_bwd_block( const float p_dropout, const float softmax_scale, const bool is_causal, + const bool is_bf16, void *softmax_lse_ptr, void *dsoftmax_ptr, void *workspace_ptr, diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index ae3816032..0d4c5c36d 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -214,6 +214,6 @@ bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream); bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream); -void run_fmha_block_fp16_sm80(Launch_params &launch_params, const bool configure); +void run_fmha_block_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream); +void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu deleted file mode 100644 index c6c45177e..000000000 --- a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2022, Tri Dao. - */ -#include "fmha.h" -#include "fmha_block_dgrad_kernel_1xN_loop.h" - -template -__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { - fmha::compute_block_dq_dk_dv_1xN(params); -} - -template -void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); - static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - static_assert(smem_size_dp_sum == 16 * 4 * 2); - - constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum; - - bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" - bool is_causal = params.is_causal; - auto kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? 
&fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } - - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - dim3 grid(params.b, params.h); - kernel<<>>(params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} - -void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - if (params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } -} diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 79e0a88c8..01282691c 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -13,12 +13,12 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M], Smem_dp_sum smem, const int buffer_idx) { #pragma unroll for (int mi = 0; mi < M; ++mi) { - sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi])); + sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi])); } static_assert(M == 1); smem.store(sum[0], buffer_idx); @@ -30,6 +30,14 @@ template= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + // The description of the CTA tile for the 1st batched GEMM. using Cta_tile_p = typename Kernel_traits::Cta_tile_p; // The description of the CTA tile for the 2nd batched GEMM. 
@@ -103,7 +111,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum; // using Gemm1 = Gemm_Q_K; - using Gemm1 = Gemm_Q_K; + using Gemm1 = Gemm_Q_K; using Softmax = fmha::Softmax; @@ -242,7 +250,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // if (Is_first) { // if (true) { if (Is_first || mask_val % 2 == 1) { - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0); + dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0); const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW; if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row); @@ -365,7 +373,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); + softmax.template pack(frag_p); // Store s * dmask to smem for transpose smem_s.store(frag_p); @@ -414,9 +422,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, smem_do.load(frag_do[ki & 1], ki); if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[ki & 1], ki); - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); } // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); @@ -430,9 +438,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, { int ki = Mma_tile_p::MMAS_K; if (!Kernel_traits::V_IN_REGS) { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); } } @@ -470,7 +478,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, if (is_first_read) { softmax.subtract_dp_sum(dp_sum); } Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - softmax.template pack<__half>(frag_dp); + softmax.template pack(frag_dp); if (!Is_dropout) { #pragma unroll @@ -521,13 +529,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // Trigger the load from shared memory for the next series of Q values. smem_kt.load(frag_kt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } // Do the final stage of math. 
{ int ki = Mma_tile_dq::MMAS_K; - fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } @@ -551,7 +559,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - frag_s[ki][mi].template hrelu_<__half>(); + frag_s[ki][mi].template hrelu_(); } } } @@ -561,13 +569,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // Trigger the load from shared memory for the next series of Q values. smem_dot.load(frag_dot[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // __syncthreads(); @@ -590,7 +598,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, if (Is_first || mask_val_next % 2 == 1) { // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum); // smem_dp_sum.move_to_next_write_buffer(); - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2); + dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2); const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW; if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row_1); @@ -619,13 +627,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // Trigger the load from shared memory for the next series of Q values. smem_qt.load(frag_qt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Make sure dQ is in shared memory. @@ -645,7 +653,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // } dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f); // Output the values. - gmem_dq.template store<__half>(dq_out, 0); + gmem_dq.template store(dq_out, 0); } else { // Output the values. gmem_dq_tmp.store(dq_out, 0); @@ -700,11 +708,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // the total amount of shared mem? 
// Epilogue swizzle for dV Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.template store<__half>(acc_dv); + smem_dv.template store(acc_dv); // Epilogue swizzle for dK Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.template store<__half>(acc_dk); + smem_dk.template store(acc_dk); __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu new file mode 100644 index 000000000..b33880f18 --- /dev/null +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu @@ -0,0 +1,66 @@ +/* Copyright (c) 2022, Tri Dao. + */ +#include "fmha.h" +#include "static_switch.h" +#include "fmha_block_dgrad_kernel_1xN_loop.h" + +template +__global__ void fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { + fmha::compute_block_dq_dk_dv_1xN(params); +} + +template +void run_fmha_block_dgrad_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); + static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + static_assert(smem_size_dp_sum == 16 * 4 * 2); + + constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum; + + bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" + bool is_causal = params.is_causal; + auto kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { + kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? 
&fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + } + + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); +} + +void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(params.is_bf16, ([&] { + if (params.d == 16) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 32) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 64) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu similarity index 67% rename from csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu rename to csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu index d1a90633e..01e1d324e 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu @@ -26,25 +26,26 @@ ******************************************************************************/ #include "fmha.h" +#include "static_switch.h" #include "fmha_block_fprop_kernel_1xN.h" template -__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) { +__global__ void fmha_block_fprop_sm80_loop_kernel(FMHA_fprop_params params) { fmha::device_block_1xN_loop(params); } template -void run_fmha_block_fp16_sm80_loop_(Launch_params &launch_params, +void run_fmha_block_sm80_loop_(Launch_params &launch_params, const bool configure) { bool is_causal = launch_params.params.is_causal; // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? auto kernel = launch_params.is_dropout ? (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)) + ? (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel) + : (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel)) : (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)); + ? (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel) + : (launch_params.return_softmax ? 
&fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel)); constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; @@ -75,16 +76,18 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params &launch_par FMHA_CHECK_CUDA(cudaPeekAtLastError()); } -void run_fmha_block_fp16_sm80(Launch_params &launch_params, +void run_fmha_block_sm80(Launch_params &launch_params, const bool configure) { - if (launch_params.params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } -} \ No newline at end of file + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if (launch_params.params.d == 16) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 32) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 64) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index daa8c186e..50980db06 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -39,6 +39,13 @@ namespace fmha { template inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif // The description of the CTA tile for the 1st batched GEMM. using Cta_tile_p = typename Kernel_traits::Cta_tile_p; @@ -73,7 +80,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; - using Gemm1 = Gemm_Q_K; + using Gemm1 = Gemm_Q_K; using Softmax = fmha::Softmax; @@ -340,7 +347,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); + softmax.template pack(frag_p); if (Return_softmax) { gmem_s.store(frag_p, mask); if (not_last_iter) { @@ -358,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].template hrelu_<__half>(); + frag_p[ki][mi].template hrelu_(); } } } @@ -370,7 +377,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // Do this part of O = P^T * V^T. 
#pragma unroll for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]); + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); } // The mapping from tidx to rows changes between the softmax and the O-reduction. @@ -471,7 +478,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // Output the values. if (is_final_write) { - gmem_o.template store<__half>(out, 0); + gmem_o.template store(out, 0); } else { gmem_o_tmp.store(out, 0); } From dc557473ec08c0f8d918bb7ebf8de27bab9dec92 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Mon, 29 May 2023 10:39:20 +0000 Subject: [PATCH 4/7] fix bf16 bug in grad --- csrc/flash_attn/src/fmha/gemm.h | 3 +- csrc/flash_attn/src/fmha/utils.h | 36 ++++++++++++++----- .../src/fmha_block_dgrad_kernel_1xN_loop.h | 33 +++++++++++++++-- 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 2fff2b219..72d68b0ac 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -135,10 +135,11 @@ struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { } // Multiply by another fragment. + template inline __device__ void hmul(const Fragment &other) { #pragma unroll for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii)); + this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii)); } } diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h index 110dda25f..0494e4c0b 100644 --- a/csrc/flash_attn/src/fmha/utils.h +++ b/csrc/flash_attn/src/fmha/utils.h @@ -272,7 +272,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) { +template +static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b); + +template<> +inline __device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) { // uint32_t c; // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); // return c; @@ -281,6 +285,18 @@ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) { return reinterpret_cast(result); } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) { + // uint32_t c; + // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + // return c; + __nv_bfloat162 result = __hmul2(reinterpret_cast(a), + reinterpret_cast(b)); + return reinterpret_cast(result); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmul4(uint2 a, uint2 b) { @@ -292,23 +308,25 @@ static inline __device__ uint2 hmul4(uint2 a, uint2 b) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template static inline __device__ uint4 hmul8(uint4 a, uint4 b) { uint4 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - c.z = hmul2(a.z, b.z); - c.w = hmul2(a.w, b.w); + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// +template static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { uint4 c; - c.x = hmul2(a, b.x); - 
c.y = hmul2(a, b.y); - c.z = hmul2(a, b.z); - c.w = hmul2(a, b.w); + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); return c; } diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 01282691c..9ad4569b1 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -31,6 +31,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, const int loop_step_idx) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + constexpr bool is_fp16_type = std::is_same::value; using elem_type = typename Kernel_traits::elem_type; #else constexpr bool is_fp16_type = std::is_same::value; @@ -38,6 +39,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, using elem_type = __half; #endif + // The description of the CTA tile for the 1st batched GEMM. using Cta_tile_p = typename Kernel_traits::Cta_tile_p; // The description of the CTA tile for the 2nd batched GEMM. @@ -262,7 +264,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, const uint32_t scale_dropout = params.scale_dropout; #pragma unroll for(int it=0; it < Gmem_tile_v::LDGS; it++){ - gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); + gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); } } @@ -485,10 +487,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { #pragma unroll for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - frag_p[mi][ni].hmul(frag_dp[mi][ni]); + frag_p[mi][ni].template hmul(frag_dp[mi][ni]); } } - } else { + } else if (is_fp16_type) { __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2]; for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]); @@ -511,6 +513,31 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, } } } + } else { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __nv_bfloat162 dp_sum_half[Mma_tile_p::MMAS_M * 2]; + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + dp_sum_half[mi] = __float2bfloat162_rn(dp_sum[mi]); + } + const __nv_bfloat16 zero_h = __nv_bfloat16(0.f); + #pragma unroll + for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + #pragma unroll + for (int ii = 0; ii < 4; ++ii) { + const __nv_bfloat162 p = frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii); + const __nv_bfloat162 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__nv_bfloat162>(ii)); + // If this element is dropped, then frag_p stores -p instead of p. + // So pd holds -p * dp_sum in that case. + const __nv_bfloat162 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]); + const __nv_bfloat16 low = __low2bfloat16(p) >= zero_h ? __low2bfloat16(pdp) : __low2bfloat16(pd); + const __nv_bfloat16 high = __low2bfloat16(p) >= zero_h ? 
__low2bfloat16(pdp) : __low2bfloat16(pd); + frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii) = __halves2bfloat162(low, high); + } + } + } +#endif } // Store dp to smem for transpose From ff74bc0c2032f92b97b13aa34ffcdb34eea3692e Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Wed, 31 May 2023 12:17:17 +0000 Subject: [PATCH 5/7] add dim128 support for block sparsewq --- csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h | 2 +- csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu | 3 +++ csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 9ad4569b1..1ed692eab 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -49,7 +49,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, fmha::Cta_tile_extd; static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128); - static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64); + static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); static_assert(Cta_tile_dkv::K == 16); // The MMA tile for the 1st GEMM. diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu index b33880f18..410b94409 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu @@ -61,6 +61,9 @@ void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t str } else if (params.d == 64) { using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 128) { + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); } })); } diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu index 01e1d324e..0adc617aa 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu @@ -88,6 +88,9 @@ void run_fmha_block_sm80(Launch_params &launch_params, } else if (launch_params.params.d == 64) { using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 128) { + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); } })); } From 98625ce9fd8023d36986635a0d4bad097177f141 Mon Sep 17 00:00:00 2001 From: umiswing Date: Mon, 5 Jun 2023 08:18:53 +0000 Subject: [PATCH 6/7] Add 128 hdim. 
--- csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index ce5410fc8..a396eb48c 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -643,7 +643,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // if (Is_dropout) { // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); // } - dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f); + for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { + dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); + } // Output the values. gmem_dq.template store<__half>(dq_out, 0); } else { From dd68e2c60edb6ca55ecc3d42161ed2566ba1df34 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 7 Jun 2023 08:26:46 +0000 Subject: [PATCH 7/7] support seqlen parallel --- csrc/flash_attn/src/fmha.h | 2 +- .../src/fmha_block_dgrad_kernel_1xN_loop.h | 169 +++++++++++++++--- .../src/fmha_block_dgrad_kernel_loop.sm80.cu | 103 ++++++++--- 3 files changed, 231 insertions(+), 43 deletions(-) diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 0d4c5c36d..3f6e0f654 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -216,4 +216,4 @@ bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t void run_fmha_block_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream); +void run_fmha_block_dgrad_sm80(FMHA_dgrad_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 2a7452d0d..ce36dab73 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -26,7 +26,105 @@ inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const ui //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale, + Gmem_softmax_sum gmem_softmax_d, int tidx) { + float sum[M]; + fmha::SumOp sum_op; + #pragma unroll + for (int mi = 0; mi < M; ++mi) { + sum[mi] = fmha::Allreduce::run( + fmha::hmulsum8(do_[mi], o[mi]), sum_op + ) * scale; + } + const int dp_sum_row = tidx / THREADS_PER_ROW; + if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) { + gmem_softmax_d.store_row(reinterpret_cast(sum), dp_sum_row); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 3rd batched GEMM. 
+ using Cta_tile_dkv = + fmha::Cta_tile_extd; + + static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); + static_assert(Cta_tile_dkv::K == 16); + + // The global memory tile to load dO. + using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; + + // The global memory tile to load O.Loading O here is similar to loading dO. + using Gmem_tile_o = Gmem_tile_do; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. + const int tidx = threadIdx.x; + + // How many steps to jump per iteration. + const int step_stride = gridDim.z; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if( binfo.stop_early() ) return; + + // Allocate the global memory tile loader for dO. + Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, + params.d, binfo, tidx, true); + + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, + params.d, binfo, tidx, true); + + Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx); + + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M; + // Wind gmem tiles to the correct position. + gmem_do.move(blockIdx.z); + gmem_o.move(blockIdx.z); + gmem_softmax_d.move(blockIdx.z); + + // Load over the entire sequence length. + for (int l = blockIdx.z; l < steps; l += step_stride) { + if (l * Cta_tile_p::M >= binfo.actual_seqlen_q) + break; + + gmem_do.load(); + gmem_do.move(step_stride); + gmem_o.load(); + gmem_o.move(step_stride); + + dot_do_o( + gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx + ); + gmem_softmax_d.move(step_stride); + } // Outer loop over the sequence length. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph, const int loop_step_idx) { @@ -207,7 +305,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, gmem_q.move(block_row_idx_to_move); gmem_do.move(block_row_idx_to_move); gmem_o.move(block_row_idx_to_move); - gmem_dq.move(block_row_idx_to_move); + //gmem_dq.move(block_row_idx_to_move); + if (!Seq_parallel) { + gmem_dq.move(block_row_idx_to_move); + } // If Seq_parallel, we're not using gmem_dq at all gmem_dq_tmp.move(block_row_idx_to_move); // TODO: need to move gmem_s if we want the intermediate result for debugging gmem_softmax_lse.move(block_row_idx_to_move); @@ -613,7 +714,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP]; // if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); } - if (!is_first_read) { gmem_dq_tmp.load(dq_out, 0); } + if (!is_first_read && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); } // __syncthreads(); // Commit the values for Q and dO into shared memory. @@ -667,25 +768,33 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, __syncthreads(); // Load from shared memory. - is_first_read ? 
smem_dq.template load(dq_out) : smem_dq.template load(dq_out); - - const bool is_final_write = - Is_last - || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) - || ((mask_val & 0x2) != 0) - || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); - if (is_final_write) { - // if (Is_dropout) { - // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); - // } - for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { - dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); - } - // Output the values. - gmem_dq.template store(dq_out, 0); - } else { - // Output the values. - gmem_dq_tmp.store(dq_out, 0); + (is_first_read || Seq_parallel) ? smem_dq.template load(dq_out) : smem_dq.template load(dq_out); + + if(!Seq_parallel){ + const bool is_final_write = + Is_last + || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) + || ((mask_val & 0x2) != 0) + || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); + if (is_final_write) { + // if (Is_dropout) { + // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); + // } + for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { + dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); + } + // Output the values. + gmem_dq.template store(dq_out, 0); + } else { + // Output the values. + gmem_dq_tmp.store(dq_out, 0); + } + }else{ + for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { + // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); + dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout); + } + gmem_dq_tmp.atomic_add(dq_out, 0); } // Move to the next part of the output. @@ -804,6 +913,22 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params ¶ms) { } } +template +inline __device__ void compute_block_dq_dk_dv_1xN_seqparallel(const Params ¶ms) { + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. 
+ const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx; + auto seeds = philox::unpack(params.philox_args); + Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + + int loop_step_idx = blockIdx.z; + compute_block_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); +} //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu index 410b94409..9e97c02e1 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu @@ -3,14 +3,42 @@ #include "fmha.h" #include "static_switch.h" #include "fmha_block_dgrad_kernel_1xN_loop.h" +#include "cuda_utils.h" + +template +__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) { + fmha::compute_dot_do_o(params); +} template __global__ void fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { fmha::compute_block_dq_dk_dv_1xN(params); } +template +__global__ void fmha_block_dgrad_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) { + fmha::compute_block_dq_dk_dv_1xN_seqparallel(params); +} +inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen, + int blocksize, bool is_causal) { + float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm); + float eff_1 = n_waves_1 / ceil(n_waves_1); + int num_splits_parallel = seqlen / blocksize; + float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm); + float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel); + float discount_factor; + if (!is_causal) { + discount_factor = 1.f + float(blocksize) / seqlen; + } else { // For causal, parallelizing seems to help with load-balancing as well + // For example, if headdim=128, seqlen >= 1280 always prefers parallel + if (seqlen / blocksize >= 10) return num_splits_parallel; + discount_factor = 1.f + 0.5 * float(blocksize) / seqlen; + } + float eff_parallel = eff_parallel_raw / discount_factor; + return eff_1 >= eff_parallel ? 1 : num_splits_parallel; +} template -void run_fmha_block_dgrad_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { +void run_fmha_block_dgrad_sm80_loop_(FMHA_dgrad_params ¶ms, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; @@ -27,30 +55,65 @@ void run_fmha_block_dgrad_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" bool is_causal = params.is_causal; - auto kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? 
&fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = is_dropout + BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] { + auto kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { + kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = is_dropout ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); - } + } - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - dim3 grid(params.b, params.h); - kernel<<>>(params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + auto kernel_seqparallel = params.is_causal + ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_seqparallel_kernel + : &fmha_block_dgrad_sm80_dq_dk_dv_loop_seqparallel_kernel; + + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + //dim3 grid(params.b, params.h); + //kernel<<>>(params); + //FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // Automatically set num_splits to maximize occupancy + if (params.num_splits <= 0) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv); + auto dprops = GetDeviceProperties(-1); + // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount); + constexpr int M = Kernel_traits::Cta_tile_p::M; + // We don't want more than 10 splits due to numerical error. + // Numerical error on dk/dv scales as sqrt(num_splits). + params.num_splits = num_splits_heuristic_bwd( + params.b * params.h, dprops->multiProcessorCount, + ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal + ); + } + //if (configure) return; + if (params.num_splits == 1) { + dim3 grid(params.b, params.h, params.num_splits); + kernel<<>>(params); + } else { + dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128); + fmha_bwd_dot_do_o_kernel<<>>(params); + int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c + dim3 grid(params.b, params.h, num_splits); + kernel_seqparallel<<>>(params); + } + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + })); } -void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { +void run_fmha_block_dgrad_sm80(FMHA_dgrad_params ¶ms, cudaStream_t stream) { FP16_SWITCH(params.is_bf16, ([&] { if (params.d == 16) { using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
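
For readers wiring the new entry points into a framework, the sketch below shows one way the block-sparse forward call might be driven from the host, following the two-phase pattern used inside flash_attn_fwd_block (a call with out == nullptr only reports the required workspace size and returns). This is an illustrative sketch, not part of the patch: the d_* device pointers, the shape arguments, and run_block_sparse_fwd itself are placeholders assumed to be set up by the caller, and dropout / causal masking are disabled for brevity.

// Hedged usage sketch (not part of the patch series): drives flash_attn_fwd_block
// with the workspace-size query pattern used by the implementation above.
// All d_* pointers and shape arguments are assumed to be valid device buffers
// prepared by the caller; error handling is reduced to simple returns.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include "flash_attn.h"

bool run_block_sparse_fwd(const void *d_q, const void *d_k, const void *d_v, void *d_out,
                          const void *d_cu_seqlens_q, const void *d_cu_seqlens_k,
                          const void *d_blockmask, void *d_softmax_lse,
                          int total_q, int total_k, int batch_size,
                          int num_heads, int head_size,
                          int max_seqlen_q, int max_seqlen_k,
                          float softmax_scale, cudaStream_t stream) {
    const float p_dropout = 0.f;   // no dropout in this sketch
    const bool is_causal = false;
    const bool is_bf16 = false;    // fp16 path
    const uint64_t seed = 0, offset = 0;

    // Phase 1: query the workspace size by passing out == nullptr.
    uint64_t workspace_size = 0;
    if (!flash_attn_fwd_block(d_q, d_k, d_v, /*out=*/nullptr,
                              d_cu_seqlens_q, d_cu_seqlens_k, d_blockmask,
                              total_q, total_k, batch_size, num_heads, head_size,
                              max_seqlen_q, max_seqlen_k,
                              p_dropout, softmax_scale, is_causal, is_bf16,
                              d_softmax_lse, /*softmax_ptr=*/nullptr,
                              /*workspace_ptr=*/nullptr, &workspace_size,
                              stream, seed, offset)) {
        std::fprintf(stderr, "workspace query failed: %s\n", flash_attn_error());
        return false;
    }

    // Phase 2: allocate the workspace (if any) and run the actual kernel.
    void *workspace = nullptr;
    if (workspace_size > 0 && cudaMalloc(&workspace, workspace_size) != cudaSuccess) {
        return false;
    }
    bool ok = flash_attn_fwd_block(d_q, d_k, d_v, d_out,
                                   d_cu_seqlens_q, d_cu_seqlens_k, d_blockmask,
                                   total_q, total_k, batch_size, num_heads, head_size,
                                   max_seqlen_q, max_seqlen_k,
                                   p_dropout, softmax_scale, is_causal, is_bf16,
                                   d_softmax_lse, /*softmax_ptr=*/nullptr,
                                   workspace, &workspace_size,
                                   stream, seed, offset);
    if (!ok) std::fprintf(stderr, "flash_attn_fwd_block failed: %s\n", flash_attn_error());
    if (workspace) cudaFree(workspace);
    return ok;
}

The same query-then-run pattern applies to flash_attn_bwd_block, which reports a float accumulation buffer of total_q x num_heads x head_size elements whenever max_seqlen_k exceeds 256.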