diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index e565d674c..9aed7b27e 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,6 +1,9 @@ # Upcoming Release ## Major features and improvements * Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets` and renamed to `PartitionedDataset` and `IncrementalDataset`. +* Added `polars.LazyPolarsDataset`, a `GenericDataSet` using [polars](https://www.pola.rs/)'s Lazy API. +* Renamed `polars.GenericDataSet` to `polars.EagerPolarsDataset` to better reflect the difference between the two dataset classes. +* Added a deprecation warning when using `polars.GenericDataSet` or `polars.GenericDataset`, noting that both have been renamed to `polars.EagerPolarsDataset`. * Delayed backend connection for `pandas.SQLTableDataset`, `pandas.SQLQueryDataset`, and `snowflake.SnowparkTableDataset`. In practice, this means that a dataset's connection details aren't used (or validated) until the dataset is accessed. On the plus side, the cost of connection isn't incurred regardless of when or whether the dataset is used. ## Bug fixes and other changes diff --git a/kedro-datasets/docs/source/kedro_datasets.rst b/kedro-datasets/docs/source/kedro_datasets.rst index 67f87e0e3..3091b3c4a 100644 --- a/kedro-datasets/docs/source/kedro_datasets.rst +++ b/kedro-datasets/docs/source/kedro_datasets.rst @@ -73,6 +73,8 @@ kedro_datasets kedro_datasets.polars.CSVDataset kedro_datasets.polars.GenericDataSet kedro_datasets.polars.GenericDataset + kedro_datasets.polars.EagerPolarsDataset + kedro_datasets.polars.LazyPolarsDataset kedro_datasets.redis.PickleDataSet kedro_datasets.redis.PickleDataset kedro_datasets.snowflake.SnowparkTableDataSet diff --git a/kedro-datasets/kedro_datasets/polars/__init__.py b/kedro-datasets/kedro_datasets/polars/__init__.py index 5070de80d..236b6eb7b 100644 --- a/kedro-datasets/kedro_datasets/polars/__init__.py +++ b/kedro-datasets/kedro_datasets/polars/__init__.py @@ -8,13 +8,20 @@ # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 CSVDataSet: type[CSVDataset] CSVDataset: Any -GenericDataSet: type[GenericDataset] -GenericDataset: Any +GenericDataSet: type[EagerPolarsDataset] +GenericDataset: type[EagerPolarsDataset] +EagerPolarsDataset: Any +LazyPolarsDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ "csv_dataset": ["CSVDataSet", "CSVDataset"], - "generic_dataset": ["GenericDataSet", "GenericDataset"], + "eager_polars_dataset": [ + "EagerPolarsDataset", + "GenericDataSet", + "GenericDataset", + ], + "lazy_polars_dataset": ["LazyPolarsDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/polars/generic_dataset.py b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py similarity index 92% rename from kedro-datasets/kedro_datasets/polars/generic_dataset.py rename to kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py index aa6eedd48..24a5a1acc 100644 --- a/kedro-datasets/kedro_datasets/polars/generic_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py @@ -1,4 +1,4 @@ -"""``GenericDataset`` loads/saves data from/to a data file using an underlying +"""``EagerPolarsDataset`` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the type of read/write target. 
""" @@ -16,8 +16,8 @@ from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class GenericDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): - """``polars.GenericDataset`` loads/saves data from/to a data file using an underlying +class EagerPolarsDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): + """``polars.EagerPolarsDataset`` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the dynamically select the appropriate type of read/write on a best effort basis. @@ -27,7 +27,7 @@ class GenericDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): .. code-block:: yaml cars: - type: polars.GenericDataset + type: polars.EagerPolarsDataset file_format: parquet filepath: s3://data/01_raw/company/cars.parquet load_args: @@ -39,12 +39,12 @@ class GenericDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): .. code-block:: pycon - >>> from kedro_datasets.polars import GenericDataset + >>> from kedro_datasets.polars import EagerPolarsDataset >>> import polars as pl >>> >>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) >>> - >>> dataset = GenericDataset(filepath="test.parquet", file_format="parquet") + >>> dataset = EagerPolarsDataset(filepath="test.parquet", file_format="parquet") >>> dataset.save(data) >>> reloaded = dataset.load() >>> assert data.frame_equal(reloaded) @@ -64,7 +64,7 @@ def __init__( # noqa: PLR0913 credentials: Dict[str, Any] = None, fs_args: Dict[str, Any] = None, ): - """Creates a new instance of ``GenericDataset`` pointing to a concrete data file + """Creates a new instance of ``EagerPolarsDataset`` pointing to a concrete data file on a specific filesystem. The appropriate polars load/save methods are dynamically identified by string matching on a best effort basis. @@ -200,7 +200,8 @@ def _invalidate_cache(self) -> None: _DEPRECATED_CLASSES = { - "GenericDataSet": GenericDataset, + "GenericDataSet": EagerPolarsDataset, + "GenericDataset": EagerPolarsDataset, } diff --git a/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py b/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py new file mode 100644 index 000000000..698fa2392 --- /dev/null +++ b/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py @@ -0,0 +1,248 @@ +"""``LazyPolarsDataset`` loads/saves data from/to a data file using an underlying +filesystem (e.g.: local, S3, GCS). It uses polars to handle the +type of read/write target. +""" +import logging +from copy import deepcopy +from io import BytesIO +from pathlib import PurePosixPath +from typing import Any, ClassVar, Dict, Optional, Union + +import fsspec +import polars as pl +import pyarrow.dataset as ds +from kedro.io.core import ( + AbstractVersionedDataSet, + DatasetError, + Version, + get_filepath_str, + get_protocol_and_path, +) + +ACCEPTED_FILE_FORMATS = ["csv", "parquet"] + +PolarsFrame = Union[pl.LazyFrame, pl.DataFrame] + +logger = logging.getLogger(__name__) + + +class LazyPolarsDataset(AbstractVersionedDataSet[pl.LazyFrame, PolarsFrame]): + """``LazyPolarsDataset`` loads/saves data from/to a data file using an + underlying filesystem (e.g.: local, S3, GCS). It uses Polars to handle + the type of read/write target. It uses lazy loading with Polars Lazy API, but it can + save both Lazy and Eager Polars DataFrames. + + Example usage for the `YAML API `_: + + .. 
code-block:: yaml + +        cars: +          type: polars.LazyPolarsDataset +          filepath: data/01_raw/company/cars.csv +          load_args: +            sep: "," +            parse_dates: False +          save_args: +            has_header: False +            null_value: "somenullstring" + +        motorbikes: +          type: polars.LazyPolarsDataset +          filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv +          credentials: dev_s3 + +    Example using Python API: +    :: + +        >>> from kedro_datasets.polars import LazyPolarsDataset +        >>> import polars as pl +        >>> +        >>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]}) +        >>> +        >>> data_set = LazyPolarsDataset(filepath="test.csv", file_format="csv") +        >>> data_set.save(data) +        >>> reloaded = data_set.load() +        >>> assert data.frame_equal(reloaded.collect()) + +    """ + +    DEFAULT_LOAD_ARGS: ClassVar[Dict[str, Any]] = {} +    DEFAULT_SAVE_ARGS: ClassVar[Dict[str, Any]] = {} + +    def __init__(  # noqa: PLR0913 +        self, +        filepath: str, +        file_format: str, +        load_args: Optional[Dict[str, Any]] = None, +        save_args: Optional[Dict[str, Any]] = None, +        version: Version = None, +        credentials: Optional[Dict[str, Any]] = None, +        fs_args: Optional[Dict[str, Any]] = None, +        metadata: Optional[Dict[str, Any]] = None, +    ) -> None: +        """Creates a new instance of ``LazyPolarsDataset`` pointing to a concrete +        data file on a specific filesystem. + +        Args: +            filepath: Filepath in POSIX format to a file prefixed with a protocol like +                `s3://`. +                If prefix is not provided, `file` protocol (local filesystem) +                will be used. +                The prefix should be any protocol supported by ``fsspec``. +                Key assumption: The first argument of either load/save method points to +                a filepath/buffer/io type location. There are some read/write targets +                such as 'clipboard' or 'records' that will fail since they do not take a +                filepath-like argument. +            file_format: String which is used to match the appropriate load/save method +                on a best effort basis. For example if 'csv' is passed, the +                `polars.scan_csv` and +                `polars.DataFrame.write_csv` methods will be used. An error will +                be raised if the format is not one of the accepted file formats +                ('csv' and 'parquet'). +            load_args: Polars options for loading files. +                Here you can find all available arguments: +                https://pola-rs.github.io/polars/py-polars/html/reference/io.html +                All defaults are preserved. +            save_args: Polars options for saving files. +                Here you can find all available arguments: +                https://pola-rs.github.io/polars/py-polars/html/reference/io.html +                All defaults are preserved. +            version: If specified, should be an instance of +                ``kedro.io.core.Version``. If its ``load`` attribute is +                None, the latest version will be loaded. If its ``save`` +                attribute is None, save version will be autogenerated. +            credentials: Credentials required to get access to the underlying filesystem. +                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. +            fs_args: Extra arguments to pass into underlying filesystem class constructor +                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as +                to pass to the filesystem's `open` method through nested keys +                `open_args_load` and `open_args_save`. +                Here you can find all available arguments for `open`: +                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open +                All defaults are preserved, except `mode`, which is set to `r` when loading +                and to `w` when saving. +            metadata: Any arbitrary metadata. +                This is ignored by Kedro, but may be consumed by users or external plugins. 
+    Raises: +        DatasetError: Will be raised if the given ``file_format`` is not one of the +            accepted formats, i.e. no appropriate read or write method can be identified. +        """ +        self._file_format = file_format.lower() + +        if self._file_format not in ACCEPTED_FILE_FORMATS: +            raise DatasetError( +                f"'{self._file_format}' is not an accepted format " +                f"({ACCEPTED_FILE_FORMATS}) ensure that your 'file_format' parameter " +                "has been defined correctly as per the Polars API " +                "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" +            ) + +        _fs_args = deepcopy(fs_args) or {} +        _credentials = deepcopy(credentials) or {} + +        protocol, path = get_protocol_and_path(filepath, version) +        if protocol == "file": +            _fs_args.setdefault("auto_mkdir", True) + +        self._protocol = protocol +        self._storage_options = {**_credentials, **_fs_args} +        self._fs = fsspec.filesystem(self._protocol, **self._storage_options) + +        self.metadata = metadata + +        super().__init__( +            filepath=PurePosixPath(path), +            version=version, +            exists_function=self._fs.exists, +            glob_function=self._fs.glob, +        ) + +        # Handle default load and save arguments +        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) +        if load_args is not None: +            self._load_args.update(load_args) +        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) +        if save_args is not None: +            self._save_args.update(save_args) + +        if "storage_options" in self._save_args or "storage_options" in self._load_args: +            logger.warning( +                "Dropping 'storage_options' for %s, " +                "please specify them under 'fs_args' or 'credentials'.", +                self._filepath, +            ) +            self._save_args.pop("storage_options", None) +            self._load_args.pop("storage_options", None) + +    def _describe(self) -> Dict[str, Any]: +        return { +            "filepath": self._filepath, +            "protocol": self._protocol, +            "load_args": self._load_args, +            "save_args": self._save_args, +            "version": self._version, +        } + +    def _load(self) -> pl.LazyFrame: +        load_path = str(self._get_load_path()) + +        if self._protocol == "file": +            # With local filesystems, we can use Polars' built-in I/O methods: +            load_method = getattr(pl, f"scan_{self._file_format}", None) +            return load_method(load_path, **self._load_args) + +        # For object storage, we use pyarrow for I/O: +        dataset = ds.dataset( +            load_path, filesystem=self._fs, format=self._file_format, **self._load_args +        ) +        return pl.scan_pyarrow_dataset(dataset) + +    def _save(self, data: Union[pl.DataFrame, pl.LazyFrame]) -> None: +        save_path = get_filepath_str(self._get_save_path(), self._protocol) + +        collected_data = None +        if isinstance(data, pl.LazyFrame): +            collected_data = data.collect() +        else: +            collected_data = data + +        # Note: polars does support writing partitioned parquet files; +        # it leverages Arrow to do so, see e.g. 
+        # https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html +        save_method = getattr(collected_data, f"write_{self._file_format}", None) +        if save_method: +            buf = BytesIO() +            save_method(file=buf, **self._save_args) +            with self._fs.open(save_path, mode="wb") as fs_file: +                fs_file.write(buf.getvalue()) +                self._invalidate_cache() +        # Given how the LazyPolarsDataset logic is currently written, with +        # ACCEPTED_FILE_FORMATS and the check in the `__init__` method, +        # this else branch is never reached, hence we exclude it from the coverage report +        # but leave it in for consistency between the Eager and Lazy classes +        else:  # pragma: no cover +            raise DatasetError( +                f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' " +                "method, please ensure that your 'file_format' parameter has been " +                "defined correctly as per the Polars API " +                "https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/index.html" +            ) + +    def _exists(self) -> bool: +        try: +            load_path = get_filepath_str(self._get_load_path(), self._protocol) +        except DatasetError:  # pragma: no cover +            return False + +        return self._fs.exists(load_path) + +    def _release(self) -> None: +        super()._release() +        self._invalidate_cache() + +    def _invalidate_cache(self) -> None: +        """Invalidate underlying filesystem caches.""" +        filepath = get_filepath_str(self._filepath, self._protocol) +        self._fs.invalidate_cache(filepath) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 51416af42..00c03ac34 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -31,6 +31,7 @@ version = {attr = "kedro_datasets.__version__"} [tool.coverage.report] fail_under = 100 show_missing = true +# temporarily ignore kedro_datasets/__init__.py in coverage report omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/databricks/*"] exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"] diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index a22e83f81..c0c5e8c43 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -59,6 +59,15 @@ def _collect_requirements(requires): [ POLARS, "pyarrow>=4.0", "xlsx2csv>=0.8.0", "deltalake >= 0.6.2" ], +    "polars.EagerPolarsDataset": +    [ +        POLARS, "pyarrow>=4.0", "xlsx2csv>=0.8.0", "deltalake >= 0.6.2" +    ], +    "polars.LazyPolarsDataset": +    [ +        # Note: there is no Lazy read Excel option, so we exclude xlsx2csv here. 
+ POLARS, "pyarrow>=4.0", "deltalake >= 0.6.2" + ], } redis_require = {"redis.PickleDataSet": ["redis~=4.1"]} snowflake_require = { diff --git a/kedro-datasets/tests/polars/test_generic_dataset.py b/kedro-datasets/tests/polars/test_eager_polars_dataset.py similarity index 89% rename from kedro-datasets/tests/polars/test_generic_dataset.py rename to kedro-datasets/tests/polars/test_eager_polars_dataset.py index b300cfd78..5687972db 100644 --- a/kedro-datasets/tests/polars/test_generic_dataset.py +++ b/kedro-datasets/tests/polars/test_eager_polars_dataset.py @@ -16,8 +16,8 @@ from kedro_datasets import KedroDeprecationWarning from kedro_datasets._io import DatasetError -from kedro_datasets.polars import GenericDataset -from kedro_datasets.polars.generic_dataset import _DEPRECATED_CLASSES +from kedro_datasets.polars import EagerPolarsDataset +from kedro_datasets.polars.eager_polars_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -37,7 +37,7 @@ def filepath_parquet(tmp_path): @pytest.fixture def versioned_csv_dataset(filepath_csv, load_version, save_version): - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(load_version, save_version), @@ -47,7 +47,7 @@ def versioned_csv_dataset(filepath_csv, load_version, save_version): @pytest.fixture def versioned_ipc_dataset(filepath_ipc, load_version, save_version): - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(load_version, save_version), @@ -57,7 +57,7 @@ def versioned_ipc_dataset(filepath_ipc, load_version, save_version): @pytest.fixture def versioned_parquet_dataset(filepath_parquet, load_version, save_version): - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(load_version, save_version), @@ -67,7 +67,7 @@ def versioned_parquet_dataset(filepath_parquet, load_version, save_version): @pytest.fixture def csv_dataset(filepath_csv): - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", ) @@ -87,7 +87,7 @@ def filepath_excel(tmp_path): def parquet_dataset_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet): dummy_dataframe.write_parquet(filepath_parquet) - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", load_args={"low_memory": True}, @@ -99,14 +99,15 @@ def excel_dataset(dummy_dataframe: pl.DataFrame, filepath_excel): pd_df = dummy_dataframe.to_pandas() pd_df.to_excel(filepath_excel, index=False) - return GenericDataset( + return EagerPolarsDataset( filepath=filepath_excel.as_posix(), file_format="excel", ) @pytest.mark.parametrize( - "module_name", ["kedro_datasets.polars", "kedro_datasets.polars.generic_dataset"] + "module_name", + ["kedro_datasets.polars", "kedro_datasets.polars.eager_polars_dataset"], ) @pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) def test_deprecation(module_name, class_name): @@ -116,7 +117,7 @@ def test_deprecation(module_name, class_name): getattr(importlib.import_module(module_name), class_name) -class TestGenericExcelDataset: +class TestEagerExcelDataset: def test_load(self, excel_dataset): df = excel_dataset.load() assert df.shape == (2, 3) @@ -142,7 +143,7 @@ def test_save_and_load(self, excel_dataset, dummy_dataframe): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - dataset = GenericDataset( + dataset = EagerPolarsDataset( 
filepath=filepath, file_format="excel", credentials=credentials, @@ -157,14 +158,14 @@ def test_protocol_usage(self, filepath, instance_type, credentials): def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.csv" - dataset = GenericDataset(filepath=filepath, file_format="excel") + dataset = EagerPolarsDataset(filepath=filepath, file_format="excel") assert dataset._version_cache.currsize == 0 # no cache if unversioned dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) assert dataset._version_cache.currsize == 0 -class TestGenericParquetDatasetVersioned: +class TestEagerParquetDatasetVersioned: def test_load_args(self, parquet_dataset_ignore): df = parquet_dataset_ignore.load() assert df.shape == (2, 3) @@ -179,8 +180,8 @@ def test_version_str_repr(self, filepath_parquet, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_parquet.as_posix() - ds = GenericDataset(filepath=filepath, file_format="parquet") - ds_versioned = GenericDataset( + ds = EagerPolarsDataset(filepath=filepath, file_format="parquet") + ds_versioned = EagerPolarsDataset( filepath=filepath, file_format="parquet", version=Version(load_version, save_version), @@ -189,8 +190,8 @@ def test_version_str_repr(self, filepath_parquet, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataset" in str(ds_versioned) - assert "GenericDataset" in str(ds) + assert "EagerPolarsDataset" in str(ds_versioned) + assert "EagerPolarsDataset" in str(ds) def test_multiple_loads( self, versioned_parquet_dataset, dummy_dataframe, filepath_parquet @@ -204,7 +205,7 @@ def test_multiple_loads( sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataset( + EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(v_new, v_new), @@ -214,7 +215,7 @@ def test_multiple_loads( v2 = versioned_parquet_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! 
- ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -225,7 +226,7 @@ def test_multiple_loads( def test_multiple_saves(self, dummy_dataframe, filepath_parquet): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataset( + ds_versioned = EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -246,7 +247,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_parquet): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -254,7 +255,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_parquet): assert ds_new.resolve_load_version() == second_load_version -class TestGenericIPCDatasetVersioned: +class TestEagerIPCDatasetVersioned: def test_save_and_load(self, versioned_ipc_dataset, dummy_dataframe): """Test saving and reloading the data set.""" versioned_ipc_dataset.save(dummy_dataframe) @@ -265,8 +266,8 @@ def test_version_str_repr(self, filepath_ipc, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_ipc.as_posix() - ds = GenericDataset(filepath=filepath, file_format="ipc") - ds_versioned = GenericDataset( + ds = EagerPolarsDataset(filepath=filepath, file_format="ipc") + ds_versioned = EagerPolarsDataset( filepath=filepath, file_format="ipc", version=Version(load_version, save_version), @@ -275,8 +276,8 @@ def test_version_str_repr(self, filepath_ipc, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataset" in str(ds_versioned) - assert "GenericDataset" in str(ds) + assert "EagerPolarsDataset" in str(ds_versioned) + assert "EagerPolarsDataset" in str(ds) def test_multiple_loads(self, versioned_ipc_dataset, dummy_dataframe, filepath_ipc): """Test that if a new version is created mid-run, by an @@ -288,7 +289,7 @@ def test_multiple_loads(self, versioned_ipc_dataset, dummy_dataframe, filepath_i sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataset( + EagerPolarsDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(v_new, v_new), @@ -298,7 +299,7 @@ def test_multiple_loads(self, versioned_ipc_dataset, dummy_dataframe, filepath_i v2 = versioned_ipc_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! 
- ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -309,7 +310,7 @@ def test_multiple_loads(self, versioned_ipc_dataset, dummy_dataframe, filepath_i def test_multiple_saves(self, dummy_dataframe, filepath_ipc): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataset( + ds_versioned = EagerPolarsDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -330,7 +331,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_ipc): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -338,13 +339,13 @@ def test_multiple_saves(self, dummy_dataframe, filepath_ipc): assert ds_new.resolve_load_version() == second_load_version -class TestGenericCSVDatasetVersioned: +class TestEagerCSVDatasetVersioned: def test_version_str_repr(self, filepath_csv, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_csv.as_posix() - ds = GenericDataset(filepath=filepath, file_format="csv") - ds_versioned = GenericDataset( + ds = EagerPolarsDataset(filepath=filepath, file_format="csv") + ds_versioned = EagerPolarsDataset( filepath=filepath, file_format="csv", version=Version(load_version, save_version), @@ -353,8 +354,8 @@ def test_version_str_repr(self, filepath_csv, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataset" in str(ds_versioned) - assert "GenericDataset" in str(ds) + assert "EagerPolarsDataset" in str(ds_versioned) + assert "EagerPolarsDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) @@ -375,7 +376,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_c sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataset( + EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(v_new, v_new), @@ -385,7 +386,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_c v2 = versioned_csv_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! 
- ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -396,7 +397,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_c def test_multiple_saves(self, dummy_dataframe, filepath_csv): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataset( + ds_versioned = EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -417,7 +418,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataset( + ds_new = EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -426,7 +427,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): def test_release_instance_cache(self, dummy_dataframe, filepath_csv): """Test that cache invalidation does not affect other instances""" - ds_a = GenericDataset( + ds_a = EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -435,7 +436,7 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): ds_a.save(dummy_dataframe) # create a version assert ds_a._version_cache.currsize == 2 - ds_b = GenericDataset( + ds_b = EagerPolarsDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -456,7 +457,7 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): def test_no_versions(self, versioned_csv_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GenericDataset\(.+\)" + pattern = r"Did not find any versions for EagerPolarsDataset\(.+\)" with pytest.raises(DatasetError, match=pattern): versioned_csv_dataset.load() @@ -471,7 +472,7 @@ def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe): corresponding Generic (csv) file for a given save version already exists.""" versioned_csv_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for GenericDataset\(.+\) must " + r"Save path \'.+\' for EagerPolarsDataset\(.+\) must " r"not exist if versioning is enabled\." 
) with pytest.raises(DatasetError, match=pattern): @@ -490,7 +491,7 @@ def test_save_version_warning( the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GenericDataset\(.+\)" + rf"'{load_version}' for EagerPolarsDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): versioned_csv_dataset.save(dummy_dataframe) @@ -516,9 +517,9 @@ def test_versioning_existing_dataset( assert versioned_csv_dataset.exists() -class TestBadGenericDataset: +class TestBadEagerPolarsDataset: def test_bad_file_format_argument(self): - ds = GenericDataset(filepath="test.kedro", file_format="kedro") + ds = EagerPolarsDataset(filepath="test.kedro", file_format="kedro") pattern = ( "Unable to retrieve 'polars.DataFrame.write_kedro' method, please " diff --git a/kedro-datasets/tests/polars/test_lazy_polars_dataset.py b/kedro-datasets/tests/polars/test_lazy_polars_dataset.py new file mode 100644 index 000000000..4eac0accd --- /dev/null +++ b/kedro-datasets/tests/polars/test_lazy_polars_dataset.py @@ -0,0 +1,412 @@ +import re +from pathlib import Path, PurePosixPath +from time import sleep + +import boto3 +import polars as pl +import pytest +from adlfs import AzureBlobFileSystem +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io import DatasetError, Version +from kedro.io.core import PROTOCOL_DELIMITER, generate_timestamp +from moto import mock_s3 +from polars.testing import assert_frame_equal +from s3fs.core import S3FileSystem + +from kedro_datasets.polars import LazyPolarsDataset +from kedro_datasets.polars.lazy_polars_dataset import ACCEPTED_FILE_FORMATS + +BUCKET_NAME = "test_bucket" +FILE_NAME = "test.csv" + + +@pytest.fixture +def filepath_csv(tmp_path): + return (tmp_path / "test.csv").as_posix() + + +@pytest.fixture +def filepath_pq(tmp_path): + return (tmp_path / "test.pq").as_posix() + + +@pytest.fixture +def dummy_dataframe(): + return pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) + + +@pytest.fixture +def csv_data_set(filepath_csv, load_args, save_args, fs_args): + return LazyPolarsDataset( + filepath=filepath_csv, + file_format="csv", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def parquet_dataset(filepath_pq, load_args, save_args, fs_args): + return LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def parquet_dataset_ignore(filepath_pq): + return LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + load_args={"low_memory": True}, + ) + + +@pytest.fixture +def versioned_parquet_dataset( + filepath_pq, + save_args, + load_version, + save_version, +): + return LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(load_version, save_version), + save_args={}, + ) + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing using moto.""" + with mock_s3(): + conn = boto3.client( + "s3", + aws_access_key_id="fake_access_key", + aws_secret_access_key="fake_secret_key", + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_csv_in_s3(mocked_s3_bucket, dummy_dataframe): + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, + Key=FILE_NAME, + Body=dummy_dataframe.write_csv(), + ) + return f"s3://{BUCKET_NAME}/{FILE_NAME}" + + +class TestLazyCSVDataset: + 
"""Test class for LazyPolarsDataset csv functionality""" + + def test_exists(self, csv_data_set, dummy_dataframe): + """Test `exists` method invocation for both existing and + nonexistent data set. + """ + assert not csv_data_set.exists() + csv_data_set.save(dummy_dataframe) + assert csv_data_set.exists() + + def test_load(self, dummy_dataframe, csv_data_set, filepath_csv): + dummy_dataframe.write_csv(filepath_csv) + df = csv_data_set.load() + assert df.collect().shape == (2, 3) + + def test_load_s3(self, dummy_dataframe, mocked_csv_in_s3): + ds = LazyPolarsDataset(mocked_csv_in_s3, file_format="csv") + + assert ds._protocol == "s3" + + loaded_df = ds.load().collect() + assert_frame_equal(loaded_df, dummy_dataframe) + + def test_save_and_load(self, csv_data_set, dummy_dataframe): + csv_data_set.save(dummy_dataframe) + reloaded_df = csv_data_set.load().collect() + assert_frame_equal(dummy_dataframe, reloaded_df) + + def test_load_missing_file(self, csv_data_set): + """Check the error when trying to load missing file.""" + pattern = r"Failed while loading data from data set LazyPolarsDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + csv_data_set.load() + + @pytest.mark.parametrize( + "load_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_load_extra_params(self, csv_data_set, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert csv_data_set._load_args[key] == value + + @pytest.mark.parametrize( + "save_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_save_extra_params(self, csv_data_set, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert csv_data_set._save_args[key] == value + + @pytest.mark.parametrize( + "load_args,save_args", + [ + ({"storage_options": {"a": "b"}}, {}), + ({}, {"storage_options": {"a": "b"}}), + ({"storage_options": {"a": "b"}}, {"storage_options": {"x": "y"}}), + ], + ) + def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): + filepath = str(tmp_path / "test.csv") + + ds = LazyPolarsDataset( + filepath=filepath, + file_format="csv", + load_args=load_args, + save_args=save_args, + ) + + records = [r for r in caplog.records if r.levelname == "WARNING"] + expected_log_message = ( + f"Dropping 'storage_options' for {filepath}, " + f"please specify them under 'fs_args' or 'credentials'." 
+ ) + assert records[0].getMessage() == expected_log_message + assert "storage_options" not in ds._save_args + assert "storage_options" not in ds._load_args + + @pytest.mark.parametrize( + "filepath,instance_type,credentials", + [ + ("s3://bucket/file.csv", S3FileSystem, {}), + ("file:///tmp/test.csv", LocalFileSystem, {}), + ("/tmp/test.csv", LocalFileSystem, {}), + ("gcs://bucket/file.csv", GCSFileSystem, {}), + ("https://example.com/file.csv", HTTPFileSystem, {}), + ( + "abfs://bucket/file.csv", + AzureBlobFileSystem, + {"account_name": "test", "account_key": "test"}, + ), + ], + ) + def test_protocol_usage(self, filepath, instance_type, credentials): + dataset = LazyPolarsDataset( + filepath=filepath, + file_format="csv", + credentials=credentials, + ) + assert isinstance(dataset._fs, instance_type) + + path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] + + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + filepath = "test.csv" + dataset = LazyPolarsDataset(filepath=filepath, file_format="csv") + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() + fs_mock.invalidate_cache.assert_called_once_with(filepath) + assert dataset._version_cache.currsize == 0 + + +class TestLazyParquetDatasetVersioned: + def test_load_args(self, parquet_dataset_ignore, dummy_dataframe, filepath_pq): + dummy_dataframe.write_parquet(filepath_pq) + df = parquet_dataset_ignore.load().collect() + assert df.shape == (2, 3) + + def test_save_and_load(self, versioned_parquet_dataset, dummy_dataframe): + """Test saving and reloading the data set.""" + versioned_parquet_dataset.save(dummy_dataframe.lazy()) + reloaded_df = versioned_parquet_dataset.load().collect() + assert_frame_equal(dummy_dataframe, reloaded_df) + + def test_version_str_repr(self, filepath_pq, load_version, save_version): + """Test that version is in string representation of the class instance + when applicable.""" + ds = LazyPolarsDataset(filepath=filepath_pq, file_format="parquet") + ds_versioned = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(load_version, save_version), + ) + assert filepath_pq in str(ds) + assert filepath_pq in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "LazyPolarsDataset" in str(ds_versioned) + assert "LazyPolarsDataset" in str(ds) + assert "protocol" in str(ds_versioned) + assert "protocol" in str(ds) + + def test_multiple_loads( + self, versioned_parquet_dataset, dummy_dataframe, filepath_pq + ): + """Test that if a new version is created mid-run, by an + external system, it won't be loaded in the current run.""" + versioned_parquet_dataset.save(dummy_dataframe) + versioned_parquet_dataset.load() + v1 = versioned_parquet_dataset.resolve_load_version() + + sleep(0.5) + # force-drop a newer version into the same location + v_new = generate_timestamp() + LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(v_new, v_new), + ).save(dummy_dataframe) + + versioned_parquet_dataset.load() + v2 = versioned_parquet_dataset.resolve_load_version() + + assert v2 == v1 # v2 should not be v_new! 
+ ds_new = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(None, None), + ) + assert ( + ds_new.resolve_load_version() == v_new + ) # new version is discoverable by a new instance + + def test_multiple_saves(self, dummy_dataframe, filepath_pq): + """Test multiple cycles of save followed by load for the same dataset""" + ds_versioned = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(None, None), + ) + + # first save + ds_versioned.save(dummy_dataframe) + first_save_version = ds_versioned.resolve_save_version() + first_load_version = ds_versioned.resolve_load_version() + assert first_load_version == first_save_version + + # second save + sleep(0.5) + ds_versioned.save(dummy_dataframe) + second_save_version = ds_versioned.resolve_save_version() + second_load_version = ds_versioned.resolve_load_version() + assert second_load_version == second_save_version + assert second_load_version > first_load_version + + # another dataset + ds_new = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(None, None), + ) + assert ds_new.resolve_load_version() == second_load_version + + def test_release_instance_cache(self, dummy_dataframe, filepath_pq): + """Test that cache invalidation does not affect other instances""" + ds_a = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(None, None), + ) + assert ds_a._version_cache.currsize == 0 + ds_a.save(dummy_dataframe) # create a version + assert ds_a._version_cache.currsize == 2 + + ds_b = LazyPolarsDataset( + filepath=filepath_pq, + file_format="parquet", + version=Version(None, None), + ) + assert ds_b._version_cache.currsize == 0 + ds_b.resolve_save_version() + assert ds_b._version_cache.currsize == 1 + ds_b.resolve_load_version() + assert ds_b._version_cache.currsize == 2 + + ds_a.release() + + # dataset A cache is cleared + assert ds_a._version_cache.currsize == 0 + + # dataset B cache is unaffected + assert ds_b._version_cache.currsize == 2 + + def test_no_versions(self, versioned_parquet_dataset): + """Check the error if no versions are available for load.""" + pattern = r"Did not find any versions for LazyPolarsDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.load() + + def test_prevent_overwrite(self, versioned_parquet_dataset, dummy_dataframe): + """Check the error when attempting to override the data set if the + corresponding Generic (parquet) file for a given save version already exists.""" + versioned_parquet_dataset.save(dummy_dataframe) + pattern = ( + r"Save path \'.+\' for LazyPolarsDataset\(.+\) must " + r"not exist if versioning is enabled\." 
+ ) + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.save(dummy_dataframe) + + @pytest.mark.parametrize( + "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True + ) + @pytest.mark.parametrize( + "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True + ) + def test_save_version_warning( + self, versioned_parquet_dataset, load_version, save_version, dummy_dataframe + ): + """Check the warning when saving to the path that differs from + the subsequent load path.""" + pattern = ( + rf"Save version '{save_version}' did not match load version " + rf"'{load_version}' for LazyPolarsDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + versioned_parquet_dataset.save(dummy_dataframe) + + def test_versioning_existing_dataset( + self, parquet_dataset, versioned_parquet_dataset, dummy_dataframe + ): + """Check the error when attempting to save a versioned dataset on top of an + already existing (non-versioned) dataset.""" + parquet_dataset.save(dummy_dataframe) + assert parquet_dataset.exists() + assert parquet_dataset._filepath == versioned_parquet_dataset._filepath + pattern = ( + f"(?=.*file with the same name already exists in the directory)" + f"(?=.*{versioned_parquet_dataset._filepath.parent.as_posix()})" + ) + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.save(dummy_dataframe) + + # Remove non-versioned dataset and try again + Path(parquet_dataset._filepath.as_posix()).unlink() + versioned_parquet_dataset.save(dummy_dataframe) + assert versioned_parquet_dataset.exists() + + +class TestBadLazyPolarsDataset: + def test_bad_file_format_argument(self): + + pattern = ( + "'kedro' is not an accepted format " + f"({ACCEPTED_FILE_FORMATS}) ensure that your 'file_format' parameter " + "has been defined correctly as per the Polars API " + "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" + ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + LazyPolarsDataset(filepath="test.kedro", file_format="kedro")