In PyTorch, you can define a custom forward pass by subclassing torch.autograd.Function. This lets you specify both the forward computation and the gradient (backward) computation of your custom function.
For example, you could implement the Jax f function as follows in PyTorch:
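A minimal sketch, assuming the JAX function in question is the `f(x, y) = sin(x) * y` example from the JAX custom-VJP tutorial (the original snippet is not included here):

```python
import torch

class F(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Save the inputs so the backward pass can reuse them.
        ctx.save_for_backward(x, y)
        return torch.sin(x) * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # Chain rule: d/dx [sin(x) * y] = cos(x) * y,  d/dy [sin(x) * y] = sin(x)
        grad_x = grad_output * torch.cos(x) * y
        grad_y = grad_output * torch.sin(x)
        return grad_x, grad_y


# Usage: the Function is invoked through its apply method.
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
out = F.apply(x, y)
out.backward()
print(x.grad, y.grad)
```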
Here, ctx.save_for_backward saves x and y for use in the backward pass. The backward method then computes the gradients with respect to x and y from the saved tensors via the chain rule. Finally, the custom function is applied to the inputs x and y by calling its apply method.
Is there a way to define a custom forward pass, as in JAX, where one can output a residual that is then used by the backward pass?
For example, is the following example (from the JAX docs) implementable in autograd?
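The referenced snippet is missing here; presumably it is the `custom_vjp` example from the JAX documentation, in which a forward rule returns residuals alongside the primal output (reproduced below as an assumption):

```python
from jax import custom_vjp
import jax.numpy as jnp

@custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Return the primal output together with residuals for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res  # residuals computed in f_fwd
    return (cos_x * y * g, sin_x * g)

f.defvjp(f_fwd, f_bwd)
```

In the PyTorch sketch above, the residuals returned by `f_fwd` correspond to what is stashed on `ctx` (via `ctx.save_for_backward` for tensors, or plain attributes for other values) and read back in `backward`.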