Skip to content

Commit

Permalink
added vmap test
Browse files Browse the repository at this point in the history
  • Loading branch information
dkweiss31 committed Nov 27, 2024
1 parent c994a60 commit 78289c9
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,44 @@ def test_t0_eq_t1(subs):
assert tree_allclose(sol.ys, compare)


@pytest.mark.parametrize("subs", [True, False])
def test_vmap_t0_eq_t1(subs):
ntsave = 4
y0 = jnp.array([2.0])
term = diffrax.ODETerm(lambda t, y, args: y)

def _solve(tf):
ts = jnp.linspace(0.0, tf, ntsave)
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
subs = (get0, get1)
saveat = diffrax.SaveAt(subs=subs)
return diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
)

compare = jnp.full((ntsave + 1, *y0.shape), y0)
sol = jax.vmap(_solve)(jnp.array([0.0, 1.0]))
assert tree_allclose(sol.ys[0][0], compare) # pyright: ignore
assert tree_allclose(sol.ys[1][0], compare) # pyright: ignore

regular_solve = _solve(1.0)
assert tree_allclose(sol.ys[0][1], regular_solve.ys[0]) # pyright: ignore
assert tree_allclose(sol.ys[1][1], regular_solve.ys[1]) # pyright: ignore


def test_trivial_dense():
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
y0 = jnp.array([2.1])
Expand Down

0 comments on commit 78289c9

Please sign in to comment.