diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 41d4ea9fd..46bf11a3b 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -792,9 +792,7 @@ def _generate_random_str_id() -> str: # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx") length_of_tensor_args = sum(args_tensor_mask) new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args] - old_scope = jit_ctx.computation_trace.scopes - fwd_bsyms = [] - jit_ctx.computation_trace.scopes = [fwd_bsyms] + jit_ctx.computation_trace.push_scope([]) fwd_result = _interpret_call(fwd, *new_fwd_args) if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: @@ -804,6 +802,7 @@ def _generate_random_str_id() -> str: unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:] + fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope() producer_map = utils.producers(fwd_bsyms) tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {} for p in tree_flatten((output, saved_values))[0]: @@ -838,7 +837,7 @@ def _generate_random_str_id() -> str: _object_ctx=fwd_bsyms[0]._object_ctx, _executor=fwd_bsyms[0]._executor, ) - old_scope[-1].append(bsym_of_custom_autograd_func) + jit_ctx.computation_trace.scopes[-1].append(bsym_of_custom_autograd_func) # Define augmented fwd rule and backward rule on the fly. augmented_fwd_trace = TraceCtx() @@ -863,10 +862,8 @@ def augmented_fwd_rule(*args): augmented_forward_impls[sym.id] = augmented_fwd_rule - bwd_bsyms = [] - jit_ctx.computation_trace.scopes = [bwd_bsyms] + jit_ctx.computation_trace.push_scope([]) bwd_trace = TraceCtx() - bwd_trace.bound_symbols = bwd_bsyms grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output)) bwd_args = (wrap_const(None),) @@ -876,6 +873,7 @@ def augmented_fwd_rule(*args): if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return bwd_result unwrapped_bwd_result = unwrap(bwd_result) + bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope() bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=())) bwd_si = SigInfo(f"bwd_{si.name}") @@ -884,7 +882,6 @@ def augmented_fwd_rule(*args): bwd_trace._siginfo = bwd_si backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False) - jit_ctx.computation_trace.scopes = old_scope return wrapped_output