-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(datasets): Add geopandas ParquetDataset
- Loading branch information
1 parent
994f86c
commit 711093b
Showing
4 changed files
with
401 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
162 changes: 162 additions & 0 deletions
162
kedro-datasets/kedro_datasets/geopandas/parquet_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
"""ParquetDataset loads and saves data to a local parquet file. The | ||
underlying functionality is supported by geopandas, so it supports all | ||
allowed geopandas (pandas) options for loading and saving parquet files. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import copy | ||
from pathlib import PurePosixPath | ||
from typing import Any, Union | ||
|
||
import fsspec | ||
import geopandas as gpd | ||
from kedro.io.core import ( | ||
AbstractVersionedDataset, | ||
DatasetError, | ||
Version, | ||
get_filepath_str, | ||
get_protocol_and_path, | ||
) | ||
|
||
|
||
class ParquetDataset(
    AbstractVersionedDataset[
        gpd.GeoDataFrame, Union[gpd.GeoDataFrame, dict[str, gpd.GeoDataFrame]]
    ]
):
    """``ParquetDataset`` loads/saves data to a parquet file using an underlying filesystem
    (eg: local, S3, GCS).
    The underlying functionality is supported by geopandas, so it supports all
    allowed geopandas (pandas) options for loading and saving parquet files.

    Example:

    .. code-block:: pycon

        >>> import geopandas as gpd
        >>> from kedro_datasets.geopandas import ParquetDataset
        >>> from shapely.geometry import Point
        >>>
        >>> data = gpd.GeoDataFrame(
        ...     {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]},
        ...     geometry=[Point(1, 1), Point(2, 4)],
        ... )
        >>> dataset = ParquetDataset(filepath=tmp_path / "test.parquet", save_args=None)
        >>> dataset.save(data)
        >>> reloaded = dataset.load()
        >>>
        >>> assert data.equals(reloaded)

    """

    # Class-level defaults; each instance works on its own deep copy, merged
    # with the ``load_args`` / ``save_args`` given to the constructor.
    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        *,
        filepath: str,
        load_args: dict[str, Any] | None = None,
        save_args: dict[str, Any] | None = None,
        version: Version | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``ParquetDataset`` pointing to a concrete parquet file
        on a specific ``fsspec`` filesystem.

        Args:
            filepath: Filepath in POSIX format to a parquet file prefixed with a protocol like
                `s3://`. If prefix is not provided `file` protocol (local filesystem) will be used.
                The prefix should be any protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            load_args: GeoPandas options for loading parquet files.
                Here you can find all available arguments:
                https://geopandas.org/en/stable/docs/reference/api/geopandas.read_parquet.html
            save_args: GeoPandas options for saving parquet files.
                Here you can find all available arguments:
                https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_parquet.html
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: credentials required to access the underlying filesystem.
                Eg. for ``GCSFileSystem`` it would look like `{'token': None}`.
            fs_args: Extra arguments to pass into underlying filesystem class constructor
                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
                to pass to the filesystem's `open` method through nested keys
                `open_args_load` and `open_args_save`.
                Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                All defaults are preserved, except `mode`, which is set to `wb` when saving.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
        """
        # Deep-copy so the caller's dictionaries are never mutated by the
        # ``pop`` / ``setdefault`` calls below.
        _fs_args = copy.deepcopy(fs_args) or {}
        # ``or {}`` also covers a key that is present but explicitly ``None``;
        # otherwise ``setdefault`` / ``**`` unpacking further down would fail.
        _fs_open_args_load = _fs_args.pop("open_args_load", None) or {}
        _fs_open_args_save = _fs_args.pop("open_args_save", None) or {}
        _credentials = copy.deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        self._protocol = protocol
        if protocol == "file":
            # Let the local filesystem create missing parent directories
            # unless the caller chose otherwise.
            _fs_args.setdefault("auto_mkdir", True)

        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Caller-supplied arguments override the class-level defaults.
        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)

        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # Parquet is written as binary unless the caller set another mode.
        _fs_open_args_save.setdefault("mode", "wb")
        self._fs_open_args_load = _fs_open_args_load
        self._fs_open_args_save = _fs_open_args_save

    def _load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
        """Read the parquet file at the load path with ``gpd.read_parquet``.

        Returns:
            The object returned by ``gpd.read_parquet`` when called with the
            opened file and the configured ``load_args``.
        """
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
            return gpd.read_parquet(fs_file, **self._load_args)

    def _save(self, data: gpd.GeoDataFrame) -> None:
        """Write ``data`` to the save path with ``GeoDataFrame.to_parquet``.

        Args:
            data: The GeoDataFrame to persist; written with the configured
                ``save_args``.
        """
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
            data.to_parquet(fs_file, **self._save_args)
        # Drop the filesystem cache once the file has been written and closed.
        self.invalidate_cache()

    def _exists(self) -> bool:
        """Report whether the file at the load path exists on the filesystem.

        Returns:
            ``False`` when resolving the load path raises ``DatasetError``;
            otherwise the result of the filesystem's ``exists`` check.
        """
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DatasetError:
            return False
        return self._fs.exists(load_path)

    def _describe(self) -> dict[str, Any]:
        """Return the dataset's configuration as a dictionary.

        Returns:
            A mapping with the keys ``filepath``, ``protocol``, ``load_args``,
            ``save_args`` and ``version``.
        """
        return {
            "filepath": self._filepath,
            "protocol": self._protocol,
            "load_args": self._load_args,
            "save_args": self._save_args,
            "version": self._version,
        }

    def _release(self) -> None:
        """Release the dataset by invalidating the filesystem cache."""
        self.invalidate_cache()

    def invalidate_cache(self) -> None:
        """Invalidate underlying filesystem cache."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.