From 94c3409bca09633354c8d4da7a015c4748227eb3 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 13 Nov 2024 20:28:17 +0900 Subject: [PATCH] simplified siginfo Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 46bf11a3b..be1974859 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -844,13 +844,7 @@ def _generate_random_str_id() -> str: for bsym in fwd_bsyms: augmented_fwd_trace.add_bound_symbol(bsym) augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=())) - si = SigInfo(f"augmented_autograd_function_apply_{sym_id}") - for a in bsym_of_custom_autograd_func.args: - if isinstance(a, Proxy): - si.args.append((a.name, None)) - else: - pa = proxy(a) - si.args.append((pa.name, a)) + si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", bsym_of_custom_autograd_func.args) augmented_fwd_trace._siginfo = si augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False) @@ -876,10 +870,7 @@ def augmented_fwd_rule(*args): bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope() bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=())) - bwd_si = SigInfo(f"bwd_{si.name}") - for a in saved_values + grads: - bwd_si.args.append((a.name, None)) - bwd_trace._siginfo = bwd_si + bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{si.name}", saved_values + grads) backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False) return wrapped_output