Skip to content

Commit

Permalink
check-and-clip to [0,1] in linear_rescale to protect against floating…
Browse files Browse the repository at this point in the history
… point errors that sometimes occur on GPUs.
  • Loading branch information
mattlevine22 committed Nov 13, 2024
1 parent 8b34a1c commit af2bf17
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,22 @@ def fill_forward(


def linear_rescale(t0, t, t1) -> Array:
"""Calculates (t - t0) / (t1 - t0), assuming t0 <= t <= t1.
"""Calculates (t - t0) / (t1 - t0).
Specially handles the edge case t0 == t1:
- zero is returned;
- gradients through all three arguments are zero.
- output conditionally clipped to be in [0,1] to protect
from floating point errors.
"""

cond = t0 == t1
numerator = cast(Array, jnp.where(cond, 0, t - t0))
denominator = cast(Array, jnp.where(cond, 1, t1 - t0))
return numerator / denominator
out = numerator / denominator
positive_between = (t0 < t1) & (t0 <= t) & (t <= t1)
negative_between = (t1 < t0) & (t <= t0) & (t1 <= t)
return jnp.where(positive_between | negative_between, jnp.clip(out, 0, 1), out)


def adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> RealScalarLike:
Expand Down

0 comments on commit af2bf17

Please sign in to comment.