Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fixes for JAX 0.4.34 #871

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Release
uses: patrick-kidger/action_update_python_project@v2
uses: patrick-kidger/action_update_python_project@v4
with:
python-version: "3.11"
test-script: |
Expand All @@ -21,7 +21,3 @@ jobs:
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
github-token: ${{ github.token }}
email-user: ${{ secrets.email_user }}
email-token: ${{ secrets.email_token }}
email-server: ${{ secrets.email_server }}
email-target: ${{ secrets.email_target }}
10 changes: 7 additions & 3 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,9 +878,13 @@ def _none_to_zero(ct, x):
if x is None:
return None
else:
# No raising-to-vspace. JAX is internally inconsistent, and expects integers
# to have integer tangents from custom_{jvp,vjp} rules
aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) # .at_least_vspace()
aval = jax.core.raise_to_shaped(jax.core.get_aval(x))
if hasattr(aval, "to_tangent_aval"):
# Earlier versions of JAX were internally inconsistent, and expected
# e.g. integer primals to have integer tangents from `custom_{jvp,vjp}`
# rules.
# That changed in JAX 0.4.34.
aval = aval.to_tangent_aval() # pyright: ignore
return jax.custom_derivatives.SymbolicZero(aval)
else:
return ct
Expand Down
4 changes: 2 additions & 2 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def _postprocess(out):

try:
# Added in JAX 0.4.34.
JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore[reportAttributeAccessIssue]
JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore
except AttributeError:
try:
# Forward compatibility in case they ever decide to fix the capitalization.
JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore[reportAttributeAccessIssue]
JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore
except AttributeError:
# Not public API, so wrap in a try-except for forward compatibility.
try:
Expand Down
9 changes: 8 additions & 1 deletion equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def _is_array_like_internal(x):

def _zero_from_primal(p):
assert type(p) is not ad.UndefinedPrimal
return ad.Zero(jax.core.get_aval(p).at_least_vspace())
aval = jax.core.get_aval(p)
if hasattr(aval, "to_tangent_aval"):
# JAX >=0.4.34
aval = aval.to_tangent_aval() # pyright: ignore
else:
# earlier JAX
aval = aval.at_least_vspace()
return ad.Zero(aval)


def _combine(dynamic, static):
Expand Down
Loading