
Best practice when taking the grad for a PCG #127

Open
Magwos opened this issue Jan 10, 2025 · 6 comments
Labels: question User queries

Comments

@Magwos commented Jan 10, 2025

Hi, thanks a lot for your amazing work with this package!

I currently have an issue when auto-differentiating linear_solve with a preconditioner passed in options, as it leads to
RuntimeError: Unexpected tangent. lineax.linear_solve(..., options=...) cannot be autodifferentiated.

Here is a minimal working example (thanks @ASKabalan for this!) to reproduce the error

import jax
import jax.numpy as jnp
import lineax as lx
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


def f(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    y0, y1, y2 = y
    f0 = 5 * y0 + y1**2
    f1 = y1 - y2 + 5
    f2 = y0 / (1 + 5 * y2**2)
    return jnp.stack([f0, f1, f2])


y = jnp.array([[1.0, 2.0, 3.0], [5.0, 6.0, 7.0], [8.0, 9.0, 10.0]])


def my_linear_operator(x, matrix):
    return jnp.dot(matrix, x)


def fun(y: Float[Array, "3"], args) -> Float[Array, "3"]:
    operator = lx.JacobianLinearOperator(f, y, args=None)
    vector = f(y, args=None)
    solver = lx.NormalCG(rtol=1e-6, atol=1e-6)

    A = lx.FunctionLinearOperator(lambda x: my_linear_operator(x, y), operator.in_structure())
    preconditioner = lx.TaggedLinearOperator(A, lx.positive_semidefinite_tag)

    options = {"preconditioner": preconditioner}

    solution = lx.linear_solve(operator, vector, solver, options=options)

    return solution.value


fun(y, args=None)

# can be jitted
jax.jit(fun)(y, args=None)

# cannot be diffed
jax.grad(fun)(y, args=None)

I am differentiating a linear_solve with a preconditioner in the context of a minimization (using optax or optimistix), where a jax.grad is computed at each step of the minimization and the preconditioner should depend on the updated parameters.

What would be the best practice in this context?

@johannahaffner

Differentiating with respect to options is currently not supported, see #104 (comment). (And options could, in principle, be all kinds of things.)

Do you need the derivative of the solution with respect to the initial values? Or do you need a derivative of something else with respect to y, which gets updated here?

@Magwos (Author) commented Jan 10, 2025

Hi @johannahaffner,
I need a derivative of the solution with respect to the initial values; however, the preconditioner gets updated with these values as well, as in the example.

@patrick-kidger (Owner)

I think you want to wrap your options in a lax.stop_gradient.

As Johanna notes, options could in principle be anything, so we don't do this automatically -- we don't want to silently not compute gradients where some might be expected!

But I think in your case, no gradients are probably what you want.

@patrick-kidger added the question User queries label Jan 10, 2025
@johannahaffner commented Jan 10, 2025

I had tried that, but since the preconditioner is a FunctionLinearOperator, it contains a jaxpr, which is not a valid input to jax.lax.stop_gradient. Materialising the preconditioner works, but kind of defeats the purpose of using CG.

(FWIW, solution.value is not a scalar, so a gradient is not possible - but eqx.filter_jacrev seems to work. We do get an error in the solve for the materialised preconditioner though, maybe a loss of precision? Then again, the rows of y are linearly dependent, so this could be it as well.)

@patrick-kidger (Owner)

Hmm, probably stop_gradient should just be applied to the arrays then, via equinox.partition and equinox.combine.

Maybe we should just always apply such a stop-gradient to preconditioners? Mathematically the output should not depend on the preconditioner anyway, and it would save users from having to do this.
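
A minimal sketch of this per-leaf approach, assuming the operator, vector, solver, and preconditioner variables from the MWE above are in scope (untested; the exact leaf structure of the lineax operators is an implementation detail):

import equinox as eqx
import jax
import lineax as lx

# Split the preconditioner into its array leaves and everything else
# (jaxprs, static metadata), stop gradients on the arrays only, then
# recombine into an operator with the same structure.
arrays, static = eqx.partition(preconditioner, eqx.is_array)
arrays = jax.lax.stop_gradient(arrays)
preconditioner = eqx.combine(arrays, static)

solution = lx.linear_solve(
    operator, vector, solver, options={"preconditioner": preconditioner}
)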

@ASKabalan

I think so, yes.
Because in our case the array that goes into the preconditioner is the initial values, so we do want the gradient with respect to this array, just not to propagate it through one section of the code.

Example

import jax
from jax import lax


def fn(a):
    a_bis = a * 2
    return a_bis**2


def fn_with_stop_grad(a):
    a_bis = a * 2
    a_bis = lax.stop_gradient(a_bis)
    return a_bis**2


def required_out(a):
    return a**2


a = 2.0

jax.grad(fn)(a)                 # 16: correct
jax.grad(fn_with_stop_grad)(a)  # 0: gradient entirely stopped
jax.grad(required_out)(2 * a)   # 8: what we want, only the gradient of the 2 * a step is stopped

So stopping the gradient on an array sets the entire autodiff graph upstream of it to zero, not just one section of it.
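
Putting the two points together, here is a hedged sketch of the per-leaf stop_gradient workaround applied to fun from the MWE above (it reuses f, my_linear_operator, and y from that example; fun_stopped_preconditioner is just an illustrative name, and jax.jacrev is used rather than jax.grad because solution.value is not a scalar):

import equinox as eqx
import jax
import lineax as lx


def fun_stopped_preconditioner(y, args):
    operator = lx.JacobianLinearOperator(f, y, args=None)
    vector = f(y, args=None)
    solver = lx.NormalCG(rtol=1e-6, atol=1e-6)

    A = lx.FunctionLinearOperator(lambda x: my_linear_operator(x, y), operator.in_structure())
    preconditioner = lx.TaggedLinearOperator(A, lx.positive_semidefinite_tag)

    # Stop gradients only on the array leaves of the preconditioner, so that
    # derivatives with respect to y still flow through operator and vector,
    # while the preconditioner contributes no tangent of its own.
    arrays, static = eqx.partition(preconditioner, eqx.is_array)
    preconditioner = eqx.combine(jax.lax.stop_gradient(arrays), static)

    solution = lx.linear_solve(operator, vector, solver, options={"preconditioner": preconditioner})
    return solution.value


jax.jacrev(fun_stopped_preconditioner)(y, args=None)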
