From 368014411412f303e342598b1f3cd8062552d0fe Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Oct 2022 13:37:02 -0700 Subject: [PATCH] Add `jax.Array` public type which replaces `jnp.DeviceArray` and `jax.core.Tracer`. PiperOrigin-RevId: 482008250 --- jax_cfd/base/grids.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax_cfd/base/grids.py b/jax_cfd/base/grids.py index 468052b..dc00c9d 100644 --- a/jax_cfd/base/grids.py +++ b/jax_cfd/base/grids.py @@ -83,8 +83,7 @@ def dtype(self): def shape(self) -> Tuple[int, ...]: return self.data.shape - _HANDLED_TYPES = (numbers.Number, np.ndarray, jnp.DeviceArray, - jax.ShapedArray, jax.core.Tracer) + _HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, jax.ShapedArray) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Define arithmetic on GridArrays using NumPy's mixin."""