Skip to content

Commit

Permalink
Compatibility with JAX 0.4.36, which removes ConcreteArray
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 29, 2024
1 parent aae4742 commit ba09fba
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
5 changes: 2 additions & 3 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

Expand Down Expand Up @@ -258,12 +259,10 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike:
# Some values (made_jump and result) are not used in many common use-cases. If we
# detect that they're unused then we make sure they're non-Array Python values, so
# that we can special case on them at trace time and get a performance boost.
if isinstance(static_x, (bool, int, float, complex)):
if isinstance(static_x, (bool, int, float, complex, np.ndarray)):
return static_x
elif static_x is None:
return x
elif type(jax.core.get_aval(static_x)) is jax.core.ConcreteArray:
return static_x
else:
return x

Expand Down
9 changes: 3 additions & 6 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, PyTree, Shaped

Expand Down Expand Up @@ -146,12 +147,8 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
# predicate is statically known.
# This in turn allows us to perform some trace-time optimisations that XLA isn't
# smart enough to do on its own.
if (
type(pred) is not bool
and type(jax.core.get_aval(pred)) is jax.core.ConcreteArray
):
with jax.ensure_compile_time_eval():
pred = pred.item()
if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == ():
pred = pred.item()
if pred is True:
return a
elif pred is False:
Expand Down

0 comments on commit ba09fba

Please sign in to comment.