Skip to content

Commit

Permalink
simplified siginfo
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 13, 2024
1 parent 8791f81 commit 94c3409
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,13 +844,7 @@ def _generate_random_str_id() -> str:
for bsym in fwd_bsyms:
augmented_fwd_trace.add_bound_symbol(bsym)
augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
si = SigInfo(f"augmented_autograd_function_apply_{sym_id}")
for a in bsym_of_custom_autograd_func.args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, a))
si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", bsym_of_custom_autograd_func.args)
augmented_fwd_trace._siginfo = si
augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False)

Expand All @@ -876,10 +870,7 @@ def augmented_fwd_rule(*args):
bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope()
bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))

bwd_si = SigInfo(f"bwd_{si.name}")
for a in saved_values + grads:
bwd_si.args.append((a.name, None))
bwd_trace._siginfo = bwd_si
bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{si.name}", saved_values + grads)
backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)

return wrapped_output
Expand Down

0 comments on commit 94c3409

Please sign in to comment.