Skip to content

Commit

Permalink
Check that the mapping is valid after setting it.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcenacp committed Sep 25, 2024
1 parent 7574ded commit 40589fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
11 changes: 11 additions & 0 deletions python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def get_operations(ctx: Context, metadata: Metadata) -> OperationGraph:
return operations


def _check_mapping(metadata: Metadata, mapping: Mapping[str, epath.Path]):
    """Checks that the mapping is valid, i.e. keys are actual UUIDs and paths exist.

    Args:
        metadata: The metadata whose `nodes()` provide the set of known UUIDs.
        mapping: Mapping from a node UUID to the path that should be used for it.

    Raises:
        ValueError: If a key of the mapping is not the UUID of any node in
            `metadata`, or if a path in the mapping does not exist. The UUID is
            checked before the path for each entry, and the first invalid entry
            (in the mapping's iteration order) is the one reported.
    """
    # Build the lookup set once so that each membership test below is O(1).
    known_uuids = {node.uuid for node in metadata.nodes()}
    # NOTE: the loop variables must stay named `uuid` and `path`, because the
    # `{uuid=}` / `{path=}` f-string specifiers embed those names in the messages.
    for uuid, path in mapping.items():
        if uuid not in known_uuids:
            raise ValueError(f"{uuid=} in the mapping doesn't exist in the JSON-LD")
        if not path.exists():
            raise ValueError(f"{path=} doesn't exist on disk")


def _expand_mapping(
mapping: Mapping[str, epath.PathLike] | None,
) -> Mapping[str, epath.Path]:
Expand Down Expand Up @@ -83,6 +93,7 @@ def __post_init__(self):
self.metadata = Metadata.from_file(ctx=ctx, file=self.jsonld)
else:
return
_check_mapping(self.metadata, ctx.mapping)
# Draw the structure graph for debugging purposes.
if self.debug:
graphs_utils.pretty_print_graph(ctx.graph)
Expand Down
31 changes: 30 additions & 1 deletion python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def load_records_and_test_equality(
record_set_name: str,
num_records: int,
filters: dict[str, Any] | None = None,
mapping: dict[str, epath.PathLike] | None = None,
):
filters_command = ""
if filters:
Expand All @@ -109,7 +110,7 @@ def load_records_and_test_equality(
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
dataset = datasets.Dataset(config)
dataset = datasets.Dataset(config, mapping=mapping)
records = dataset.records(record_set_name, filters=filters)
records = iter(records)
length = 0
Expand Down Expand Up @@ -330,3 +331,31 @@ def test_validate_filters(filters, raises):
datasets._validate_filters(filters)
else:
datasets._validate_filters(filters)


@parametrize_version()
def test_check_mapping_when_the_key_does_not_exist(version):
    """Expects a ValueError when a mapping key is reported as missing from the JSON-LD."""
    unknown_key_mapping = {"this_UUID_does_not_exist": "/this/path/does/not/exist"}
    expect_error = pytest.raises(ValueError, match="doesn't exist in the JSON-LD")
    with expect_error:
        # -1 is passed as the record count; presumably it is never compared
        # because the error is expected before any record is read.
        load_records_and_test_equality(
            version,
            "simple-parquet/metadata.json",
            "persons",
            -1,
            mapping=unknown_key_mapping,
        )


@parametrize_version()
def test_check_mapping_when_the_path_does_not_exist(version):
    """Expects a ValueError when a mapped path is reported as missing on disk."""
    missing_path_mapping = {"dataframe": "/this/path/does/not/exist"}
    expect_error = pytest.raises(ValueError, match="doesn't exist on disk")
    with expect_error:
        # -1 is passed as the record count; presumably it is never compared
        # because the error is expected before any record is read.
        load_records_and_test_equality(
            version,
            "simple-parquet/metadata.json",
            "persons",
            -1,
            mapping=missing_path_mapping,
        )

0 comments on commit 40589fe

Please sign in to comment.