Skip to content

Commit

Permalink
Check that the mapping is valid after setting it. (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcenacp authored Sep 25, 2024
1 parent 7574ded commit df13167
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 27 deletions.
11 changes: 11 additions & 0 deletions python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def get_operations(ctx: Context, metadata: Metadata) -> OperationGraph:
return operations


def _check_mapping(metadata: Metadata, mapping: Mapping[str, epath.Path]):
    """Checks that the mapping is valid, i.e. keys are actual UUIDs and paths exist.

    Args:
        metadata: The metadata whose nodes define the set of valid UUIDs.
        mapping: Maps node UUIDs to local paths that should replace the originals.

    Raises:
        ValueError: If a mapping key is not the UUID of any node in the JSON-LD,
            or if a mapped path does not exist on disk.
    """
    # Build the set once (set comprehension, not set([...])) so each
    # membership test in the loop below is O(1).
    uuids = {node.uuid for node in metadata.nodes()}
    for uuid, path in mapping.items():
        if uuid not in uuids:
            raise ValueError(f"{uuid=} in the mapping doesn't exist in the JSON-LD")
        if not path.exists():
            raise ValueError(f"{path=} doesn't exist on disk")


def _expand_mapping(
mapping: Mapping[str, epath.PathLike] | None,
) -> Mapping[str, epath.Path]:
Expand Down Expand Up @@ -83,6 +93,7 @@ def __post_init__(self):
self.metadata = Metadata.from_file(ctx=ctx, file=self.jsonld)
else:
return
_check_mapping(self.metadata, ctx.mapping)
# Draw the structure graph for debugging purposes.
if self.debug:
graphs_utils.pretty_print_graph(ctx.graph)
Expand Down
87 changes: 60 additions & 27 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mlcroissant._src.tests.records import record_to_python
from mlcroissant._src.tests.versions import parametrize_version

_REPOSITORY_FOLDER = epath.Path(__file__).parent.parent.parent.parent.parent


# End-to-end tests on real data. The data is in `tests/graphs/*/metadata.json`.
def get_error_msg(folder: epath.Path):
Expand Down Expand Up @@ -88,6 +90,7 @@ def load_records_and_test_equality(
record_set_name: str,
num_records: int,
filters: dict[str, Any] | None = None,
mapping: dict[str, epath.PathLike] | None = None,
):
filters_command = ""
if filters:
Expand All @@ -99,17 +102,12 @@ def load_records_and_test_equality(
f" {record_set_name} --num_records {num_records} --debug --update_output"
f" {filters_command}`"
)
config = (
epath.Path(__file__).parent.parent.parent.parent.parent
/ "datasets"
/ version
/ dataset_name
)
config = _REPOSITORY_FOLDER / "datasets" / version / dataset_name
output_file = config.parent / "output" / f"{record_set_name}.jsonl"
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
dataset = datasets.Dataset(config)
dataset = datasets.Dataset(config, mapping=mapping)
records = dataset.records(record_set_name, filters=filters)
records = iter(records)
length = 0
Expand Down Expand Up @@ -144,12 +142,7 @@ def load_records_with_beam_and_test_equality(
dataset_name: str,
record_set_name: str,
):
jsonld = (
epath.Path(__file__).parent.parent.parent.parent.parent
/ "datasets"
/ version
/ dataset_name
)
jsonld = _REPOSITORY_FOLDER / "datasets" / version / dataset_name
output_file = jsonld.parent / "output" / f"{record_set_name}.jsonl"
with output_file.open("rb") as f:
lines = f.readlines()
Expand Down Expand Up @@ -281,12 +274,7 @@ def test_load_from_huggingface():

@parametrize_version()
def test_raises_when_the_record_set_does_not_exist(version):
dataset_folder = (
epath.Path(__file__).parent.parent.parent.parent.parent
/ "datasets"
/ version
/ "titanic"
)
dataset_folder = _REPOSITORY_FOLDER / "datasets" / version / "titanic"
dataset = datasets.Dataset(dataset_folder / "metadata.json")
with pytest.raises(ValueError, match="did not find"):
dataset.records("this_record_set_does_not_exist")
Expand All @@ -297,15 +285,9 @@ def test_cypress_fixtures(version):
# Cypress cannot read files outside of its direct scope, so we have to copy them
# as fixtures. This test tests that the copies are equal to the original.
fixture_folder: epath.Path = (
epath.Path(__file__).parent.parent.parent.parent.parent
/ "editor"
/ "cypress"
/ "fixtures"
/ version
)
datasets_folder: epath.Path = (
epath.Path(__file__).parent.parent.parent.parent.parent / "datasets" / version
_REPOSITORY_FOLDER / "editor" / "cypress" / "fixtures" / version
)
datasets_folder: epath.Path = _REPOSITORY_FOLDER / "datasets" / version
for fixture in fixture_folder.glob("*.json"):
dataset = datasets_folder / f"{fixture.stem}" / "metadata.json"
assert json.load(fixture.open()) == json.load(dataset.open()), (
Expand All @@ -330,3 +312,54 @@ def test_validate_filters(filters, raises):
datasets._validate_filters(filters)
else:
datasets._validate_filters(filters)


@parametrize_version()
def test_check_mapping_when_the_key_does_not_exist(version):
    """A mapping keyed on an unknown UUID is rejected with a clear error."""
    # No node in the dataset has this UUID, so the mapping check must fail
    # before any record is read.
    bad_mapping = {"this_UUID_does_not_exist": "/this/path/does/not/exist"}
    with pytest.raises(ValueError, match="doesn't exist in the JSON-LD"):
        load_records_and_test_equality(
            version,
            "simple-parquet/metadata.json",
            "persons",
            -1,
            mapping=bad_mapping,
        )


@parametrize_version()
def test_check_mapping_when_the_path_does_not_exist(version):
    """A mapping pointing at a missing file is rejected with a clear error."""
    # The UUID "dataframe" is valid, but the target path is absent on disk.
    bad_mapping = {"dataframe": "/this/path/does/not/exist"}
    with pytest.raises(ValueError, match="doesn't exist on disk"):
        load_records_and_test_equality(
            version,
            "simple-parquet/metadata.json",
            "persons",
            -1,
            mapping=bad_mapping,
        )


@parametrize_version()
def test_check_mapping_when_the_mapping_is_correct(version, tmp_path):
    """A valid mapping (known UUID, existing path) loads records normally."""
    source = (
        _REPOSITORY_FOLDER
        / "datasets"
        / version
        / "simple-parquet/data/dataframe.parquet"
    )
    assert source.exists()
    # Duplicate the parquet file into a temp dir so the mapping points at a
    # path distinct from the one declared in the JSON-LD.
    relocated = tmp_path / "dataframe.parquet"
    source.copy(relocated)
    load_records_and_test_equality(
        version,
        "simple-parquet/metadata.json",
        "persons",
        -1,
        mapping={"dataframe": relocated},
    )

0 comments on commit df13167

Please sign in to comment.