From cc2bbbed560d33077b3172f324a4d640e3528625 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Wed, 13 Nov 2024 17:45:42 +0900
Subject: [PATCH] cherry-pick of #1256 as of 627845d367bd7857f40b6547b851d1843e186f68

Signed-off-by: Masaki Kozuki
---
 thunder/core/jit_ext.py           | 119 ++++++++++++++++++++++++++++++
 thunder/tests/test_jit_general.py |  13 ++--
 thunder/torch/__init__.py         |  40 ----------
 3 files changed, 126 insertions(+), 46 deletions(-)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index c7ec43ad45..41d4ea9fdd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -769,6 +769,125 @@ def grad_transform(*args, **kwargs):
     return forward_result
 
 
+# ref: https://github.com/pytorch/pytorch/blob/38114ec/torch/_functorch/autograd_function.py#L715-L752
+@register_general_jit_lookaside(torch.ops.higher_order.autograd_function_apply)
+def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_args, **fwd_kwargs):
+    from thunder.core import utils
+    from thunder.core.baseutils import sequencify
+    from thunder.core.pytree import tree_flatten, tree_map
+    from thunder.core.transforms import VJPDual, augmented_forward_impls, backward_impls
+
+    def _generate_random_str_id() -> str:
+        import secrets
+        import string
+
+        length = 5
+        return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))
+
+    jit_ctx: JitCtx = get_jit_ctx()
+
+    args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
+    # TODO(crcrpar): Think about making use of `non_differentiable_idx`
+    # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
+    # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
+    length_of_tensor_args = sum(args_tensor_mask)
+    new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
+    old_scope = jit_ctx.computation_trace.scopes
+    fwd_bsyms = []
+    jit_ctx.computation_trace.scopes = [fwd_bsyms]
+
+    fwd_result = _interpret_call(fwd, *new_fwd_args)
+    if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return fwd_result
+    output, saved_values = unwrap(fwd_result)
+    wrapped_output = wrap(output, provenance=fwd_result.provenance)
+
+    unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]
+
+    producer_map = utils.producers(fwd_bsyms)
+    tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
+    for p in tree_flatten((output, saved_values))[0]:
+        if not isinstance(p, TensorProxy):
+            continue
+        if p in producer_map:
+            prod_bsym = producer_map[p]
+            tensor_to_prod_bsym[variableify(p)] = prod_bsym
+    prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}
+
+    # Encapsulate custom fwd into a bsym.
+    sym_id = f"autograd_function_apply_{_generate_random_str_id()}"
+    sym = Symbol(
+        name=sym_id,
+        id=sym_id,
+        _module=fwd_bsyms[-1].sym.module,
+    )
+    bsym_of_custom_autograd_func = BoundSymbol(
+        sym,
+        args=unwrapped_fwd_args,
+        kwargs={},
+        output=output,
+        subsymbols=fwd_bsyms,
+        header=(
+            f"output of fwd_body: {output}, saved_values from fwd_body: "
+            f"{[t.name if isinstance(t, Proxy) else t for t in saved_values]}"
+        ),
+        source_filename=jit_ctx.computation_trace._current_source_filename,
+        source_positions=None,
+        _call_ctx=fwd_bsyms[0]._call_ctx,
+        _import_ctx=fwd_bsyms[0]._import_ctx,
+        _object_ctx=fwd_bsyms[0]._object_ctx,
+        _executor=fwd_bsyms[0]._executor,
+    )
+    old_scope[-1].append(bsym_of_custom_autograd_func)
+
+    # Define augmented fwd rule and backward rule on the fly.
+    augmented_fwd_trace = TraceCtx()
+    for bsym in fwd_bsyms:
+        augmented_fwd_trace.add_bound_symbol(bsym)
+    augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
+    si = SigInfo(f"augmented_autograd_function_apply_{sym_id}")
+    for a in bsym_of_custom_autograd_func.args:
+        if isinstance(a, Proxy):
+            si.args.append((a.name, None))
+        else:
+            pa = proxy(a)
+            si.args.append((pa.name, a))
+    augmented_fwd_trace._siginfo = si
+    augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False)
+
+    def augmented_fwd_rule(*args):
+        # First arg is `None` or `FunctionCtx`
+        updated_output, updated_saved_values = augmented_fwd_callable(*args)
+        residuals = tuple(sequencify(updated_saved_values))
+        return VJPDual(primal=updated_output, residuals=residuals)
+
+    augmented_forward_impls[sym.id] = augmented_fwd_rule
+
+    bwd_bsyms = []
+    jit_ctx.computation_trace.scopes = [bwd_bsyms]
+    bwd_trace = TraceCtx()
+    bwd_trace.bound_symbols = bwd_bsyms
+
+    grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output))
+    bwd_args = (wrap_const(None),)
+    bwd_tensor_args = grads + tuple(saved_values)
+    wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=fwd_result.provenance), bwd_tensor_args)
+    bwd_result = _interpret_call(bwd, *(bwd_args + wrapped_bwd_tensor_args))
+    if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return bwd_result
+    unwrapped_bwd_result = unwrap(bwd_result)
+    bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))
+
+    bwd_si = SigInfo(f"bwd_{si.name}")
+    for a in saved_values + grads:
+        bwd_si.args.append((a.name, None))
+    bwd_trace._siginfo = bwd_si
+    backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)
+
+    jit_ctx.computation_trace.scopes = old_scope
+    return wrapped_output
+
+
 @register_general_jit_lookaside(torch.autocast.__enter__)
 def autocast_enter(autocast_obj):
     unwrap_autocast_obj = unwrap(autocast_obj)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 3c5f4c4d0f..af6d1365cf 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1205,13 +1205,15 @@ def forward(self, x):
 
 
 def test_autograd_function_apply():
+    # see https://github.com/Lightning-AI/lightning-thunder/issues/1248#issuecomment-2388655917
+    # for why `torch.foo` instead of `torch.Tensor.foo`
     def forward(ctx, x):
         saved_for_backward = (x,)
-        return x.sin(), saved_for_backward
+        return torch.sin(x), saved_for_backward
 
     def backward(ctx, grad_output, *saved_tensors):
         (x,) = saved_tensors
-        return grad_output * x.cos()
+        return grad_output * torch.cos(x)
 
     def my_sin(x):
         return torch.ops.higher_order.autograd_function_apply(
@@ -1231,11 +1233,10 @@ def my_sin(x):
     torch.testing.assert_close(y, y_ref)
 
     initial_computation_trace = thunder.last_traces(jitted)[0]
-    assert any(
-        bsym.sym.id == "torch.ops.higher_order.autograd_function_apply"
-        for bsym in initial_computation_trace.bound_symbols
-        if isinstance(bsym.sym.id, str)
+    bsym_str_ids = tuple(
+        bsym.sym.id for bsym in initial_computation_trace.bound_symbols if isinstance(bsym.sym.id, str)
     )
+    assert any(bsid.startswith("autograd_function_apply") for bsid in bsym_str_ids), bsym_str_ids
 
     grad = torch.rand_like(y)
     actual_grad = torch.autograd.grad(y, x, grad)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 7047485ff6..9c92ec6d84 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5571,46 +5571,6 @@ def wait(slf) -> None:
         utils.check(False, lambda: "torch.distributed is not available")
 
 
-# ref: https://github.com/pytorch/pytorch/blob/b99ef1a/torch/_functorch/autograd_function.py#L715-L752
-@torchsymbol(
-    torch.ops.higher_order.autograd_function_apply,
-    id="torch.ops.higher_order.autograd_function_apply",
-    is_method=False,
-)
-def autograd_function_apply(
-    fwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
-    bwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
-    *args: Any,
-    args_tensor_mask: Sequence[bool] | None,
-    non_differentiable_idx: Sequence[int] | None = None,
-) -> TensorProxy | tuple[TensorProxy, ...]:
-    result, saved_for_backward = fwd(None, *args)
-    return result
-
-
-@register_augmented_forward("torch.ops.higher_order.autograd_function_apply")
-def augmented_forward_autograd_function_apply(
-    fwd: Callable[list[Any | TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
-    bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
-    *args: Any,
-    args_tensor_mask: Sequence[bool],
-    non_differentiable_idx: Sequence[int] | None = None,
-) -> tuple[TensorProxy | tuple[TensorProxy, ...], tuple[Any, ...]]:
-    result, saved_for_backward = fwd(None, *args)
-    return result, (saved_for_backward, bwd, args_tensor_mask, non_differentiable_idx)
-
-
-@register_backward("torch.ops.higher_order.autograd_function_apply")
-def backward_autograd_function_apply(
-    saved_for_backward: tuple[Any, ...],
-    bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
-    args_tensor_mask: Sequence[bool],
-    non_differentiable_idx: Sequence[int] | None = None,
-    *grad_output: Sequence[TensorProxy],
-) -> tuple[Any, ...]:
-    return bwd(None, *grad_output, *saved_for_backward)
-
-
 @torchsymbol(
     torch.amp.autocast_mode._enter_autocast,
     id="torch.amp.autocast_mode._enter_autocast",
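
Reviewer note (not part of the patch): below is a minimal sketch of how the new lookaside is exercised end to end, mirroring the updated test_autograd_function_apply above. It assumes torch and thunder are importable; the helper names (fwd, bwd, my_sin), the tensor shape, and passing only args_tensor_mask (omitting the newer non_differentiable_idx keyword) are illustrative choices, and the expected values are checked against torch.sin/torch.cos directly rather than an eager call of the higher-order op.

import torch
import thunder


def fwd(ctx, x):
    # Use torch.sin rather than x.sin(); see the comment added in the test.
    saved_for_backward = (x,)
    return torch.sin(x), saved_for_backward


def bwd(ctx, grad_output, *saved_tensors):
    (x,) = saved_tensors
    return grad_output * torch.cos(x)


def my_sin(x):
    # args_tensor_mask marks which positional args are tensors; the lookaside
    # reads it from fwd_kwargs["args_tensor_mask"].
    return torch.ops.higher_order.autograd_function_apply(
        fwd,
        bwd,
        x,
        args_tensor_mask=[True],
    )


x = torch.randn(2, 2, requires_grad=True)
jitted = thunder.jit(my_sin)
y = jitted(x)
torch.testing.assert_close(y, torch.sin(x))

grad = torch.rand_like(y)
(actual_grad,) = torch.autograd.grad(y, x, grad)
torch.testing.assert_close(actual_grad, grad * torch.cos(x))

Because the lookaside registers the encapsulating bound symbol under a randomly suffixed id (sym_id = f"autograd_function_apply_{...}"), the updated test asserts on the "autograd_function_apply" prefix of the symbol ids rather than on the exact "torch.ops.higher_order.autograd_function_apply" id used by the removed torchsymbol.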