diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py index cd2bc858..b5709d83 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets.py +++ b/python/mlcroissant/mlcroissant/_src/datasets.py @@ -46,6 +46,16 @@ def get_operations(ctx: Context, metadata: Metadata) -> OperationGraph: return operations +def _check_mapping(metadata: Metadata, mapping: Mapping[str, epath.Path]): + """Checks that the mapping is valid, i.e. keys are actual UUIDs and paths exist.""" + uuids = set([node.uuid for node in metadata.nodes()]) + for uuid, path in mapping.items(): + if uuid not in uuids: + raise ValueError(f"{uuid=} in the mapping doesn't exist in the JSON-LD") + if not path.exists(): + raise ValueError(f"{path=} doesn't exist on disk") + + def _expand_mapping( mapping: Mapping[str, epath.PathLike] | None, ) -> Mapping[str, epath.Path]: @@ -83,6 +93,7 @@ def __post_init__(self): self.metadata = Metadata.from_file(ctx=ctx, file=self.jsonld) else: return + _check_mapping(self.metadata, ctx.mapping) # Draw the structure graph for debugging purposes. if self.debug: graphs_utils.pretty_print_graph(ctx.graph) diff --git a/python/mlcroissant/mlcroissant/_src/datasets_test.py b/python/mlcroissant/mlcroissant/_src/datasets_test.py index 97cf22cb..72cceeed 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets_test.py +++ b/python/mlcroissant/mlcroissant/_src/datasets_test.py @@ -15,6 +15,8 @@ from mlcroissant._src.tests.records import record_to_python from mlcroissant._src.tests.versions import parametrize_version +_REPOSITORY_FOLDER = epath.Path(__file__).parent.parent.parent.parent.parent + # End-to-end tests on real data. The data is in `tests/graphs/*/metadata.json`. def get_error_msg(folder: epath.Path): @@ -88,6 +90,7 @@ def load_records_and_test_equality( record_set_name: str, num_records: int, filters: dict[str, Any] | None = None, + mapping: dict[str, epath.PathLike] | None = None, ): filters_command = "" if filters: @@ -99,17 +102,12 @@ def load_records_and_test_equality( f" {record_set_name} --num_records {num_records} --debug --update_output" f" {filters_command}`" ) - config = ( - epath.Path(__file__).parent.parent.parent.parent.parent - / "datasets" - / version - / dataset_name - ) + config = _REPOSITORY_FOLDER / "datasets" / version / dataset_name output_file = config.parent / "output" / f"{record_set_name}.jsonl" with output_file.open("rb") as f: lines = f.readlines() expected_records = [json.loads(line) for line in lines] - dataset = datasets.Dataset(config) + dataset = datasets.Dataset(config, mapping=mapping) records = dataset.records(record_set_name, filters=filters) records = iter(records) length = 0 @@ -144,12 +142,7 @@ def load_records_with_beam_and_test_equality( dataset_name: str, record_set_name: str, ): - jsonld = ( - epath.Path(__file__).parent.parent.parent.parent.parent - / "datasets" - / version - / dataset_name - ) + jsonld = _REPOSITORY_FOLDER / "datasets" / version / dataset_name output_file = jsonld.parent / "output" / f"{record_set_name}.jsonl" with output_file.open("rb") as f: lines = f.readlines() @@ -281,12 +274,7 @@ def test_load_from_huggingface(): @parametrize_version() def test_raises_when_the_record_set_does_not_exist(version): - dataset_folder = ( - epath.Path(__file__).parent.parent.parent.parent.parent - / "datasets" - / version - / "titanic" - ) + dataset_folder = _REPOSITORY_FOLDER / "datasets" / version / "titanic" dataset = datasets.Dataset(dataset_folder / "metadata.json") with pytest.raises(ValueError, match="did not find"): dataset.records("this_record_set_does_not_exist") @@ -297,15 +285,9 @@ def test_cypress_fixtures(version): # Cypress cannot read files outside of its direct scope, so we have to copy them # as fixtures. This test tests that the copies are equal to the original. fixture_folder: epath.Path = ( - epath.Path(__file__).parent.parent.parent.parent.parent - / "editor" - / "cypress" - / "fixtures" - / version - ) - datasets_folder: epath.Path = ( - epath.Path(__file__).parent.parent.parent.parent.parent / "datasets" / version + _REPOSITORY_FOLDER / "editor" / "cypress" / "fixtures" / version ) + datasets_folder: epath.Path = _REPOSITORY_FOLDER / "datasets" / version for fixture in fixture_folder.glob("*.json"): dataset = datasets_folder / f"{fixture.stem}" / "metadata.json" assert json.load(fixture.open()) == json.load(dataset.open()), ( @@ -330,3 +312,54 @@ def test_validate_filters(filters, raises): datasets._validate_filters(filters) else: datasets._validate_filters(filters) + + +@parametrize_version() +def test_check_mapping_when_the_key_does_not_exist(version): + dataset_name = "simple-parquet/metadata.json" + record_set_name = "persons" + with pytest.raises(ValueError, match="doesn't exist in the JSON-LD"): + load_records_and_test_equality( + version, + dataset_name, + record_set_name, + -1, + mapping={"this_UUID_does_not_exist": "/this/path/does/not/exist"}, + ) + + +@parametrize_version() +def test_check_mapping_when_the_path_does_not_exist(version): + dataset_name = "simple-parquet/metadata.json" + record_set_name = "persons" + with pytest.raises(ValueError, match="doesn't exist on disk"): + load_records_and_test_equality( + version, + dataset_name, + record_set_name, + -1, + mapping={"dataframe": "/this/path/does/not/exist"}, + ) + + +@parametrize_version() +def test_check_mapping_when_the_mapping_is_correct(version, tmp_path): + dataset_name = "simple-parquet/metadata.json" + record_set_name = "persons" + old_path = ( + _REPOSITORY_FOLDER + / "datasets" + / version + / "simple-parquet/data/dataframe.parquet" + ) + assert old_path.exists() + # Copy the dataframe to a temporary file: + new_path = tmp_path / "dataframe.parquet" + old_path.copy(new_path) + load_records_and_test_equality( + version, + dataset_name, + record_set_name, + -1, + mapping={"dataframe": new_path}, + )