diff --git a/diffrax/_progress_meter.py b/diffrax/_progress_meter.py index 8ad0acaf..dc792101 100644 --- a/diffrax/_progress_meter.py +++ b/diffrax/_progress_meter.py @@ -291,11 +291,14 @@ def _step(_progress, _idx): except KeyError: pass # E.g. the backward pass after a forward pass. else: - step_bar(bar, _progress) + # As above, `_idx` may have a spurious batch tracer. Correspondingly + # `_progress` may pick up spurious length-1 batch dimensions from + # `vmap_method="expand_dims"` below. Remove them now. + step_bar(bar, np.array(_progress).reshape(())) # Return the idx to thread the callbacks in the correct order. return _idx - return jax.pure_callback(_step, idx, progress, idx, vectorized=True) + return jax.pure_callback(_step, idx, progress, idx, vmap_method="expand_dims") def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike): def _close(_idx): diff --git a/pyproject.toml b/pyproject.toml index be5520e7..62055140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.28", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"] [build-system] requires = ["hatchling"]