
Commit

Update
AnnaTrainingG committed Dec 7, 2023
1 parent 89f1ce7 commit 4a63f7c
Showing 2 changed files with 50 additions and 46 deletions.
80 changes: 40 additions & 40 deletions csrc/CMakeLists.txt
@@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)

set(FA2_SOURCES_CU
flash_attn/src/cuda_utils.cu
-#flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
-#flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
-#flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
-#flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
-#flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
+flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
+flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
+flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
+flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
)

add_library(flashattn SHARED
@@ -63,12 +63,12 @@ target_include_directories(flashattn PRIVATE
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
-#flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu
-#flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu
-#flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu
-#flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu
-#flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
-#flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
+flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu
+flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu
+flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu
+flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu
+flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
+flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
flash_attn_with_bias_and_mask/src/utils.cu)

add_library(flashattn_with_bias_mask STATIC
@@ -103,7 +103,7 @@ target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
"SHELL:-gencode arch=compute_80,code=sm_80"
>)

target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
@@ -118,7 +118,7 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
"SHELL:-gencode arch=compute_80,code=sm_80"
>)

set(CMAKE_CXX_STANDARD 17)
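
Both gencode hunks above drop the compute_90/sm_90 target, so after this commit the CUDA objects for flashattn and flashattn_with_bias_mask are built for SM80 only. Below is a minimal sketch of how the architecture list could instead be kept configurable; the FA_CUDA_ARCHS cache variable and FA_GENCODE_FLAGS list are hypothetical names for illustration, not part of this commit or the repository.

    # Sketch only: hypothetical FA_CUDA_ARCHS cache variable, not from this commit.
    # Builds one "SHELL:-gencode ..." flag per requested architecture.
    set(FA_CUDA_ARCHS "80" CACHE STRING "Semicolon-separated CUDA architectures, e.g. 80;90")
    set(FA_GENCODE_FLAGS "")
    foreach(arch IN LISTS FA_CUDA_ARCHS)
      list(APPEND FA_GENCODE_FLAGS
           "SHELL:-gencode arch=compute_${arch},code=sm_${arch}")
    endforeach()
    target_compile_options(flashattn PRIVATE
      $<$<COMPILE_LANGUAGE:CUDA>:${FA_GENCODE_FLAGS}>)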
16 changes: 10 additions & 6 deletions csrc/setup.py
@@ -138,7 +138,7 @@ def get_data_files():
data_files = []
#source_lib_path = 'libflashattn.so'
#data_files.append((".", [source_lib_path]))
-data_files.append((".", ['flashattn_advanced.so']))
+data_files.append((".", ['libflashattn_advanced.so']))
return data_files
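
The get_data_files() hunk above corrects the packaged library name from flashattn_advanced.so to libflashattn_advanced.so. Below is a minimal sketch of the same helper with an explicit existence check; the guard and its error message are illustrative additions, not part of this commit.

    import os

    def get_data_files():
        # Sketch only: packages the shared library built by the CMake targets above.
        # The os.path.exists guard is an illustrative addition, not part of this commit.
        data_files = []
        lib_name = "libflashattn_advanced.so"
        if not os.path.exists(lib_name):
            raise FileNotFoundError(f"{lib_name} not found; build the CMake targets first")
        data_files.append((".", [lib_name]))
        return data_files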


@@ -213,9 +213,13 @@ def run(self):
},
python_requires=">=3.7",
install_requires=[
"paddle",
"einops",
"packaging",
"ninja",
],
"common",
"dual",
"tight>=0.1.0",
"data",
"prox",
"ninja", # Put ninja before paddle if paddle depends on it
"einops",
"packaging",
],
)
