diff --git a/optimum/quanto/library/ext/cuda/__init__.py b/optimum/quanto/library/ext/cuda/__init__.py
index b4cdb707..b7897390 100644
--- a/optimum/quanto/library/ext/cuda/__init__.py
+++ b/optimum/quanto/library/ext/cuda/__init__.py
@@ -24,6 +24,24 @@
 _ext = None
 
 
+def get_min_cuda_arch():
+    capability_list = []
+    supported_sm = [int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list() if "sm_" in arch]
+    if supported_sm:
+        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            # Capability of the device may be higher than what's supported by the user's
+            # NVCC, causing compilation error. User's NVCC is expected to match the one
+            # used to build pytorch, so we use the maximum supported capability of pytorch
+            # to clamp the capability.
+            capability = min(max_supported_sm, capability)
+            if capability not in capability_list:
+                capability_list.append(capability)
+    min_capability = min(sorted(capability_list)) if len(capability_list) > 0 else (0, 0)
+    return f"{min_capability[0]}{min_capability[1]}0"
+
+
 def ext():
     """Helper to load the CUDA ext only when it is required"""
     global _ext
@@ -44,15 +62,20 @@ def ext():
             "--use_fast_math",
             "--threads=8",
         ]
+        # We need to know the minimum CUDA Arch to select only the relevant kernels
+        # but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
+        quanto_cuda_arch = get_min_cuda_arch()
+        extra_cuda_cflags += [f"-DQUANTO_CUDA_ARCH={quanto_cuda_arch}"]
         module_path = os.path.dirname(__file__)
+        sources = [
+            f"{module_path}/unpack.cu",
+            f"{module_path}/awq/v2/gemm_cuda.cu",
+            f"{module_path}/awq/v2/gemv_cuda.cu",
+            f"{module_path}/pybind_module.cpp",
+        ]
         _ext = load(
             name="quanto_cuda",
-            sources=[
-                f"{module_path}/unpack.cu",
-                f"{module_path}/awq/v2/gemm_cuda.cu",
-                f"{module_path}/awq/v2/gemv_cuda.cu",
-                f"{module_path}/pybind_module.cpp",
-            ],
+            sources=sources,
             extra_cflags=extra_cflags,
             extra_cuda_cflags=extra_cuda_cflags,
         )
diff --git a/optimum/quanto/library/ext/cuda/awq/v2/gemm_cuda.cu b/optimum/quanto/library/ext/cuda/awq/v2/gemm_cuda.cu
index 3a9268c5..c333290e 100644
--- a/optimum/quanto/library/ext/cuda/awq/v2/gemm_cuda.cu
+++ b/optimum/quanto/library/ext/cuda/awq/v2/gemm_cuda.cu
@@ -5,7 +5,7 @@
 #include 
 #include 
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(QUANTO_CUDA_ARCH) and QUANTO_CUDA_ARCH >= 800
 // The following GEMMs requires m16n8k16 which is only supported for CUDA arch after sm80
 #define kInterleave 4
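
As an illustration of the clamping performed by the new get_min_cuda_arch() helper, the short Python sketch below replays the same arithmetic on hypothetical inputs: fake_arch_list and fake_device_capabilities are made-up stand-ins for what torch.cuda.get_arch_list() and torch.cuda.get_device_capability() would return on a machine whose GPU is newer than the architectures PyTorch was built for.

# Hypothetical inputs: a PyTorch build with kernels up to sm_86, one GPU
# reporting capability (9, 0) and another reporting (8, 0).
fake_arch_list = ["sm_70", "sm_80", "sm_86"]
fake_device_capabilities = [(9, 0), (8, 0)]

# Highest architecture the PyTorch build (and presumably the matching NVCC) supports.
supported_sm = [int(arch.split("_")[1]) for arch in fake_arch_list if "sm_" in arch]
max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)  # (8, 6)

# Clamp each device to that ceiling, then keep the lowest remaining capability.
clamped = sorted({min(capability, max_supported_sm) for capability in fake_device_capabilities})
min_capability = clamped[0] if clamped else (0, 0)  # (8, 0)

# Value the extension would be compiled with.
print(f"-DQUANTO_CUDA_ARCH={min_capability[0]}{min_capability[1]}0")  # -DQUANTO_CUDA_ARCH=800

With that macro defined on the nvcc command line, the QUANTO_CUDA_ARCH >= 800 guard in gemm_cuda.cu can be evaluated in host code as well, which __CUDA_ARCH__ alone would not allow since it is only set during device compilation.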