diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py
index 049b27f200..7173ee047a 100644
--- a/tests/runner/test_parallel_runner.py
+++ b/tests/runner/test_parallel_runner.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 from concurrent.futures.process import ProcessPoolExecutor
 from typing import Any
 
@@ -253,6 +254,23 @@ def _describe(self) -> dict[str, Any]:
 ParallelRunnerManager.register("LoggingDataset", LoggingDataset)
 
 
+@pytest.fixture
+def logging_dataset_catalog():
+    log = []
+    persistent_dataset = LoggingDataset(log, "in", "stuff")
+    return DataCatalog(
+        {
+            "ds0_A": persistent_dataset,
+            "ds0_B": persistent_dataset,
+            "ds2_A": persistent_dataset,
+            "ds2_B": persistent_dataset,
+            "dsX": persistent_dataset,
+            "dsY": persistent_dataset,
+            "params:p": MemoryDataset(1),
+        }
+    )
+
+
 @pytest.mark.parametrize("is_async", [False, True])
 class TestParallelRunnerRelease:
     def test_dont_release_inputs_and_outputs(self, is_async):
@@ -342,3 +360,76 @@ def test_release_transcoded(self, is_async):
 
         # we want to see both datasets being released
         assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")]
+
+
+class TestSuggestResumeScenario:
+    @pytest.mark.parametrize(
+        "failing_node_names,expected_pattern",
+        [
+            (["node1_A", "node1_B"], r"No nodes ran."),
+            (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
+            (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
+            (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
+            (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
+            (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
+        ],
+    )
+    def test_suggest_resume_scenario(
+        self,
+        caplog,
+        two_branches_crossed_pipeline,
+        logging_dataset_catalog,
+        failing_node_names,
+        expected_pattern,
+    ):
+        nodes = {n.name: n for n in two_branches_crossed_pipeline.nodes}
+        for name in failing_node_names:
+            two_branches_crossed_pipeline -= modular_pipeline([nodes[name]])
+            two_branches_crossed_pipeline += modular_pipeline(
+                [nodes[name]._copy(func=exception_fn)]
+            )
+        with pytest.raises(Exception):
+            ParallelRunner().run(
+                two_branches_crossed_pipeline,
+                logging_dataset_catalog,
+                hook_manager=_create_hook_manager(),
+            )
+        assert re.search(expected_pattern, caplog.text)
+
+    @pytest.mark.parametrize(
+        "failing_node_names,expected_pattern",
+        [
+            (["node1_A", "node1_B"], r"No nodes ran."),
+            (["node2"], r'"node1_A,node1_B"'),
+            (["node3_A"], r"(node3_A,node3_B|node3_A)"),
+            (["node4_A"], r"(node3_A,node3_B|node3_A)"),
+            (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_A)"),
+            (["node2", "node4_A"], r'"node1_A,node1_B"'),
+        ],
+    )
+    def test_stricter_suggest_resume_scenario(
+        self,
+        caplog,
+        two_branches_crossed_pipeline_variable_inputs,
+        logging_dataset_catalog,
+        failing_node_names,
+        expected_pattern,
+    ):
+        """
+        Stricter version of previous test.
+        Covers pipelines where inputs are shared across nodes.
+        """
+        test_pipeline = two_branches_crossed_pipeline_variable_inputs
+
+        nodes = {n.name: n for n in test_pipeline.nodes}
+        for name in failing_node_names:
+            test_pipeline -= modular_pipeline([nodes[name]])
+            test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)])
+
+        with pytest.raises(Exception, match="test exception"):
+            ParallelRunner().run(
+                test_pipeline,
+                logging_dataset_catalog,
+                hook_manager=_create_hook_manager(),
+            )
+        assert re.search(expected_pattern, caplog.text)