diff --git a/kedro-datasets/kedro_datasets/geopandas/__init__.py b/kedro-datasets/kedro_datasets/geopandas/__init__.py
index d4843aa68..a257e4121 100644
--- a/kedro-datasets/kedro_datasets/geopandas/__init__.py
+++ b/kedro-datasets/kedro_datasets/geopandas/__init__.py
@@ -8,5 +8,9 @@
 GeoJSONDataset: Any
 
 __getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"geojson_dataset": ["GeoJSONDataset"]}
+    __name__,
+    submod_attrs={
+        "geojson_dataset": ["GeoJSONDataset"],
+        "parquet_dataset": ["ParquetDataset"],
+    },
 )
diff --git a/kedro-datasets/kedro_datasets/geopandas/parquet_dataset.py b/kedro-datasets/kedro_datasets/geopandas/parquet_dataset.py
new file mode 100644
index 000000000..831bc4b3b
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/geopandas/parquet_dataset.py
@@ -0,0 +1,162 @@
+"""ParquetDataset loads and saves data to a parquet file. The
+underlying functionality is supported by geopandas, so it supports all
+allowed geopandas (pandas) options for loading and saving parquet files.
+"""
+
+from __future__ import annotations
+
+import copy
+from pathlib import PurePosixPath
+from typing import Any, Union
+
+import fsspec
+import geopandas as gpd
+from kedro.io.core import (
+    AbstractVersionedDataset,
+    DatasetError,
+    Version,
+    get_filepath_str,
+    get_protocol_and_path,
+)
+
+
+class ParquetDataset(
+    AbstractVersionedDataset[
+        gpd.GeoDataFrame, Union[gpd.GeoDataFrame, dict[str, gpd.GeoDataFrame]]
+    ]
+):
+    """``ParquetDataset`` loads/saves data to a parquet file using an underlying filesystem
+    (eg: local, S3, GCS).
+    The underlying functionality is supported by geopandas, so it supports all
+    allowed geopandas (pandas) options for loading and saving parquet files.
+
+    Example:
+
+    .. code-block:: pycon
+
+        >>> import geopandas as gpd
+        >>> from kedro_datasets.geopandas import ParquetDataset
+        >>> from shapely.geometry import Point
+        >>>
+        >>> data = gpd.GeoDataFrame(
+        ...     {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]},
+        ...     geometry=[Point(1, 1), Point(2, 4)],
+        ... )
+        >>> dataset = ParquetDataset(filepath=tmp_path / "test.parquet", save_args=None)
+        >>> dataset.save(data)
+        >>> reloaded = dataset.load()
+        >>>
+        >>> assert data.equals(reloaded)
+
+    """
+
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
+
+    def __init__(  # noqa: PLR0913
+        self,
+        *,
+        filepath: str,
+        load_args: dict[str, Any] | None = None,
+        save_args: dict[str, Any] | None = None,
+        version: Version | None = None,
+        credentials: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        """Creates a new instance of ``ParquetDataset`` pointing to a concrete parquet file
+        on a specific filesystem.
+
+        Args:
+
+            filepath: Filepath in POSIX format to a parquet file prefixed with a protocol like
+                `s3://`. If prefix is not provided `file` protocol (local filesystem) will be used.
+                The prefix should be any protocol supported by ``fsspec``.
+                Note: `http(s)` doesn't support versioning.
+            load_args: GeoPandas options for loading parquet files.
+                Here you can find all available arguments:
+                https://geopandas.org/en/stable/docs/reference/api/geopandas.read_parquet.html
+            save_args: GeoPandas options for saving parquet files.
+                Here you can find all available arguments:
+                https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_parquet.html
+            version: If specified, should be an instance of ``kedro.io.core.Version``.
+                If its ``load`` attribute is None, the latest version will be loaded.
+                If its ``save`` attribute is None, the save version will be autogenerated.
+            credentials: credentials required to access the underlying filesystem.
+                Eg. for ``GCSFileSystem`` it would look like `{'token': None}`.
+            fs_args: Extra arguments to pass into underlying filesystem class constructor
+                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
+                to pass to the filesystem's `open` method through nested keys
+                `open_args_load` and `open_args_save`.
+                Here you can find all available arguments for `open`:
+                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
+                All defaults are preserved, except `mode`, which is set to `wb` when saving.
+            metadata: Any arbitrary metadata.
+                This is ignored by Kedro, but may be consumed by users or external plugins.
+        """
+        _fs_args = copy.deepcopy(fs_args) or {}
+        _fs_open_args_load = _fs_args.pop("open_args_load", {})
+        _fs_open_args_save = _fs_args.pop("open_args_save", {})
+        _credentials = copy.deepcopy(credentials) or {}
+        protocol, path = get_protocol_and_path(filepath, version)
+        self._protocol = protocol
+        if protocol == "file":
+            _fs_args.setdefault("auto_mkdir", True)
+
+        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)
+
+        self.metadata = metadata
+
+        super().__init__(
+            filepath=PurePosixPath(path),
+            version=version,
+            exists_function=self._fs.exists,
+            glob_function=self._fs.glob,
+        )
+
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        _fs_open_args_save.setdefault("mode", "wb")
+        self._fs_open_args_load = _fs_open_args_load
+        self._fs_open_args_save = _fs_open_args_save
+
+    def _load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
+        load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
+            return gpd.read_parquet(fs_file, **self._load_args)
+
+    def _save(self, data: gpd.GeoDataFrame) -> None:
+        save_path = get_filepath_str(self._get_save_path(), self._protocol)
+        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
+            data.to_parquet(fs_file, **self._save_args)
+        self.invalidate_cache()
+
+    def _exists(self) -> bool:
+        try:
+            load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        except DatasetError:
+            return False
+        return self._fs.exists(load_path)
+
+    def _describe(self) -> dict[str, Any]:
+        return {
+            "filepath": self._filepath,
+            "protocol": self._protocol,
+            "load_args": self._load_args,
+            "save_args": self._save_args,
+            "version": self._version,
+        }
+
+    def _release(self) -> None:
+        self.invalidate_cache()
+
+    def invalidate_cache(self) -> None:
+        """Invalidate underlying filesystem cache."""
+        filepath = get_filepath_str(self._filepath, self._protocol)
+        self._fs.invalidate_cache(filepath)
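Reviewer note: a minimal usage sketch of the new class against a remote filesystem, to make the `credentials`/`fs_args`/`load_args`/`save_args` plumbing above concrete. The bucket, key and credential values are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch only: bucket, object key and credential values are made up.
import geopandas as gpd
from shapely.geometry import Point

from kedro_datasets.geopandas import ParquetDataset

gdf = gpd.GeoDataFrame(
    {"name": ["a", "b"]},
    geometry=[Point(0, 0), Point(1, 1)],
    crs="EPSG:4326",
)

# The protocol prefix selects the fsspec implementation; credentials and fs_args
# are forwarded to the filesystem constructor.
dataset = ParquetDataset(
    filepath="s3://my-bucket/geo/data.parquet",  # hypothetical bucket/key
    credentials={"key": "AWS_KEY", "secret": "AWS_SECRET"},  # hypothetical values
    load_args={"columns": ["name", "geometry"]},  # passed to geopandas.read_parquet
    save_args={"index": False},  # passed to GeoDataFrame.to_parquet
)

dataset.save(gdf)  # writes via fsspec in "wb" mode
reloaded = dataset.load()  # returns a GeoDataFrame
```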
diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml
index 270a2673c..f0388acbb 100644
--- a/kedro-datasets/pyproject.toml
+++ b/kedro-datasets/pyproject.toml
@@ -25,6 +25,7 @@ polars-base = ["polars>=0.18.0",]
 plotly-base = ["plotly>=4.8.0, <6.0"]
 delta-base = ["delta-spark~=1.2.1",]
 networkx-base = ["networkx~=2.4"]
+geopandas-base = ["geopandas>=0.8.0, <1.0"]
 
 # Individual Datasets
 api-apidataset = ["requests~=2.20"]
@@ -39,8 +40,9 @@ dask = ["kedro-datasets[dask-parquetdataset]"]
 databricks-managedtabledataset = ["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"]
 databricks = ["kedro-datasets[databricks-managedtabledataset]"]
 
-geopandas-geojsondataset = ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"]
-geopandas = ["kedro-datasets[geopandas-geojsondataset]"]
+geopandas-geojsondataset = ["kedro-datasets[geopandas-base]", "pyproj~=3.0"]
+geopandas-parquetdataset = ["kedro-datasets[geopandas-base]"]
+geopandas = ["kedro-datasets[geopandas-geojsondataset,geopandas-parquetdataset]"]
 
 holoviews-holoviewswriter = ["holoviews~=1.13.0"]
 holoviews = ["kedro-datasets[holoviews-holoviewswriter]"]
@@ -199,7 +201,7 @@ test = [
     "dill~=0.3.1",
     "filelock>=3.4.0, <4.0",
     "gcsfs>=2023.1, <2023.3",
-    "geopandas>=0.6.0, <1.0",
+    "geopandas>=0.8.0, <1.0",
     "hdfs>=2.5.8, <3.0",
     "holoviews>=1.13.0",
     "ibis-framework[duckdb,examples]",
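Reviewer note: with the `lazy.attach` change in `geopandas/__init__.py` and the new `geopandas-parquetdataset` / umbrella `geopandas` extras above, the class becomes importable from `kedro_datasets.geopandas` without eagerly importing the heavy dependencies at package import time. A small sketch of the expected behaviour, assuming standard `lazy_loader` semantics:

```python
# Sketch: expected behaviour of the lazy re-export added in geopandas/__init__.py.
import kedro_datasets.geopandas as geo

# The name is advertised without importing parquet_dataset (or geopandas) yet.
assert "ParquetDataset" in dir(geo)

# First attribute access imports the submodule that defines the class.
ParquetDataset = geo.ParquetDataset
print(ParquetDataset.__module__)  # -> kedro_datasets.geopandas.parquet_dataset
```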
["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"] databricks = ["kedro-datasets[databricks-managedtabledataset]"] -geopandas-geojsondataset = ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] -geopandas = ["kedro-datasets[geopandas-geojsondataset]"] +geopandas-geojsondataset = ["kedro-datasets[geopandas-base]", "pyproj~=3.0"] +geopandas-parquetdataset = ["kedro-datasets[geopandas-base]"] +geopandas = ["kedro-datasets[geopandas-geojsondataset,geopandas-parquetdataset]"] holoviews-holoviewswriter = ["holoviews~=1.13.0"] holoviews = ["kedro-datasets[holoviews-holoviewswriter]"] @@ -199,7 +201,7 @@ test = [ "dill~=0.3.1", "filelock>=3.4.0, <4.0", "gcsfs>=2023.1, <2023.3", - "geopandas>=0.6.0, <1.0", + "geopandas>=0.8.0, <1.0", "hdfs>=2.5.8, <3.0", "holoviews>=1.13.0", "ibis-framework[duckdb,examples]", diff --git a/kedro-datasets/tests/geopandas/test_parquet_dataset.py b/kedro-datasets/tests/geopandas/test_parquet_dataset.py new file mode 100644 index 000000000..f0e305cc8 --- /dev/null +++ b/kedro-datasets/tests/geopandas/test_parquet_dataset.py @@ -0,0 +1,229 @@ +from pathlib import Path, PurePosixPath + +import geopandas as gpd +import pytest +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version, generate_timestamp +from pandas.testing import assert_frame_equal +from s3fs import S3FileSystem +from shapely.geometry import Point + +from kedro_datasets.geopandas import ParquetDataset + + +@pytest.fixture(params=[None]) +def load_version(request): + return request.param + + +@pytest.fixture(params=[None]) +def save_version(request): + return request.param or generate_timestamp() + + +@pytest.fixture +def filepath(tmp_path): + return (tmp_path / "test.parquet").as_posix() + + +@pytest.fixture(params=[None]) +def load_args(request): + return request.param + + +@pytest.fixture(params=[None]) +def save_args(request): + return request.param + + +@pytest.fixture +def dummy_dataframe(): + return gpd.GeoDataFrame( + {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}, + geometry=[Point(1, 1), Point(2, 2)], + ) + + +@pytest.fixture +def parquet_dataset(filepath, load_args, save_args, fs_args): + return ParquetDataset( + filepath=filepath, load_args=load_args, save_args=save_args, fs_args=fs_args + ) + + +@pytest.fixture +def versioned_parquet_dataset(filepath, load_version, save_version): + return ParquetDataset( + filepath=filepath, version=Version(load_version, save_version) + ) + + +class TestParquetDataset: + def test_save_and_load(self, parquet_dataset, dummy_dataframe): + """Test that saved and reloaded data matches the original one.""" + parquet_dataset.save(dummy_dataframe) + reloaded_df = parquet_dataset.load() + assert_frame_equal(reloaded_df, dummy_dataframe) + assert parquet_dataset._fs_open_args_load == {} + assert parquet_dataset._fs_open_args_save == {"mode": "wb"} + + @pytest.mark.parametrize("parquet_dataset", [{"index": False}], indirect=True) + def test_load_missing_file(self, parquet_dataset): + """Check the error while trying to load from missing source.""" + pattern = r"Failed while loading data from data set ParquetDataSet" + with pytest.raises(DatasetError, match=pattern): + parquet_dataset.load() + + def test_exists(self, parquet_dataset, dummy_dataframe): + """Test `exists` method invocation for both cases.""" + assert not parquet_dataset.exists() + parquet_dataset.save(dummy_dataframe) + assert 
+class TestParquetDatasetVersioned:
+    def test_version_str_repr(self, load_version, save_version):
+        """Test that version is in string representation of the class instance
+        when applicable."""
+        filepath = "test.parquet"
+        ds = ParquetDataset(filepath=filepath)
+        ds_versioned = ParquetDataset(
+            filepath=filepath, version=Version(load_version, save_version)
+        )
+        assert filepath in str(ds)
+        assert "version" not in str(ds)
+
+        assert filepath in str(ds_versioned)
+        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
+        assert ver_str in str(ds_versioned)
+        assert "ParquetDataset" in str(ds_versioned)
+        assert "ParquetDataset" in str(ds)
+        assert "protocol" in str(ds_versioned)
+        assert "protocol" in str(ds)
+
+    def test_save_and_load(self, versioned_parquet_dataset, dummy_dataframe):
+        """Test that saved and reloaded data matches the original one for
+        the versioned data set."""
+        versioned_parquet_dataset.save(dummy_dataframe)
+        reloaded_df = versioned_parquet_dataset.load()
+        assert_frame_equal(reloaded_df, dummy_dataframe)
+
+    def test_no_versions(self, versioned_parquet_dataset):
+        """Check the error if no versions are available for load."""
+        pattern = r"Did not find any versions for ParquetDataset\(.+\)"
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_parquet_dataset.load()
+
+    def test_exists(self, versioned_parquet_dataset, dummy_dataframe):
+        """Test `exists` method invocation for versioned data set."""
+        assert not versioned_parquet_dataset.exists()
+        versioned_parquet_dataset.save(dummy_dataframe)
+        assert versioned_parquet_dataset.exists()
+
+    def test_prevent_override(self, versioned_parquet_dataset, dummy_dataframe):
+        """Check the error when attempting to override the same data set
+        version."""
+        versioned_parquet_dataset.save(dummy_dataframe)
+        pattern = (
+            r"Save path \'.+\' for ParquetDataset\(.+\) must not "
+            r"exist if versioning is enabled"
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_parquet_dataset.save(dummy_dataframe)
+
+    @pytest.mark.parametrize(
+        "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
+    )
+    @pytest.mark.parametrize(
+        "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True
+    )
+    def test_save_version_warning(
+        self, versioned_parquet_dataset, load_version, save_version, dummy_dataframe
+    ):
+        """Check the warning when saving to the path that differs from
+        the subsequent load path."""
+        pattern = (
+            rf"Save version '{save_version}' did not match load version "
+            rf"'{load_version}' for ParquetDataset\(.+\)"
+        )
+        with pytest.warns(UserWarning, match=pattern):
+            versioned_parquet_dataset.save(dummy_dataframe)
+
+    def test_http_filesystem_no_versioning(self):
+        pattern = "Versioning is not supported for HTTP protocols."
+
+        with pytest.raises(DatasetError, match=pattern):
+            ParquetDataset(
+                filepath="https://example.com/file.parquet", version=Version(None, None)
+            )
+
+    def test_versioning_existing_dataset(
+        self, parquet_dataset, versioned_parquet_dataset, dummy_dataframe
+    ):
+        """Check the error when attempting to save a versioned dataset on top of an
+        already existing (non-versioned) dataset."""
+        parquet_dataset.save(dummy_dataframe)
+        assert parquet_dataset.exists()
+        assert parquet_dataset._filepath == versioned_parquet_dataset._filepath
+        pattern = (
+            f"(?=.*file with the same name already exists in the directory)"
+            f"(?=.*{versioned_parquet_dataset._filepath.parent.as_posix()})"
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_parquet_dataset.save(dummy_dataframe)
+
+        # Remove non-versioned dataset and try again
+        Path(parquet_dataset._filepath.as_posix()).unlink()
+        versioned_parquet_dataset.save(dummy_dataframe)
+        assert versioned_parquet_dataset.exists()
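Reviewer note: to make the versioned behaviour exercised by `TestParquetDatasetVersioned` above concrete, a short sketch of the round trip; the local path is an illustrative assumption.

```python
# Sketch of the versioned round trip the tests above exercise; paths are illustrative.
import geopandas as gpd
from kedro.io.core import Version
from shapely.geometry import Point

from kedro_datasets.geopandas import ParquetDataset

gdf = gpd.GeoDataFrame({"col1": [1, 2]}, geometry=[Point(1, 1), Point(2, 2)])

# Version(None, None): the save version is generated from a timestamp,
# and load resolves to the latest available version.
versioned = ParquetDataset(filepath="data/geo.parquet", version=Version(None, None))

versioned.save(gdf)  # writes data/geo.parquet/<timestamp>/geo.parquet
reloaded = versioned.load()  # loads the latest saved version
assert gdf.equals(reloaded)

# Saving the same resolved version again raises DatasetError,
# which is what test_prevent_override checks.
```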