diff --git a/test/test_integrate.py b/test/test_integrate.py index 8706a597..efa5cbf6 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -801,7 +801,6 @@ def dynamics(t, y, args): param = args return param - y - def event_fn(t, y, args, **kwargs): return y - 1.5 @@ -809,8 +808,8 @@ def single_loss_fn(param): solver = diffrax.Euler() root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm) event = diffrax.Event(event_fn, root_finder) - term = diffrax.ODETerm(dynamics) - + term = diffrax.ODETerm(dynamics) + sol = diffrax.diffeqsolve( term, solver=solver, @@ -831,9 +830,11 @@ def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray: def grad_fn(params: jnp.ndarray) -> jnp.ndarray: return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params) + + batch = jnp.array([1.0, 2.0, 3.0]) try: - grad = grad_fn(y0) + grad = grad_fn(batch) except NotImplementedError as e: pytest.fail(f"NotImplementedError was raised: {e}") except Exception as e: