From 52a02361fb70b173d4381a0f1e6d90155795ae8b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 2 Mar 2024 14:41:23 +0100 Subject: [PATCH] grad-of-vmap-of-linear_solve with symbolic zero cotangents no longer crashes --- lineax/_solve.py | 5 +++++ tests/test_solve.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/lineax/_solve.py b/lineax/_solve.py index 1cf9100..36fad0b 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -267,6 +267,11 @@ def _linear_solve_transpose(inputs, cts_out): jtu.tree_map( _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined ) + cts_solution = jtu.tree_map( + ft.partial(eqxi.materialise_zeros, allow_struct=True), + operator.in_structure(), + cts_solution, + ) operator_transpose = operator.transpose() state_transpose, options_transpose = solver.transpose(state, options) cts_vector, _, _ = eqxi.filter_primitive_bind( diff --git a/tests/test_solve.py b/tests/test_solve.py index b76ba67..c87348f 100644 --- a/tests/test_solve.py +++ b/tests/test_solve.py @@ -138,3 +138,21 @@ def fn(y): grad, sol = jax.grad(f, has_aux=True)(x, z) assert tree_allclose(grad, -z / (x**2)) assert tree_allclose(sol, z / x) + + +def test_grad_vmap_symbolic_cotangent(): + def f(x): + return x[0], x[1] + + @jax.vmap + def to_vmap(x): + op = lx.FunctionLinearOperator(f, jax.eval_shape(lambda: x)) + sol = lx.linear_solve(op, x) + return sol.value[0] + + @jax.grad + def to_grad(x): + return jnp.sum(to_vmap(x)) + + x = (jnp.arange(3.0), jnp.arange(3.0)) + to_grad(x)