Skip to content

Commit

Permalink
Upgrade jax and fix deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Dec 19, 2024
1 parent d7d2cb9 commit 3f32594
Show file tree
Hide file tree
Showing 14 changed files with 67 additions and 51 deletions.
3 changes: 2 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.numpy as jnp
import jax.tree_util as jtu
Expand Down Expand Up @@ -598,7 +599,7 @@ class _ClosureConvert(Module):
# Important that `jaxpr` be a leaf (and not static), so that it is a tuple element
# when passing through `filter_primitive_bind` and thus visible to
# `jax.core.subjaxprs`
jaxpr: jax.core.Jaxpr
jaxpr: jax.extend.core.Jaxpr
consts: PyTree[ArrayLike] # Captured in the PyTree structure of _ClosureConvert
in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
out_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
Expand Down
4 changes: 2 additions & 2 deletions equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -49,7 +49,7 @@ def _fn(*_dynamic_flat):
def filter_make_jaxpr(
fun: Callable[_P, Any],
) -> Callable[
_P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
_P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
]:
"""As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.
Expand Down
7 changes: 4 additions & 3 deletions equinox/_unvmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
Expand All @@ -10,7 +11,7 @@

# unvmap_all

unvmap_all_p = jax.core.Primitive("unvmap_all")
unvmap_all_p = jax.extend.core.Primitive("unvmap_all")


def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -41,7 +42,7 @@ def _unvmap_all_batch(x, batch_axes):

# unvmap_any

unvmap_any_p = jax.core.Primitive("unvmap_any")
unvmap_any_p = jax.extend.core.Primitive("unvmap_any")


def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
Expand Down Expand Up @@ -72,7 +73,7 @@ def _unvmap_any_batch(x, batch_axes):

# unvmap_max

unvmap_max_p = jax.core.Primitive("unvmap_max")
unvmap_max_p = jax.extend.core.Primitive("unvmap_max")


def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]:
Expand Down
4 changes: 2 additions & 2 deletions equinox/debug/_announce_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -124,7 +124,7 @@ def _mlir(*x, stack, name, intermediates, announce):
return x


announce_jaxpr_p = jax.core.Primitive("announce_jaxpr")
announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr")
announce_jaxpr_p.multiple_results = True
announce_jaxpr_p.def_impl(_impl)
announce_jaxpr_p.def_abstract_eval(_abstract)
Expand Down
40 changes: 22 additions & 18 deletions equinox/internal/_finalise_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import jax.core
import jax.custom_derivatives
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand All @@ -36,13 +37,13 @@ def _safe_map(f, *args):

def _maybe_finalise_jaxpr(val: Any):
is_open_jaxpr = False
if isinstance(val, jax.core.Jaxpr):
if isinstance(val, jax.extend.core.Jaxpr):
if len(val.constvars) == 0:
is_open_jaxpr = True
val = jax.core.ClosedJaxpr(val, [])
val = jax.extend.core.ClosedJaxpr(val, [])
else:
return val
if isinstance(val, jax.core.ClosedJaxpr):
if isinstance(val, jax.extend.core.ClosedJaxpr):
val = finalise_jaxpr(val)
if is_open_jaxpr:
val = val.jaxpr
Expand All @@ -60,33 +61,33 @@ def _finalise_jaxprs_in_params(params):
return new_params


def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.bind(*args, **kwargs)


def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.impl(*args, **kwargs)


primitive_finalisations = {}


def register_impl_finalisation(prim: jax.core.Primitive):
def register_impl_finalisation(prim: jax.extend.core.Primitive):
primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim)


def finalise_eval_jaxpr(jaxpr: jax.core.Jaxpr, consts, *args):
def finalise_eval_jaxpr(jaxpr: jax.extend.core.Jaxpr, consts, *args):
"""As jax.core.eval_jaxpr, but finalises (typically by calling `impl` rather than
`bind` for custom primitives).
"""

def read(v: jax.core.Atom) -> Any:
return v.val if isinstance(v, jax.core.Literal) else env[v]
return v.val if isinstance(v, jax.extend.core.Literal) else env[v]

def write(v: jax.core.Var, val: Any) -> None:
def write(v: jax.extend.core.Var, val: Any) -> None:
env[v] = val

env: dict[jax.core.Var, Any] = {}
env: dict[jax.extend.core.Var, Any] = {}
_safe_map(write, jaxpr.constvars, consts)
_safe_map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
Expand All @@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None:
return _safe_map(read, jaxpr.outvars)


def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr):
def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr):
"""As `jax.core.jaxpr_as_fn`, but the result is finalised."""
return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)


def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr:
def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr:
"""A jaxpr-to-jaxpr transformation that performs finalisation."""
fn = finalise_jaxpr_as_fn(jaxpr)
args = [
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars
]
return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))
return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))


def finalise_fn(fn):
Expand All @@ -136,13 +137,15 @@ def _finalise_fn(*args):
@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[False] = False
) -> Callable[..., jax.core.ClosedJaxpr]: ...
) -> Callable[..., jax.extend.core.ClosedJaxpr]: ...


@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[True] = True
) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ...
) -> Callable[
..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
]: ...


@overload
Expand All @@ -151,7 +154,8 @@ def finalise_make_jaxpr(
) -> Callable[
...,
Union[
jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
jax.extend.core.ClosedJaxpr,
tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]],
],
]: ...

Expand All @@ -164,12 +168,12 @@ def _finalise_make_jaxpr(*args):
*args
)
if return_shape:
jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr, struct = jaxpr_struct
jaxpr = finalise_jaxpr(jaxpr)
return jaxpr, struct
else:
jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct)
jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct)
jaxpr = finalise_jaxpr(jaxpr_struct)
return jaxpr

Expand Down
4 changes: 2 additions & 2 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, TYPE_CHECKING, Union

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -105,7 +105,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap")
select_if_vmap_p.def_impl(_select_if_vmap_impl)
select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract)
ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp
Expand Down
3 changes: 2 additions & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
return result


noinline_p = jax.core.Primitive("noinline")
noinline_p = jax.extend.core.Primitive("noinline")
noinline_p.multiple_results = True
noinline_p.def_impl(_noinline_impl)
noinline_p.def_abstract_eval(_noinline_abstract)
Expand Down
8 changes: 4 additions & 4 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand All @@ -29,7 +29,7 @@ def _error(*args, name):
return _error


nontraceable_p = jax.core.Primitive("nontraceable")
nontraceable_p = jax.extend.core.Primitive("nontraceable")
nontraceable_p.def_impl(_nontraceable_impl)
nontraceable_p.def_abstract_eval(_nontraceable_impl)
ad.primitive_jvps[nontraceable_p] = _make_error("differentiation")
Expand All @@ -53,7 +53,7 @@ def nontraceable(x, *, name="nontraceable operation"):
return combine(dynamic, static)


nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward")
nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward")


def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic):
Expand Down Expand Up @@ -137,7 +137,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p = jax.extend.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
Expand Down
5 changes: 3 additions & 2 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
Expand Down Expand Up @@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten):
return _wrapper


def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree:
def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree:
"""Calls a primitive that has had its rules defined using the filter
functions above.
"""
Expand Down Expand Up @@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False):


def create_vprim(name: str, impl, abstract_eval, jvp, transpose):
prim = jax.core.Primitive(name)
prim = jax.extend.core.Primitive(name)
prim.multiple_results = True

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
Expand Down
3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ plugins:
- import jaxtyping
- jaxtyping.set_array_name_format("array")
- import jax
- import jax.extend.core
- jax.ShapeDtypeStruct.__module__ = "jax"
- jax.core.ClosedJaxpr.__module__ = "jax.core"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.core"

selection:
inherited_members: true # Allow looking up inherited methods
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "equinox"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
requires-python =">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Patrick Kidger", email = "[email protected]"},
Expand All @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/equinox" }
dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
Expand Down
Loading

0 comments on commit 3f32594

Please sign in to comment.