Fix compilation with clang on ARM64 #1285

Open
wants to merge 1 commit into
base: main

sclarkson

Compiling with clang-17 on ARM64 fails with the following error:

FAILED: /home/sclarkson/flash-attention/build/temp.linux-aarch64-3.10/csrc/flash_attn/flash_api.o 
clang++-17 -MMD -MF /home/sclarkson/flash-attention/build/temp.linux-aarch64-3.10/csrc/flash_attn/flash_api.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/sclarkson/flash-attention/csrc/flash_attn -I/home/sclarkson/flash-attention/csrc/flash_attn/src -I/home/sclarkson/flash-attention/csrc/cutlass/include -I/usr/lib/python3/dist-packages/torch/include -I/usr/lib/python3/dist-packages/torch/include/torch/csrc/api/include -I/usr/lib/python3/dist-packages/torch/include/TH -I/usr/lib/python3/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/sclarkson/flash-attention/csrc/flash_attn/flash_api.cpp -o /home/sclarkson/flash-attention/build/temp.linux-aarch64-3.10/csrc/flash_attn/flash_api.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flash_attn_2_cuda -D_GLIBCXX_USE_CXX11_ABI=1
/home/sclarkson/flash-attention/csrc/flash_attn/flash_api.cpp:440:38: error: non-constant-expression cannot be narrowed from type 'char' to 'DeviceIndex' (aka 'signed char') in initializer list [-Wc++11-narrowing]
  440 |     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
      |                                      ^~~~~~~~~~~~~~~~~~~~
/home/sclarkson/flash-attention/csrc/flash_attn/flash_api.cpp:440:38: note: insert an explicit cast to silence this issue
  440 |     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
      |                                      ^~~~~~~~~~~~~~~~~~~~
      |                                      static_cast<DeviceIndex>( )

It seems that pytorch/pytorch@10f3abc (first released in 2.3.0) changed get_device() to return a c10::DeviceIndex. Since c10::DeviceIndex is an alias for int8_t (a signed char) and plain char is unsigned on ARM, passing the (char) cast result to the brace initializer is a narrowing conversion, which clang rejects.
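The platform dependence described above can be checked in isolation. The following is a minimal sketch, not code from this PR: `DeviceIndex` is a local stand-in for `c10::DeviceIndex`, and `char_fits_in_device_index` is a hypothetical name introduced only for illustration. It evaluates to true where plain `char` is signed (typical x86-64 Linux) and false where it is unsigned (typical ARM64 Linux), which is the case where the brace initialization narrows.

```cpp
#include <cstdint>
#include <limits>
#include <type_traits>

// Local stand-in for c10::DeviceIndex, which the PR describes as an alias for int8_t.
using DeviceIndex = std::int8_t;

// Hypothetical helper constant: true when every value of plain `char` is
// representable in DeviceIndex. When this is false, converting a non-constant
// `char` to DeviceIndex inside braces is a narrowing conversion.
constexpr bool char_fits_in_device_index =
    std::numeric_limits<char>::min() >= std::numeric_limits<DeviceIndex>::min() &&
    std::numeric_limits<char>::max() <= std::numeric_limits<DeviceIndex>::max();
```

On a platform with 8-bit `char`, this constant simply tracks whether `char` is signed, which is why the same source line can compile on one architecture and fail on another.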

Simply removing the cast would work on PyTorch >= 2.3.0, but would break the build on older versions.

Instead, switch to the device() API for which there is already a CUDAGuard constructor.
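In flash_api.cpp the change would presumably amount to `at::cuda::CUDAGuard device_guard{q.device()};`. To illustrate why this sidesteps the narrowing problem, here is a self-contained sketch that compiles without libtorch. `Device`, `Guard`, `Tensor`, and `guarded_index` are simplified, hypothetical stand-ins for `c10::Device`, `at::cuda::CUDAGuard`, `at::Tensor`, and the call site. They are not the real PyTorch definitions.

```cpp
#include <cstdint>

// Simplified stand-ins for the PyTorch types (assumptions, not the real definitions).
using DeviceIndex = std::int8_t;

struct Device {
    DeviceIndex index;
};

struct Guard {
    DeviceIndex current;
    // Mirrors the two ways of constructing a guard: from an index or from a device.
    constexpr explicit Guard(DeviceIndex idx) : current(idx) {}
    constexpr explicit Guard(Device dev) : current(dev.index) {}
};

struct Tensor {
    Device dev;
    constexpr Device device() const { return dev; }
};

// Passing the Device object selects the Guard(Device) constructor, so no integer
// conversion takes place in the braces and the signedness of `char` is irrelevant.
constexpr DeviceIndex guarded_index(const Tensor& q) {
    Guard device_guard{q.device()};
    return device_guard.current;
}
```

Because the argument type matches the constructor parameter exactly, the result does not depend on what integer type get_device() returns in a given PyTorch version.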
