Skip to content

Commit

Permalink
Added test for grad-of-discontinuous-forcing
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jul 28, 2023
1 parent b71f8fa commit 8c0d6f0
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions test/test_adaptive_stepsize_controller.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from .helpers import shaped_allclose


def test_step_ts():
term = diffrax.ODETerm(lambda t, y, args: -0.2 * y)
Expand Down Expand Up @@ -90,3 +93,45 @@ def run(ys, controller, state):
ys = (y0, y1_candidate, y_error)
grads = run(ys, stepsize_controller, state)
assert not any(jnp.isnan(grad).any() for grad in grads)


def test_grad_of_discontinuous_forcing():
def vector_field(t, y, forcing):
y, _ = y
dy = -y + forcing(t)
dsum = y
return dy, dsum

def run(t):
term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
t0 = 0
t1 = 1
dt0 = None
y0 = 1.0
stepsize_controller = diffrax.PIDController(
rtol=1e-8, atol=1e-8, step_ts=t[None]
)

def forcing(s):
return jnp.where(s < t, 0, 1)

sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
(y0, 0),
args=forcing,
stepsize_controller=stepsize_controller,
)
_, sum = sol.ys
(sum,) = sum
return sum

r = jax.jit(run)
eps = 1e-5
finite_diff = (r(0.5) - r(0.5 - eps)) / eps
autodiff = jax.jit(jax.grad(run))(0.5)
assert shaped_allclose(finite_diff, autodiff)

0 comments on commit 8c0d6f0

Please sign in to comment.