
Compute a Jacobian matrix of a function w.r.t. a tensor #671

Answered by awni
Maverobot asked this question in Q&A

You can do this very nicely in MLX with mx.vjp (or mx.jvp) and mx.vmap. JAX has some really nice documentation on combining vmap with autograd to get Jacobians, Hessians, etc. The ideas mostly translate to MLX with a few slight API changes.

So in MLX you could do:

import mlx.core as mx

def fun(x):
    return x / x.sum()

def jacobian(f, x):
    B, N = x.shape
    # Batch of N x N identity matrices; slices along axis 1 are the cotangent basis vectors
    I = mx.broadcast_to(mx.eye(N), (B, N, N))
    def vjpfn(y):
        # mx.vjp returns (outputs, vjps); cotangent e_i pulls back row i of the Jacobian
        return mx.vjp(f, (x,), (y,))
    # Map over the basis vectors to assemble the full (B, N, N) Jacobian
    return mx.vmap(vjpfn, in_axes=1, out_axes=1)(I)

# B = 2, N = 3
x = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(jacobian(fun, x))

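As a sanity check on what the Jacobian should contain, the entries for this particular f can be worked out by hand: with s = sum_k x_k, d f_i / d x_j = delta_ij / s - x_i / s**2. A quick sketch comparing the closed form against central finite differences (NumPy is used here only for the cross-check and is not part of the MLX answer; the function names are illustrative):

```python
import numpy as np

def fun(x):
    return x / x.sum()

def analytic_jacobian(x):
    # For f_i(x) = x_i / s with s = sum(x):
    #   d f_i / d x_j = delta_ij / s - x_i / s**2
    s = x.sum()
    n = len(x)
    return np.eye(n) / s - np.outer(x, np.ones(n)) / s**2

def fd_jacobian(f, x, eps=1e-6):
    # Central finite differences, one column per input coordinate
    n = len(x)
    J = np.zeros((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        J[:, j] = (f(x + e) - f(x - e)) / (2 * eps)
    return J

x = np.array([1.0, 2.0, 3.0])
print(np.allclose(analytic_jacobian(x), fd_jacobian(fun, x), atol=1e-5))
```

The same closed form can be checked against each batch row of the MLX result.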
Answer selected by Maverobot