From 3553ae1360acbaf9f01143cab300dd0864e85a9e Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 17 Dec 2024 21:10:39 +0100 Subject: [PATCH] Fix pyright failures from numpy 2. --- diffrax/_misc.py | 2 +- diffrax/_progress_meter.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index ac61b813..7c6fa53b 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -148,7 +148,7 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike # This in turn allows us to perform some trace-time optimisations that XLA isn't # smart enough to do on its own. if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == (): - pred = pred.item() + pred = cast(BoolScalarLike, pred.item()) if pred is True: return a elif pred is False: diff --git a/diffrax/_progress_meter.py b/diffrax/_progress_meter.py index 8a813be6..8ad0acaf 100644 --- a/diffrax/_progress_meter.py +++ b/diffrax/_progress_meter.py @@ -123,8 +123,9 @@ def _step_bar(bar: list[float], progress: FloatScalarLike) -> None: if eqx.is_array(progress): # May not be an array when called with `JAX_DISABLE_JIT=1` progress = cast(Union[Array, np.ndarray], progress) - progress = progress.item() - progress = cast(float, progress) + progress = cast(float, progress.item()) + else: + progress = cast(float, progress) bar[0] = progress print(f"{100 * progress:.2f}%")