Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ build support for use with LibTorch #819

Open · wants to merge 4 commits into main
91 changes: 91 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,91 @@
cmake_minimum_required(VERSION 3.18)

project(FlashAttention)

find_package(CUDA REQUIRED)
find_package(Torch REQUIRED PATHS ${LIBTORCH_PATH})

if(NOT CUDA_VERSION VERSION_GREATER_EQUAL "11.6")
    message(FATAL_ERROR "CUDA version must be at least 11.6")
endif()

# Override the default MSVC Release flags so that /DNDEBUG is not set
# (keeps assert() active) and define CXX_BUILD to mark the DLL build.
set(CMAKE_CXX_FLAGS_RELEASE "/MD /O2 /Ob2 /DCXX_BUILD" CACHE STRING "Release flags" FORCE)

# require c++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
    -std=c++17
    -O3
    -U__CUDA_NO_HALF_OPERATORS__
    -U__CUDA_NO_HALF_CONVERSIONS__
    -U__CUDA_NO_HALF2_OPERATORS__
    -U__CUDA_NO_BFLOAT16_CONVERSIONS__
    --expt-relaxed-constexpr
    --expt-extended-lambda
    --use_fast_math
    --threads 4
    -gencode arch=compute_80,code=sm_80
)

if(CUDA_VERSION VERSION_GREATER "11.8")
    set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_90,code=sm_90)
endif()

if(EXISTS ${LIBTORCH_PATH}/include/ATen/CudaGeneratorImpl.h)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /DOLD_GENERATOR_PATH")
    set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-DOLD_GENERATOR_PATH)
endif()

include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn
${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/src
${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass/include
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)

cuda_add_library(flash_attn SHARED
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
)

target_link_libraries(flash_attn "${TORCH_LIBRARIES}")
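
Not part of the diff: the CUDA version gate above only runs at configure time, so it can help a consumer to confirm at run time that LibTorch actually sees a device covered by the compiled architectures (sm_80, plus sm_90 when built with CUDA > 11.8). A minimal sketch using standard LibTorch/ATen calls (file name is illustrative):

// check_device.cpp — hypothetical sanity check, not part of this PR.
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <iostream>

int main() {
    if (!torch::cuda::is_available()) {
        std::cerr << "LibTorch sees no CUDA device\n";
        return 1;
    }
    const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    std::cout << prop->name << " (SM " << prop->major << "." << prop->minor << ")\n";
    // The kernels in this library are compiled for compute capability 8.0+.
    return prop->major >= 8 ? 0 : 1;
}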

21 changes: 1 addition & 20 deletions csrc/flash_attn/flash_api.cpp
@@ -2,22 +2,12 @@
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"
#include "flash_api.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")


void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
@@ -1458,12 +1448,3 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
}
return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
139 changes: 139 additions & 0 deletions csrc/flash_attn/flash_api.h
@@ -0,0 +1,139 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

// Include <torch/python.h> (Python extension builds) or <torch/all.h> (C++ builds), plus <torch/nn/functional.h>, instead of torch/extension.h, since we don't need all of the torch headers.
#ifndef PY_BUILD
#include <torch/all.h>
#else
#include <torch/python.h>
#endif
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"

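// On Windows, a C++ (non-Python) build must export the API with __declspec(dllexport)
// when compiling the library itself (CXX_BUILD, set by CMakeLists.txt) and import it
// with __declspec(dllimport) in consumers; Python extension builds need neither.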
#if !defined(PY_BUILD) && defined(_WIN32)
#ifdef CXX_BUILD
#define EXPORT __declspec(dllexport)
#else
#define EXPORT __declspec(dllimport)
#endif
#else
#define EXPORT
#endif

EXPORT std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_);

EXPORT std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_);

EXPORT std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

EXPORT std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

EXPORT std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
);

#ifdef PY_BUILD
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
#endif
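
With these declarations exported, a plain LibTorch program can call the forward kernel directly. Below is a minimal sketch, not part of this PR: it assumes the library built by the CMakeLists.txt above, the repository's csrc/flash_attn, csrc/flash_attn/src, and csrc/cutlass/include directories on the include path (mirroring the include_directories() call), fp16 inputs on an SM 8.0+ device, and that the attention output is the first tensor in the returned vector; the file name and shapes are illustrative.

// example_fwd.cpp — hypothetical consumer; link against flash_attn and LibTorch.
#include <cmath>
#include <iostream>
#include "flash_api.h"

int main() {
    const int64_t batch = 2, seqlen = 128, num_heads = 8, head_size = 64;
    auto opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);

    // q/k/v layout: batch_size x seqlen x num_heads x head_size (see comments above).
    at::Tensor q = torch::randn({batch, seqlen, num_heads, head_size}, opts);
    at::Tensor k = torch::randn({batch, seqlen, num_heads, head_size}, opts);
    at::Tensor v = torch::randn({batch, seqlen, num_heads, head_size}, opts);

    c10::optional<at::Tensor> out_;           // empty: let the kernel allocate the output
    c10::optional<at::Tensor> alibi_slopes_;  // empty: no ALiBi
    c10::optional<at::Generator> gen_;        // empty: default CUDA generator

    const float softmax_scale = 1.0f / std::sqrt(static_cast<float>(head_size));
    // window_size_left/right = -1 disables local (sliding-window) attention.
    auto result = mha_fwd(q, k, v, out_, alibi_slopes_,
                          /*p_dropout=*/0.0f, softmax_scale,
                          /*is_causal=*/true,
                          /*window_size_left=*/-1, /*window_size_right=*/-1,
                          /*return_softmax=*/false, gen_);

    // The attention output comes first in the returned vector.
    std::cout << "out: " << result[0].sizes() << std::endl;
    return 0;
}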
2 changes: 1 addition & 1 deletion setup.py
@@ -184,7 +184,7 @@ def append_nvcc_threads(nvcc_extra_args):
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"cxx": ["-O3", "-std=c++17", "-DPY_BUILD"] + generator_flag,
"nvcc": append_nvcc_threads(
[
"-O3",