diff --git a/fx2ait/fx2ait/fx2ait.py b/fx2ait/fx2ait/fx2ait.py index d9f054932..33e0f79c2 100644 --- a/fx2ait/fx2ait/fx2ait.py +++ b/fx2ait/fx2ait/fx2ait.py @@ -74,6 +74,8 @@ def __init__( use_tanh_for_sigmoid: bool = False, profile_timeout: int = 500, optimize_for_compilation_time: bool = False, + allow_cutlass_sm90: bool = False, + force_cutlass_sm90: bool = False, ): """ Args: @@ -93,6 +95,8 @@ def __init__( save_remote_cache: whether to save the updated cache use_fast_math: whether to use fast math in CUDA kernels use_tanh_for_sigmoid: whether to use tanh to approximate sigmoid in CUDA kernels + allow_cutlass_sm90: generate cutlass sm90 kernels alongside sm80 kernels on sm90 arch + force_cutlass_sm90: only generate cutlass sm90 kernels on sm90 arch profile_timeout: timeout in seconds for AIT profilers to complete optimize_for_compilation_time: we use O1 and disable the ProfileImpl function to reduce compilation time. """ @@ -119,6 +123,8 @@ def __init__( self.use_fp16_acc = use_fp16_acc self.use_fast_math = use_fast_math self.use_tanh_for_sigmoid = use_tanh_for_sigmoid + self.allow_cutlass_sm90 = allow_cutlass_sm90 + self.force_cutlass_sm90 = force_cutlass_sm90 self.optimize_for_compilation_time = optimize_for_compilation_time self.hardware_target = self._create_target() self.input_specs = input_specs @@ -149,6 +155,8 @@ def _create_target(self): remote_cache_bytes=self.remote_cache_bytes, use_fast_math=self.use_fast_math, use_tanh_for_sigmoid=self.use_tanh_for_sigmoid, + allow_cutlass_sm90=self.allow_cutlass_sm90, + force_cutlass_sm90=self.force_cutlass_sm90, optimize_for_compilation_time=self.optimize_for_compilation_time, ) diff --git a/fx2ait/fx2ait/lower/lower_settings.py b/fx2ait/fx2ait/lower/lower_settings.py index 465c0bb9f..a63d19b7b 100644 --- a/fx2ait/fx2ait/lower/lower_settings.py +++ b/fx2ait/fx2ait/lower/lower_settings.py @@ -87,3 +87,7 @@ class LowerSettings: optimize_for_compilation_time: bool = False # If True, use tanh to approximate sigmoid in CUDA kernels use_tanh_for_sigmoid: bool = False + # generate cutlass sm90 kernels alongside sm80 kernels on sm90 arch + allow_cutlass_sm90: bool = False + # only generate cutlass sm90 kernels on sm90 arch + force_cutlass_sm90: bool = False diff --git a/python/aitemplate/backend/cuda/target_def.py b/python/aitemplate/backend/cuda/target_def.py index 2036ffa68..81d883e62 100644 --- a/python/aitemplate/backend/cuda/target_def.py +++ b/python/aitemplate/backend/cuda/target_def.py @@ -235,7 +235,17 @@ def __enter__(self): super().__enter__() self._gen_cutlass_lib_pkg() f_gen_ops = registry.get("cuda.gen_cutlass_ops") - self._operators = f_gen_ops(self._arch, self._cuda_version) + allow_cutlass_sm90 = ( + self._kwargs.get("allow_cutlass_sm90", False) + or environ.allow_cutlass_sm90_kernels() + ) + force_cutlass_sm90 = ( + self._kwargs.get("force_cutlass_sm90", False) + or environ.force_cutlass_sm90_kernels() + ) + self._operators = f_gen_ops( + self._arch, self._cuda_version, allow_cutlass_sm90, force_cutlass_sm90 + ) def __exit__(self, ptype, value, trace): super().__exit__(ptype, value, trace) diff --git a/python/aitemplate/backend/cuda/utils.py b/python/aitemplate/backend/cuda/utils.py index e87b2a4bb..f2a1c4900 100644 --- a/python/aitemplate/backend/cuda/utils.py +++ b/python/aitemplate/backend/cuda/utils.py @@ -18,10 +18,6 @@ import logging from aitemplate.backend import registry -from aitemplate.utils.environ import ( - allow_cutlass_sm90_kernels, - force_cutlass_sm90_kernels, -) from aitemplate.utils.mk_cutlass_lib.mk_cutlass_lib import mk_cutlass_lib # pylint: disable=C0103,C0415,W0707 @@ -51,7 +47,12 @@ def __init__(self, arch): @registry.reg("cuda.gen_cutlass_ops") -def gen_ops(arch, cuda_version): +def gen_ops( + arch, + cuda_version, + allow_cutlass_sm90, + force_cutlass_sm90, +): import cutlass_lib args = Args(arch) @@ -60,9 +61,9 @@ def gen_ops(arch, cuda_version): manifest = cutlass_lib.manifest.Manifest(args) if arch == "90": - if force_cutlass_sm90_kernels(): + if force_cutlass_sm90: cutlass_lib.generator.GenerateSM90(manifest, args.cuda_version) - elif allow_cutlass_sm90_kernels(): + elif allow_cutlass_sm90: cutlass_lib.generator.GenerateSM90(manifest, args.cuda_version) cutlass_lib.generator.GenerateSM80(manifest, args.cuda_version) cutlass_lib.extra_operation.GenerateSM80(manifest, args) diff --git a/tests/unittest/ops/test_gemm_bias.py b/tests/unittest/ops/test_gemm_bias.py index 330a030f2..60ac66b80 100644 --- a/tests/unittest/ops/test_gemm_bias.py +++ b/tests/unittest/ops/test_gemm_bias.py @@ -41,8 +41,12 @@ def __init__(self, *args, **kwargs): super(GEMMBiasTestCase, self).__init__(*args, **kwargs) self._test_id = 0 - def _test_rcr(self, Ms, N, K, test_name, dtype="float16"): - target = detect_target() + def _test_rcr( + self, Ms, N, K, test_name, dtype="float16", allow_sm90=False, force_sm90=False + ): + target = detect_target( + allow_cutlass_sm90=allow_sm90, force_cutlass_sm90=force_sm90 + ) tolerance_limits = _TOLERANCE_LIMITS[dtype] MDim = shape_utils.gen_int_var_min_max(Ms, name="m") X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True) @@ -108,6 +112,26 @@ def test_rcr_bfloat16_bf16(self): ) def test_rcr_sm90(self) -> None: + with env_variables( + INSIDE_RE_WORKER="1", + FORCE_PROFILE="1", + ): + self._test_rcr( + Ms=[128], + N=32, + K=32, + test_name="target_fp16_allow_sm90", + dtype="float16", + allow_sm90=True, + ) + self._test_rcr( + Ms=[128], + N=32, + K=32, + test_name="target_fp16_force_sm90", + dtype="float16", + force_sm90=True, + ) with env_variables( AIT_FORCE_CUTLASS_SM90_KERNELS="1", INSIDE_RE_WORKER="1",