fix(datasets): Add parameter to enable/disable lazy saving for PartitionedDataset #978

Open · wants to merge 21 commits into base: main · showing changes from 9 commits
5 changes: 5 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -1,6 +1,11 @@
# Upcoming Release
## Major features and improvements


## Bug fixes and other changes

- Made `PartitionedDataset` treat only lambda functions as lazy-save wrappers; other callable objects are now passed to the underlying dataset as-is instead of being called.

## Breaking Changes
## Community contributions

9 changes: 5 additions & 4 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
@@ -1,5 +1,6 @@
"""``CSVDataset`` is a dataset used to load and save data to CSV files using Dask
dataframe"""

from __future__ import annotations

from copy import deepcopy
@@ -13,7 +14,7 @@
class CSVDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]):
"""``CSVDataset`` loads and saves data to comma-separated value file(s). It uses Dask
remote data services to handle the corresponding load and save operations:
-https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html
+https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html

Example usage for the
`YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
@@ -67,13 +68,13 @@ def __init__( # noqa: PLR0913
filepath: Filepath in POSIX format to a CSV file
CSV collection or the directory of a multipart CSV.
load_args: Additional loading options `dask.dataframe.read_csv`:
-https://docs.dask.org/en/latest/generated/dask.dataframe.read_csv.html
+https://docs.dask.org/en/stable/generated/dask.dataframe.read_csv.html
save_args: Additional saving options for `dask.dataframe.to_csv`:
-https://docs.dask.org/en/latest/generated/dask.dataframe.to_csv.html
+https://docs.dask.org/en/stable/generated/dask.dataframe.to_csv.html
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Optional parameters to the backend file system driver:
-https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters
+https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html#optional-parameters
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
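For context on the docstring touched above, here is a minimal usage sketch of `CSVDataset` — the S3 path, credentials, and load/save options are illustrative assumptions, not part of this PR:

```python
from kedro_datasets.dask import CSVDataset

# Illustrative remote location and credentials; adjust for your filesystem.
dataset = CSVDataset(
    filepath="s3://my-bucket/raw/*.csv",
    load_args={"blocksize": "64MB"},   # forwarded to dask.dataframe.read_csv
    save_args={"index": False},        # forwarded to dask.dataframe.to_csv
    credentials={"key": "<aws-key>", "secret": "<aws-secret>"},
)

ddf = dataset.load()   # a lazy dask.dataframe.DataFrame
dataset.save(ddf)
```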
9 changes: 5 additions & 4 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
@@ -1,5 +1,6 @@
"""``ParquetDataset`` is a dataset used to load and save data to parquet files using Dask
dataframe"""

from __future__ import annotations

from copy import deepcopy
@@ -14,7 +15,7 @@
class ParquetDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]):
"""``ParquetDataset`` loads and saves data to parquet file(s). It uses Dask
remote data services to handle the corresponding load and save operations:
-https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html
+https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html

Example usage for the
`YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
@@ -97,13 +98,13 @@ def __init__( # noqa: PLR0913
filepath: Filepath in POSIX format to a parquet file
parquet collection or the directory of a multipart parquet.
load_args: Additional loading options `dask.dataframe.read_parquet`:
-https://docs.dask.org/en/latest/generated/dask.dataframe.read_parquet.html
+https://docs.dask.org/en/stable/generated/dask.dataframe.read_parquet.html
save_args: Additional saving options for `dask.dataframe.to_parquet`:
-https://docs.dask.org/en/latest/generated/dask.dataframe.to_parquet.html
+https://docs.dask.org/en/stable/generated/dask.dataframe.to_parquet.html
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Optional parameters to the backend file system driver:
-https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters
+https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html#optional-parameters
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
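The same pattern applies to `ParquetDataset`; a short sketch with illustrative paths and options (the `{"token": None}` credentials mirror the docstring's `GCSFileSystem` example):

```python
from kedro_datasets.dask import ParquetDataset

# Illustrative GCS location; the options shown are assumptions, not defaults.
dataset = ParquetDataset(
    filepath="gcs://my-bucket/features.parquet",
    load_args={"columns": ["id", "value"]},  # forwarded to dask.dataframe.read_parquet
    save_args={"compression": "snappy"},     # forwarded to dask.dataframe.to_parquet
    credentials={"token": None},
)

ddf = dataset.load()
dataset.save(ddf)
```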
kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
@@ -43,6 +43,11 @@ def _grandparent(path: str) -> str:
return str(grandparent)


def _islambda(obj: object) -> bool:
    """Check if object is a lambda function."""
    return callable(obj) and hasattr(obj, "__name__") and obj.__name__ == "<lambda>"


class PartitionedDataset(AbstractDataset[dict[str, Any], dict[str, Callable[[], Any]]]):
"""``PartitionedDataset`` loads and saves partitioned file-like data using the
underlying dataset definition. For filesystem level operations it uses `fsspec`:
@@ -311,7 +316,7 @@ def save(self, data: dict[str, Any]) -> None:
# join the protocol back since tools like PySpark may rely on it
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
-if callable(partition_data):
+if _islambda(partition_data):
partition_data = partition_data() # noqa: PLW2901
dataset.save(partition_data)
self._invalidate_caches()
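To illustrate the effect of the `_islambda` check above, here is a minimal sketch of saving with `PartitionedDataset` after this change. The path and dataset choice are assumptions; `PickleDataset` is used so that a plain function object can be serialised, as in the new test below:

```python
import pandas as pd
from kedro_datasets.partitions import PartitionedDataset

# Hypothetical local path; PickleDataset can serialise both data frames
# and plain callables, so every branch below succeeds.
pds = PartitionedDataset(
    path="data/07_model_output/parts",
    dataset="kedro_datasets.pickle.PickleDataset",
    filename_suffix=".pkl",
)

df = pd.DataFrame({"foo": [42, 7], "bar": ["a", "b"]})

def build_partition():
    return pd.DataFrame({"baz": [1, 2, 3]})

pds.save(
    {
        "eager": df,               # plain data: saved directly
        "lazy": lambda: df,        # lambda: called at save time, its result is saved
        "as_is": build_partition,  # named callable: no longer invoked; the function
                                   # object itself is handed to the underlying dataset
    }
)
```

Previously any callable was invoked before saving, which made it impossible to store callable objects themselves as partition data.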
27 changes: 23 additions & 4 deletions kedro-datasets/tests/partitions/test_partitioned_dataset.py
@@ -52,6 +52,10 @@ def filepath_csvs(tmp_path):
]


def original_data_callable():
    return pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})


class FakeDataset:  # pylint: disable=too-few-public-methods
    pass

@@ -101,24 +105,39 @@ def test_save(self, dataset, local_csvs, suffix):
reloaded_data = loaded_partitions[part_id]()
assert_frame_equal(reloaded_data, original_data)

@pytest.mark.parametrize("dataset", ["kedro_datasets.pickle.PickleDataset"])
@pytest.mark.parametrize("suffix", ["", ".csv"])
def test_callable_save(self, dataset, local_csvs, suffix):
pds = PartitionedDataset(
path=str(local_csvs), dataset=dataset, filename_suffix=suffix
)

part_id = "new/data"
pds.save({part_id: original_data_callable})

assert (local_csvs / "new" / ("data" + suffix)).is_file()
loaded_partitions = pds.load()
assert part_id in loaded_partitions
reloaded_data = loaded_partitions[part_id]()
assert reloaded_data == original_data_callable

@pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
@pytest.mark.parametrize("suffix", ["", ".csv"])
def test_lazy_save(self, dataset, local_csvs, suffix):
pds = PartitionedDataset(
path=str(local_csvs), dataset=dataset, filename_suffix=suffix
)

-def original_data():
-    return pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
+original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})

part_id = "new/data"
-pds.save({part_id: original_data})
+pds.save({part_id: lambda: original_data})

assert (local_csvs / "new" / ("data" + suffix)).is_file()
loaded_partitions = pds.load()
assert part_id in loaded_partitions
reloaded_data = loaded_partitions[part_id]()
-assert_frame_equal(reloaded_data, original_data())
+assert_frame_equal(reloaded_data, original_data)

def test_save_invalidates_cache(self, local_csvs, mocker):
"""Test that save calls invalidate partition cache"""