From af2bf177982fd84fe6dcd587258dd98a342e87e5 Mon Sep 17 00:00:00 2001 From: mattlevine22 Date: Wed, 13 Nov 2024 15:15:54 -0500 Subject: [PATCH] check-and-clip to [0,1] in linear_rescale to protect against floating point errors that sometimes occur on GPUs. --- diffrax/_misc.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 5a5a3b67..c10b37eb 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -73,17 +73,22 @@ def fill_forward( def linear_rescale(t0, t, t1) -> Array: - """Calculates (t - t0) / (t1 - t0), assuming t0 <= t <= t1. + """Calculates (t - t0) / (t1 - t0). Specially handles the edge case t0 == t1: - zero is returned; - gradients through all three arguments are zero. + - output conditionally clipped to be in [0,1] to protect + from floating point errors. """ cond = t0 == t1 numerator = cast(Array, jnp.where(cond, 0, t - t0)) denominator = cast(Array, jnp.where(cond, 1, t1 - t0)) - return numerator / denominator + out = numerator / denominator + positive_between = (t0 < t1) & (t0 <= t) & (t <= t1) + negative_between = (t1 < t0) & (t <= t0) & (t1 <= t) + return jnp.where(positive_between | negative_between, jnp.clip(out, 0, 1), out) def adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> RealScalarLike: