diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 46977e56..c60591b4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Release - uses: patrick-kidger/action_update_python_project@v2 + uses: patrick-kidger/action_update_python_project@v4 with: python-version: "3.11" test-script: | @@ -21,7 +21,3 @@ jobs: pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger github-token: ${{ github.token }} - email-user: ${{ secrets.email_user }} - email-token: ${{ secrets.email_token }} - email-server: ${{ secrets.email_server }} - email-target: ${{ secrets.email_target }} diff --git a/equinox/_ad.py b/equinox/_ad.py index 2aaac6af..374347e4 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -878,9 +878,13 @@ def _none_to_zero(ct, x): if x is None: return None else: - # No raising-to-vspace. JAX is internally inconsistent, and expects integers - # to have integer tangents from custom_{jvp,vjp} rules - aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) # .at_least_vspace() + aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) + if hasattr(aval, "to_tangent_aval"): + # Earlier versions of JAX were internally inconsistent, and expected + # e.g. integer primals to have integer tangents from `custom_{jvp,vjp}` + # rules. + # That changed in JAX 0.4.34. + aval = aval.to_tangent_aval() # pyright: ignore return jax.custom_derivatives.SymbolicZero(aval) else: return ct diff --git a/equinox/_jit.py b/equinox/_jit.py index b1eea161..1d6b190a 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -110,11 +110,11 @@ def _postprocess(out): try: # Added in JAX 0.4.34. - JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore[reportAttributeAccessIssue] + JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore except AttributeError: try: # Forward compatibility in case they ever decide to fix the capitalization. - JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore[reportAttributeAccessIssue] + JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore except AttributeError: # Not public API, so wrap in a try-except for forward compatibility. try: diff --git a/equinox/internal/_primitive.py b/equinox/internal/_primitive.py index 91ffa0e5..c2d63f81 100644 --- a/equinox/internal/_primitive.py +++ b/equinox/internal/_primitive.py @@ -59,7 +59,14 @@ def _is_array_like_internal(x): def _zero_from_primal(p): assert type(p) is not ad.UndefinedPrimal - return ad.Zero(jax.core.get_aval(p).at_least_vspace()) + aval = jax.core.get_aval(p) + if hasattr(aval, "to_tangent_aval"): + # JAX >=0.4.34 + aval = aval.to_tangent_aval() # pyright: ignore + else: + # earlier JAX + aval = aval.at_least_vspace() + return ad.Zero(aval) def _combine(dynamic, static):