From e33c17eeb7ca5ec081f763f6f3d146c2e8df4231 Mon Sep 17 00:00:00 2001 From: Stefano Cortinovis Date: Thu, 1 Feb 2024 16:04:36 +0000 Subject: [PATCH] Replace jnp.where with static_select --- diffrax/_step_size_controller/adaptive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 4b3d432a..927eaea3 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -27,7 +27,7 @@ VF, Y, ) -from .._misc import upcast_or_raise +from .._misc import static_select, upcast_or_raise from .._solution import RESULTS from .._term import AbstractTerm, ODETerm from .base import AbstractStepSizeController @@ -604,7 +604,7 @@ def _scale(_y0, _y1_candidate, _y_error): # This is important because we don't know whether or not the jump is as a # result of a left- or right-discontinuity, so we have to skip the jump # location altogether. - _t1 = jnp.where(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) + _t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) else: _t1 = t1 next_t0 = jnp.where(keep_step, _t1, t0)