Skip to content

Commit

Permalink
Silence warning from pure_callback(vectorized=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 24, 2024
1 parent 10bff5a commit 2589876
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions diffrax/_progress_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,14 @@ def _step(_progress, _idx):
except KeyError:
pass # E.g. the backward pass after a forward pass.
else:
step_bar(bar, _progress)
# As above, `_idx` may have a spurious batch tracer. Correspondingly
# `_progress` may pick up spurious length-1 batch dimensions from
# `vmap_method="expand_dims"` below. Remove them now.
step_bar(bar, np.array(_progress).reshape(()))
# Return the idx to thread the callbacks in the correct order.
return _idx

return jax.pure_callback(_step, idx, progress, idx, vectorized=True)
return jax.pure_callback(_step, idx, progress, idx, vmap_method="expand_dims")

def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike):
def _close(_idx):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
dependencies = ["jax>=0.4.28", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.7"]

[build-system]
requires = ["hatchling"]
Expand Down

0 comments on commit 2589876

Please sign in to comment.