diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 019697f71..f64ed05ac 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -441,8 +441,8 @@ class FusionDefinitionWrapper: cache_clear: None | Callable = None last_used: None | FusionDefinition = None last_inputs: None | Sequence[tuple] = None - store_inputs: bool = False, - _enable_options: None | list[str] = None, + store_inputs: bool = (False,) + _enable_options: None | list[str] = (None,) _disable_options: None | list[str] = None def __call__(self, *args): @@ -559,8 +559,12 @@ def create_fusion_definition_wrapper( store_inputs: None | bool = get_compile_option( "nv_store_fusion_inputs", "Allow nvFuser to store fusion inputs for repro." ) - _enable_options: None | list[str] = get_compile_option("nv_enable_options", "List of NVFUSER_ENABLE options to set.") - _disable_options: None | list[str] = get_compile_option("nv_disable_options", "List of NVFUSER_DISABLE options to set.") + _enable_options: None | list[str] = get_compile_option( + "nv_enable_options", "List of NVFUSER_ENABLE options to set." + ) + _disable_options: None | list[str] = get_compile_option( + "nv_disable_options", "List of NVFUSER_DISABLE options to set." + ) tensor_indices = [] for idx, x in enumerate(sorted_unique_inputs):