From 8ba696acd424f1edbee944f92823aa512b8eaa1f Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Wed, 4 Jan 2023 12:17:25 +0100 Subject: [PATCH 01/23] Fix broken dask_array_type import --- cupy_xarray/accessors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 325f2f3..06a5631 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -5,7 +5,10 @@ register_dataarray_accessor, register_dataset_accessor, ) -from xarray.core.pycompat import dask_array_type +from xarray.core.pycompat import DuckArrayModule + +dsk = DuckArrayModule("dask") +dask_array_type = dsk.type @register_dataarray_accessor("cupy") From 485d9c1a53052df29cd270e90d9b3b770760c0c0 Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Fri, 13 Jan 2023 13:49:46 +0100 Subject: [PATCH 02/23] Fix is_cupy for pint arrays This commit fix the issue where `is_cupy` shows wrong status for pint arrays on GPU --- cupy_xarray/accessors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 06a5631..f88d946 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -7,8 +7,8 @@ ) from xarray.core.pycompat import DuckArrayModule -dsk = DuckArrayModule("dask") -dask_array_type = dsk.type +dask_array_type = DuckArrayModule("dask").type +pint_array_type = DuckArrayModule("pint").type @register_dataarray_accessor("cupy") @@ -26,6 +26,8 @@ def is_cupy(self): """bool: The underlying data is a cupy array.""" if isinstance(self.da.data, dask_array_type): return isinstance(self.da.data._meta, cp.ndarray) + if isinstance(self.da.data, pint_array_type): + return isinstance(self.da.data.magnitude, cp.ndarray) return isinstance(self.da.data, cp.ndarray) def as_cupy(self): From 45aac6cd546b6fedd0d4d5875b14caf67569f7f3 Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Fri, 13 Jan 2023 17:49:16 +0100 Subject: [PATCH 03/23] Fix units drop during as_cupy conversion from pint xarray --- cupy_xarray/accessors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index f88d946..ca9ba26 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -61,6 +61,8 @@ def as_cupy(self): name=self.da.name, attrs=self.da.attrs, ) + if isinstance(self.da.data, pint_array_type): + return self.da.pint.dequantify().cupy.as_cupy().pint.quantify() return DataArray( data=cp.asarray(self.da.data), coords=self.da.coords, From 528b25854091238a8fa334e9b72eff82c4922bef Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Mon, 16 Jan 2023 12:03:20 +0100 Subject: [PATCH 04/23] Construct pint array without pint xarray accessor --- cupy_xarray/accessors.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index ca9ba26..9cc572b 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -62,7 +62,13 @@ def as_cupy(self): attrs=self.da.attrs, ) if isinstance(self.da.data, pint_array_type): - return self.da.pint.dequantify().cupy.as_cupy().pint.quantify() + return DataArray( + data=(self.da.data.units * cp.asarray(self.da.data.magnitude)), + coords=self.da.coords, + dims=self.da.dims, + name=self.da.name, + attrs=self.da.attrs, + ) return DataArray( data=cp.asarray(self.da.data), coords=self.da.coords, From a864e6ae6a4129acd187290231f28ae57606ef1a Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Mon, 16 Jan 2023 12:21:43 +0100 Subject: [PATCH 05/23] Refactor with _as_dataarray wrapper --- cupy_xarray/accessors.py | 43 +++++++++++++--------------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 9cc572b..4cb3fc6 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -54,28 +54,14 @@ def as_cupy(self): """ if isinstance(self.da.data, dask_array_type): - return DataArray( + return self._as_dataarray( data=self.da.data.map_blocks(cp.asarray), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, ) if isinstance(self.da.data, pint_array_type): - return DataArray( + return self._as_dataarray( data=(self.da.data.units * cp.asarray(self.da.data.magnitude)), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, ) - return DataArray( - data=cp.asarray(self.da.data), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) + return self._as_dataarray(data=cp.asarray(self.da.data)) def as_numpy(self): """ @@ -89,27 +75,26 @@ def as_numpy(self): """ if self.is_cupy: if isinstance(self.da.data, dask_array_type): - return DataArray( + return self._as_dataarray( data=self.da.data.map_blocks( lambda block: block.get(), dtype=self.da.data._meta.dtype ), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, ) - return DataArray( - data=self.da.data.get(), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) + return self._as_dataarray(data=self.da.data.get()) return self.da.as_numpy() def get(self): return self.da.data.get() + def _as_dataarray(self, data): + return DataArray( + data=data, + coords=self.da.coords, + dims=self.da.dims, + name=self.da.name, + attrs=self.da.attrs, + ) + @register_dataset_accessor("cupy") class CupyDatasetAccessor: From 1c4b1a7596f5d013aafde6f12bfe2650def0ab32 Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:08:53 +0100 Subject: [PATCH 06/23] Add basic tests for pint arrays --- cupy_xarray/tests/test_accessors.py | 37 ++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index cad6955..3bc4d8a 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -1,9 +1,15 @@ import numpy as np import pytest import xarray as xr -from xarray.core.pycompat import dask_array_type +from xarray.tests import requires_pint +from xarray.core.pycompat import DuckArrayModule + +dask_array_type = DuckArrayModule("dask").type +pint_array_type = DuckArrayModule("pint").type import cupy_xarray # noqa: F401 +from pint import UnitRegistry +ureg = UnitRegistry() @pytest.fixture @@ -26,6 +32,21 @@ def tutorial_da_air_dask(tutorial_ds_air_dask): return tutorial_ds_air_dask.air +@pytest.fixture +def tutorial_ds_air_pint(): + return ( + xr.tutorial.load_dataset("air_temperature") + * ureg.Quantity("degree_Kelvin") + ) + + +@pytest.fixture +def tutorial_da_air_pint(tutorial_ds_air_pint): + return tutorial_ds_air_pint.air + + + + def test_data_set_accessor(tutorial_ds_air): ds = tutorial_ds_air assert hasattr(ds, "cupy") @@ -64,3 +85,17 @@ def test_data_array_accessor_dask(tutorial_da_air_dask): da = da.cupy.as_numpy() assert not da.cupy.is_cupy + + +@requires_pint +def test_data_array_accessor_pint(tutorial_da_air_pint): + da = tutorial_da_air_pint + assert hasattr(da, "cupy") + assert not da.cupy.is_cupy + + da = da.as_cupy() + assert da.cupy.is_cupy + assert isinstance(da.data, pint_array_type) + + da = da.cupy.as_numpy() + assert not da.cupy.is_cupy From fa059bf9924c5490a688f20271cc7e6e2bf7a17c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Jan 2023 17:10:42 +0000 Subject: [PATCH 07/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cupy_xarray/tests/test_accessors.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 3bc4d8a..7d82c5b 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -1,14 +1,16 @@ import numpy as np import pytest import xarray as xr -from xarray.tests import requires_pint from xarray.core.pycompat import DuckArrayModule +from xarray.tests import requires_pint dask_array_type = DuckArrayModule("dask").type pint_array_type = DuckArrayModule("pint").type -import cupy_xarray # noqa: F401 from pint import UnitRegistry + +import cupy_xarray # noqa: F401 + ureg = UnitRegistry() @@ -34,10 +36,7 @@ def tutorial_da_air_dask(tutorial_ds_air_dask): @pytest.fixture def tutorial_ds_air_pint(): - return ( - xr.tutorial.load_dataset("air_temperature") - * ureg.Quantity("degree_Kelvin") - ) + return xr.tutorial.load_dataset("air_temperature") * ureg.Quantity("degree_Kelvin") @pytest.fixture @@ -45,8 +44,6 @@ def tutorial_da_air_pint(tutorial_ds_air_pint): return tutorial_ds_air_pint.air - - def test_data_set_accessor(tutorial_ds_air): ds = tutorial_ds_air assert hasattr(ds, "cupy") From c7580b3a67125e1a53f3dcd328ec70dc2c80c95a Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Mon, 23 Jan 2023 15:21:32 +0100 Subject: [PATCH 08/23] Fix as_cupy() for arrays created by pint_xarray --- cupy_xarray/accessors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 4cb3fc6..5565f09 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -58,8 +58,9 @@ def as_cupy(self): data=self.da.data.map_blocks(cp.asarray), ) if isinstance(self.da.data, pint_array_type): + from pint import Quantity return self._as_dataarray( - data=(self.da.data.units * cp.asarray(self.da.data.magnitude)), + data=Quantity(cp.asarray(self.da.data.magnitude), units=self.da.data.units), ) return self._as_dataarray(data=cp.asarray(self.da.data)) From 9ad2357216e0ccfcff4cafb410cf2f16180a63dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Jan 2023 14:21:41 +0000 Subject: [PATCH 09/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cupy_xarray/accessors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 5565f09..b1acdf4 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -59,6 +59,7 @@ def as_cupy(self): ) if isinstance(self.da.data, pint_array_type): from pint import Quantity + return self._as_dataarray( data=Quantity(cp.asarray(self.da.data.magnitude), units=self.da.data.units), ) From ca9b51ce8c4a77ec397ee3278a250d5cacd25bd2 Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Tue, 24 Jan 2023 16:32:43 +0100 Subject: [PATCH 10/23] Add pint nested array support --- cupy_xarray/accessors.py | 60 +++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index b1acdf4..11e0d52 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -24,11 +24,15 @@ def __init__(self, da): @property def is_cupy(self): """bool: The underlying data is a cupy array.""" - if isinstance(self.da.data, dask_array_type): - return isinstance(self.da.data._meta, cp.ndarray) - if isinstance(self.da.data, pint_array_type): - return isinstance(self.da.data.magnitude, cp.ndarray) - return isinstance(self.da.data, cp.ndarray) + return self._get_datatype(self.da.data) + + @classmethod + def _get_datatype(cls, data): + if isinstance(data, dask_array_type): + return isinstance(data._meta, cp.ndarray) + elif isinstance(data, pint_array_type): + return cls._get_datatype(data.magnitude) + return isinstance(data, cp.ndarray) def as_cupy(self): """ @@ -53,17 +57,20 @@ def as_cupy(self): """ - if isinstance(self.da.data, dask_array_type): - return self._as_dataarray( - data=self.da.data.map_blocks(cp.asarray), - ) - if isinstance(self.da.data, pint_array_type): + return self._as_dataarray(self._as_cupy_data(self.da.data)) + + @classmethod + def _as_cupy_data(cls, data): + if isinstance(data, dask_array_type): + return data.map_blocks(cp.asarray) + if isinstance(data, pint_array_type): from pint import Quantity - return self._as_dataarray( - data=Quantity(cp.asarray(self.da.data.magnitude), units=self.da.data.units), + return Quantity( + cls._as_cupy_data(data.magnitude), + units=data.units, ) - return self._as_dataarray(data=cp.asarray(self.da.data)) + return cp.asarray(data) def as_numpy(self): """ @@ -75,15 +82,24 @@ def as_numpy(self): DataArray with underlying data cast to numpy. """ - if self.is_cupy: - if isinstance(self.da.data, dask_array_type): - return self._as_dataarray( - data=self.da.data.map_blocks( - lambda block: block.get(), dtype=self.da.data._meta.dtype - ), - ) - return self._as_dataarray(data=self.da.data.get()) - return self.da.as_numpy() + return self._as_dataarray(self._as_numpy_data(self.da.data)) + + @classmethod + def _as_numpy_data(cls, data): + if isinstance(data, dask_array_type): + return data.map_blocks( + lambda block: block.get(), dtype=data._meta.dtype + ) + if isinstance(data, pint_array_type): + from pint import Quantity + + return Quantity( + cls._as_numpy_data(data.magnitude), + units=data.units, + ) + if isinstance(data, cp.ndarray): + return data.get() + return data def get(self): return self.da.data.get() From efac1451f0e1cb1f7c9c51119195b06f5cce6af2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jan 2023 15:33:53 +0000 Subject: [PATCH 11/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cupy_xarray/accessors.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 11e0d52..7976a04 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -87,9 +87,7 @@ def as_numpy(self): @classmethod def _as_numpy_data(cls, data): if isinstance(data, dask_array_type): - return data.map_blocks( - lambda block: block.get(), dtype=data._meta.dtype - ) + return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype) if isinstance(data, pint_array_type): from pint import Quantity From 12e476a3c156342880218d8835a6f824568e80c8 Mon Sep 17 00:00:00 2001 From: kadykov <62546709+kadykov@users.noreply.github.com> Date: Tue, 24 Jan 2023 16:34:37 +0100 Subject: [PATCH 12/23] Add basic pint(dask) array tests --- cupy_xarray/tests/test_accessors.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 7d82c5b..a69e0f4 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -44,6 +44,19 @@ def tutorial_da_air_pint(tutorial_ds_air_pint): return tutorial_ds_air_pint.air +@pytest.fixture +def tutorial_ds_air_pint_dask(): + return ( + xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1}) + * ureg.Quantity("degree_Kelvin") + ) + + +@pytest.fixture +def tutorial_da_air_pint_dask(tutorial_ds_air_pint_dask): + return tutorial_ds_air_pint_dask.air + + def test_data_set_accessor(tutorial_ds_air): ds = tutorial_ds_air assert hasattr(ds, "cupy") @@ -96,3 +109,21 @@ def test_data_array_accessor_pint(tutorial_da_air_pint): da = da.cupy.as_numpy() assert not da.cupy.is_cupy + assert isinstance(da.data, pint_array_type) + + +@requires_pint +def test_data_array_accessor_pint_dask(tutorial_da_air_pint_dask): + da = tutorial_da_air_pint_dask + assert hasattr(da, "cupy") + assert not da.cupy.is_cupy + + da = da.as_cupy() + assert da.cupy.is_cupy + assert isinstance(da.data, pint_array_type) + assert isinstance(da.data.magnitude, dask_array_type) + + da = da.cupy.as_numpy() + assert not da.cupy.is_cupy + assert isinstance(da.data, pint_array_type) + assert isinstance(da.data.magnitude, dask_array_type) From 44bd5ad4b096e86b2fae12df5c1d0909a396573a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jan 2023 15:34:48 +0000 Subject: [PATCH 13/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cupy_xarray/tests/test_accessors.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index a69e0f4..1b63c2f 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -46,10 +46,9 @@ def tutorial_da_air_pint(tutorial_ds_air_pint): @pytest.fixture def tutorial_ds_air_pint_dask(): - return ( - xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1}) - * ureg.Quantity("degree_Kelvin") - ) + return xr.tutorial.open_dataset( + "air_temperature", chunks={"lat": 25, "lon": 25, "time": -1} + ) * ureg.Quantity("degree_Kelvin") @pytest.fixture From 0d841d197b24ed0b100cb3e87922119342023da8 Mon Sep 17 00:00:00 2001 From: Aleksandr Kadykov Date: Thu, 9 Mar 2023 18:10:38 +0100 Subject: [PATCH 14/23] Move helper functions outside of the class --- cupy_xarray/accessors.py | 59 ++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 7976a04..463bcbe 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -57,20 +57,7 @@ def as_cupy(self): """ - return self._as_dataarray(self._as_cupy_data(self.da.data)) - - @classmethod - def _as_cupy_data(cls, data): - if isinstance(data, dask_array_type): - return data.map_blocks(cp.asarray) - if isinstance(data, pint_array_type): - from pint import Quantity - - return Quantity( - cls._as_cupy_data(data.magnitude), - units=data.units, - ) - return cp.asarray(data) + return self._as_dataarray(_as_cupy_data(self.da.data)) def as_numpy(self): """ @@ -82,22 +69,7 @@ def as_numpy(self): DataArray with underlying data cast to numpy. """ - return self._as_dataarray(self._as_numpy_data(self.da.data)) - - @classmethod - def _as_numpy_data(cls, data): - if isinstance(data, dask_array_type): - return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype) - if isinstance(data, pint_array_type): - from pint import Quantity - - return Quantity( - cls._as_numpy_data(data.magnitude), - units=data.units, - ) - if isinstance(data, cp.ndarray): - return data.get() - return data + return self._as_dataarray(_as_numpy_data(self.da.data)) def get(self): return self.da.data.get() @@ -112,6 +84,33 @@ def _as_dataarray(self, data): ) +def _as_cupy_data(data): + if isinstance(data, dask_array_type): + return data.map_blocks(cp.asarray) + if isinstance(data, pint_array_type): + from pint import Quantity # pylint: disable=import-outside-toplevel + + return Quantity( + _as_cupy_data(data.magnitude), + units=data.units, + ) + return cp.asarray(data) + +def _as_numpy_data(data): + if isinstance(data, dask_array_type): + return data.map_blocks( + lambda block: block.get(), dtype=data._meta.dtype + ) + if isinstance(data, pint_array_type): + from pint import Quantity # pylint: disable=import-outside-toplevel + + return Quantity( + _as_numpy_data(data.magnitude), + units=data.units, + ) + return data.get() if isinstance(data, cp.ndarray) else data + + @register_dataset_accessor("cupy") class CupyDatasetAccessor: """ From 5f1079909394a1e52b317277038fecc9a21c1460 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Mar 2023 17:10:58 +0000 Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cupy_xarray/accessors.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 463bcbe..db31b1a 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -88,7 +88,7 @@ def _as_cupy_data(data): if isinstance(data, dask_array_type): return data.map_blocks(cp.asarray) if isinstance(data, pint_array_type): - from pint import Quantity # pylint: disable=import-outside-toplevel + from pint import Quantity # pylint: disable=import-outside-toplevel return Quantity( _as_cupy_data(data.magnitude), @@ -96,13 +96,12 @@ def _as_cupy_data(data): ) return cp.asarray(data) + def _as_numpy_data(data): if isinstance(data, dask_array_type): - return data.map_blocks( - lambda block: block.get(), dtype=data._meta.dtype - ) + return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype) if isinstance(data, pint_array_type): - from pint import Quantity # pylint: disable=import-outside-toplevel + from pint import Quantity # pylint: disable=import-outside-toplevel return Quantity( _as_numpy_data(data.magnitude), From 34136e38eed8cc62b1a1f24bbff6cb18a1946ae8 Mon Sep 17 00:00:00 2001 From: Aleksandr Kadykov Date: Wed, 5 Apr 2023 16:15:55 +0200 Subject: [PATCH 16/23] Fix Dataset.cupy.as_numpy() error --- cupy_xarray/accessors.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index db31b1a..996833f 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -125,8 +125,14 @@ def is_cupy(self): return all([da.cupy.is_cupy for da in self.ds.data_vars.values()]) def as_cupy(self): - data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()} - return Dataset(data_vars=data_vars, coords=self.ds.coords, attrs=self.ds.attrs) + if not self.is_cupy: + data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()} + return Dataset( + data_vars=data_vars, + coords=self.ds.coords, + attrs=self.ds.attrs, + ) + return self.ds def as_numpy(self): if self.is_cupy: @@ -136,8 +142,7 @@ def as_numpy(self): coords=self.ds.coords, attrs=self.ds.attrs, ) - else: - return self.ds.as_numpy() + return self.ds # Attach the `as_cupy` methods to the top level `Dataset` and `Dataarray` objects. From 5eeaa4ce285f8c20b602c55b9fa6aa895135455b Mon Sep 17 00:00:00 2001 From: Aleksandr Kadykov Date: Wed, 5 Apr 2023 16:17:54 +0200 Subject: [PATCH 17/23] Refactor is_cupy, as_cupy() and as_numpy() tests --- cupy_xarray/tests/test_accessors.py | 354 ++++++++++++++++++++++------ 1 file changed, 280 insertions(+), 74 deletions(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 1b63c2f..eea4a6f 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -1,128 +1,334 @@ +"""Tests for cupy-xarray accessors""" +import cupy as cp +import dask.array as da import numpy as np +import pint_xarray # pylint:disable=unused-import import pytest import xarray as xr from xarray.core.pycompat import DuckArrayModule -from xarray.tests import requires_pint +from xarray.tests import requires_cupy, requires_dask, requires_pint + +import cupy_xarray # noqa: F401 pylint:disable=unused-import dask_array_type = DuckArrayModule("dask").type pint_array_type = DuckArrayModule("pint").type +cupy_array_type = DuckArrayModule("cupy").type -from pint import UnitRegistry - -import cupy_xarray # noqa: F401 -ureg = UnitRegistry() +@pytest.fixture +def dataarray_numpy(): + """Prepare numpy DataArray""" + return xr.DataArray( + np.random.rand(2, 3), + attrs={"units": "candle"}, + ) @pytest.fixture -def tutorial_ds_air(): - return xr.tutorial.load_dataset("air_temperature") +def dataarray_cupy(): + """Prepare cupy DataArray""" + return xr.DataArray( + cp.random.rand(2, 3), + attrs={"units": "kelvin"}, + ) @pytest.fixture -def tutorial_da_air(tutorial_ds_air): - return tutorial_ds_air.air +def dataarray_dask(): + """Prepare dask DataArray""" + return xr.DataArray( + da.asarray(np.random.rand(2, 3)), + attrs={"units": "mole"}, + ) @pytest.fixture -def tutorial_ds_air_dask(): - return xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1}) +def dataarray_dask_cupy(): + """Prepare dask(cupy) DataArray""" + return xr.DataArray( + da.asarray(cp.random.rand(2, 3)), + attrs={"units": "mole"}, + ) + + +def to_pint(dataarray): + """Convert DataArray data to pint""" + return dataarray.pint.quantify() @pytest.fixture -def tutorial_da_air_dask(tutorial_ds_air_dask): - return tutorial_ds_air_dask.air +def dataarray_pint_numpy(dataarray_numpy): # pylint:disable=redefined-outer-name + """Prepare pint(numpy) DataArray""" + return to_pint(dataarray_numpy) @pytest.fixture -def tutorial_ds_air_pint(): - return xr.tutorial.load_dataset("air_temperature") * ureg.Quantity("degree_Kelvin") +def dataarray_pint_cupy(dataarray_cupy): # pylint:disable=redefined-outer-name + """Prepare pint(cupy) DataArray""" + return to_pint(dataarray_cupy) @pytest.fixture -def tutorial_da_air_pint(tutorial_ds_air_pint): - return tutorial_ds_air_pint.air +def dataarray_pint_dask(dataarray_dask): # pylint:disable=redefined-outer-name + """Prepare pint(dask) DataArray""" + return to_pint(dataarray_dask) @pytest.fixture -def tutorial_ds_air_pint_dask(): - return xr.tutorial.open_dataset( - "air_temperature", chunks={"lat": 25, "lon": 25, "time": -1} - ) * ureg.Quantity("degree_Kelvin") +def dataarray_pint_dask_cupy(dataarray_dask_cupy): # pylint:disable=redefined-outer-name + """Prepare pint(dask(cupy)) DataArray""" + return to_pint(dataarray_dask_cupy) + + +def to_dataset(dataarray): + """Convert DataArray to Dataset""" + return xr.Dataset(data_vars={"foo": dataarray}) @pytest.fixture -def tutorial_da_air_pint_dask(tutorial_ds_air_pint_dask): - return tutorial_ds_air_pint_dask.air +def dataset_numpy(dataarray_numpy): # pylint:disable=redefined-outer-name + """Prepare numpy Dataset""" + return to_dataset(dataarray_numpy) + +@pytest.fixture +def dataset_cupy(dataarray_cupy): # pylint:disable=redefined-outer-name + """Prepare cupy Dataset""" + return to_dataset(dataarray_cupy) -def test_data_set_accessor(tutorial_ds_air): - ds = tutorial_ds_air - assert hasattr(ds, "cupy") - assert not ds.cupy.is_cupy - ds = ds.as_cupy() - assert ds.cupy.is_cupy +@pytest.fixture +def dataset_dask(dataarray_dask): # pylint:disable=redefined-outer-name + """Prepare dask Dataset""" + return to_dataset(dataarray_dask) - ds = ds.cupy.as_numpy() - assert not ds.cupy.is_cupy +@pytest.fixture +def dataset_dask_cupy(dataarray_dask_cupy): # pylint:disable=redefined-outer-name + """Prepare dask(cupy) Dataset""" + return to_dataset(dataarray_dask_cupy) -def test_data_array_accessor(tutorial_da_air): - da = tutorial_da_air - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy - da = da.as_cupy() - assert da.cupy.is_cupy +@pytest.fixture +def dataset_pint_numpy(dataarray_pint_numpy): # pylint:disable=redefined-outer-name + """Prepare pint(numpy) Dataset""" + return to_dataset(dataarray_pint_numpy) - garr = da.cupy.get() - assert isinstance(garr, np.ndarray) - da = da.cupy.as_numpy() - assert not da.cupy.is_cupy +@pytest.fixture +def dataset_pint_cupy(dataarray_pint_cupy): # pylint:disable=redefined-outer-name + """Prepare pint(cupy) Dataset""" + return to_dataset(dataarray_pint_cupy) -def test_data_array_accessor_dask(tutorial_da_air_dask): - da = tutorial_da_air_dask - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy +@pytest.fixture +def dataset_pint_dask(dataarray_pint_dask): # pylint:disable=redefined-outer-name + """Prepare pint(dask) Dataset""" + return to_dataset(dataarray_pint_dask) - da = da.as_cupy() - assert da.cupy.is_cupy - assert isinstance(da.data, dask_array_type) - da = da.cupy.as_numpy() - assert not da.cupy.is_cupy +@pytest.fixture +def dataset_pint_dask_cupy(dataarray_pint_dask_cupy): # pylint:disable=redefined-outer-name + """Prepare pint(dask(cupy)) Dataset""" + return to_dataset(dataarray_pint_dask_cupy) @requires_pint -def test_data_array_accessor_pint(tutorial_da_air_pint): - da = tutorial_da_air_pint - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy +@requires_dask +@requires_cupy +def test_is_cupy( + dataarray_numpy, # pylint:disable=redefined-outer-name + dataarray_cupy, # pylint:disable=redefined-outer-name + dataarray_dask, # pylint:disable=redefined-outer-name + dataarray_dask_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_numpy, # pylint:disable=redefined-outer-name + dataarray_pint_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_dask, # pylint:disable=redefined-outer-name + dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name + dataset_numpy, # pylint:disable=redefined-outer-name + dataset_cupy, # pylint:disable=redefined-outer-name + dataset_dask, # pylint:disable=redefined-outer-name + dataset_dask_cupy, # pylint:disable=redefined-outer-name + dataset_pint_numpy, # pylint:disable=redefined-outer-name + dataset_pint_cupy, # pylint:disable=redefined-outer-name + dataset_pint_dask, # pylint:disable=redefined-outer-name + dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name +): + """Test is_cupy property in cupy xarray accessor""" + # Test all dataarray types + assert not dataarray_numpy.cupy.is_cupy + assert dataarray_cupy.cupy.is_cupy + assert not dataarray_dask.cupy.is_cupy + assert dataarray_dask_cupy.cupy.is_cupy + + # Test all pinted dataarray types + assert not dataarray_pint_numpy.cupy.is_cupy + assert dataarray_pint_cupy.cupy.is_cupy + assert not dataarray_pint_dask.cupy.is_cupy + assert dataarray_pint_dask_cupy.cupy.is_cupy + + # Test all dataset types + assert not dataset_numpy.cupy.is_cupy + assert dataset_cupy.cupy.is_cupy + assert not dataset_dask.cupy.is_cupy + assert dataset_dask_cupy.cupy.is_cupy + + # Test all pinted dataset types + assert not dataset_pint_numpy.cupy.is_cupy + assert dataset_pint_cupy.cupy.is_cupy + assert not dataset_pint_dask.cupy.is_cupy + assert dataset_pint_dask_cupy.cupy.is_cupy - da = da.as_cupy() - assert da.cupy.is_cupy - assert isinstance(da.data, pint_array_type) - da = da.cupy.as_numpy() - assert not da.cupy.is_cupy - assert isinstance(da.data, pint_array_type) +@requires_pint +@requires_dask +@requires_cupy +def test_as_cupy( + dataarray_numpy, # pylint:disable=redefined-outer-name + dataarray_cupy, # pylint:disable=redefined-outer-name + dataarray_dask, # pylint:disable=redefined-outer-name + dataarray_dask_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_numpy, # pylint:disable=redefined-outer-name + dataarray_pint_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_dask, # pylint:disable=redefined-outer-name + dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name + dataset_numpy, # pylint:disable=redefined-outer-name + dataset_cupy, # pylint:disable=redefined-outer-name + dataset_dask, # pylint:disable=redefined-outer-name + dataset_dask_cupy, # pylint:disable=redefined-outer-name + dataset_pint_numpy, # pylint:disable=redefined-outer-name + dataset_pint_cupy, # pylint:disable=redefined-outer-name + dataset_pint_dask, # pylint:disable=redefined-outer-name + dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name +): + """Test as_cupy() method in cupy xarray accessor""" + # Apply cupy.as_cupy() to all dataarray types + dataarray_numpy_as_cupy = dataarray_numpy.cupy.as_cupy() + dataarray_cupy_as_cupy = dataarray_cupy.cupy.as_cupy() + dataarray_dask_as_cupy = dataarray_dask.cupy.as_cupy() + dataarray_dask_cupy_as_cupy = dataarray_dask_cupy.cupy.as_cupy() + dataarray_pint_numpy_as_cupy = dataarray_pint_numpy.cupy.as_cupy() + dataarray_pint_cupy_as_cupy = dataarray_pint_cupy.cupy.as_cupy() + dataarray_pint_dask_as_cupy = dataarray_pint_dask.cupy.as_cupy() + dataarray_pint_dask_cupy_as_cupy = dataarray_pint_dask_cupy.cupy.as_cupy() + dataset_numpy_as_cupy = dataset_numpy.cupy.as_cupy() + dataset_cupy_as_cupy = dataset_cupy.cupy.as_cupy() + dataset_dask_as_cupy = dataset_dask.cupy.as_cupy() + dataset_dask_cupy_as_cupy = dataset_dask_cupy.cupy.as_cupy() + dataset_pint_numpy_as_cupy = dataset_pint_numpy.cupy.as_cupy() + dataset_pint_cupy_as_cupy = dataset_pint_cupy.cupy.as_cupy() + dataset_pint_dask_as_cupy = dataset_pint_dask.cupy.as_cupy() + dataset_pint_dask_cupy_as_cupy = dataset_pint_dask_cupy.cupy.as_cupy() + + # Test that all types are cupy-based + assert dataarray_numpy_as_cupy.cupy.is_cupy + assert dataarray_cupy_as_cupy.cupy.is_cupy + assert dataarray_dask_as_cupy.cupy.is_cupy + assert dataarray_dask_cupy_as_cupy.cupy.is_cupy + assert dataarray_pint_numpy_as_cupy.cupy.is_cupy + assert dataarray_pint_cupy_as_cupy.cupy.is_cupy + assert dataarray_pint_dask_as_cupy.cupy.is_cupy + assert dataarray_pint_dask_cupy_as_cupy.cupy.is_cupy + assert dataset_numpy_as_cupy.cupy.is_cupy + assert dataset_cupy_as_cupy.cupy.is_cupy + assert dataset_dask_as_cupy.cupy.is_cupy + assert dataset_dask_cupy_as_cupy.cupy.is_cupy + assert dataset_pint_numpy_as_cupy.cupy.is_cupy + assert dataset_pint_cupy_as_cupy.cupy.is_cupy + assert dataset_pint_dask_as_cupy.cupy.is_cupy + assert dataset_pint_dask_cupy_as_cupy.cupy.is_cupy + + # Check that we keep the original data type (except pure numpy) + assert isinstance(dataarray_numpy_as_cupy.data, cupy_array_type) + assert isinstance(dataarray_cupy_as_cupy.data, cupy_array_type) + assert isinstance(dataarray_dask_as_cupy.data, dask_array_type) + assert isinstance(dataarray_dask_cupy_as_cupy.data, dask_array_type) + assert isinstance(dataarray_pint_numpy_as_cupy.data, pint_array_type) + assert isinstance(dataarray_pint_cupy_as_cupy.data, pint_array_type) + assert isinstance(dataarray_pint_dask_as_cupy.data, pint_array_type) + assert isinstance(dataarray_pint_dask_cupy_as_cupy.data, pint_array_type) + assert isinstance(dataset_numpy_as_cupy["foo"].data, cupy_array_type) + assert isinstance(dataset_cupy_as_cupy["foo"].data, cupy_array_type) + assert isinstance(dataset_dask_as_cupy["foo"].data, dask_array_type) + assert isinstance(dataset_dask_cupy_as_cupy["foo"].data, dask_array_type) + assert isinstance(dataset_pint_numpy_as_cupy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_cupy_as_cupy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_dask_as_cupy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_dask_cupy_as_cupy["foo"].data, pint_array_type) @requires_pint -def test_data_array_accessor_pint_dask(tutorial_da_air_pint_dask): - da = tutorial_da_air_pint_dask - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy - - da = da.as_cupy() - assert da.cupy.is_cupy - assert isinstance(da.data, pint_array_type) - assert isinstance(da.data.magnitude, dask_array_type) - - da = da.cupy.as_numpy() - assert not da.cupy.is_cupy - assert isinstance(da.data, pint_array_type) - assert isinstance(da.data.magnitude, dask_array_type) +@requires_dask +@requires_cupy +def test_as_numpy( + dataarray_numpy, # pylint:disable=redefined-outer-name + dataarray_cupy, # pylint:disable=redefined-outer-name + dataarray_dask, # pylint:disable=redefined-outer-name + dataarray_dask_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_numpy, # pylint:disable=redefined-outer-name + dataarray_pint_cupy, # pylint:disable=redefined-outer-name + dataarray_pint_dask, # pylint:disable=redefined-outer-name + dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name + dataset_numpy, # pylint:disable=redefined-outer-name + dataset_cupy, # pylint:disable=redefined-outer-name + dataset_dask, # pylint:disable=redefined-outer-name + dataset_dask_cupy, # pylint:disable=redefined-outer-name + dataset_pint_numpy, # pylint:disable=redefined-outer-name + dataset_pint_cupy, # pylint:disable=redefined-outer-name + dataset_pint_dask, # pylint:disable=redefined-outer-name + dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name +): + """Test as_numpy() method in cupy xarray accessor""" + # Apply cupy.as_numpy() to all dataarray types + dataarray_numpy_as_numpy = dataarray_numpy.cupy.as_numpy() + dataarray_cupy_as_numpy = dataarray_cupy.cupy.as_numpy() + dataarray_dask_as_numpy = dataarray_dask.cupy.as_numpy() + dataarray_dask_cupy_as_numpy = dataarray_dask_cupy.cupy.as_numpy() + dataarray_pint_numpy_as_numpy = dataarray_pint_numpy.cupy.as_numpy() + dataarray_pint_cupy_as_numpy = dataarray_pint_cupy.cupy.as_numpy() + dataarray_pint_dask_as_numpy = dataarray_pint_dask.cupy.as_numpy() + dataarray_pint_dask_cupy_as_numpy = dataarray_pint_dask_cupy.cupy.as_numpy() + dataset_numpy_as_numpy = dataset_numpy.cupy.as_numpy() + dataset_cupy_as_numpy = dataset_cupy.cupy.as_numpy() + dataset_dask_as_numpy = dataset_dask.cupy.as_numpy() + dataset_dask_cupy_as_numpy = dataset_dask_cupy.cupy.as_numpy() + dataset_pint_numpy_as_numpy = dataset_pint_numpy.cupy.as_numpy() + dataset_pint_cupy_as_numpy = dataset_pint_cupy.cupy.as_numpy() + dataset_pint_dask_as_numpy = dataset_pint_dask.cupy.as_numpy() + dataset_pint_dask_cupy_as_numpy = dataset_pint_dask_cupy.cupy.as_numpy() + + # Test that all types are not cupy-based + assert not dataarray_numpy_as_numpy.cupy.is_cupy + assert not dataarray_cupy_as_numpy.cupy.is_cupy + assert not dataarray_dask_as_numpy.cupy.is_cupy + assert not dataarray_dask_cupy_as_numpy.cupy.is_cupy + assert not dataarray_pint_numpy_as_numpy.cupy.is_cupy + assert not dataarray_pint_cupy_as_numpy.cupy.is_cupy + assert not dataarray_pint_dask_as_numpy.cupy.is_cupy + assert not dataarray_pint_dask_cupy_as_numpy.cupy.is_cupy + assert not dataset_numpy_as_numpy.cupy.is_cupy + assert not dataset_cupy_as_numpy.cupy.is_cupy + assert not dataset_dask_as_numpy.cupy.is_cupy + assert not dataset_dask_cupy_as_numpy.cupy.is_cupy + assert not dataset_pint_numpy_as_numpy.cupy.is_cupy + assert not dataset_pint_cupy_as_numpy.cupy.is_cupy + assert not dataset_pint_dask_as_numpy.cupy.is_cupy + assert not dataset_pint_dask_cupy_as_numpy.cupy.is_cupy + + # Check that we keep the original data type (except pure numpy) + assert isinstance(dataarray_numpy_as_numpy.data, np.ndarray) + assert isinstance(dataarray_cupy_as_numpy.data, np.ndarray) + assert isinstance(dataarray_dask_as_numpy.data, dask_array_type) + assert isinstance(dataarray_dask_cupy_as_numpy.data, dask_array_type) + assert isinstance(dataarray_pint_numpy_as_numpy.data, pint_array_type) + assert isinstance(dataarray_pint_cupy_as_numpy.data, pint_array_type) + assert isinstance(dataarray_pint_dask_as_numpy.data, pint_array_type) + assert isinstance(dataarray_pint_dask_cupy_as_numpy.data, pint_array_type) + assert isinstance(dataset_numpy_as_numpy["foo"].data, np.ndarray) + assert isinstance(dataset_cupy_as_numpy["foo"].data, np.ndarray) + assert isinstance(dataset_dask_as_numpy["foo"].data, dask_array_type) + assert isinstance(dataset_dask_cupy_as_numpy["foo"].data, dask_array_type) + assert isinstance(dataset_pint_numpy_as_numpy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_cupy_as_numpy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_dask_as_numpy["foo"].data, pint_array_type) + assert isinstance(dataset_pint_dask_cupy_as_numpy["foo"].data, pint_array_type) From 04a830da81dd297318040f872b1bdbc3b58461a3 Mon Sep 17 00:00:00 2001 From: Aleksandr Kadykov Date: Thu, 25 May 2023 11:35:04 +0200 Subject: [PATCH 18/23] Ignore flake8 unused import warning --- cupy_xarray/tests/test_accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index eea4a6f..7f4f869 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -2,7 +2,7 @@ import cupy as cp import dask.array as da import numpy as np -import pint_xarray # pylint:disable=unused-import +import pint_xarray # noqa: F401 pylint:disable=unused-import import pytest import xarray as xr from xarray.core.pycompat import DuckArrayModule From 41c0d8d42261007265dc13f6699557a64e9695f0 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 27 Oct 2023 15:34:29 -0600 Subject: [PATCH 19/23] cleanup --- cupy_xarray/accessors.py | 50 +++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 996833f..b6dd1cc 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -1,16 +1,19 @@ import cupy as cp -from xarray import ( - DataArray, - Dataset, - register_dataarray_accessor, - register_dataset_accessor, -) +from xarray import Dataset, register_dataarray_accessor, register_dataset_accessor from xarray.core.pycompat import DuckArrayModule dask_array_type = DuckArrayModule("dask").type pint_array_type = DuckArrayModule("pint").type +def _get_datatype(cls, data): + if isinstance(data, dask_array_type): + return isinstance(data._meta, cp.ndarray) + elif isinstance(data, pint_array_type): + return _get_datatype(data.magnitude) + return isinstance(data, cp.ndarray) + + @register_dataarray_accessor("cupy") class CupyDataArrayAccessor: """ @@ -22,17 +25,9 @@ def __init__(self, da): self.da = da @property - def is_cupy(self): - """bool: The underlying data is a cupy array.""" - return self._get_datatype(self.da.data) - - @classmethod - def _get_datatype(cls, data): - if isinstance(data, dask_array_type): - return isinstance(data._meta, cp.ndarray) - elif isinstance(data, pint_array_type): - return cls._get_datatype(data.magnitude) - return isinstance(data, cp.ndarray) + def is_cupy(self) -> bool: + """True if the underlying data is a cupy array.""" + return _get_datatype(self.da.data) def as_cupy(self): """ @@ -57,7 +52,7 @@ def as_cupy(self): """ - return self._as_dataarray(_as_cupy_data(self.da.data)) + return self.da.copy(data=_as_cupy_data(self.da.data)) def as_numpy(self): """ @@ -69,20 +64,11 @@ def as_numpy(self): DataArray with underlying data cast to numpy. """ - return self._as_dataarray(_as_numpy_data(self.da.data)) + raise NotImplementedError("Please use .as_numpy DataArray method directly.") def get(self): return self.da.data.get() - def _as_dataarray(self, data): - return DataArray( - data=data, - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) - def _as_cupy_data(data): if isinstance(data, dask_array_type): @@ -121,7 +107,13 @@ def __init__(self, ds): self.ds = ds @property - def is_cupy(self): + def has_cupy(self) -> bool: + """True if any data variable contains a cupy array.""" + return any([da.cupy.is_cupy for da in self.ds.data_vars.values()]) + + @property + def is_cupy(self) -> bool: + """True if all data variables contain cupy arrays.""" return all([da.cupy.is_cupy for da in self.ds.data_vars.values()]) def as_cupy(self): From c076239eb35a6bd95e5bc0cd1d0cc495e3330cf7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 27 Oct 2023 15:52:30 -0600 Subject: [PATCH 20/23] Clean up --- cupy_xarray/tests/test_accessors.py | 351 +++------------------------- 1 file changed, 37 insertions(+), 314 deletions(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 7f4f869..a349170 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -1,334 +1,57 @@ """Tests for cupy-xarray accessors""" import cupy as cp -import dask.array as da import numpy as np -import pint_xarray # noqa: F401 pylint:disable=unused-import import pytest import xarray as xr -from xarray.core.pycompat import DuckArrayModule -from xarray.tests import requires_cupy, requires_dask, requires_pint +from xarray.tests import requires_dask, requires_pint import cupy_xarray # noqa: F401 pylint:disable=unused-import -dask_array_type = DuckArrayModule("dask").type -pint_array_type = DuckArrayModule("pint").type -cupy_array_type = DuckArrayModule("cupy").type +da = xr.DataArray(np.random.rand(2, 3), attrs={"units": "candle"}) +ds = xr.Dataset({"a": da}) -@pytest.fixture -def dataarray_numpy(): - """Prepare numpy DataArray""" - return xr.DataArray( - np.random.rand(2, 3), - attrs={"units": "candle"}, - ) - - -@pytest.fixture -def dataarray_cupy(): - """Prepare cupy DataArray""" - return xr.DataArray( - cp.random.rand(2, 3), - attrs={"units": "kelvin"}, - ) - - -@pytest.fixture -def dataarray_dask(): - """Prepare dask DataArray""" - return xr.DataArray( - da.asarray(np.random.rand(2, 3)), - attrs={"units": "mole"}, - ) - - -@pytest.fixture -def dataarray_dask_cupy(): - """Prepare dask(cupy) DataArray""" - return xr.DataArray( - da.asarray(cp.random.rand(2, 3)), - attrs={"units": "mole"}, - ) - - -def to_pint(dataarray): - """Convert DataArray data to pint""" - return dataarray.pint.quantify() - - -@pytest.fixture -def dataarray_pint_numpy(dataarray_numpy): # pylint:disable=redefined-outer-name - """Prepare pint(numpy) DataArray""" - return to_pint(dataarray_numpy) - - -@pytest.fixture -def dataarray_pint_cupy(dataarray_cupy): # pylint:disable=redefined-outer-name - """Prepare pint(cupy) DataArray""" - return to_pint(dataarray_cupy) - - -@pytest.fixture -def dataarray_pint_dask(dataarray_dask): # pylint:disable=redefined-outer-name - """Prepare pint(dask) DataArray""" - return to_pint(dataarray_dask) - - -@pytest.fixture -def dataarray_pint_dask_cupy(dataarray_dask_cupy): # pylint:disable=redefined-outer-name - """Prepare pint(dask(cupy)) DataArray""" - return to_pint(dataarray_dask_cupy) - - -def to_dataset(dataarray): - """Convert DataArray to Dataset""" - return xr.Dataset(data_vars={"foo": dataarray}) - - -@pytest.fixture -def dataset_numpy(dataarray_numpy): # pylint:disable=redefined-outer-name - """Prepare numpy Dataset""" - return to_dataset(dataarray_numpy) - - -@pytest.fixture -def dataset_cupy(dataarray_cupy): # pylint:disable=redefined-outer-name - """Prepare cupy Dataset""" - return to_dataset(dataarray_cupy) - - -@pytest.fixture -def dataset_dask(dataarray_dask): # pylint:disable=redefined-outer-name - """Prepare dask Dataset""" - return to_dataset(dataarray_dask) - - -@pytest.fixture -def dataset_dask_cupy(dataarray_dask_cupy): # pylint:disable=redefined-outer-name - """Prepare dask(cupy) Dataset""" - return to_dataset(dataarray_dask_cupy) - - -@pytest.fixture -def dataset_pint_numpy(dataarray_pint_numpy): # pylint:disable=redefined-outer-name - """Prepare pint(numpy) Dataset""" - return to_dataset(dataarray_pint_numpy) - - -@pytest.fixture -def dataset_pint_cupy(dataarray_pint_cupy): # pylint:disable=redefined-outer-name - """Prepare pint(cupy) Dataset""" - return to_dataset(dataarray_pint_cupy) - - -@pytest.fixture -def dataset_pint_dask(dataarray_pint_dask): # pylint:disable=redefined-outer-name - """Prepare pint(dask) Dataset""" - return to_dataset(dataarray_pint_dask) - - -@pytest.fixture -def dataset_pint_dask_cupy(dataarray_pint_dask_cupy): # pylint:disable=redefined-outer-name - """Prepare pint(dask(cupy)) Dataset""" - return to_dataset(dataarray_pint_dask_cupy) - - -@requires_pint -@requires_dask -@requires_cupy -def test_is_cupy( - dataarray_numpy, # pylint:disable=redefined-outer-name - dataarray_cupy, # pylint:disable=redefined-outer-name - dataarray_dask, # pylint:disable=redefined-outer-name - dataarray_dask_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_numpy, # pylint:disable=redefined-outer-name - dataarray_pint_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_dask, # pylint:disable=redefined-outer-name - dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name - dataset_numpy, # pylint:disable=redefined-outer-name - dataset_cupy, # pylint:disable=redefined-outer-name - dataset_dask, # pylint:disable=redefined-outer-name - dataset_dask_cupy, # pylint:disable=redefined-outer-name - dataset_pint_numpy, # pylint:disable=redefined-outer-name - dataset_pint_cupy, # pylint:disable=redefined-outer-name - dataset_pint_dask, # pylint:disable=redefined-outer-name - dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name -): +@pytest.mark.parametrize("obj", [da, ds]) +def test_numpy(obj): """Test is_cupy property in cupy xarray accessor""" - # Test all dataarray types - assert not dataarray_numpy.cupy.is_cupy - assert dataarray_cupy.cupy.is_cupy - assert not dataarray_dask.cupy.is_cupy - assert dataarray_dask_cupy.cupy.is_cupy - - # Test all pinted dataarray types - assert not dataarray_pint_numpy.cupy.is_cupy - assert dataarray_pint_cupy.cupy.is_cupy - assert not dataarray_pint_dask.cupy.is_cupy - assert dataarray_pint_dask_cupy.cupy.is_cupy - # Test all dataset types - assert not dataset_numpy.cupy.is_cupy - assert dataset_cupy.cupy.is_cupy - assert not dataset_dask.cupy.is_cupy - assert dataset_dask_cupy.cupy.is_cupy + assert not da.cupy.is_cupy + cpda = da.cupy.as_cupy() + assert cpda.is_cupy - # Test all pinted dataset types - assert not dataset_pint_numpy.cupy.is_cupy - assert dataset_pint_cupy.cupy.is_cupy - assert not dataset_pint_dask.cupy.is_cupy - assert dataset_pint_dask_cupy.cupy.is_cupy + as_numpy = cpda.as_numpy() + assert not cpda.cupy.is_cupy + if isinstance(as_numpy, xr.DataArray): + assert isinstance(as_numpy.data, np.ndarray) -@requires_pint @requires_dask -@requires_cupy -def test_as_cupy( - dataarray_numpy, # pylint:disable=redefined-outer-name - dataarray_cupy, # pylint:disable=redefined-outer-name - dataarray_dask, # pylint:disable=redefined-outer-name - dataarray_dask_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_numpy, # pylint:disable=redefined-outer-name - dataarray_pint_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_dask, # pylint:disable=redefined-outer-name - dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name - dataset_numpy, # pylint:disable=redefined-outer-name - dataset_cupy, # pylint:disable=redefined-outer-name - dataset_dask, # pylint:disable=redefined-outer-name - dataset_dask_cupy, # pylint:disable=redefined-outer-name - dataset_pint_numpy, # pylint:disable=redefined-outer-name - dataset_pint_cupy, # pylint:disable=redefined-outer-name - dataset_pint_dask, # pylint:disable=redefined-outer-name - dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name -): - """Test as_cupy() method in cupy xarray accessor""" - # Apply cupy.as_cupy() to all dataarray types - dataarray_numpy_as_cupy = dataarray_numpy.cupy.as_cupy() - dataarray_cupy_as_cupy = dataarray_cupy.cupy.as_cupy() - dataarray_dask_as_cupy = dataarray_dask.cupy.as_cupy() - dataarray_dask_cupy_as_cupy = dataarray_dask_cupy.cupy.as_cupy() - dataarray_pint_numpy_as_cupy = dataarray_pint_numpy.cupy.as_cupy() - dataarray_pint_cupy_as_cupy = dataarray_pint_cupy.cupy.as_cupy() - dataarray_pint_dask_as_cupy = dataarray_pint_dask.cupy.as_cupy() - dataarray_pint_dask_cupy_as_cupy = dataarray_pint_dask_cupy.cupy.as_cupy() - dataset_numpy_as_cupy = dataset_numpy.cupy.as_cupy() - dataset_cupy_as_cupy = dataset_cupy.cupy.as_cupy() - dataset_dask_as_cupy = dataset_dask.cupy.as_cupy() - dataset_dask_cupy_as_cupy = dataset_dask_cupy.cupy.as_cupy() - dataset_pint_numpy_as_cupy = dataset_pint_numpy.cupy.as_cupy() - dataset_pint_cupy_as_cupy = dataset_pint_cupy.cupy.as_cupy() - dataset_pint_dask_as_cupy = dataset_pint_dask.cupy.as_cupy() - dataset_pint_dask_cupy_as_cupy = dataset_pint_dask_cupy.cupy.as_cupy() - - # Test that all types are cupy-based - assert dataarray_numpy_as_cupy.cupy.is_cupy - assert dataarray_cupy_as_cupy.cupy.is_cupy - assert dataarray_dask_as_cupy.cupy.is_cupy - assert dataarray_dask_cupy_as_cupy.cupy.is_cupy - assert dataarray_pint_numpy_as_cupy.cupy.is_cupy - assert dataarray_pint_cupy_as_cupy.cupy.is_cupy - assert dataarray_pint_dask_as_cupy.cupy.is_cupy - assert dataarray_pint_dask_cupy_as_cupy.cupy.is_cupy - assert dataset_numpy_as_cupy.cupy.is_cupy - assert dataset_cupy_as_cupy.cupy.is_cupy - assert dataset_dask_as_cupy.cupy.is_cupy - assert dataset_dask_cupy_as_cupy.cupy.is_cupy - assert dataset_pint_numpy_as_cupy.cupy.is_cupy - assert dataset_pint_cupy_as_cupy.cupy.is_cupy - assert dataset_pint_dask_as_cupy.cupy.is_cupy - assert dataset_pint_dask_cupy_as_cupy.cupy.is_cupy +@pytest.mark.parametrize("obj", [da, ds]) +def test_dask(obj): + """Test is_cupy property in cupy xarray accessor""" + as_dask = obj.chunk() + assert not as_dask.cupy.is_cupy + cpda = as_dask.cupy.as_cupy() + assert cpda.cupy.is_cupy - # Check that we keep the original data type (except pure numpy) - assert isinstance(dataarray_numpy_as_cupy.data, cupy_array_type) - assert isinstance(dataarray_cupy_as_cupy.data, cupy_array_type) - assert isinstance(dataarray_dask_as_cupy.data, dask_array_type) - assert isinstance(dataarray_dask_cupy_as_cupy.data, dask_array_type) - assert isinstance(dataarray_pint_numpy_as_cupy.data, pint_array_type) - assert isinstance(dataarray_pint_cupy_as_cupy.data, pint_array_type) - assert isinstance(dataarray_pint_dask_as_cupy.data, pint_array_type) - assert isinstance(dataarray_pint_dask_cupy_as_cupy.data, pint_array_type) - assert isinstance(dataset_numpy_as_cupy["foo"].data, cupy_array_type) - assert isinstance(dataset_cupy_as_cupy["foo"].data, cupy_array_type) - assert isinstance(dataset_dask_as_cupy["foo"].data, dask_array_type) - assert isinstance(dataset_dask_cupy_as_cupy["foo"].data, dask_array_type) - assert isinstance(dataset_pint_numpy_as_cupy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_cupy_as_cupy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_dask_as_cupy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_dask_cupy_as_cupy["foo"].data, pint_array_type) + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data._meta, cp.ndarray) @requires_pint -@requires_dask -@requires_cupy -def test_as_numpy( - dataarray_numpy, # pylint:disable=redefined-outer-name - dataarray_cupy, # pylint:disable=redefined-outer-name - dataarray_dask, # pylint:disable=redefined-outer-name - dataarray_dask_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_numpy, # pylint:disable=redefined-outer-name - dataarray_pint_cupy, # pylint:disable=redefined-outer-name - dataarray_pint_dask, # pylint:disable=redefined-outer-name - dataarray_pint_dask_cupy, # pylint:disable=redefined-outer-name - dataset_numpy, # pylint:disable=redefined-outer-name - dataset_cupy, # pylint:disable=redefined-outer-name - dataset_dask, # pylint:disable=redefined-outer-name - dataset_dask_cupy, # pylint:disable=redefined-outer-name - dataset_pint_numpy, # pylint:disable=redefined-outer-name - dataset_pint_cupy, # pylint:disable=redefined-outer-name - dataset_pint_dask, # pylint:disable=redefined-outer-name - dataset_pint_dask_cupy, # pylint:disable=redefined-outer-name -): - """Test as_numpy() method in cupy xarray accessor""" - # Apply cupy.as_numpy() to all dataarray types - dataarray_numpy_as_numpy = dataarray_numpy.cupy.as_numpy() - dataarray_cupy_as_numpy = dataarray_cupy.cupy.as_numpy() - dataarray_dask_as_numpy = dataarray_dask.cupy.as_numpy() - dataarray_dask_cupy_as_numpy = dataarray_dask_cupy.cupy.as_numpy() - dataarray_pint_numpy_as_numpy = dataarray_pint_numpy.cupy.as_numpy() - dataarray_pint_cupy_as_numpy = dataarray_pint_cupy.cupy.as_numpy() - dataarray_pint_dask_as_numpy = dataarray_pint_dask.cupy.as_numpy() - dataarray_pint_dask_cupy_as_numpy = dataarray_pint_dask_cupy.cupy.as_numpy() - dataset_numpy_as_numpy = dataset_numpy.cupy.as_numpy() - dataset_cupy_as_numpy = dataset_cupy.cupy.as_numpy() - dataset_dask_as_numpy = dataset_dask.cupy.as_numpy() - dataset_dask_cupy_as_numpy = dataset_dask_cupy.cupy.as_numpy() - dataset_pint_numpy_as_numpy = dataset_pint_numpy.cupy.as_numpy() - dataset_pint_cupy_as_numpy = dataset_pint_cupy.cupy.as_numpy() - dataset_pint_dask_as_numpy = dataset_pint_dask.cupy.as_numpy() - dataset_pint_dask_cupy_as_numpy = dataset_pint_dask_cupy.cupy.as_numpy() - - # Test that all types are not cupy-based - assert not dataarray_numpy_as_numpy.cupy.is_cupy - assert not dataarray_cupy_as_numpy.cupy.is_cupy - assert not dataarray_dask_as_numpy.cupy.is_cupy - assert not dataarray_dask_cupy_as_numpy.cupy.is_cupy - assert not dataarray_pint_numpy_as_numpy.cupy.is_cupy - assert not dataarray_pint_cupy_as_numpy.cupy.is_cupy - assert not dataarray_pint_dask_as_numpy.cupy.is_cupy - assert not dataarray_pint_dask_cupy_as_numpy.cupy.is_cupy - assert not dataset_numpy_as_numpy.cupy.is_cupy - assert not dataset_cupy_as_numpy.cupy.is_cupy - assert not dataset_dask_as_numpy.cupy.is_cupy - assert not dataset_dask_cupy_as_numpy.cupy.is_cupy - assert not dataset_pint_numpy_as_numpy.cupy.is_cupy - assert not dataset_pint_cupy_as_numpy.cupy.is_cupy - assert not dataset_pint_dask_as_numpy.cupy.is_cupy - assert not dataset_pint_dask_cupy_as_numpy.cupy.is_cupy - - # Check that we keep the original data type (except pure numpy) - assert isinstance(dataarray_numpy_as_numpy.data, np.ndarray) - assert isinstance(dataarray_cupy_as_numpy.data, np.ndarray) - assert isinstance(dataarray_dask_as_numpy.data, dask_array_type) - assert isinstance(dataarray_dask_cupy_as_numpy.data, dask_array_type) - assert isinstance(dataarray_pint_numpy_as_numpy.data, pint_array_type) - assert isinstance(dataarray_pint_cupy_as_numpy.data, pint_array_type) - assert isinstance(dataarray_pint_dask_as_numpy.data, pint_array_type) - assert isinstance(dataarray_pint_dask_cupy_as_numpy.data, pint_array_type) - assert isinstance(dataset_numpy_as_numpy["foo"].data, np.ndarray) - assert isinstance(dataset_cupy_as_numpy["foo"].data, np.ndarray) - assert isinstance(dataset_dask_as_numpy["foo"].data, dask_array_type) - assert isinstance(dataset_dask_cupy_as_numpy["foo"].data, dask_array_type) - assert isinstance(dataset_pint_numpy_as_numpy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_cupy_as_numpy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_dask_as_numpy["foo"].data, pint_array_type) - assert isinstance(dataset_pint_dask_cupy_as_numpy["foo"].data, pint_array_type) +@pytest.mark.parametrize("obj", [da, ds]) +def test_pint(obj): + import pint_xarray # noqa + + as_pint = obj.pint.quantify() + + assert not as_pint.cupy.is_cupy + cpda = as_pint.cupy.as_cup() + assert cpda.cupy.is_cupy + + as_dask = as_pint.chunk() + assert not as_dask.cupy.is_cupy + cpda = as_dask.cupy.as_cupy() + assert cpda.cupy.is_cupy + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data._meta, cp.ndarray) From 53e5a86d9befb036d5ced47469a6fbcec29598c4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 27 Oct 2023 16:20:21 -0600 Subject: [PATCH 21/23] Fixes. --- cupy_xarray/accessors.py | 2 +- cupy_xarray/tests/test_accessors.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index b6dd1cc..1bcb7f8 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -6,7 +6,7 @@ pint_array_type = DuckArrayModule("pint").type -def _get_datatype(cls, data): +def _get_datatype(data): if isinstance(data, dask_array_type): return isinstance(data._meta, cp.ndarray) elif isinstance(data, pint_array_type): diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index a349170..10e89fe 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -17,10 +17,10 @@ def test_numpy(obj): assert not da.cupy.is_cupy cpda = da.cupy.as_cupy() - assert cpda.is_cupy + assert cpda.cupy.is_cupy as_numpy = cpda.as_numpy() - assert not cpda.cupy.is_cupy + assert not as_numpy.cupy.is_cupy if isinstance(as_numpy, xr.DataArray): assert isinstance(as_numpy.data, np.ndarray) From cff0456b9cedef022f38694d90ee1124bbdfebf2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 27 Oct 2023 16:23:17 -0600 Subject: [PATCH 22/23] more fix --- cupy_xarray/tests/test_accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 10e89fe..2048717 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -46,7 +46,7 @@ def test_pint(obj): as_pint = obj.pint.quantify() assert not as_pint.cupy.is_cupy - cpda = as_pint.cupy.as_cup() + cpda = as_pint.cupy.as_cupy() assert cpda.cupy.is_cupy as_dask = as_pint.chunk() From 0c36ef0a0fdfa6ce58030d1aa0a2f95982ace7d4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 27 Oct 2023 16:33:24 -0600 Subject: [PATCH 23/23] Failing test --- cupy_xarray/tests/test_accessors.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index 2048717..18c8f02 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -41,15 +41,21 @@ def test_dask(obj): @requires_pint @pytest.mark.parametrize("obj", [da, ds]) def test_pint(obj): + import pint import pint_xarray # noqa as_pint = obj.pint.quantify() assert not as_pint.cupy.is_cupy cpda = as_pint.cupy.as_cupy() + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data, pint.Quantity) assert cpda.cupy.is_cupy as_dask = as_pint.chunk() + if isinstance(as_dask, xr.DataArray): + assert isinstance(as_dask.data, pint.Quantity) + assert isinstance(as_dask.data.magnitude._meta, np.ndarray) assert not as_dask.cupy.is_cupy cpda = as_dask.cupy.as_cupy() assert cpda.cupy.is_cupy