Fa cmake #29

Open

wants to merge 36 commits into base: main

Commits (36)
b5b20b8  update  (AnnaTrainingG, Dec 6, 2023)
199b9d6  has data  (AnnaTrainingG, Dec 6, 2023)
a582b3a  update  (AnnaTrainingG, Dec 6, 2023)
78080dd  update  (AnnaTrainingG, Dec 6, 2023)
0d6766e  update  (AnnaTrainingG, Dec 6, 2023)
17b89c9  updat  (AnnaTrainingG, Dec 6, 2023)
1355060  update  (AnnaTrainingG, Dec 6, 2023)
d536119  update  (AnnaTrainingG, Dec 6, 2023)
66fc8a7  update  (AnnaTrainingG, Dec 6, 2023)
41ebd07  all  (AnnaTrainingG, Dec 6, 2023)
ad614e0  update  (AnnaTrainingG, Dec 6, 2023)
c8d003a  update  (AnnaTrainingG, Dec 6, 2023)
559a479  update  (AnnaTrainingG, Dec 6, 2023)
bd670ae  80 90  (AnnaTrainingG, Dec 6, 2023)
e4b5006  error  (AnnaTrainingG, Dec 6, 2023)
4f7f1f0  update build ok  (AnnaTrainingG, Dec 6, 2023)
8c12f72  update  (AnnaTrainingG, Dec 6, 2023)
7b257e8  update  (AnnaTrainingG, Dec 6, 2023)
4fd33ea  updaet  (AnnaTrainingG, Dec 6, 2023)
f03a1df  updaet  (AnnaTrainingG, Dec 6, 2023)
d810108  upate  (AnnaTrainingG, Dec 6, 2023)
48eb647  update  (AnnaTrainingG, Dec 6, 2023)
7bb6f31  update  (AnnaTrainingG, Dec 6, 2023)
58563ba  udpate  (AnnaTrainingG, Dec 6, 2023)
e856a05  update  (AnnaTrainingG, Dec 6, 2023)
256a3c6  update  (AnnaTrainingG, Dec 7, 2023)
af386bf  update  (AnnaTrainingG, Dec 7, 2023)
3aca223  Update  (AnnaTrainingG, Dec 7, 2023)
06edc27  update  (AnnaTrainingG, Dec 8, 2023)
45fcc53  update  (AnnaTrainingG, Dec 8, 2023)
940a8ae  default  (AnnaTrainingG, Dec 8, 2023)
6b6c7a8  update  (AnnaTrainingG, Dec 8, 2023)
18ae756  update equal  (AnnaTrainingG, Dec 8, 2023)
d926c09  for so  (AnnaTrainingG, Dec 10, 2023)
a2714eb  Update CMakeLists.txt  (AnnaTrainingG, Dec 11, 2023)
a61e35b  update fa1 mask  (AnnaTrainingG, Dec 11, 2023)
42 changes: 35 additions & 7 deletions csrc/CMakeLists.txt
@@ -1,12 +1,19 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(flash-attention LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Git QUIET REQUIRED)

execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE GIT_SUBMOD_RESULT)

#cmake -DWITH_ADVANCED=ON
if (WITH_ADVANCED)
add_compile_definitions(PADDLE_WITH_ADVANCED)
endif()

add_definitions("-DFLASH_ATTN_WITH_TORCH=0")

set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
@@ -55,6 +62,7 @@ target_include_directories(flashattn PRIVATE
flash_attn
${CUTLASS_3_DIR}/include)

if (WITH_ADVANCED)
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
@@ -65,6 +73,12 @@ set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
flash_attn_with_bias_and_mask/src/utils.cu)
else()
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
flash_attn_with_bias_and_mask/src/utils.cu)
endif()

add_library(flashattn_with_bias_mask STATIC
flash_attn_with_bias_and_mask/
@@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")
Collaborator:
With this pattern, if -DNVCC_ARCH_BIN=... is passed on the command line, which value will be used: 80 or the externally supplied one?

Author (AnnaTrainingG):
The externally supplied value. I have verified experimentally that the external setting is picked up.


if (NOT DEFINED NVCC_ARCH_BIN)
message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()
message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")

foreach(arch ${FA_NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -131,7 +141,25 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
"${FA_GENCODE_OPTION}"
>)


INSTALL(TARGETS flashattn
Collaborator:
The name of the generated shared library should ideally differ depending on whether the advanced feature is enabled or disabled, so that the Paddle framework can tell the two apart more easily when loading it.

Author (AnnaTrainingG, Dec 8, 2023):
When the feature is disabled the library is libflashattn.so; when it is enabled it is libflashattn_advanced.so.

LIBRARY DESTINATION "lib")

INSTALL(FILES capi/flash_attn.h DESTINATION "include")

if (WITH_ADVANCED)
set_target_properties(flashattn PROPERTIES
OUTPUT_NAME libflashattn_advanced
PREFIX ""
)
add_custom_target(build_whl
COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
DEPENDS flashattn
COMMENT "Running build wheel"
)

add_custom_target(default_target DEPENDS build_whl)

set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
endif()
16 changes: 16 additions & 0 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_deterministic = params.num_splits == 1;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
#ifdef PADDLE_WITH_ADVANCED
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
@@ -82,6 +83,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
});
});
});
#else
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, false,false>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
#endif
Collaborator:
Could this code be simplified to avoid having two branches? For example, BOOL_SWITCH could be improved so that it uses a different definition depending on WITH_ADVANCED, so that maintenance does not become harder later. Something like:

#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH(...) ...
#else
#define BOOL_SWITCH(...) ...
#endif


auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
24 changes: 24 additions & 0 deletions csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool return_softmax = params.p_ptr != nullptr;
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
#ifdef PADDLE_WITH_ADVANCED
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
@@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
#else
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
Collaborator:
Same as above: it would be best to improve BOOL_SWITCH so that it uses a different definition depending on whether WITH_ADVANCED is enabled, to avoid the duplicated code.

BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
Collaborator:
Shouldn't this branch keep the is_equal_qk template? My understanding is that the non-advanced branch needs to be the performance-optimal version for the causal case.

Author (AnnaTrainingG):
This has been fixed.

BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, false, Is_equal_seq_qk>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
#endif
}

template<typename T>
221 changes: 221 additions & 0 deletions csrc/setup.py
@@ -0,0 +1,221 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import ast
import logging
import os
import platform
import re
import shutil
import subprocess
import sys
import warnings
from pathlib import Path

from packaging.version import parse
from setuptools import find_packages, setup
from setuptools.command.install import install as _install
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import paddle
from paddle.utils.cpp_extension.extension_utils import find_cuda_home

version_detail = sys.version_info
python_version = platform.python_version()
version = version_detail[0] + version_detail[1] / 10
env_version = os.getenv("PY_VERSION")

if version < 3.7:
Collaborator:
This check is rather crude: version is a float rather than an integer. It would be better to compare the integer components of version_detail instead.

See the sketch after the version-handling block below.

raise RuntimeError(
f"Paddle only supports Python version >= 3.7 now,"
f"you are using Python {python_version}"
)
elif env_version is None:
print(f"export PY_VERSION = { python_version }")
os.environ["PY_VERSION"] = python_version

elif env_version != version:
warnings.warn(
f"You set PY_VERSION={env_version}, but"
f"your current python environment is {version}"
f"we will use your current python version to execute"
)
os.environ["PY_VERSION"] = python_version
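For reference, a minimal sketch of the integer comparison suggested in the review comment above, comparing the components of sys.version_info directly instead of a derived float. The MIN_PYTHON name is illustrative and not part of this PR:

import platform
import sys

MIN_PYTHON = (3, 7)

# Tuple comparison avoids the float arithmetic on version_detail above.
if sys.version_info[:2] < MIN_PYTHON:
    raise RuntimeError(
        f"Paddle only supports Python >= {'.'.join(map(str, MIN_PYTHON))}, "
        f"you are using Python {platform.python_version()}"
    )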

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Collaborator:
Why does this need to depend on Paddle's paths?

Author (AnnaTrainingG):
So that the installed libflash_attn_advanced.so can be copied into the Paddle lib directory.

Collaborator:
I am not sure about this approach; please have @sneaxiy take a look as well. My understanding is:

  1. Even if the FA shared library is installed in its own directory, it should still be discoverable.
  2. Once FA is packaged as a .whl, the dependency on FA would have to go into Paddle's requirements.txt, and at the time FA is installed Paddle may well not be installed yet.

Collaborator:
Shouldn't FA be an external operator on the Paddle side, similar to the relationship between xformers and PyTorch, rather than having the new .so directly replace the FA .so inside Paddle?
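On the first point above, a package can usually load a shared library shipped next to its own files instead of copying it into Paddle's lib directory. A minimal sketch, with an illustrative helper name that is not part of this PR:

import ctypes
import os

def load_bundled_flashattn():
    # Look for the .so next to this module rather than inside Paddle's lib path.
    pkg_dir = os.path.dirname(os.path.abspath(__file__))
    so_path = os.path.join(pkg_dir, "libflashattn_advanced.so")
    return ctypes.CDLL(so_path)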


print("Paddle Include Path:", paddle_include_path)
print("Paddle Lib Path:", paddle_lib_path)

# preparing parameters for setup()
paddle_version = paddle.version.full_version
cuda_version = paddle.version.cuda_version


with open("../../README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
CUDA_HOME = find_cuda_home()
PACKAGE_NAME = "paddle_flash_attn"


def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith('linux'):
return 'linux_x86_64'
elif sys.platform == 'darwin':
mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
return f'macosx_{mac_version}_x86_64'
elif sys.platform == 'win32':
return 'win_amd64'
else:
raise ValueError(f'Unsupported platform: {sys.platform}')


def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])

return raw_output, bare_metal_version


def _is_cuda_available():
"""
Check whether CUDA is available.
"""
try:
assert len(paddle.static.cuda_places()) > 0
Collaborator:
It would be better not to rely on a Paddle API to check whether CUDA devices are present.

Author (AnnaTrainingG):
Have not found another way yet.

See the paddle-free sketch after this function.

return True
except Exception as e:
logging.warning(
"You are using GPU version PaddlePaddle, but there is no GPU "
"detected on your machine. Maybe CUDA devices is not set properly."
f"\n Original Error is {e}"
)
return False
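One paddle-free way to probe for CUDA devices, sketched as a possible alternative to the check discussed above. The function name is illustrative, and it assumes a Linux system where the CUDA driver exposes libcuda.so.1:

import ctypes

def cuda_device_count():
    # Returns 0 if the CUDA driver is missing, fails to initialize, or reports no devices.
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")
    except OSError:
        return 0
    if libcuda.cuInit(0) != 0:
        return 0
    count = ctypes.c_int(0)
    if libcuda.cuDeviceGetCount(ctypes.byref(count)) != 0:
        return 0
    return count.value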


check = _is_cuda_available()
cmdclass = {}


def get_package_version():
with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f:
version_match = re.search(
r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE
)
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
if local_version:
return f"{public_version}+{local_version}"
else:
return str(public_version)


def get_data_files():
data_files = []
#source_lib_path = 'libflashattn.so'
#data_files.append((".", [source_lib_path]))
data_files.append((".", ['libflashattn_advanced.so']))
return data_files


class CustomWheelsCommand(_bdist_wheel):
"""
The CustomWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
find an existing wheel (which is currently the case for all flash attention installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuit the standard full build pipeline.
"""

def run(self):
self.run_command('build_ext')
super().run()
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build paddle, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper()

# Determine wheel URL based on CUDA version, paddle version, python version and OS
wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl'
Collaborator:
The wheel name should not need to include the Paddle version, right? flashattn itself does not depend on a particular Paddle version; it is Paddle that depends on the flashattn version.

Author (AnnaTrainingG):
OK, this will be removed later. The current default is: paddle_flash_attn-2.0.8-cp37-none-any.whl

impl_tag, abi_tag, plat_tag = self.get_tag()
original_wheel_name = (
f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
)

# new_wheel_name = wheel_filename
new_wheel_name = (
f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}"
)
shutil.move(
f"{self.dist_dir}/{original_wheel_name}.whl",
f"{self.dist_dir}/{new_wheel_name}.whl",
)


class CustomInstallCommand(_install):
def run(self):
_install.run(self)
install_path = self.install_lib
source_lib_path = os.path.abspath('libflashattn_advanced.so')
destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so')
shutil.copy(f"{source_lib_path}", f"{destination_lib_path}")


setup(
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(),
data_files=get_data_files(),
package_data={PACKAGE_NAME: ['build/libflashattn.so']},
author_email="[email protected]",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/PaddlePaddle/flash-attention",
classifiers=[
"Programming Language :: Python :: 37",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
cmdclass={
'bdist_wheel': CustomWheelsCommand,
'install': CustomInstallCommand,
},
python_requires=">=3.7",
install_requires=[
"common",
"dual",
"tight>=0.1.0",
"data",
"prox",
"ninja", # Put ninja before paddle if paddle depends on it
"einops",
"packaging",
],
)