diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 8cd67914..5900bb45 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -318,40 +318,32 @@ def evaluate( if t1 is not None: return self.evaluate(t1, left=left) - self.evaluate(t0, left=left) t = t0 * self.direction - ts_0 = self.ts[0] - ts_1 = self.ts[self.ts_size - 1] - pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1) - eval_fn = ft.partial(self.__class__._evaluate, t=t, left=left) - nan_fn = self.__class__._nan - # Use cond to avoid generating nans unless we have to. - out = lax.cond(pred, eval_fn, nan_fn, self) + t_bounded = self._nan_if_out_of_bounds(t) + out = self._get_local_interpolation(t_bounded, left).evaluate( + t_bounded, left=left + ) keep = ft.partial(jnp.where, (t == self.t0_if_trivial) & (self.ts_size == 1)) return jtu.tree_map(keep, self.y0_if_trivial, out) @eqx.filter_jit def derivative(self, t: Scalar, left: bool = True) -> PyTree: t = t * self.direction + t = self._nan_if_out_of_bounds(t) + out = self._get_local_interpolation(t, left).derivative(t, left=left) + return (self.direction * out**ω).ω + + def _nan_if_out_of_bounds(self, t): # Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid, # even if we throw it away because self.ts_size == 0. ts_0 = self.ts[0] ts_1 = self.ts[self.ts_size - 1] - pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1) - deriv_fn = ft.partial(self.__class__._derivative, t=t, left=left) - nan_fn = self.__class__._nan - # Use cond to avoid generating nans unless we have to. - return lax.cond(pred, deriv_fn, nan_fn, self) - - def _evaluate(self, t, left): - return self._get_local_interpolation(t, left).evaluate(t, left=left) - - def _derivative(self, t, left): - out = self._get_local_interpolation(t, left).derivative(t, left=left) - return (self.direction * out**ω).ω - - def _nan(self): - return jtu.tree_map( - ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0_if_trivial - ) + out_of_bounds = (self.ts_size <= 1) | (t < ts_0) | (t > ts_1) + make_nans = lambda t: jnp.where(out_of_bounds, jnp.nan, t) + identity = lambda t: t + # Avoid making NaNs unless we have to, by using a cond. + # (For the sake of JAX_DEBUG_NANS.) + t = lax.cond(eqxi.unvmap_any(out_of_bounds), make_nans, identity, t) + return t @property def t0(self):