Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThunderFX: Save the reproducer script into files #1380

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open

Conversation

kiya00
Copy link
Collaborator

@kiya00 kiya00 commented Oct 31, 2024

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1082.

Based on the code provided by @tfogal
(https://github.com/tfogal/NeMo/blob/a0c711deae6c6f7342662795425684a342f95b8d/examples/multimodal/multimodal_llm/neva/neva_pretrain.py#L164), I added the ThunderCompiler.save_reproducer_to_folder interface to save the reproducer script in an "offline" way. SubgraphInfo.thunder_compiled_fns_example_inputs is added to record the input tensor metadata and after execution we retrieve the information in SubgraphInfo and write the reproducer to file.

The save_dynamo_repro option is added to the benchmark_litgpt.py, an example of its use: python thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-7b-hf --compile dynamo+thunder --n_layers=2 --save_dynamo_repro='tmp/bench'

TODO: support for saving the repro of module with checkpointing needs #1437

An example of the saved reproducer script using CPU inputs
"""
Environment information get from `torch.utils.collect_env.get_pretty_env_info()`:

PyTorch version: 2.6.0a0+gite2e425b
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Sep 11 2024, 14:17:37) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX 6000 Ada Generation
GPU 1: NVIDIA RTX 6000 Ada Generation

Nvidia driver version: 545.29.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             32
On-line CPU(s) list:                0-31
Vendor ID:                          AuthenticAMD
BIOS Vendor ID:                     Advanced Micro Devices, Inc.
Model name:                         AMD Ryzen 9 7950X 16-Core Processor
BIOS Model name:                    AMD Ryzen 9 7950X 16-Core Processor             Unknown CPU @ 4.5GHz
BIOS CPU family:                    107
CPU family:                         25
Model:                              97
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          1
Stepping:                           2
CPU(s) scaling MHz:                 66%
CPU max MHz:                        5881.0000
CPU min MHz:                        400.0000
BogoMIPS:                           8999.83
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                     AMD-V
L1d cache:                          512 KiB (16 instances)
L1i cache:                          512 KiB (16 instances)
L2 cache:                           16 MiB (16 instances)
L3 cache:                           64 MiB (2 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cudnn-frontend==1.8.0
[pip3] optree==0.13.0
[pip3] optree==0.13.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.6.0a0+gite2e425b
[pip3] torchmetrics==1.5.2
[pip3] torchvision==0.19.0a0+d23a6e1
[pip3] triton==3.1.0
[conda] Could not collect

Versions of Thunder related libraries:
lightning-thunder==0.2.0.dev0
nvfuser==0.2.22+git5b9fb77

The torch.fx.Graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=l_x_]
    %x : [num_users=2] = call_function[target=torch.sin](args = (%l_x_,), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    return (x, gt)
"""
import os

import torch
import thunder
def test_g1_thunder_1():
  class DynamoModule(torch.nn.Module):
    def forward(self, l_x_ : torch.Tensor):
        x = torch.sin(l_x_);  l_x_ = None
        sum_1 = x.sum()
        gt = sum_1 > 0;  sum_1 = None
        return (x, gt)

  inputs = [
    torch.testing.make_tensor((31,), dtype=torch.int64,  device='cpu', requires_grad=False, low=3, high=9,).as_strided((4, 4), (8, 2)),
  ]
  # NOTE the `BACKEND` environment variable is intended to provide some common ways to debug/benchmark thunder.jit
  # with different backend and compilation options. By default, it uses the original Thunder options that are executed
  backend = os.getenv("BACKEND")
  if backend == None or backend == "thunder":
    fqn = thunder.jit(DynamoModule(), executors=[thunder.extend.get_executor('apex'),thunder.extend.get_executor('cudnn'),thunder.extend.get_executor('sdpa'),thunder.extend.get_executor('torchcompile_cat'),thunder.extend.get_executor('nvfuser')],)
  elif backend == "torch.compile":
    fqn = torch.compile(DynamoModule())
  elif backend == "dynamo-eager":
    fqn = torch.compile(DynamoModule(), backend="eager")
  for i in range(3): # warmup runs
    fqn(*inputs)
  fqn(*inputs)

test_g1_thunder_1()
An example of the saved reproducer script using CUDA inputs
"""
Environment information get from `torch.utils.collect_env.get_pretty_env_info()`:

PyTorch version: 2.6.0a0+gite2e425b
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Sep 11 2024, 14:17:37) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX 6000 Ada Generation
GPU 1: NVIDIA RTX 6000 Ada Generation

Nvidia driver version: 545.29.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             32
On-line CPU(s) list:                0-31
Vendor ID:                          AuthenticAMD
BIOS Vendor ID:                     Advanced Micro Devices, Inc.
Model name:                         AMD Ryzen 9 7950X 16-Core Processor
BIOS Model name:                    AMD Ryzen 9 7950X 16-Core Processor             Unknown CPU @ 4.5GHz
BIOS CPU family:                    107
CPU family:                         25
Model:                              97
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          1
Stepping:                           2
CPU(s) scaling MHz:                 69%
CPU max MHz:                        5881.0000
CPU min MHz:                        400.0000
BogoMIPS:                           8999.83
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                     AMD-V
L1d cache:                          512 KiB (16 instances)
L1i cache:                          512 KiB (16 instances)
L2 cache:                           16 MiB (16 instances)
L3 cache:                           64 MiB (2 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cudnn-frontend==1.8.0
[pip3] optree==0.13.0
[pip3] optree==0.13.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.6.0a0+gite2e425b
[pip3] torchmetrics==1.5.2
[pip3] torchvision==0.19.0a0+d23a6e1
[pip3] triton==3.1.0
[conda] Could not collect

Versions of Thunder related libraries:
lightning-thunder==0.2.0.dev0
nvfuser==0.2.22+git5b9fb77

The torch.fx.Graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=l_x_]
    %x : [num_users=2] = call_function[target=torch.sin](args = (%l_x_,), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%x,), kwargs = {})
    %gt : [num_users=1] = call_function[target=operator.gt](args = (%sum_1, 0), kwargs = {})
    return (x, gt)
"""
import os

import torch
import thunder
import thunder.transforms.cudagraph
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

_execs = [
  thunder.extend.get_executor("nvfuser"),
  thunder.extend.get_executor("sdpa"),
  thunder.extend.get_executor("cudnn"),
]

def test_g1_thunder_1():
  class DynamoModule(torch.nn.Module):
    def forward(self, l_x_ : torch.Tensor):
        x = torch.sin(l_x_);  l_x_ = None
        sum_1 = x.sum()
        gt = sum_1 > 0;  sum_1 = None
        return (x, gt)

  inputs = [
    torch.testing.make_tensor((31,), dtype=torch.int64,  device='cuda:0', requires_grad=False, low=4, high=9,).as_strided((4, 4), (8, 2)),
  ]
  # NOTE the `BACKEND` environment variable is intended to provide some common ways to debug/benchmark thunder.jit
  # with different backend and compilation options. By default, it uses the original Thunder options that are executed
  backend = os.getenv("BACKEND")
  if backend == None or backend == "thunder":
    fqn = thunder.jit(DynamoModule(), transforms=[thunder.dev_utils.nvtx_profile_transform.NvtxProfileTransform(), thunder.transforms.cudagraph.CUDAGraphTransform()],executors=[thunder.extend.get_executor('nvfuser')],cache='no caching',langctx=None,record_history=False,)
  elif backend == "torch.compile":
    fqn = torch.compile(DynamoModule())
  elif backend == "dynamo-eager":
    fqn = torch.compile(DynamoModule(), backend="eager")
  elif backend == "thunder-nvtxprofile":
    fqn = thunder.jit(DynamoModule(), transforms=[NvtxProfileTransform()])
  elif backend == "thunder-no-torch.compile":
    fqn = thunder.jit(DynamoModule(), executors=_execs)
  elif backend == "thunder-cudagraph":
    xform = thunder.transforms.cudagraph.CUDAGraphTransform()
    fqn = thunder.jit(DynamoModule(), transform=[xform])
  post_graph = os.getenv("POST_GRAPH", "0")
  if int(post_graph) > 0:
    fqn = torch.cuda.make_graphed_callables(
      fqn, inputs,
      num_warmup_iters=1, allow_unused_input=True
    )
  torch.cuda.nvtx.range_push("g1_thunder_1 warmups")
  for i in range(3): # warmup runs
    fqn(*inputs)
  torch.cuda.synchronize()
  torch.cuda.nvtx.range_pop()
  torch.cuda.nvtx.range_push("g1_thunder_1")
  fqn(*inputs)
  torch.cuda.synchronize()
  torch.cuda.nvtx.range_pop()

test_g1_thunder_1()

@mruberry
Copy link
Collaborator

This is really cool! A couple questions, @kiya00:

backend = os.getenv("BACKEND")
  if backend == None or backend == "thunder":
    ...

For this part, can the script instead know how to produce the same compilation as in the original reproduction? Like if the original was compiled like torch.compile(DynamoModule(), backend="eager") then that's what appears in this reproduction without having to query a "BACKEND" environment variable. Same for the query to the "POST_GRAPH" environment variable. Maybe this information can be queried from the jitted function's compile statistics?

torch.cuda.nvtx.range_push("g1_thunder_1 compilation")

Do this CUDA-specific calls only appear if one or more of the input tensors is generated on a CUDA device?

inputs = [
    torch.randint(low=3, high=9, size=(31,), dtype=torch.int64, layout=torch.strided, device="cuda:0", requires_grad=False).as_strided((4, 4), (8, 2)),
  ]

Where do the low and high values for this call to randint come from? Should we think about using make_tensor instead? make_tensor can be an easier interface for creating tensors of different datatypes.

Can we add a comment with the original FX graph, too?

Versions of Thunder related libraries:
lightning-thunder==0.2.0.dev0
nvfuser==0.2.21+gitf6975f3

Not a question, I just thought this was really cool and helpful.

torch.cuda.nvtx.range_push("g1_thunder_1")
fqn(*inputs)
torch.cuda.synchronize()

Why this final call the the function?

@kiya00
Copy link
Collaborator Author

kiya00 commented Nov 12, 2024

Hi @mruberry , thanks for the advice I'll change it accordingly

Like if the original was compiled like torch.compile(DynamoModule(), backend="eager") then that's what appears in this reproduction without having to query a "BACKEND" environment variable. Same for the query to the "POST_GRAPH" environment variable. Maybe this information can be queried from the jitted function's compile statistics?

I mostly kept the repro script as the one written by Tom, I think it's more like a debug+benchmark script on different backend depending on the environment variable. If we just want to produce the same compilation as in the original reproduction, I can change it to the thunder options actually used.

inputs = [
torch.randint(low=3, high=9, size=(31,), dtype=torch.int64, layout=torch.strided, device="cuda:0", requires_grad=False).as_strided((4, 4), (8, 2)),
]
Where do the low and high values for this call to randint come from? Should we think about using make_tensor instead? make_tensor can be an easier interface for creating tensors of different datatypes.

When we can get the real tensor instead of FakeTensor in the Dynamo graph, the min/max value can be obtained from it. And sometimes it's needed to ensure correctness (e.g. nanogpt input must be in range 0-255). Currently the original torch interface is used to create the inputs, like the nvFuser repro does, maybe it's more user friendly to use the native torch API, but it's easy to change to make_tensor.

torch.cuda.nvtx.range_push("g1_thunder_1 compilation")
Do this CUDA-specific calls only appear if one or more of the input tensors is generated on a CUDA device?

I'll modify it to only appear when cuda is used.

@mruberry
Copy link
Collaborator

Hi @mruberry , thanks for the advice I'll change it accordingly

Like if the original was compiled like torch.compile(DynamoModule(), backend="eager") then that's what appears in this reproduction without having to query a "BACKEND" environment variable. Same for the query to the "POST_GRAPH" environment variable. Maybe this information can be queried from the jitted function's compile statistics?

I mostly kept the repro script as the one written by Tom, I think it's more like a debug+benchmark script on different backend depending on the environment variable. If we just want to produce the same compilation as in the original reproduction, I can change it to the thunder options actually used.

That makes a lot of sense! I think we can let @tfogal comment when he's back, but I think the principal desire behind these reproduction scripts is to create a standalone file that someone interested in reviewing thunder's performance or correctness can quickly run to replicate an issue. That's why I'd suggest making it so that when someone clicks "run" the script executes what thunder did the same way thunder did it. Of course it's great to add notes for how to override/compare that behavior, too!

inputs = [
torch.randint(low=3, high=9, size=(31,), dtype=torch.int64, layout=torch.strided, device="cuda:0", requires_grad=False).as_strided((4, 4), (8, 2)),
]
Where do the low and high values for this call to randint come from? Should we think about using make_tensor instead? make_tensor can be an easier interface for creating tensors of different datatypes.

When we can get the real tensor instead of FakeTensor in the Dynamo graph, the min/max value can be obtained from it. And sometimes it's needed to ensure correctness (e.g. nanogpt input must be in range 0-255). Currently the original torch interface is used to create the inputs, like the nvFuser repro does, maybe it's more user friendly to use the native torch API, but it's easy to change to make_tensor.

That's really cool. Querying for the min and max values from the real tensors sounds like a good solution. I would take a look at make_tensor and see if that's easier. It's what PyTorch uses to synthetically generate tensors in its test suite.

torch.cuda.nvtx.range_push("g1_thunder_1 compilation")
Do this CUDA-specific calls only appear if one or more of the input tensors is generated on a CUDA device?

I'll modify it to only appear when cuda is used.

Awesome!

@kiya00 kiya00 marked this pull request as ready for review November 15, 2024 16:09
@tfogal
Copy link
Collaborator

tfogal commented Nov 15, 2024

Maybe [backend and transformations] information can be queried from the jitted function's compile statistics?

I mostly kept the repro script [...] I can change it to the thunder options actually used.

That makes a lot of sense! I think we can let @tfogal comment when he's back, but I think the principal desire behind these reproduction scripts is to create a standalone file that someone interested in reviewing thunder's performance or correctness can quickly run to replicate an issue. That's why I'd suggest making it so that when someone clicks "run" the script executes what thunder did the same way thunder did it. Of course it's great to add notes for how to override/compare that behavior, too!

Thanks for the ping, sorry for the delay.

The original code conflates two things, and I think there's a good argument to be made that during this cleanup process we rethink things. That is, we've got two use cases to service:

  • "Dump each graph so that I can rerun it and reproduce a bug", and
  • "Let me fiddle with how each subgraph is run so that I can identify how Thunder should optimize"

These two are very close but not quite the same.

I think/hope that the ThunderCompilerBenchmark class that Yan added recently suffices for the latter goal.

Without thinking about how hard it is / if it's appropriate for a single PR, what I'd love to see is:

  1. We try/catch such that when Thunder catastrophically fails, we dump a simple reproducer exactly as was given to us.
  2. When the user says so (env var? kwarg during ThunderCompiler instantiation? I don't know...), we use the same code as (1) but parametrized in such a way that the generated code uses ThunderCompilerBenchmark. The Benchmark class should in turn be able provide all the answers that the "fiddle" case was after.

Does that make sense to both of you?

@mruberry
Copy link
Collaborator

Does that make sense to both of you?

Yep. I think providing a default that mimics the actual behavior, and the ability to override/tweak/investigate alternatives would be great!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add option to ThunderCompiler to save gm.code or gm.print_readable to file
3 participants