Skip to content

Commit

Permalink
Replace jnp.where with static_select
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanocortinovis authored and patrick-kidger committed Feb 8, 2024
1 parent 8aafefc commit e33c17e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions diffrax/_step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
VF,
Y,
)
from .._misc import upcast_or_raise
from .._misc import static_select, upcast_or_raise
from .._solution import RESULTS
from .._term import AbstractTerm, ODETerm
from .base import AbstractStepSizeController
Expand Down Expand Up @@ -604,7 +604,7 @@ def _scale(_y0, _y1_candidate, _y_error):
# This is important because we don't know whether or not the jump is as a
# result of a left- or right-discontinuity, so we have to skip the jump
# location altogether.
_t1 = jnp.where(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1)
_t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1)
else:
_t1 = t1
next_t0 = jnp.where(keep_step, _t1, t0)
Expand Down

0 comments on commit e33c17e

Please sign in to comment.