From 5ee8abbf79acd815073ed5f8877ea450677660d0 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Fri, 9 Feb 2024 03:23:26 +0200
Subject: [PATCH 1/3] Added cmakelists + conditional directive for flash_api.cpp

---
 CMakeLists.txt                | 92 +++++++++++++++++++++++++++++++++++
 csrc/flash_attn/flash_api.cpp |  7 ++-
 2 files changed, 98 insertions(+), 1 deletion(-)
 create mode 100644 CMakeLists.txt

diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 000000000..bf70b412a
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,92 @@
+cmake_minimum_required(VERSION 3.18)
+
+project(FlashAttention)
+
+find_package(CUDA REQUIRED)
+find_package(Torch REQUIRED PATHS ${LIBTORCH_PATH})
+
+if(NOT CUDA_VERSION VERSION_GREATER_EQUAL "11.6")
+  message(FATAL_ERROR "CUDA version must be at least 11.6")
+endif()
+
+# Set CMAKE_CXX_FLAGS to make sure -DNDEBUG is not set
+set(CMAKE_CXX_FLAGS_RELEASE "/MD /O2 /Ob2 /DCXX_BUILD " CACHE STRING "Release flags" FORCE)
+
+# require c++17
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++17;-O3;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;-U__CUDA_NO_HALF2_OPERATORS__;-U__CUDA_NO_BFLOAT16_CONVERSIONS__;--expt-relaxed-constexpr;--expt-extended-lambda;--use_fast_math;--threads;4;-gencode;arch=compute_80,code=sm_80;)
+
+if(CUDA_VERSION VERSION_GREATER "11.8")
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_90,code=sm_90)
+endif()
+
+if (EXISTS ${LIBTORCH_PATH}/include/ATen/CudaGeneratorImpl.h)
+  set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} /DOLD_GENERATOR_PATH)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-DOLD_GENERATOR_PATH)
+endif()
+
+include_directories(
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/src
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass/include
+    ${CUDA_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
+)
+
+cuda_add_library(flash_attn SHARED
+    csrc/flash_attn/flash_api.cpp
+    csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
+    csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
+    csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
+)
+
+target_compile_definitions(flash_attn PRIVATE CXX_BUILD)
+target_link_libraries(flash_attn "${TORCH_LIBRARIES}")
+
diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 79284dc34..57129d7d9 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -3,7 +3,11 @@
  ******************************************************************************/
 
 // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#ifdef CXX_BUILD
+#include 
+#else
 #include <torch/python.h>
+#endif
 #include <torch/nn/functional.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -17,7 +21,6 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
-
 void set_params_fprop(Flash_fwd_params &params,
                       // sizes
                       const size_t b,
@@ -1459,6 +1462,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     return {out, softmax_lse};
 }
 
+#ifndef CXX_BUILD
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
     m.def("fwd", &mha_fwd, "Forward pass");
@@ -1467,3 +1471,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
     m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
+#endif
\ No newline at end of file

From d1f9ad239ac892099561eac63a22a955e1ec6c86 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Fri, 9 Feb 2024 12:04:25 +0200
Subject: [PATCH 2/3] Swapped cxx_build with py_build

---
 CMakeLists.txt                | 1 -
 csrc/flash_attn/flash_api.cpp | 6 +++---
 setup.py                      | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf70b412a..50d951151 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,6 +87,5 @@ cuda_add_library(flash_attn SHARED
     csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
 )
 
-target_compile_definitions(flash_attn PRIVATE CXX_BUILD)
 target_link_libraries(flash_attn "${TORCH_LIBRARIES}")
 
diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 57129d7d9..abe5adedd 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -2,8 +2,8 @@
  * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
-// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
-#ifdef CXX_BUILD
+// Include ( or ) and headers instead of torch/extension.h since we don't need all of the torch headers.
+#ifndef PY_BUILD
 #include 
 #else
 #include <torch/python.h>
 #endif
@@ -1462,7 +1462,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     return {out, softmax_lse};
 }
 
-#ifndef CXX_BUILD
+#ifdef PY_BUILD
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
     m.def("fwd", &mha_fwd, "Forward pass");
diff --git a/setup.py b/setup.py
index de1503fa0..c7359395f 100644
--- a/setup.py
+++ b/setup.py
@@ -184,7 +184,7 @@ def append_nvcc_threads(nvcc_extra_args):
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3", "-std=c++17"] + generator_flag,
+                "cxx": ["-O3", "-std=c++17", "-DPY_BUILD"] + generator_flag,
                 "nvcc": append_nvcc_threads(
                     [
                         "-O3",

From b254f9c1c8a52771d4a404d3819d93a109a008b8 Mon Sep 17 00:00:00 2001
From: Shaltiel Shmidman
Date: Fri, 9 Feb 2024 15:32:00 +0200
Subject: [PATCH 3/3] Added header file

---
 csrc/flash_attn/flash_api.cpp |  26 +------
 csrc/flash_attn/flash_api.h   | 139 ++++++++++++++++++++++++++++++++++
 2 files changed, 140 insertions(+), 25 deletions(-)
 create mode 100644 csrc/flash_attn/flash_api.h

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index abe5adedd..f05ac7841 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -2,20 +2,7 @@
  * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
-// Include ( or ) and headers instead of torch/extension.h since we don't need all of the torch headers.
-#ifndef PY_BUILD
-#include 
-#else
-#include <torch/python.h>
-#endif
-#include <torch/nn/functional.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include <cutlass/numeric_types.h>
-
-#include "flash.h"
-#include "static_switch.h"
+#include "flash_api.h"
 
 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -1461,14 +1448,3 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }
     return {out, softmax_lse};
 }
-
-#ifdef PY_BUILD
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.doc() = "FlashAttention";
-    m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
-    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
-}
-#endif
\ No newline at end of file
diff --git a/csrc/flash_attn/flash_api.h b/csrc/flash_attn/flash_api.h
new file mode 100644
index 000000000..286ce5a01
--- /dev/null
+++ b/csrc/flash_attn/flash_api.h
@@ -0,0 +1,139 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+// Include ( or ) and headers instead of torch/extension.h since we don't need all of the torch headers.
+#ifndef PY_BUILD
+#include 
+#else
+#include <torch/python.h>
+#endif
+#include <torch/nn/functional.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cutlass/numeric_types.h>
+
+#include "flash.h"
+#include "static_switch.h"
+
+#if !defined(PY_BUILD) && defined(_WIN32)
+#ifdef CXX_BUILD
+#define EXPORT __declspec(dllexport)
+#else
+#define EXPORT __declspec(dllimport)
+#endif
+#else
+#define EXPORT
+#endif
+
+EXPORT std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const bool return_softmax,
+        c10::optional<at::Generator> gen_);
+
+EXPORT std::vector<at::Tensor>
+mha_varlen_fwd(at::Tensor &q,                            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &out_,          // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,           // b+1
+               const at::Tensor &cu_seqlens_k,           // b+1
+               c10::optional<at::Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               int max_seqlen_q,
+               const int max_seqlen_k,
+               const float p_dropout,
+               const float softmax_scale,
+               const bool zero_tensors,
+               bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const bool return_softmax,
+               c10::optional<at::Generator> gen_);
+
+EXPORT std::vector<at::Tensor>
+mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
+        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &softmax_lse,            // b x h x seqlen_q
+        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+        const float p_dropout,                    // probability to drop
+        const float softmax_scale,
+        const bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const bool deterministic,
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state);
+
+EXPORT std::vector<at::Tensor>
+mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads, x head_size
+               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &out,                    // total_q x num_heads x head_size
+               const at::Tensor &softmax_lse,            // b x h x s   softmax logsumexp
+               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,           // b+1
+               const at::Tensor &cu_seqlens_k,           // b+1
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               const int max_seqlen_q,
+               const int max_seqlen_k,                   // max sequence length to choose the kernel
+               const float p_dropout,                    // probability to drop
+               const float softmax_scale,
+               const bool zero_tensors,
+               const bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const bool deterministic,
+               c10::optional<at::Generator> gen_,
+               c10::optional<at::Tensor> &rng_state);
+
+EXPORT std::vector<at::Tensor>
+mha_fwd_kvcache(at::Tensor &q,                                      // batch_size x seqlen_q x num_heads x head_size
+                const at::Tensor &kcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                const at::Tensor &vcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                c10::optional<const at::Tensor> &k_,                // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_,                // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &seqlens_k_,        // batch_size
+                c10::optional<const at::Tensor> &rotary_cos_,       // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &rotary_sin_,       // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_,  // indices to index into the KV cache
+                c10::optional<at::Tensor> &block_table_,            // batch_size x max_num_blocks_per_seq
+                c10::optional<at::Tensor> &alibi_slopes_,           // num_heads or batch_size x num_heads
+                c10::optional<at::Tensor> &out_,                    // batch_size x seqlen_q x num_heads x head_size
+                const float softmax_scale,
+                bool is_causal,
+                int window_size_left,
+                int window_size_right,
+                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+                int num_splits
+                );
+
+#ifdef PY_BUILD
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "FlashAttention";
+    m.def("fwd", &mha_fwd, "Forward pass");
+    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
+    m.def("bwd", &mha_bwd, "Backward pass");
+    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
+}
+#endif
\ No newline at end of file
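
Usage note: a minimal standalone consumer of the resulting flash_attn shared library might look like the sketch below. It is illustrative only, and everything in it that is not taken from the patches above is an assumption: the libtorch location, the cmake invocation, the tensor shapes, and the softmax scale. The call follows the mha_fwd declaration from flash_api.h.

// Sketch of a standalone C++ caller (assumptions: a CUDA device is available,
// the library was configured roughly as
//   cmake -DLIBTORCH_PATH=/path/to/libtorch .. && cmake --build . --config Release
// and this program is linked against flash_attn and libtorch).
#include <cmath>
#include <iostream>
#include <vector>
#include <torch/torch.h>
#include "flash_api.h"

int main() {
    auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    // batch_size x seqlen x num_heads x head_size, as documented in flash_api.h.
    at::Tensor q = torch::randn({2, 128, 8, 64}, opts);
    at::Tensor k = torch::randn({2, 128, 8, 64}, opts);
    at::Tensor v = torch::randn({2, 128, 8, 64}, opts);
    c10::optional<at::Tensor> out_, alibi_slopes_;  // left empty: let the library allocate the output, no ALiBi
    c10::optional<at::Generator> gen_;              // default RNG; only consulted when dropout > 0
    std::vector<at::Tensor> result = mha_fwd(q, k, v, out_, alibi_slopes_,
                                             /*p_dropout=*/0.0f,
                                             /*softmax_scale=*/1.0f / std::sqrt(64.0f),
                                             /*is_causal=*/true,
                                             /*window_size_left=*/-1,
                                             /*window_size_right=*/-1,
                                             /*return_softmax=*/false,
                                             gen_);
    // result[0] is the attention output; the remaining entries are auxiliary
    // tensors (for example the softmax logsumexp) returned by the implementation.
    std::cout << result[0].sizes() << std::endl;
    return 0;
}

Design-wise, the PY_BUILD switch keeps a single flash_api.cpp for both targets: setup.py now passes -DPY_BUILD so the Python wheel still gets the pybind11 module, while the CMake path leaves PY_BUILD undefined and exposes the same mha_* functions as plain C++ symbols (marked dllexport on Windows via the CXX_BUILD/EXPORT macros).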