diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index e9c0136c..84019f01 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -43,7 +43,9 @@ def __init__( key: "jax.random.PRNGKey", ): self.shape = ( - jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape + jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None)) + if is_tuple_of_ints(shape) + else shape ) self.key = key if any( diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 38cb96e2..0941d544 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -75,7 +75,9 @@ def __init__( self.t1 = t1 self.tol = tol self.shape = ( - jax.ShapeDtypeStruct(shape, None) if is_tuple_of_ints(shape) else shape + jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None)) + if is_tuple_of_ints(shape) + else shape ) if any( not jnp.issubdtype(x.dtype, jnp.inexact)