From 40589fe7df170358b60a336782720f6b14fb0253 Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Wed, 25 Sep 2024 09:59:28 +0000 Subject: [PATCH] Check that the mapping is valid after setting it. --- .../mlcroissant/mlcroissant/_src/datasets.py | 11 +++++++ .../mlcroissant/_src/datasets_test.py | 31 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py index cd2bc858..b5709d83 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets.py +++ b/python/mlcroissant/mlcroissant/_src/datasets.py @@ -46,6 +46,16 @@ def get_operations(ctx: Context, metadata: Metadata) -> OperationGraph: return operations +def _check_mapping(metadata: Metadata, mapping: Mapping[str, epath.Path]): + """Checks that the mapping is valid, i.e. keys are actual UUIDs and paths exist.""" + uuids = set([node.uuid for node in metadata.nodes()]) + for uuid, path in mapping.items(): + if uuid not in uuids: + raise ValueError(f"{uuid=} in the mapping doesn't exist in the JSON-LD") + if not path.exists(): + raise ValueError(f"{path=} doesn't exist on disk") + + def _expand_mapping( mapping: Mapping[str, epath.PathLike] | None, ) -> Mapping[str, epath.Path]: @@ -83,6 +93,7 @@ def __post_init__(self): self.metadata = Metadata.from_file(ctx=ctx, file=self.jsonld) else: return + _check_mapping(self.metadata, ctx.mapping) # Draw the structure graph for debugging purposes. if self.debug: graphs_utils.pretty_print_graph(ctx.graph) diff --git a/python/mlcroissant/mlcroissant/_src/datasets_test.py b/python/mlcroissant/mlcroissant/_src/datasets_test.py index 97cf22cb..7562ae2b 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets_test.py +++ b/python/mlcroissant/mlcroissant/_src/datasets_test.py @@ -88,6 +88,7 @@ def load_records_and_test_equality( record_set_name: str, num_records: int, filters: dict[str, Any] | None = None, + mapping: dict[str, epath.PathLike] | None = None, ): filters_command = "" if filters: @@ -109,7 +110,7 @@ def load_records_and_test_equality( with output_file.open("rb") as f: lines = f.readlines() expected_records = [json.loads(line) for line in lines] - dataset = datasets.Dataset(config) + dataset = datasets.Dataset(config, mapping=mapping) records = dataset.records(record_set_name, filters=filters) records = iter(records) length = 0 @@ -330,3 +331,31 @@ def test_validate_filters(filters, raises): datasets._validate_filters(filters) else: datasets._validate_filters(filters) + + +@parametrize_version() +def test_check_mapping_when_the_key_does_not_exist(version): + dataset_name = "simple-parquet/metadata.json" + record_set_name = "persons" + with pytest.raises(ValueError, match="doesn't exist in the JSON-LD"): + load_records_and_test_equality( + version, + dataset_name, + record_set_name, + -1, + mapping={"this_UUID_does_not_exist": "/this/path/does/not/exist"}, + ) + + +@parametrize_version() +def test_check_mapping_when_the_path_does_not_exist(version): + dataset_name = "simple-parquet/metadata.json" + record_set_name = "persons" + with pytest.raises(ValueError, match="doesn't exist on disk"): + load_records_and_test_equality( + version, + dataset_name, + record_set_name, + -1, + mapping={"dataframe": "/this/path/does/not/exist"}, + )