Skip to content

Commit

Permalink
Add jax.Array public type which replaces jnp.DeviceArray and `jax…
Browse files Browse the repository at this point in the history
….core.Tracer`.

PiperOrigin-RevId: 482008250
  • Loading branch information
yashk2810 authored and JAX-CFD authors committed Oct 18, 2022
1 parent 3c4663b commit 3680144
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions jax_cfd/base/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def dtype(self):
def shape(self) -> Tuple[int, ...]:
return self.data.shape

_HANDLED_TYPES = (numbers.Number, np.ndarray, jnp.DeviceArray,
jax.ShapedArray, jax.core.Tracer)
_HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, jax.ShapedArray)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Define arithmetic on GridArrays using NumPy's mixin."""
Expand Down

0 comments on commit 3680144

Please sign in to comment.