diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 5b305fa98..a19b92744 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -1,5 +1,7 @@
 cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
 project(flash-attention LANGUAGES CXX CUDA)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 find_package(Git QUIET REQUIRED)
 
@@ -7,6 +9,11 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
                 WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
                 RESULT_VARIABLE GIT_SUBMOD_RESULT)
 
+#cmake -DWITH_ADVANCED=ON
+if (WITH_ADVANCED)
+  add_compile_definitions(PADDLE_WITH_ADVANCED)
+endif()
+
 add_definitions("-DFLASH_ATTN_WITH_TORCH=0")
 
 set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
@@ -55,6 +62,7 @@ target_include_directories(flashattn PRIVATE
   flash_attn
   ${CUTLASS_3_DIR}/include)
 
+if (WITH_ADVANCED)
 set(FA1_SOURCES_CU
   flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
   flash_attn_with_bias_and_mask/src/cuda_utils.cu
@@ -65,6 +73,12 @@ set(FA1_SOURCES_CU
   flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
   flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
   flash_attn_with_bias_and_mask/src/utils.cu)
+else()
+set(FA1_SOURCES_CU
+  flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
+  flash_attn_with_bias_and_mask/src/cuda_utils.cu
+  flash_attn_with_bias_and_mask/src/utils.cu)
+endif()
 
 add_library(flashattn_with_bias_mask STATIC
   flash_attn_with_bias_and_mask/
@@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)
 
 add_dependencies(flashattn flashattn_with_bias_mask)
 
+set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")
-if (NOT DEFINED NVCC_ARCH_BIN)
-  message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
-endif()
-
-if (NVCC_ARCH_BIN STREQUAL "")
-  message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
-endif()
+message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")
 
 STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})
 
 set(FA_GENCODE_OPTION "SHELL:")
+
 foreach(arch ${FA_NVCC_ARCH_BIN})
   if(${arch} GREATER_EQUAL 80)
     set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -131,7 +141,25 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$
   >)
+
 INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib")
 
 INSTALL(FILES capi/flash_attn.h DESTINATION "include")
+
+if (WITH_ADVANCED)
+  set_target_properties(flashattn PROPERTIES
+    OUTPUT_NAME libflashattn_advanced
+    PREFIX ""
+  )
+  add_custom_target(build_whl
+    COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
+    WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+    DEPENDS flashattn
+    COMMENT "Running build wheel"
+  )
+
+  add_custom_target(default_target DEPENDS build_whl)
+
+  set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
+endif()
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 2c62e6c57..d611b5dea 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const bool is_attn_mask = params.attn_mask_ptr != nullptr;
     const bool is_deterministic = params.num_splits == 1;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
+#ifdef PADDLE_WITH_ADVANCED
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
@@ -82,6 +83,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
             });
         });
     });
+#else
+    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                }
+                kernel<<>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+        });
+    });
+#endif
     auto kernel_dq = &flash_bwd_convert_dq_kernel;
     if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 6c6382617..ae707d0e9 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool return_softmax = params.p_ptr != nullptr;
     const bool is_attn_mask = params.attn_mask_ptr != nullptr;
     const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
+#ifdef PADDLE_WITH_ADVANCED
     BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
@@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             });
         });
     });
+#else
+    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
+                    // Will only return softmax if dropout, to reduce compilation time.
+                    auto kernel = &flash_fwd_kernel;
+                    // auto kernel = &flash_fwd_kernel;
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    int ctas_per_sm;
+                    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
+            });
+        });
+    });
+#endif
 }
 
 template
diff --git a/csrc/setup.py b/csrc/setup.py
new file mode 100644
index 000000000..060268cca
--- /dev/null
+++ b/csrc/setup.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import ast
+import logging
+import os
+import platform
+import re
+import shutil
+import subprocess
+import sys
+import warnings
+from pathlib import Path
+
+from packaging.version import parse
+from setuptools import find_packages, setup
+from setuptools.command.install import install as _install
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+import paddle
+from paddle.utils.cpp_extension.extension_utils import find_cuda_home
+
+version_detail = sys.version_info
+python_version = platform.python_version()
+version = version_detail[0] + version_detail[1] / 10
+env_version = os.getenv("PY_VERSION")
+
+if version < 3.7:
+    raise RuntimeError(
+        f"Paddle only supports Python version >= 3.7 now, "
+        f"you are using Python {python_version}"
+    )
+elif env_version is None:
+    print(f"export PY_VERSION = { python_version }")
+    os.environ["PY_VERSION"] = python_version
+
+elif env_version != python_version:
+    warnings.warn(
+        f"You set PY_VERSION={env_version}, but "
+        f"your current python environment is {python_version}; "
+        f"we will use your current python version to execute."
+    )
+    os.environ["PY_VERSION"] = python_version
+
+paddle_include_path = paddle.sysconfig.get_include()
+paddle_lib_path = paddle.sysconfig.get_lib()
+
+print("Paddle Include Path:", paddle_include_path)
+print("Paddle Lib Path:", paddle_lib_path)
+
+# preparing parameters for setup()
+paddle_version = paddle.version.full_version
+cuda_version = paddle.version.cuda_version
+
+
+with open("../../README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+CUDA_HOME = find_cuda_home()
+PACKAGE_NAME = "paddle_flash_attn"
+
+
+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith('linux'):
+        return 'linux_x86_64'
+    elif sys.platform == 'darwin':
+        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
+        return f'macosx_{mac_version}_x86_64'
+    elif sys.platform == 'win32':
+        return 'win_amd64'
+    else:
+        raise ValueError(f'Unsupported platform: {sys.platform}')
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+
+    return raw_output, bare_metal_version
+
+
+def _is_cuda_available():
+    """
+    Check whether CUDA is available.
+    """
+    try:
+        assert len(paddle.static.cuda_places()) > 0
+        return True
+    except Exception as e:
+        logging.warning(
+            "You are using GPU version PaddlePaddle, but there is no GPU "
+            "detected on your machine. Maybe CUDA devices are not set properly."
+            f"\n Original Error is {e}"
+        )
+        return False
+
+
+check = _is_cuda_available()
+cmdclass = {}
+
+
+def get_package_version():
+    with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f:
+        version_match = re.search(
+            r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE
+        )
+    public_version = ast.literal_eval(version_match.group(1))
+    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
+    if local_version:
+        return f"{public_version}+{local_version}"
+    else:
+        return str(public_version)
+
+
+def get_data_files():
+    data_files = []
+    # source_lib_path = 'libflashattn.so'
+    # data_files.append((".", [source_lib_path]))
+    data_files.append((".", ['libflashattn_advanced.so']))
+    return data_files
+
+
+class CustomWheelsCommand(_bdist_wheel):
+    """
+    CustomWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
+    find an existing wheel (which is currently the case for all flash-attention installs). We use
+    the environment parameters to detect whether a pre-built compatible wheel is already available
+    and short-circuit the standard full build pipeline.
+    """
+
+    def run(self):
+        self.run_command('build_ext')
+        super().run()
+        # Determine the version numbers used to pick the correct wheel.
+        # We're using the CUDA version used to build paddle, not the one currently installed
+        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        cxx11_abi = ""  # str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+        # Determine wheel URL based on CUDA version, paddle version, python version and OS
+        wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl'
+        impl_tag, abi_tag, plat_tag = self.get_tag()
+        original_wheel_name = (
+            f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+        )
+
+        # new_wheel_name = wheel_filename
+        new_wheel_name = (
+            f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}"
+        )
+        shutil.move(
+            f"{self.dist_dir}/{original_wheel_name}.whl",
+            f"{self.dist_dir}/{new_wheel_name}.whl",
+        )
+
+
+class CustomInstallCommand(_install):
+    def run(self):
+        _install.run(self)
+        install_path = self.install_lib
+        source_lib_path = os.path.abspath('libflashattn_advanced.so')
+        destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so')
+        shutil.copy(f"{source_lib_path}", f"{destination_lib_path}")
+
+
+setup(
+    name=PACKAGE_NAME,
+    version=get_package_version(),
+    packages=find_packages(),
+    data_files=get_data_files(),
+    package_data={PACKAGE_NAME: ['build/libflashattn.so']},
+    author_email="Paddle-better@baidu.com",
+    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/PaddlePaddle/flash-attention",
+    classifiers=[
+        "Programming Language :: Python :: 3.7",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: Unix",
+    ],
+    cmdclass={
+        'bdist_wheel': CustomWheelsCommand,
+        'install': CustomInstallCommand,
+    },
+    python_requires=">=3.7",
+    install_requires=[
+        "common",
+        "dual",
+        "tight>=0.1.0",
+        "data",
+        "prox",
+        "ninja",  # Put ninja before paddle if paddle depends on it
+        "einops",
+        "packaging",
+],
+)
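
Build note (illustrative, not part of the patch): with WITH_ADVANCED enabled, the CMake changes above build libflashattn_advanced.so and then drive csrc/setup.py through the build_whl target. A minimal sketch of the flow, assuming an out-of-source build directory named build and the default Makefile generator (both assumptions, not taken from the patch):

    cd csrc && mkdir -p build && cd build
    # WITH_ADVANCED defines PADDLE_WITH_ADVANCED; NVCC_ARCH_BIN defaults to 80 and takes "-"-separated values such as 80-90
    cmake .. -DWITH_ADVANCED=ON -DNVCC_ARCH_BIN=80
    # build_whl depends on the flashattn target and runs "python setup.py bdist_wheel" in the build directory
    make build_whl
    # installing the wheel runs CustomInstallCommand, which copies libflashattn_advanced.so into Paddle's lib directory
    pip install dist/*.whl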