diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 5fe777b4..5a5a3b67 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -177,7 +177,7 @@ def upcast_or_raise( target_dtype = jnp.result_type(array_for_dtype) with jax.numpy_dtype_promotion("standard"): promote_dtype = jnp.result_type(x_dtype, target_dtype) - config_value = jax.config.jax_numpy_dtype_promotion + config_value = jax.config.jax_numpy_dtype_promotion # pyright: ignore if config_value == "strict": if target_dtype != promote_dtype: raise ValueError(