diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index a3f103a8..6b85125b 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -183,6 +183,44 @@ def test_t0_eq_t1(subs): assert tree_allclose(sol.ys, compare) +@pytest.mark.parametrize("subs", [True, False]) +def test_vmap_t0_eq_t1(subs): + ntsave = 4 + y0 = jnp.array([2.0]) + term = diffrax.ODETerm(lambda t, y, args: y) + + def _solve(tf): + ts = jnp.linspace(0.0, tf, ntsave) + get0 = diffrax.SubSaveAt( + ts=ts, + t1=True, + ) + get1 = diffrax.SubSaveAt( + t0=True, + ts=ts, + ) + subs = (get0, get1) + saveat = diffrax.SaveAt(subs=subs) + return diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=y0, + dt0=0.1, + solver=diffrax.Dopri5(), + saveat=saveat, + ) + + compare = jnp.full((ntsave + 1, *y0.shape), y0) + sol = jax.vmap(_solve)(jnp.array([0.0, 1.0])) + assert tree_allclose(sol.ys[0][0], compare) # pyright: ignore + assert tree_allclose(sol.ys[1][0], compare) # pyright: ignore + + regular_solve = _solve(1.0) + assert tree_allclose(sol.ys[0][1], regular_solve.ys[0]) # pyright: ignore + assert tree_allclose(sol.ys[1][1], regular_solve.ys[1]) # pyright: ignore + + def test_trivial_dense(): term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) y0 = jnp.array([2.1])