Skip to content

Commit

Permalink
cherry-pick of #1256 as of 627845d
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 13, 2024
1 parent 2b8284d commit cc2bbbe
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 46 deletions.
119 changes: 119 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,125 @@ def grad_transform(*args, **kwargs):
return forward_result


# ref: https://github.com/pytorch/pytorch/blob/38114ec/torch/_functorch/autograd_function.py#L715-L752
@register_general_jit_lookaside(torch.ops.higher_order.autograd_function_apply)
def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_args, **fwd_kwargs):
from thunder.core import utils
from thunder.core.baseutils import sequencify
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.transforms import VJPDual, augmented_forward_impls, backward_impls

def _generate_random_str_id() -> str:
import secrets
import string

length = 5
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))

jit_ctx: JitCtx = get_jit_ctx()

args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
# TODO(crcrpar): Think about making use of `non_differentiable_idx`
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)
new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
old_scope = jit_ctx.computation_trace.scopes
fwd_bsyms = []
jit_ctx.computation_trace.scopes = [fwd_bsyms]

fwd_result = _interpret_call(fwd, *new_fwd_args)
if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return fwd_result
output, saved_values = unwrap(fwd_result)
wrapped_output = wrap(output, provenance=fwd_result.provenance)

unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]

producer_map = utils.producers(fwd_bsyms)
tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
for p in tree_flatten((output, saved_values))[0]:
if not isinstance(p, TensorProxy):
continue
if p in producer_map:
prod_bsym = producer_map[p]
tensor_to_prod_bsym[variableify(p)] = prod_bsym
prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}

# Encapsulate custom fwd into a bsym.
sym_id = f"autograd_function_apply_{_generate_random_str_id()}"
sym = Symbol(
name=sym_id,
id=sym_id,
_module=fwd_bsyms[-1].sym.module,
)
bsym_of_custom_autograd_func = BoundSymbol(
sym,
args=unwrapped_fwd_args,
kwargs={},
output=output,
subsymbols=fwd_bsyms,
header=(
f"output of fwd_body: {output}, saved_values from fwd_body: "
f"{[t.name if isinstance(t, Proxy) else t for t in saved_values]}"
),
source_filename=jit_ctx.computation_trace._current_source_filename,
source_positions=None,
_call_ctx=fwd_bsyms[0]._call_ctx,
_import_ctx=fwd_bsyms[0]._import_ctx,
_object_ctx=fwd_bsyms[0]._object_ctx,
_executor=fwd_bsyms[0]._executor,
)
old_scope[-1].append(bsym_of_custom_autograd_func)

# Define augmented fwd rule and backward rule on the fly.
augmented_fwd_trace = TraceCtx()
for bsym in fwd_bsyms:
augmented_fwd_trace.add_bound_symbol(bsym)
augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
si = SigInfo(f"augmented_autograd_function_apply_{sym_id}")
for a in bsym_of_custom_autograd_func.args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, a))
augmented_fwd_trace._siginfo = si
augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False)

def augmented_fwd_rule(*args):
# First arg is `None` or `FunctionCtx`
updated_output, updated_saved_values = augmented_fwd_callable(*args)
residuals = tuple(sequencify(updated_saved_values))
return VJPDual(primal=updated_output, residuals=residuals)

augmented_forward_impls[sym.id] = augmented_fwd_rule

bwd_bsyms = []
jit_ctx.computation_trace.scopes = [bwd_bsyms]
bwd_trace = TraceCtx()
bwd_trace.bound_symbols = bwd_bsyms

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output))
bwd_args = (wrap_const(None),)
bwd_tensor_args = grads + tuple(saved_values)
wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=fwd_result.provenance), bwd_tensor_args)
bwd_result = _interpret_call(bwd, *(bwd_args + wrapped_bwd_tensor_args))
if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return bwd_result
unwrapped_bwd_result = unwrap(bwd_result)
bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))

bwd_si = SigInfo(f"bwd_{si.name}")
for a in saved_values + grads:
bwd_si.args.append((a.name, None))
bwd_trace._siginfo = bwd_si
backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)

jit_ctx.computation_trace.scopes = old_scope
return wrapped_output


@register_general_jit_lookaside(torch.autocast.__enter__)
def autocast_enter(autocast_obj):
unwrap_autocast_obj = unwrap(autocast_obj)
Expand Down
13 changes: 7 additions & 6 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,13 +1205,15 @@ def forward(self, x):

def test_autograd_function_apply():

# see https://github.com/Lightning-AI/lightning-thunder/issues/1248#issuecomment-2388655917
# for why `torch.foo` instead of `torch.Tensor.foo`
def forward(ctx, x):
saved_for_backward = (x,)
return x.sin(), saved_for_backward
return torch.sin(x), saved_for_backward

def backward(ctx, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * x.cos()
return grad_output * torch.cos(x)

def my_sin(x):
return torch.ops.higher_order.autograd_function_apply(
Expand All @@ -1231,11 +1233,10 @@ def my_sin(x):
torch.testing.assert_close(y, y_ref)

initial_computation_trace = thunder.last_traces(jitted)[0]
assert any(
bsym.sym.id == "torch.ops.higher_order.autograd_function_apply"
for bsym in initial_computation_trace.bound_symbols
if isinstance(bsym.sym.id, str)
bsym_str_ids = tuple(
bsym.sym.id for bsym in initial_computation_trace.bound_symbols if isinstance(bsym.sym.id, str)
)
assert any(bsid.startswith("autograd_function_apply") for bsid in bsym_str_ids), bsym_str_ids

grad = torch.rand_like(y)
actual_grad = torch.autograd.grad(y, x, grad)
Expand Down
40 changes: 0 additions & 40 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5571,46 +5571,6 @@ def wait(slf) -> None:
utils.check(False, lambda: "torch.distributed is not available")


# ref: https://github.com/pytorch/pytorch/blob/b99ef1a/torch/_functorch/autograd_function.py#L715-L752
@torchsymbol(
torch.ops.higher_order.autograd_function_apply,
id="torch.ops.higher_order.autograd_function_apply",
is_method=False,
)
def autograd_function_apply(
fwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
bwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
*args: Any,
args_tensor_mask: Sequence[bool] | None,
non_differentiable_idx: Sequence[int] | None = None,
) -> TensorProxy | tuple[TensorProxy, ...]:
result, saved_for_backward = fwd(None, *args)
return result


@register_augmented_forward("torch.ops.higher_order.autograd_function_apply")
def augmented_forward_autograd_function_apply(
fwd: Callable[list[Any | TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
*args: Any,
args_tensor_mask: Sequence[bool],
non_differentiable_idx: Sequence[int] | None = None,
) -> tuple[TensorProxy | tuple[TensorProxy, ...], tuple[Any, ...]]:
result, saved_for_backward = fwd(None, *args)
return result, (saved_for_backward, bwd, args_tensor_mask, non_differentiable_idx)


@register_backward("torch.ops.higher_order.autograd_function_apply")
def backward_autograd_function_apply(
saved_for_backward: tuple[Any, ...],
bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
args_tensor_mask: Sequence[bool],
non_differentiable_idx: Sequence[int] | None = None,
*grad_output: Sequence[TensorProxy],
) -> tuple[Any, ...]:
return bwd(None, *grad_output, *saved_for_backward)


@torchsymbol(
torch.amp.autocast_mode._enter_autocast,
id="torch.amp.autocast_mode._enter_autocast",
Expand Down

0 comments on commit cc2bbbe

Please sign in to comment.