Add nv_enable_options and nv_disable_options #1432

Open · wants to merge 14 commits into base: main
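For context, a minimal usage sketch of the two new compile options. This is only an illustration: the option string "kernel_reuse" is taken from the test added in this PR, the options require nvFuser >= 0.2.23, and on older versions the executor warns and ignores them; nv_enable_options takes NVFUSER_ENABLE-style strings in the same way.

import torch
import thunder

def fn(a, b):
    return a + b

# Sketch: forward NVFUSER_DISABLE-style toggles to the nvFuser executor for this compile.
jfn = thunder.jit(fn, nv_disable_options=["kernel_reuse"])
out = jfn(torch.randn(4, 4, device="cuda"), torch.randn(4, 4, device="cuda"))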
23 changes: 22 additions & 1 deletion thunder/executors/nvfuserex_impl.py
@@ -438,6 +438,8 @@ class FusionDefinitionWrapper:
    last_used: None | FusionDefinition = None
    last_inputs: None | Sequence[tuple] = None
    store_inputs: bool = False
    enable_options: None | list[str] = None
    disable_options: None | list[str] = None

    def __call__(self, *args):
        fd = self.get_fd(self.to_descriptors(args))
@@ -446,8 +448,21 @@ def __call__(self, *args):
        if self.store_inputs:
            self.last_inputs = args

        kwargs = {}
        # Set device if set in one of the "factory" methods like full, iota, or uniform
        if hasattr(fd, "_selected_device"):
            kwargs["device"] = fd._selected_device

        if nvfuser_version() >= LooseVersion("0.2.23"):
            # nvFuser expects empty lists instead of None values.
            kwargs["_enable_options"] = self.enable_options if self.enable_options is not None else []
            kwargs["_disable_options"] = self.disable_options if self.disable_options is not None else []
        elif self.enable_options or self.disable_options:
            warnings.warn(
                f"nv_enable_options/nv_disable_options require nvFuser version 0.2.23 and above, found version {nvfuser_version()}. These options will be ignored."
            )

        with annotate_for_profile(self.name):
            return fd.execute(args, **kwargs)

@@ -540,6 +555,10 @@ def create_fusion_definition_wrapper(
    store_inputs: None | bool = get_compile_option(
        "nv_store_fusion_inputs", "Allow nvFuser to store fusion inputs for repro."
    )
    enable_options: None | list[str] = get_compile_option("nv_enable_options", "List of NVFUSER_ENABLE options to set.")
    disable_options: None | list[str] = get_compile_option(
        "nv_disable_options", "List of NVFUSER_DISABLE options to set."
    )

    tensor_indices = []
    for idx, x in enumerate(sorted_unique_inputs):
@@ -561,6 +580,8 @@ def get_fd(input_descriptors) -> FusionDefinition:
        get_fd.cache_info,
        get_fd.cache_clear,
        store_inputs=store_inputs,
        enable_options=enable_options,
        disable_options=disable_options,
    )
    return fdw

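For reference, a rough sketch of what the wrapper ultimately does at the nvFuser Python-frontend level once the version check passes. The _enable_options/_disable_options keyword names come from the diff above; the FusionDefinition, define_tensor, ops, and add_output calls are assumptions about the frontend API and may differ between releases.

import torch
from nvfuser import FusionDefinition, DataType

# Assumed nvFuser Python-frontend usage (signatures may vary across versions).
with FusionDefinition() as fd:
    a = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    b = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    c = fd.ops.add(a, b)
    fd.add_output(c)

inputs = [torch.randn(8, 8, device="cuda"), torch.randn(8, 8, device="cuda")]
# With nvFuser >= 0.2.23 the wrapper forwards the (possibly empty) option lists;
# "kernel_reuse" is one of the NVFUSER_DISABLE strings exercised by the PR's test.
outputs = fd.execute(inputs, _enable_options=[], _disable_options=["kernel_reuse"])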
41 changes: 41 additions & 0 deletions thunder/tests/test_nvfuser.py
@@ -1077,3 +1077,44 @@ def sdpa_fn(q, k, v, dropout_p, is_causal, scale):
    ref_outputs = (ref_attn_out, *(inp.grad for inp in ref_tensor_inputs))
    for nv_out, ref_out in zip(nv_outputs, ref_outputs):
        torch.testing.assert_close(nv_out, ref_out)


@instantiate(
    dtypes=(thunder.float32,),
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(nvFuserExecutor,),
    decorators=(
        pytest.mark.skipif(
            nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.23"),
            reason="Requires nvFuser version 0.2.23 or later",
        ),
    ),
)
def test_enable_disable_options(executor, device: str, thunder_dtype: dtypes.dtype):
    def fn(a, b):
        return torch.matmul(a, b)

    m, n, k = 24, 16, 16

    dtype = ltorch.to_torch_dtype(thunder_dtype)
    inps = [
        torch.randn(m, k, device="cuda", dtype=dtype),
        torch.randn(k, n, device="cuda", dtype=dtype),
    ]

    compiled_func = thunder.jit(
        fn,
        executors_list=executor.executors_list(),
        nv_enable_matmul=True,
        nv_enable_options=["fuse_matmul"],
        nv_disable_options=["matmul_expr_eval", "kernel_reuse"],
    )
    # This combination of options enables matmul codegen and disables expression evaluation for matmul.
    # Since the matmul scheduler does not support float32 inputs, execution should raise an error.
    # By default, without these options, this fusion runs correctly through the expr eval scheduler.
    # NOTE: This test relies on `float32` being unsupported by the nvFuser matmul scheduler.
    # If that support is added, the test will need to be updated since it will no longer
    # verify the functionality of the above flags.
    with pytest.raises(RuntimeError, match="Can not find a scheduler to schedule fusion segment"):
        out = compiled_func(*inps)