feat(cuda): compile according to capabilities
The __CUDA_ARCH__ preprocessor variable is only defined when compiling CUDA code for the device, not for the host. This means we need another preprocessor variable to decide whether to compile the AWQ kernels.
dacorvo committed Jun 12, 2024
1 parent fef7b60 commit 0ca8021
Showing 2 changed files with 30 additions and 7 deletions.
35 changes: 29 additions & 6 deletions optimum/quanto/library/ext/cuda/__init__.py
@@ -24,6 +24,24 @@
 _ext = None
 
 
+def get_min_cuda_arch():
+    capability_list = []
+    supported_sm = [int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list() if "sm_" in arch]
+    if supported_sm:
+        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            # Capability of the device may be higher than what's supported by the user's
+            # NVCC, causing compilation error. User's NVCC is expected to match the one
+            # used to build pytorch, so we use the maximum supported capability of pytorch
+            # to clamp the capability.
+            capability = min(max_supported_sm, capability)
+            if capability not in capability_list:
+                capability_list.append(capability)
+    min_capability = min(sorted(capability_list)) if len(capability_list) > 0 else (0, 0)
+    return f"{min_capability[0]}{min_capability[1]}0"
+
+
 def ext():
     """Helper to load the CUDA ext only when it is required"""
     global _ext
@@ -44,15 +62,20 @@ def ext():
             "--use_fast_math",
             "--threads=8",
         ]
+        # We need to know the minimum CUDA Arch to select only the relevant kernels
+        # but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
+        quanto_cuda_arch = get_min_cuda_arch()
+        extra_cuda_cflags += [f"-DQUANTO_CUDA_ARCH={quanto_cuda_arch}"]
         module_path = os.path.dirname(__file__)
+        sources = [
+            f"{module_path}/unpack.cu",
+            f"{module_path}/awq/v2/gemm_cuda.cu",
+            f"{module_path}/awq/v2/gemv_cuda.cu",
+            f"{module_path}/pybind_module.cpp",
+        ]
         _ext = load(
             name="quanto_cuda",
-            sources=[
-                f"{module_path}/unpack.cu",
-                f"{module_path}/awq/v2/gemm_cuda.cu",
-                f"{module_path}/awq/v2/gemv_cuda.cu",
-                f"{module_path}/pybind_module.cpp",
-            ],
+            sources=sources,
             extra_cflags=extra_cflags,
             extra_cuda_cflags=extra_cuda_cflags,
         )
2 changes: 1 addition & 1 deletion optimum/quanto/library/ext/cuda/awq/v2/gemm_cuda.cu
@@ -5,7 +5,7 @@
 #include <torch/extension.h>
 #include <cuda_pipeline_primitives.h>
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(QUANTO_CUDA_ARCH) and QUANTO_CUDA_ARCH >= 800
 // The following GEMMs requires m16n8k16 which is only supported for CUDA arch after sm80
 
 #define kInterleave 4
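To illustrate the behaviour the commit relies on, here is a minimal standalone sketch (not part of the repository; the file name and the example value 800 are made up for illustration). It shows that __CUDA_ARCH__ is only visible during nvcc's device compilation passes, while a macro passed on the command line, such as -DQUANTO_CUDA_ARCH=800, is visible to both host and device code and can therefore gate the AWQ kernels from host code as well.

// gate_example.cu -- illustrative sketch only, not part of the commit.
// Build with, e.g.:  nvcc -DQUANTO_CUDA_ARCH=800 gate_example.cu -o gate_example
#include <cstdio>

__global__ void which_arch() {
// __CUDA_ARCH__ is defined here because this is device code: nvcc sets it
// during each device compilation pass (e.g. 800 for sm_80).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    printf("device pass compiled for sm_80 or newer\n");
#else
    printf("device pass compiled for an older architecture\n");
#endif
}

int main() {
// __CUDA_ARCH__ is NOT defined in this host function, so a guard based on it
// would always be false here and silently drop the guarded host code.
// QUANTO_CUDA_ARCH comes from the command line (-D...), so it is visible in
// both the host and device passes and can gate the kernel launch too.
#if defined(QUANTO_CUDA_ARCH) && QUANTO_CUDA_ARCH >= 800
    which_arch<<<1, 1>>>();
    cudaDeviceSynchronize();
#else
    std::printf("kernels requiring sm_80 were not compiled in\n");
#endif
    return 0;
}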
