Skip to content

Commit

Permalink
Added new test checking gradient of vmapped diffeqsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
LuggiStruggi authored Jan 14, 2025
1 parent d96572e commit dab9a1f
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,51 @@ def func(self, terms, t0, y0, args):
ValueError, match=r"Terms are not compatible with solver!"
):
diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0)


def test_vmap_backprop():

def dynamics(t, y, args):
param = args
return param - y


def event_fn(t, y, args, **kwargs):
return y - 1.5

def single_loss_fn(param):
solver = diffrax.Euler()
root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
event = diffrax.Event(event_fn, root_finder)
term = diffrax.ODETerm(dynamics)

sol = diffrax.diffeqsolve(
term,
solver=solver,
t0=0.0,
t1=2.0,
dt0=0.1,
y0=0.0,
args=param,
event=event,
max_steps=1000,
)

final_y = sol.ys[-1]
return param**2 + final_y**2

def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
return jax.vmap(single_loss_fn)(params)

def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)

try:
grad = grad_fn(y0)
except NotImplementedError as e:
pytest.fail(f"NotImplementedError was raised: {e}")
except Exception as e:
pytest.fail(f"An unexpected exception was raised: {e}")

assert not jnp.isnan(grad), "Gradient should not be NaN."
assert not jnp.isinf(grad), "Gradient should not be infinite."

0 comments on commit dab9a1f

Please sign in to comment.