diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 561f838a7..4e5a204f8 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -438,6 +438,8 @@ class FusionDefinitionWrapper:
     last_used: None | FusionDefinition = None
     last_inputs: None | Sequence[tuple] = None
     store_inputs: bool = False
+    enable_options: None | list[str] = None
+    disable_options: None | list[str] = None
 
     def __call__(self, *args):
         fd = self.get_fd(self.to_descriptors(args))
@@ -446,8 +448,21 @@ def __call__(self, *args):
         if self.store_inputs:
             self.last_inputs = args
 
+        kwargs = {}
         # Set device if set in one of the "factory" methods like full, iota, or uniform
-        kwargs = {"device": fd._selected_device} if hasattr(fd, "_selected_device") else {}
+        if hasattr(fd, "_selected_device"):
+            kwargs["device"] = fd._selected_device
+
+        if nvfuser_version() >= LooseVersion("0.2.23"):
+            # nvFuser expects empty list instead of None values.
+            kwargs["_enable_options"] = self.enable_options if self.enable_options is not None else []
+            kwargs["_disable_options"] = self.disable_options if self.disable_options is not None else []
+
+        elif self.enable_options or self.disable_options:
+            warnings.warn(
+                f"nv_enable_options/nv_disable_options require nvFuser version 0.2.23 and above, found version {nvfuser_version()}. These options will be ignored."
+            )
+
         with annotate_for_profile(self.name):
             return fd.execute(args, **kwargs)
 
@@ -540,6 +555,10 @@ def create_fusion_definition_wrapper(
     store_inputs: None | bool = get_compile_option(
         "nv_store_fusion_inputs", "Allow nvFuser to store fusion inputs for repro."
     )
+    enable_options: None | list[str] = get_compile_option("nv_enable_options", "List of NVFUSER_ENABLE options to set.")
+    disable_options: None | list[str] = get_compile_option(
+        "nv_disable_options", "List of NVFUSER_DISABLE options to set."
+    )
 
     tensor_indices = []
     for idx, x in enumerate(sorted_unique_inputs):
@@ -561,6 +580,8 @@ def get_fd(input_descriptors) -> FusionDefinition:
         get_fd.cache_info,
         get_fd.cache_clear,
         store_inputs=store_inputs,
+        enable_options=enable_options,
+        disable_options=disable_options,
     )
 
     return fdw
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index 32a3eee7c..f190a59fb 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -1077,3 +1077,44 @@ def sdpa_fn(q, k, v, dropout_p, is_causal, scale):
     ref_outputs = (ref_attn_out, *(inp.grad for inp in ref_tensor_inputs))
     for nv_out, ref_out in zip(nv_outputs, ref_outputs):
         torch.testing.assert_close(nv_out, ref_out)
+
+
+@instantiate(
+    dtypes=(thunder.float32,),
+    devicetypes=(devices.DeviceType.CUDA,),
+    executors=(nvFuserExecutor,),
+    decorators=(
+        pytest.mark.skipif(
+            nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.23"),
+            reason="Requires nvFuser version 0.2.23 or later",
+        ),
+    ),
+)
+def test_enable_disable_options(executor, device: str, thunder_dtype: dtypes.dtype):
+
+    def fn(a, b):
+        return torch.matmul(a, b)
+
+    m, n, k = 24, 16, 16
+
+    dtype = ltorch.to_torch_dtype(thunder_dtype)
+    inps = [
+        torch.randn(m, k, device="cuda", dtype=dtype),
+        torch.randn(k, n, device="cuda", dtype=dtype),
+    ]
+
+    compiled_func = thunder.jit(
+        fn,
+        executors_list=executor.executors_list(),
+        nv_enable_matmul=True,
+        nv_enable_options=["fuse_matmul"],
+        nv_disable_options=["matmul_expr_eval", "kernel_reuse"],
+    )
+    # The above combination of options enables matmul codegen and disables expr evaluation for matmul.
+    # Since the matmul scheduler does not support float32 inputs, the execution should raise an error.
+    # By default, without these options, the given fusion runs correctly through the expr eval scheduler.
+    # NOTE: This test relies on `float32` being unsupported by the nvFuser matmul scheduler.
+    # If that support is added, the test will need to be updated since it will no longer
+    # verify the functionality of the above flags.
+    with pytest.raises(RuntimeError, match="Can not find a scheduler to schedule fusion segment"):
+        out = compiled_func(*inps)
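For context (not part of the patch): a minimal usage sketch of the new `nv_enable_options` / `nv_disable_options` compile options, mirroring the test above. The option strings (`fuse_matmul`, `matmul_expr_eval`) are taken from that test and are illustrative only; on nvFuser older than 0.2.23 they are ignored with a warning.

```python
# Illustrative sketch, assuming a CUDA build with nvFuser >= 0.2.23 installed.
# The option strings correspond to NVFUSER_ENABLE / NVFUSER_DISABLE values and are
# copied from the test above, not a recommended configuration.
import torch
import thunder

def fn(a, b):
    return torch.matmul(a, b)

jfn = thunder.jit(
    fn,
    nv_enable_matmul=True,                    # let nvFuser claim matmul (off by default)
    nv_enable_options=["fuse_matmul"],        # forwarded as _enable_options to fd.execute
    nv_disable_options=["matmul_expr_eval"],  # forwarded as _disable_options to fd.execute
)

a = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
b = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
out = jfn(a, b)
```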