Add tests for resume scenario suggestion to parallel runner (#4417)
* Add tests for resume scenario suggestion to parallel runner

Signed-off-by: Merel Theisen <[email protected]>
merelcht authored Jan 15, 2025
1 parent 9610e78 commit a565d66
Showing 1 changed file with 91 additions and 0 deletions.
tests/runner/test_parallel_runner.py (+91 −0)
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from concurrent.futures.process import ProcessPoolExecutor
from typing import Any

@@ -253,6 +254,23 @@ def _describe(self) -> dict[str, Any]:
ParallelRunnerManager.register("LoggingDataset", LoggingDataset)


@pytest.fixture
def logging_dataset_catalog():
    log = []
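    # One shared dataset instance, so every load/save/release lands in a single log.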
    persistent_dataset = LoggingDataset(log, "in", "stuff")
    return DataCatalog(
        {
            "ds0_A": persistent_dataset,
            "ds0_B": persistent_dataset,
            "ds2_A": persistent_dataset,
            "ds2_B": persistent_dataset,
            "dsX": persistent_dataset,
            "dsY": persistent_dataset,
            "params:p": MemoryDataset(1),
        }
    )


@pytest.mark.parametrize("is_async", [False, True])
class TestParallelRunnerRelease:
    def test_dont_release_inputs_and_outputs(self, is_async):
@@ -342,3 +360,76 @@ def test_release_transcoded(self, is_async):

        # we want to see both datasets being released
        assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")]


class TestSuggestResumeScenario:
    @pytest.mark.parametrize(
        "failing_node_names,expected_pattern",
        [
            (["node1_A", "node1_B"], r"No nodes ran."),
            (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
            (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
            (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
            (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"),
            (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
        ],
    )
    def test_suggest_resume_scenario(
        self,
        caplog,
        two_branches_crossed_pipeline,
        logging_dataset_catalog,
        failing_node_names,
        expected_pattern,
    ):
        nodes = {n.name: n for n in two_branches_crossed_pipeline.nodes}
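        # Swap each failing node for a copy whose function raises, leaving the topology unchanged.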
        for name in failing_node_names:
            two_branches_crossed_pipeline -= modular_pipeline([nodes[name]])
            two_branches_crossed_pipeline += modular_pipeline(
                [nodes[name]._copy(func=exception_fn)]
            )
        with pytest.raises(Exception):
            ParallelRunner().run(
                two_branches_crossed_pipeline,
                logging_dataset_catalog,
                hook_manager=_create_hook_manager(),
            )
        assert re.search(expected_pattern, caplog.text)

    @pytest.mark.parametrize(
        "failing_node_names,expected_pattern",
        [
            (["node1_A", "node1_B"], r"No nodes ran."),
            (["node2"], r'"node1_A,node1_B"'),
            (["node3_A"], r"(node3_A,node3_B|node3_A)"),
            (["node4_A"], r"(node3_A,node3_B|node3_A)"),
            (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_A)"),
            (["node2", "node4_A"], r'"node1_A,node1_B"'),
        ],
    )
    def test_stricter_suggest_resume_scenario(
        self,
        caplog,
        two_branches_crossed_pipeline_variable_inputs,
        logging_dataset_catalog,
        failing_node_names,
        expected_pattern,
    ):
        """
        Stricter version of the previous test.
        Covers pipelines where inputs are shared across nodes.
        """
        test_pipeline = two_branches_crossed_pipeline_variable_inputs

        nodes = {n.name: n for n in test_pipeline.nodes}
        for name in failing_node_names:
            test_pipeline -= modular_pipeline([nodes[name]])
            test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)])

        with pytest.raises(Exception, match="test exception"):
            ParallelRunner().run(
                test_pipeline,
                logging_dataset_catalog,
                hook_manager=_create_hook_manager(),
            )
        assert re.search(expected_pattern, caplog.text)
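
Note: the fixtures two_branches_crossed_pipeline and two_branches_crossed_pipeline_variable_inputs, along with helpers such as exception_fn and _create_hook_manager, come from the shared runner conftest and are not part of this diff. As a rough sketch only — the functions and wiring below are assumptions inferred from the node and dataset names used in these tests, not the conftest's actual code — the crossed pipeline looks something like this:

from kedro.pipeline import node, pipeline as modular_pipeline

def identity(x):
    return x

def two_in_two_out(a, b):
    # Hypothetical crossing point: consumes both branches, then fans back out.
    return a, b

# Two parallel branches that cross at node2 and diverge again:
# node1_A/node1_B -> node2 -> node3_A/node3_B -> node4_A/node4_B.
sketch_pipeline = modular_pipeline(
    [
        node(identity, "ds0_A", "ds1_A", name="node1_A"),
        node(identity, "ds0_B", "ds1_B", name="node1_B"),
        node(two_in_two_out, ["ds1_A", "ds1_B"], ["ds2_A", "ds2_B"], name="node2"),
        node(identity, "ds2_A", "ds3_A", name="node3_A"),
        node(identity, "ds2_B", "ds3_B", name="node3_B"),
        node(identity, "ds3_A", "ds4_A", name="node4_A"),
        node(identity, "ds3_B", "ds4_B", name="node4_B"),
    ]
)

When a run fails part-way through, the runner logs a suggestion of which nodes to resume from; the expected_pattern regexes above match the comma-separated node names in that logged suggestion.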
