diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 2aa78cbe13..a237ed4354 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -239,6 +239,7 @@ def __init__(
         use_torchao_fp8_allgather: bool = False,
         use_torchao_fp8_precompute_scale_for_fsdp: bool = False,
         fp8_shard_intermediate_activation: bool = False,
+        save_dynamo_repro: str | None = None,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -273,6 +274,11 @@ def __init__(
         self.dump_thunder_traces = dump_thunder_traces
         self.dump_memory_snapshot = dump_memory_snapshot
         self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation
+        if save_dynamo_repro is not None:
+            assert (
+                "dynamo" in self.compile and "thunder" in self.compile
+            ), "save_dynamo_repro can only be used if --compile=thunder+dynamo"
+        self.save_dynamo_repro = save_dynamo_repro

         if use_torchao_fp8_linear:
@@ -889,6 +895,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
            print(f"##########\n#{i}-th ThunderModule\n##########")
            print(b_traces[-1])

+    if benchmark.save_dynamo_repro:
+        benchmark.backend.save_reproducer_to_folder(benchmark.save_dynamo_repro)
+
    if global_rank in [0, None]:
        if return_metrics_as_json:
            benchmark.add_model_info_to_metrics()
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index a21d1b207c..05e41c8788 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -400,6 +400,9 @@ def reverse_transform_state_dict_for_submodule(
     ) -> dict[str, Any]:
         return state_dict

+    def __repr__(self) -> str:
+        return f"{self.__class__.__module__}.{self.__class__.__name__}()"
+

 def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
     """computes a canonical ordering of proxies in the bound symbols based on the order of appearance
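The `__repr__` added to the `Transform` base class is what lets transform instances round-trip into the generated reproducer script: `thunder_options_to_str` (added in `thunder/dynamo/utils.py` below) embeds `repr(value)` for every non-executor option, and the default `object.__repr__` (`<... object at 0x...>`) is not valid Python source. A quick sketch of the intended behavior, assuming this patch is applied and that `NvtxProfileTransform` subclasses `Transform`:

```python
# Any Transform subclass now prints as an importable constructor call.
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

print(repr(NvtxProfileTransform()))
# -> thunder.dev_utils.nvtx_profile_transform.NvtxProfileTransform()
```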
+ """ + if not self.subgraph_infos: + raise TypeError(f"{self} doesn't seem to have been called yet.") + + for graph_idx, subgraph_info in enumerate(self.subgraph_infos): + thunder_module_names = [] + for node in subgraph_info.split_graph_module.graph.nodes: + target = node.target + if isinstance(target, str) and target.startswith("thunder_"): + thunder_module_names.append(target) + thunder_modules = subgraph_info.thunder_compiled_fns + example_inputs = subgraph_info.thunder_compiled_fns_example_inputs + for cur_module, example_input, cur_name in safe_zip(thunder_modules, example_inputs, thunder_module_names): + reproducer( + getattr(cur_module, "_model"), + self.thunder_options, + example_input, + reproducer_folder, + f"{graph_idx+1}_{cur_name}", + ) diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py index eafd30ce0e..4585a1cae7 100644 --- a/thunder/dynamo/compiler_graph_benchmark.py +++ b/thunder/dynamo/compiler_graph_benchmark.py @@ -140,7 +140,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor cur_nodes = cur_module.graph.nodes # Greates random input values for the current module based on the faketensor 'example_value' of the placeholder node placeholders = list(n for n in cur_nodes if n.op == "placeholder") - args = chain(*map(_get_example_inputs_from_placeholder, placeholders)) + args = list(map(_get_example_inputs_from_placeholder, placeholders)) # Runs the benchmark on the original module with the generated random inputs self.run_bench(compiled_functions_to_submodule[cur_module], target, *args) self.graph_idx += 1 diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index b123400ec7..7440684e7a 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -1,5 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from itertools import chain +from functools import partial import torch from torch.fx.passes.split_module import split_module @@ -15,6 +17,7 @@ update_node_and_submodule, recompile_graph, checkpoint_converter, + _get_example_inputs_from_placeholder, ) if TYPE_CHECKING: @@ -140,10 +143,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: # Call compile on the split region/s. 
diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py
index eafd30ce0e..4585a1cae7 100644
--- a/thunder/dynamo/compiler_graph_benchmark.py
+++ b/thunder/dynamo/compiler_graph_benchmark.py
@@ -140,7 +140,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
             cur_nodes = cur_module.graph.nodes
             # Greates random input values for the current module based on the faketensor 'example_value' of the placeholder node
             placeholders = list(n for n in cur_nodes if n.op == "placeholder")
-            args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
+            args = list(map(_get_example_inputs_from_placeholder, placeholders))
             # Runs the benchmark on the original module with the generated random inputs
             self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
             self.graph_idx += 1
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index b123400ec7..7440684e7a 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING
+from itertools import chain
+from functools import partial

 import torch
 from torch.fx.passes.split_module import split_module
@@ -15,6 +17,7 @@
     update_node_and_submodule,
     recompile_graph,
     checkpoint_converter,
+    _get_example_inputs_from_placeholder,
 )

 if TYPE_CHECKING:
@@ -140,10 +143,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
     # Call compile on the split region/s.
     thunder_compiled_fns = []
+    example_input_metadatas = []
     submodule_to_compiled_fns = {}
     for node in split_gm.graph.nodes:
         if is_thunder_supported_partition(node):
             graph_module = getattr(split_gm, node.name)
+            # Record the input tensor metadata of the current module based on the FakeTensor 'example_value' of the placeholder node
+            placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder")
+            example_input_metadata = map(
+                partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
+            )
+            example_input_metadatas.append(list(example_input_metadata))
             # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
             checkpoint_converter(split_gm, graph_module)
             jit_fn = thunder_jit(graph_module)
@@ -168,6 +178,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         gm,
         split_gm,
         thunder_compiled_fns,
+        example_input_metadatas,
         submodule_to_compiled_fns,
         split_reasons,
     )
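The metadata recorded here lands in `SubgraphInfo` (see the `thunder_compiled_fns_example_inputs` field added below): one list of `ExampleInputMetaData` per Thunder-compiled submodule, in the same order as `thunder_compiled_fns`. A small inspection sketch, assuming `backend` is a `ThunderCompiler` that has already been called:

```python
info = backend.subgraph_infos[0]
for fn, metas in zip(info.thunder_compiled_fns, info.thunder_compiled_fns_example_inputs):
    for meta in metas:  # one ExampleInputMetaData per placeholder of the submodule
        print(meta.shape, meta.strides, meta.dtype, meta.device)
```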
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 8b4c690c0a..709936d8d3 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -6,8 +6,12 @@
 import inspect
 import itertools
 import copy
+import warnings
+from pathlib import Path

 import torch
+from torch.nn.modules.module import _addindent
+from torch._subclasses.fake_tensor import FakeTensor

 from thunder.torch.default_torch_ops import torch_auto_registered_ops
 from thunder.torch import _torch_to_thunder_function_map
@@ -16,6 +20,9 @@

 if TYPE_CHECKING:
     from thunder.core.symbol import Symbol
+    import os
+    from typing import Any, TextIO
+    from collections.abc import Sequence

 auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))

@@ -74,6 +81,26 @@ class SplitReason:
     exception: Exception | None = None


+@dataclasses.dataclass
+class ExampleInputMetaData:
+    """
+    Describes the metadata of a tensor, used to generate a random tensor with matching properties.
+    """
+
+    requires_grad: bool
+    layout: torch.layout
+    device: str | torch.device
+    dtype: torch.dtype
+    shape: list[int]
+    storage_shape: list[int]
+    strides: list[int]
+    min_val: int | float | None = None
+    max_val: int | float | None = None
+
+    def stride(self) -> list[int]:
+        return self.strides
+
+
 @dataclasses.dataclass(frozen=True)
 class SubgraphInfo:
     """A dataclass containing information about a subgraph.
@@ -84,6 +111,8 @@ class SubgraphInfo:
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if there the graph module was not supported by thunder.
             Look at the :attr:`split_reasons` for further information.
+        thunder_compiled_fns_example_inputs: List containing metadata of sample inputs for `thunder_compiled_fns`.
+            These inputs are used to generate random test inputs in the reproducer script.
         submodule_to_compiled_functions: Dict from subgraph to compiled function.
             This will be a dict with one pair in case the graph was not split.
         split_reasons: List of reasons explaining why the subgraph was split.
@@ -93,13 +122,14 @@ class SubgraphInfo:
     original_graph_module: torch.fx.GraphModule
     split_graph_module: torch.fx.GraphModule | None
     thunder_compiled_fns: list[Callable] | None
+    thunder_compiled_fns_example_inputs: list[list[ExampleInputMetaData]] | None
    submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
    split_reasons: list | None = None


-def _concrete_shape(x):
+def _concrete_value(vals: torch.Size | Sequence):
     """
-    Get the concrete shape for a FakeTensor if it has `torch.SymInt` in its shape.
+    Get the concrete value from the input `vals` if it contains `torch.SymInt`.
     """

     def get_backed_value(s):
@@ -108,7 +138,7 @@ def get_backed_value(s):
             # Value is already concrete.
             return s

-    return tuple(map(get_backed_value, x.shape))
+    return tuple(map(get_backed_value, vals))


 def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
@@ -143,11 +173,12 @@ def make_tensor_proxy(arg_node):
                 # Here, we only want to verify that thunder can run an operation.
                 # So, it is ok to verify with concrete value.
                 example_value = example_value.new_ones(
-                    _concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype
+                    _concrete_value(example_value.shape), device=example_value.device, dtype=example_value.dtype
                 )
             elif isinstance(example_value, tuple):
                 example_value = tuple(
-                    e_v.new_ones(_concrete_shape(e_v), device=e_v.device, dtype=e_v.dtype) for e_v in example_value
+                    e_v.new_ones(_concrete_value(e_v.shape), device=e_v.device, dtype=e_v.dtype)
+                    for e_v in example_value
                 )
             else:
                 # NOTE - This will be caught will be caught and be part of the SplitReason.
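The rename from `_concrete_shape` to `_concrete_value` reflects that the helper now concretizes any sequence that may contain `torch.SymInt` (strides as well as shapes), replacing each symbolic entry with its backed hint. A quick sketch of the behavior, assuming this patch is applied:

```python
import torch
from thunder.dynamo.utils import _concrete_value

t = torch.empty(4, 6).as_strided((2, 3), (6, 1))
assert _concrete_value(t.shape) == (2, 3)     # plain ints pass through unchanged
assert _concrete_value(t.stride()) == (6, 1)  # now also usable for strides
```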
+ """ check(node.op == "placeholder", lambda: f"The node must be placeholder type", ValueError) # Prefers to use actual example value in GraphArg if available if "grapharg" in node.meta: - example_value = node.meta["grapharg"].example - if isinstance(example_value, torch.Tensor): - return (example_value.detach().clone().requires_grad_(example_value.requires_grad),) - - check("example_value" in node.meta, lambda: "example_value does not exist in the meta of {node}", ValueError) + ev = node.meta["grapharg"].example + if isinstance(ev, torch.Tensor): + if only_metadata: + return _get_example_input_tensor_metadata(ev) + return ev.detach().clone().requires_grad_(ev.requires_grad) + + if "example_value" not in node.meta: + return None example_value = node.meta["example_value"] if isinstance(example_value, torch.Tensor): - sz = _concrete_shape(example_value) - return ( - make_tensor( - sz, - dtype=example_value.dtype, - device=example_value.device, - requires_grad=example_value.requires_grad, - ).as_strided(sz, example_value.stride()), - ) + ev_metadata = _get_example_input_tensor_metadata(example_value) + if only_metadata: + return ev_metadata + return _create_random_tensor_from_tensor_metadata(ev_metadata) elif isinstance(example_value, tuple): - return tuple( - make_tensor( - _concrete_shape(e_v), - dtype=e_v.dtype, - device=e_v.device, - requires_grad=e_v.requires_grad, - ).as_strided(_concrete_shape(e_v), e_v.stride()) - for e_v in example_value - ) + ev_metadatas = tuple(_get_example_input_tensor_metadata(e_v) for e_v in example_value) + if only_metadata: + return ev_metadatas + return tuple(_create_random_tensor_from_tensor_metadata(ev_metadata) for ev_metadata in ev_metadatas) + elif isinstance(example_value, torch.types.py_sym_types): + return example_value.node.hint else: - raise TypeError( - "The 'example_value' in the placeholder node is expected to be either a Tensor or a Tuple of Tensors." - ) + raise TypeError(f"Unsupported example_value type: {type(example_value)}") def _checkpoint_function_converter(gm: torch.fx.GraphModule): @@ -512,3 +574,193 @@ def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule) else: function_module = getattr(gm, n.args[0].name) _checkpoint_function_converter(function_module) + + +def arg_like_tensor(arg: torch.Tensor | ExampleInputMetaData, f: TextIO): + """Creates a new argument like the given tensor or tensor metadata""" + if isinstance(arg, torch.Tensor): + min_val, max_val = torch.aminmax(arg) + min_val = min_val.cpu().item() + max_val = max_val.cpu().item() + else: + min_val, max_val = arg.min_val, arg.max_val + storage_shape = _get_storage_shape(arg) if isinstance(arg, torch.Tensor) else arg.storage_shape + if min_val is not None and min_val == max_val: + meta = f"{storage_shape}, {min_val}, dtype={arg.dtype}, device='{arg.device}', requires_grad={arg.requires_grad}, layout={arg.layout}" + print(f" torch.full({meta}).as_strided({arg.shape}, {arg.stride()}),", file=f) + return + meta = f"{storage_shape}, dtype={arg.dtype}, device='{arg.device}', requires_grad={arg.requires_grad}," + meta = f"{meta} low={min_val}, high={max_val}," + print(f" torch.testing.make_tensor({meta}).as_strided({arg.shape}, {arg.stride()}),", file=f) + + +def arg_like(arg: Any, f: TextIO): + """Creates a new argument that is similar to the given arg.""" + if isinstance(arg, (torch.Tensor, ExampleInputMetaData)): + arg_like_tensor(arg, f) + else: + # Assume it's a literal that we can just print directly. 
+ print(f" {arg},", file=f) + + +def _readable( + module: torch.fx.GraphModule, + module_name: str, + print_output: bool = False, + include_stride: bool = True, + include_device: bool = True, + colored: bool = False, +): + """Modified from `torch.fx.graph_module._print_readable` (https://github.com/pytorch/pytorch/blob/3192bdeea428f2bf3a95274ee59ea41c4f8e31e9/torch/fx/graph_module.py#L297). + This is basically print_readable but it sets verbose=False (torch hardcodes it to True).""" + graph = module.graph + assert graph is not None and isinstance( + graph, torch.fx.Graph + ), "print_readable must be used on a module with a graph" + + verbose_python_code = graph.python_code( + root_module="self", + verbose=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 2) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 2) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +def get_env() -> tuple[str, str]: + """Retrieve detailed environment information using `torch.utils.collect_env.get_pretty_env_info()`. + Additionally, include the installed versions of Thunder and NvFuser (if available via pip). + """ + + from torch.utils.collect_env import run, get_pretty_env_info, get_pip_packages + + torch_env = get_pretty_env_info() + _, thunder_packages = get_pip_packages(run, {"lightning-thunder", "nvfuser"}) + return torch_env, thunder_packages + + +def thunder_options_to_str(thunder_options: dict) -> str: + from thunder import resolve_executors + + option_str = "" + for key, value in thunder_options.items(): + if key == "executors": + executors = resolve_executors(value) + option_str += f"{key}=[" + ",".join(f"thunder.extend.get_executor('{ex.name}')" for ex in executors) + "]" + else: + option_str += f"{key}={repr(value)}" + option_str += "," + return option_str + + +def reproducer( + gm: torch.fx.GraphModule, + thunder_options: dict, + args: tuple[torch.Tensor | ExampleInputMetaData], + folder: str | os.PathLike, + graph_idx: int, +): + # Ideally we'd use print_readable, but we want verbose=False and there's no + # way to set that with print_readable. 
+
+
+def reproducer(
+    gm: torch.fx.GraphModule,
+    thunder_options: dict,
+    args: tuple[torch.Tensor | ExampleInputMetaData],
+    folder: str | os.PathLike,
+    graph_idx: str,
+):
+    # Ideally we'd use print_readable, but we want verbose=False and there's no
+    # way to set that with print_readable.
+    folder = Path(folder)
+    folder.mkdir(exist_ok=True)
+    torch_env, thunder_pkgs = get_env()
+    readable = _readable(gm, "DynamoModule", print_output=False)
+    has_cuda_args = any(hasattr(arg, "device") and arg.device.type == "cuda" for arg in args)
+    thunder_options_str = thunder_options_to_str(thunder_options)
+    with open(folder / f"g{graph_idx}.py", "w") as f:
+        print('"""', file=f)
+        print("Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`:\n", file=f)
+        print(torch_env, file=f)
+        print("\nVersions of Thunder related libraries:", file=f)
+        print(thunder_pkgs, file=f)
+        print("\nThe torch.fx.Graph:", file=f)
+        print(gm.graph, file=f)
+        print('"""', file=f)
+        print("import os\n", file=f)
+        print("import torch", file=f)
+        print("import thunder", file=f)
+        if has_cuda_args:
+            print("import thunder.transforms.cudagraph", file=f)
+            print("from thunder.dev_utils.nvtx_profile_transform ", end="", file=f)
+            print("import NvtxProfileTransform\n", file=f)
+            print("_execs = [", file=f)
+            print('    thunder.extend.get_executor("nvfuser"),', file=f)
+            print('    thunder.extend.get_executor("sdpa"),', file=f)
+            print('    thunder.extend.get_executor("cudnn"),', file=f)
+            print("]\n", file=f)
+        print(f"def test_g{graph_idx}():", file=f)
+        print(" ", _addindent(readable, 2), file=f)
+        if any(arg is None for arg in args):
+            print(
+                "    # Warning: The inputs that cannot be inferred are set to None, requiring the user to manually give inputs according to the code",
+                file=f,
+            )
+        print("    inputs = [", file=f)
+        for a in args:
+            print("    ", end="", file=f)
+            arg_like(a, f)
+        print("    ]", file=f)
+        print(
+            """    # NOTE: the `BACKEND` environment variable is intended to provide some common ways to debug/benchmark thunder.jit
+    # with different backends and compilation options. By default, the original Thunder options are used.""",
+            file=f,
+        )
+        print('    backend = os.getenv("BACKEND")', file=f)
+        print('    if backend is None or backend == "thunder":', file=f)
+        print(f"        fqn = thunder.jit(DynamoModule(), {thunder_options_str})", file=f)
+        print('    elif backend == "torch.compile":', file=f)
+        print("        fqn = torch.compile(DynamoModule())", file=f)
+        print('    elif backend == "dynamo-eager":', file=f)
+        print('        fqn = torch.compile(DynamoModule(), backend="eager")', file=f)
+        if has_cuda_args:
+            print('    elif backend == "thunder-nvtxprofile":', file=f)
+            print("        fqn = thunder.jit(DynamoModule(), transforms=[NvtxProfileTransform()])", file=f)
+            print('    elif backend == "thunder-no-torch.compile":', file=f)
+            print("        fqn = thunder.jit(DynamoModule(), executors=_execs)", file=f)
+            print('    elif backend == "thunder-cudagraph":', file=f)
+            print("        xform = thunder.transforms.cudagraph.CUDAGraphTransform()", file=f)
+            print("        fqn = thunder.jit(DynamoModule(), transforms=[xform])", file=f)
+        print('    post_graph = os.getenv("POST_GRAPH", "0")', file=f)
+        print("    if int(post_graph) > 0:", file=f)
+        print("        fqn = torch.cuda.make_graphed_callables(", file=f)
+        print("            fqn, inputs,", file=f)
+        print("            num_warmup_iters=1, allow_unused_input=True", file=f)
+        print("        )", file=f)
+        print(f'    torch.cuda.nvtx.range_push("g{graph_idx} warmups")', file=f)
+        print("    for i in range(3):  # warmup runs", file=f)
+        print("        fqn(*inputs)", file=f)
+        if has_cuda_args:
+            print("    torch.cuda.synchronize()", file=f)
+        print("    torch.cuda.nvtx.range_pop()", file=f)
+        print(f'    torch.cuda.nvtx.range_push("g{graph_idx}")', file=f)
+        print("    fqn(*inputs)", file=f)
+        if has_cuda_args:
+            print("    torch.cuda.synchronize()", file=f)
+        print("    torch.cuda.nvtx.range_pop()", file=f)
+        print(f"\ntest_g{graph_idx}()", file=f)
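The generated scripts are self-contained: each embeds the environment report, the FX module source, freshly generated inputs, and the `BACKEND`/`POST_GRAPH` switches printed above. A sketch of driving one saved script through several backends (the path assumes one Dynamo graph with one Thunder submodule, as in the tests below):

```python
import os
import subprocess
import sys

for backend in (None, "torch.compile", "dynamo-eager"):
    env = dict(os.environ)
    env.pop("BACKEND", None)  # None exercises the default (original Thunder options)
    if backend is not None:
        env["BACKEND"] = backend
    subprocess.run([sys.executable, "repro/g1_thunder_1.py"], env=env, check=True)
```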
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 2f9bb0d124..f81166eb41 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1,5 +1,7 @@
 import pytest
 import warnings
+import os
+from subprocess import run
 import torch
 import torch.fx
 import torch.nn as nn
@@ -688,3 +690,76 @@ def find_target_module(model, target_module_name):
         for n in submodule.graph.nodes:
             if n.op == "call_function":
                 assert isinstance(n.target, Symbol)
+
+
+@instantiate(dtypes=NOTHING, executors=[DynamoThunderExecutor])
+def test_dynamo_reproducer_2graph(executor, device: str, dtype: dtypes.dtype, tmp_path):
+    from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
+    from thunder import nvfuser_executor
+    from thunder.transforms.cudagraph import CUDAGraphTransform
+
+    if device.startswith("cuda"):
+        backend = ThunderCompiler(
+            transforms=[
+                NvtxProfileTransform(),
+                CUDAGraphTransform(),
+            ],
+            executors=[nvfuser_executor],
+            cache="no caching",
+            langctx=None,
+            record_history=False,
+        )
+    else:
+        backend = ThunderCompiler(executors=None)
+
+    # Test non-contiguous input tensor
+    x = make_tensor((4, 4), low=3, high=10, dtype=torch.int64, device=device, noncontiguous=True)
+
+    @torch.compile(backend=backend)
+    def func(x):
+        x = torch.sin(x)
+        if x.sum() > 0:
+            return x + 1
+        else:
+            return x - 1
+
+    out = func(x)
+    backend.save_reproducer_to_folder(tmp_path)
+
+    s1 = f"{tmp_path}/g1_thunder_1.py"
+    s2 = f"{tmp_path}/g2_thunder_1.py"
+    assert os.path.exists(s1)
+    assert os.path.exists(s2)
+    result1 = run(["python", s1], capture_output=True, text=True)
+    result2 = run(["python", s2], capture_output=True, text=True)
+
+    assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}"
+    assert result2.returncode == 0, f"Reproducer {s2} failed with return code {result2.returncode}"
+
+
+@requiresCUDA
+def test_dynamo_reproducer_submodules(tmp_path):
+    from thunder.tests.distributed.helper import ToyModel
+    import torch.nn as nn
+
+    class SimpleModel(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.sub_mod = ToyModel()
+            self.seq = nn.Sequential(self.sub_mod, nn.ReLU())
+
+        def forward(self, x):
+            x = torch.sin(x)
+            x = self.seq(x)
+            return x
+
+    x = torch.randn(1, ToyModel.N_IN, device="cuda", requires_grad=True)
+    model = SimpleModel().cuda()
+    backend = ThunderCompiler()
+    jf = torch.compile(backend=backend)(model)
+    out = jf(x)
+    backend.save_reproducer_to_folder(tmp_path)
+
+    s1 = f"{tmp_path}/g1_thunder_1.py"
+    assert os.path.exists(s1)
+    result1 = run(["python", s1], capture_output=True, text=True)
+    assert result1.returncode == 0, f"Reproducer {s1} failed with return code {result1.returncode}"
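Finally, the new `save_dynamo_repro` knob in `benchmark_litgpt.py` wires all of this into the benchmark harness. A hypothetical invocation sketch, assuming `benchmark_main` forwards keyword arguments to the benchmark constructor (the model name is illustrative; the flag requires `--compile` to contain both "thunder" and "dynamo"):

```python
from thunder.benchmarks.benchmark_litgpt import benchmark_main

# Saves one reproducer script per Thunder-compiled submodule to ./repro_scripts
benchmark_main(
    compile="thunder+dynamo",
    model_name="Llama-2-7b-hf",
    save_dynamo_repro="./repro_scripts",
)
```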