ThunderFX: Save the reproducer script into files #1380

Open · wants to merge 21 commits into main
9 changes: 9 additions & 0 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -239,6 +239,7 @@ def __init__(
use_torchao_fp8_allgather: bool = False,
use_torchao_fp8_precompute_scale_for_fsdp: bool = False,
fp8_shard_intermediate_activation: bool = False,
save_dynamo_repro: str | None = None,
):
seed = 1337
torch.manual_seed(seed)
@@ -273,6 +274,11 @@ def __init__(
self.dump_thunder_traces = dump_thunder_traces
self.dump_memory_snapshot = dump_memory_snapshot
self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation
if save_dynamo_repro is not None:
assert (
"dynamo" in self.compile and "thunder" in self.compile
), "save_dynamo_repro can only be used if --compile=thunder+dynamo"
self.save_dynamo_repro = save_dynamo_repro

if use_torchao_fp8_linear:

@@ -889,6 +895,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"##########\n#{i}-th ThunderModule\n##########")
print(b_traces[-1])

if benchmark.save_dynamo_repro:
benchmark.backend.save_reproducer_to_folder(benchmark.save_dynamo_repro)

if global_rank in [0, None]:
if return_metrics_as_json:
benchmark.add_model_info_to_metrics()
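Usage note (hedged; the exact flag spelling is inferred from the assertion added above, not stated elsewhere in the diff): since save_dynamo_repro is only valid with the thunder+dynamo compile path, a run that dumps reproducer scripts would presumably be launched as python thunder/benchmarks/benchmark_litgpt.py --compile=thunder+dynamo --save_dynamo_repro=<output_dir>, after which benchmark_main calls backend.save_reproducer_to_folder on that directory.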
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
@@ -400,6 +400,9 @@ def reverse_transform_state_dict_for_submodule(
) -> dict[str, Any]:
return state_dict

def __repr__(self) -> str:
return f"{self.__class__.__module__}.{self.__class__.__name__}()"


def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
"""computes a canonical ordering of proxies in the bound symbols based on the order of appearance
34 changes: 33 additions & 1 deletion thunder/dynamo/compiler.py
@@ -7,11 +7,13 @@
import torch

from thunder.core.baseutils import run_once
from thunder.dynamo.utils import recompile_graph
from thunder.core.utils import safe_zip
from thunder.dynamo.utils import recompile_graph, reproducer
from thunder.dynamo.splitter import _splitter

if TYPE_CHECKING:
from thunder.dynamo.utils import SubgraphInfo
from os import PathLike


@run_once
@@ -81,3 +83,33 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
self.subgraph_infos.append(subgraph_info)
return split_module

def save_reproducer_to_folder(self, reproducer_folder: str | PathLike):
"""
Save a reproducer script for each GraphModule executed by Thunder into the specified `reproducer_folder`.
Each saved script is named "g[graph_id]_thunder_[module_id]", where:

- `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder.
- `module_id` indexes the submodule produced by :func:`thunder.dynamo.utils._splitter`.

Both `graph_id` and `module_id` start from 1.
"""
if not self.subgraph_infos:
raise TypeError(f"{self} doesn't seem to have been called yet.")

for graph_idx, subgraph_info in enumerate(self.subgraph_infos):
thunder_module_names = []
for node in subgraph_info.split_graph_module.graph.nodes:
target = node.target
if isinstance(target, str) and target.startswith("thunder_"):
thunder_module_names.append(target)
thunder_modules = subgraph_info.thunder_compiled_fns
example_inputs = subgraph_info.thunder_compiled_fns_example_inputs
for cur_module, example_input, cur_name in safe_zip(thunder_modules, example_inputs, thunder_module_names):
reproducer(
getattr(cur_module, "_model"),
self.thunder_options,
example_input,
reproducer_folder,
f"{graph_idx+1}_{cur_name}",
)
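A minimal usage sketch for the new method, assuming the backend is wired into torch.compile in the usual way; the toy model and the output folder name are placeholders, not part of the PR:

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()
model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(4, 16))  # run once so Dynamo invokes the backend and subgraph_infos is populated

# Writes one reproducer script per Thunder-compiled submodule into "repro_scripts"
backend.save_reproducer_to_folder("repro_scripts")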
2 changes: 1 addition & 1 deletion thunder/dynamo/compiler_graph_benchmark.py
@@ -140,7 +140,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
cur_nodes = cur_module.graph.nodes
# Creates random input values for the current module based on the faketensor 'example_value' of the placeholder node
placeholders = list(n for n in cur_nodes if n.op == "placeholder")
args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
args = list(map(_get_example_inputs_from_placeholder, placeholders))
# Runs the benchmark on the original module with the generated random inputs
self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
self.graph_idx += 1
11 changes: 11 additions & 0 deletions thunder/dynamo/splitter.py
@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from itertools import chain
from functools import partial

import torch
from torch.fx.passes.split_module import split_module
@@ -15,6 +17,7 @@
update_node_and_submodule,
recompile_graph,
checkpoint_converter,
_get_example_inputs_from_placeholder,
)

if TYPE_CHECKING:
@@ -140,10 +143,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:

# Call compile on the split region/s.
thunder_compiled_fns = []
example_input_metadatas = []
submodule_to_compiled_fns = {}
for node in split_gm.graph.nodes:
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
# Record the input tensor metadata of the current module based on the faketensor 'example_value' of the placeholder node
placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder")
example_input_metadata = map(
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
)
example_input_metadatas.append(list(example_input_metadata))
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)
jit_fn = thunder_jit(graph_module)
@@ -168,6 +178,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
gm,
split_gm,
thunder_compiled_fns,
example_input_metadatas,
submodule_to_compiled_fns,
split_reasons,
)
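For illustration only: a sketch of the kind of metadata that can be read off a placeholder's FakeTensor 'example_value'. The helper below is hypothetical and is not the actual _get_example_inputs_from_placeholder implementation.

from __future__ import annotations

import torch

def placeholder_metadata(node: torch.fx.Node) -> dict | None:
    # Dynamo stores a FakeTensor under the 'example_value' key of a placeholder node's meta
    fake = node.meta.get("example_value")
    if not isinstance(fake, torch.Tensor):
        return None
    # Shape, dtype, device, and requires_grad are enough to rebuild a random input of the same kind later
    return {
        "shape": tuple(fake.shape),
        "dtype": fake.dtype,
        "device": str(fake.device),
        "requires_grad": fake.requires_grad,
    }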