From 8c0d6f0e1f721ac5a6490bb6203657de062018d9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 28 Jul 2023 15:53:59 -0700 Subject: [PATCH] Added test for grad-of-discontinuous-forcing --- test/test_adaptive_stepsize_controller.py | 45 +++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 6f4dd494..997dc4ae 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -1,8 +1,11 @@ import diffrax import equinox as eqx +import jax import jax.numpy as jnp import jax.tree_util as jtu +from .helpers import shaped_allclose + def test_step_ts(): term = diffrax.ODETerm(lambda t, y, args: -0.2 * y) @@ -90,3 +93,45 @@ def run(ys, controller, state): ys = (y0, y1_candidate, y_error) grads = run(ys, stepsize_controller, state) assert not any(jnp.isnan(grad).any() for grad in grads) + + +def test_grad_of_discontinuous_forcing(): + def vector_field(t, y, forcing): + y, _ = y + dy = -y + forcing(t) + dsum = y + return dy, dsum + + def run(t): + term = diffrax.ODETerm(vector_field) + solver = diffrax.Tsit5() + t0 = 0 + t1 = 1 + dt0 = None + y0 = 1.0 + stepsize_controller = diffrax.PIDController( + rtol=1e-8, atol=1e-8, step_ts=t[None] + ) + + def forcing(s): + return jnp.where(s < t, 0, 1) + + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + (y0, 0), + args=forcing, + stepsize_controller=stepsize_controller, + ) + _, sum = sol.ys + (sum,) = sum + return sum + + r = jax.jit(run) + eps = 1e-5 + finite_diff = (r(0.5) - r(0.5 - eps)) / eps + autodiff = jax.jit(jax.grad(run))(0.5) + assert shaped_allclose(finite_diff, autodiff)