diff --git a/test/test_integrate.py b/test/test_integrate.py index 555d6ade..db5794c0 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -792,3 +792,51 @@ def func(self, terms, t0, y0, args): ValueError, match=r"Terms are not compatible with solver!" ): diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0) + + +def test_vmap_backprop(): + + def dynamics(t, y, args): + param = args + return param - y + + + def event_fn(t, y, args, **kwargs): + return y - 1.5 + + def single_loss_fn(param): + solver = diffrax.Euler() + root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm) + event = diffrax.Event(event_fn, root_finder) + term = diffrax.ODETerm(dynamics) + + sol = diffrax.diffeqsolve( + term, + solver=solver, + t0=0.0, + t1=2.0, + dt0=0.1, + y0=0.0, + args=param, + event=event, + max_steps=1000, + ) + + final_y = sol.ys[-1] + return param**2 + final_y**2 + + def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray: + return jax.vmap(single_loss_fn)(params) + + def grad_fn(params: jnp.ndarray) -> jnp.ndarray: + return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params) + + try: + grad = grad_fn(y0) + except NotImplementedError as e: + pytest.fail(f"NotImplementedError was raised: {e}") + except Exception as e: + pytest.fail(f"An unexpected exception was raised: {e}") + + assert not jnp.isnan(grad), "Gradient should not be NaN." + assert not jnp.isinf(grad), "Gradient should not be infinite."