Hi, I'm pretty new here and recently started learning about MLX by reimplementing existing PyTorch code. For example, I have the following PyTorch function that computes a Jacobian:

```python
import torch
import torch.autograd as autograd


def jacobian(f, x):
    """Computes the Jacobian of f w.r.t. x.

    :param f: function R^N -> R^N
    :param x: torch.Tensor of shape [B, N]
    :return: Jacobian matrix (torch.Tensor) of shape [B, N, N]
    """
    B, N = x.shape
    y = f(x)
    jacobian = list()
    for i in range(N):
        # One-hot vector selecting the i-th output component for every batch element
        v = torch.zeros_like(y)
        v[:, i] = 1.0
        # Vector-Jacobian product: gradient of y[:, i] w.r.t. x, shape [B, N]
        dy_i_dx = autograd.grad(y, x, grad_outputs=v, retain_graph=True,
                                create_graph=True, allow_unused=True)[0]
        jacobian.append(dy_i_dx)
    jacobian = torch.stack(jacobian, dim=2).requires_grad_()
    return jacobian
```

As far as I understand, the relevant function for this is mlx.core.grad. However, I could not figure out how to compute partial derivatives with it. I would very much appreciate any feedback and help. Thanks!
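For context (not from the original thread): `mx.grad` differentiates scalar-valued functions, so one naive way to get partial derivatives is to differentiate a single output component at a time, one `grad` call per Jacobian row. A minimal sketch under that assumption, using an illustrative function `f` and helper name `jacobian_row`, for a single 1-D input:

```python
import mlx.core as mx


def f(x):
    return x / x.sum()


# mx.grad requires a scalar output, so differentiating the i-th output
# component f(x)[i] gives the i-th row of the Jacobian (one call per row).
def jacobian_row(f, x, i):
    return mx.grad(lambda x: f(x)[i])(x)


x = mx.array([1.0, 2.0, 3.0])
print(jacobian_row(f, x, 0))  # d f_0 / d x_j for j = 0..2
```

This loops over output components in Python; the reply below shows how to avoid that by combining `mx.vjp` with `mx.vmap`.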
You can do this very nicely in MLX with mx.vjp (or mx.jvp) and mx.vmap. Jax has some really nice documentation on combining vmap and autograd to get Jacobians, Hessians, etc. The ideas mostly translate to MLX with a few slight API changes. So in MLX you could do:

```python
import mlx.core as mx


def fun(x):
    return x / x.sum()


def jacobian(f, x):
    B, N = x.shape
    # Batch of identity matrices: N one-hot cotangent vectors per batch element
    I = mx.broadcast_to(mx.eye(N), (B, N, N))

    def vjpfn(y):
        return mx.vjp(f, (x,), (y,))

    # vmap the vector-Jacobian product over the N one-hot vectors
    return mx.vmap(vjpfn, in_axes=1, out_axes=1)(I)


# B = 2, N = 3
x = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(jacobian(fun, x))
```
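Since the reply also mentions mx.jvp, a forward-mode variant along the same lines might look like the following. This is an untested sketch, not from the thread: it assumes `mx.jvp` is called as `mx.jvp(fun, primals, tangents)` as documented, and `jacobian_fwd` is just an illustrative name.

```python
import mlx.core as mx


def fun(x):
    return x / x.sum()


# Forward-mode sketch: mx.jvp computes J @ v, so pushing each basis vector
# through f yields one Jacobian column; vmap batches over the N basis vectors.
def jacobian_fwd(f, x):
    B, N = x.shape
    I = mx.broadcast_to(mx.eye(N), (B, N, N))

    def jvpfn(v):
        return mx.jvp(f, (x,), (v,))

    # out_axes=2 places the column index last, matching the [B, N, N] layout
    return mx.vmap(jvpfn, in_axes=1, out_axes=2)(I)


x = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(jacobian_fwd(fun, x))
```

Forward mode does one pass per input dimension rather than per output dimension, so for square Jacobians like this one the two approaches cost about the same; reverse mode (vjp) wins when there are fewer outputs than inputs, and jvp wins in the opposite case.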