-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: option to group in-memory nodes
Signed-off-by: Simon Brugman <[email protected]>
- Loading branch information
Showing
7 changed files
with
266 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
|
||
from kedro.io import DataCatalog, MemoryDataSet | ||
from kedro.pipeline.node import Node | ||
from kedro.pipeline.pipeline import Pipeline | ||
|
||
|
||
def _is_memory_dataset(catalog, dataset_name: str) -> bool:
    """Return True if ``dataset_name`` names a ``MemoryDataSet`` in ``catalog``.

    The ``parameters`` entry and any ``params:``-prefixed entries are never
    treated as memory datasets, even though Kedro keeps parameters in memory.

    Args:
        catalog: the ``DataCatalog`` to look the dataset up in.
        dataset_name: name of the dataset as it appears in the pipeline.

    Returns:
        True only when the catalog holds a ``MemoryDataSet`` under that name.
    """
    if dataset_name == "parameters" or dataset_name.startswith("params:"):
        return False

    # NOTE(review): reaches into DataCatalog internals (`_data_sets`) — there is
    # no public accessor returning the dataset object without loading it.
    # ``isinstance(None, MemoryDataSet)`` is False, so a missing name is handled
    # without the explicit ``is not None`` guard the original carried.
    return isinstance(catalog._data_sets.get(dataset_name), MemoryDataSet)
|
||
|
||
def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
    """Collect the names of every ``MemoryDataSet`` used by *pipeline*.

    Parameter entries (``parameters`` and ``params:*``) are excluded, as are
    datasets registered with any other dataset type.
    """
    memory_datasets = set()
    for dataset_name in pipeline.data_sets():
        if _is_memory_dataset(catalog, dataset_name):
            memory_datasets.add(dataset_name)
    return memory_datasets
|
||
|
||
def node_sequence_name(node_sequence: list[Node]) -> str:
    """Build a single task name by joining the node names with underscores."""
    return "_".join(node.name for node in node_sequence)
|
||
|
||
def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline):
    """Group pipeline nodes that are chained together through memory datasets.

    Nodes connected only via ``MemoryDataSet`` edges must run in the same
    process (an in-memory dataset cannot cross task boundaries), so they are
    collapsed into one named sequence. Returns a pair:

    * ``nodes`` — mapping from sequence name (node names joined by ``_``,
      see :func:`node_sequence_name`) to the list of ``Node`` objects in it;
    * ``dependencies`` — mapping from sequence name to the list of sequence
      names that depend on it (i.e. downstream groups).

    NOTE(review): relies on ``pipeline.nodes`` being topologically sorted so
    that a memory dataset's producer is always visited before its consumers —
    ``sequence_map`` entries are written by the producer and read by the
    consumer in the same pass.
    """
    # get all memory datasets in the pipeline
    ds = get_memory_datasets(catalog, pipeline)

    # Node sequences: each element is a list of nodes that must run together.
    node_sequences = []

    # Mapping from dataset name -> node sequence index
    sequence_map = {}
    for node in pipeline.nodes:
        if all(o not in ds for o in node.inputs + node.outputs):
            # standalone node: touches no memory dataset at all
            node_sequences.append([node])
        else:
            if all(i not in ds for i in node.inputs):
                # start of a sequence; create a new sequence and store the id
                node_sequences.append([node])
                sequence_id = len(node_sequences) - 1
            else:
                # continuation of a sequence; retrieve sequence_id from any
                # memory input (all memory inputs must agree on the sequence,
                # otherwise two groups would have to be merged — asserted here)
                sequence_id = None
                for i in node.inputs:
                    if i in ds:
                        assert sequence_id is None or sequence_id == sequence_map[i]
                        sequence_id = sequence_map[i]

                # Append to map
                node_sequences[sequence_id].append(node)

            # map outputs to sequence_id so downstream consumers of these
            # memory datasets join the same sequence
            for o in node.outputs:
                if o in ds:
                    sequence_map[o] = sequence_id

    # Named node sequences: sequence name -> list of member nodes
    nodes = {
        node_sequence_name(node_sequence): node_sequence
        for node_sequence in node_sequences
    }

    # Inverted mapping: node name -> name of the sequence containing it
    node_mapping = {
        node.name: sequence_name
        for sequence_name, node_sequence in nodes.items()
        for node in node_sequence
    }

    # Grouped dependencies: lift node-level dependencies to sequence level,
    # dropping self-edges (both nodes in the same group) and duplicates.
    dependencies = defaultdict(list)
    for node, parent_nodes in pipeline.node_dependencies.items():
        for parent in parent_nodes:
            parent_name = node_mapping[parent.name]
            node_name = node_mapping[node.name]
            if parent_name != node_name and (
                parent_name not in dependencies
                or node_name not in dependencies[parent_name]
            ):
                dependencies[parent_name].append(node_name)

    return nodes, dependencies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pytest | ||
from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet | ||
from kedro.pipeline import node | ||
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline | ||
|
||
from kedro_airflow.grouping import _is_memory_dataset, group_memory_nodes | ||
|
||
|
||
class TestDataSet(AbstractDataSet):
    """Minimal non-memory dataset stub for populating the mock catalog."""

    def _load(self):
        return []

    def _save(self, data) -> None:
        pass

    def _describe(self) -> dict[str, Any]:
        return {}
|
||
|
||
@pytest.mark.parametrize(
    "memory_nodes,expected_nodes,expected_dependencies",
    [
        # ds3 and ds6 in memory: f2/f3/f4/f6/f7 chain into one group; f5 only
        # consumes the persisted ds4, so it stays on its own.
        (
            ["ds3", "ds6"],
            [["f1"], ["f2", "f3", "f4", "f6", "f7"], ["f5"]],
            {"f1": ["f2_f3_f4_f6_f7"], "f2_f3_f4_f6_f7": ["f5"]},
        ),
        # Only ds3 in memory: f6 (fed by persisted ds6) splits off as its own
        # task; f7 still joins the group through its ds3 input.
        (
            ["ds3"],
            [["f1"], ["f2", "f3", "f4", "f7"], ["f5"], ["f6"]],
            {"f1": ["f2_f3_f4_f7"], "f2_f3_f4_f7": ["f5", "f6"]},
        ),
        # No memory datasets: every node is a standalone task and grouping
        # degenerates to the raw node-level dependency graph.
        (
            [],
            [["f1"], ["f2"], ["f3"], ["f4"], ["f5"], ["f6"], ["f7"]],
            {"f1": ["f2"], "f2": ["f3", "f4", "f5", "f7"], "f4": ["f6", "f7"]},
        ),
    ],
)
def test_group_memory_nodes(
    memory_nodes: list[str],
    expected_nodes: list[list[str]],
    expected_dependencies: dict[str, list[str]],
):
    """Check the grouping of memory nodes."""
    # Dataset names ds1..ds9; every memory_nodes entry must be one of them.
    nodes = [f"ds{i}" for i in range(1, 10)]
    assert all(node_name in nodes for node_name in memory_nodes)

    # Register each dataset as MemoryDataSet or the persistent TestDataSet
    # stub, depending on the parametrized memory_nodes list.
    mock_catalog = DataCatalog()
    for dataset_name in nodes:
        if dataset_name in memory_nodes:
            dataset = MemoryDataSet()
        else:
            dataset = TestDataSet()
        mock_catalog.add(dataset_name, dataset)

    def identity_one_to_one(x):
        return x

    # Fixed topology shared by all parametrizations:
    #   f1: ds1 -> ds2
    #   f2: ds2 -> (ds3, ds4)
    #   f3: ds3 -> ds5          f4: ds3 -> ds6
    #   f5: ds4 -> ds8          f6: ds6 -> ds7
    #   f7: (ds3, ds6) -> ds9
    mock_pipeline = modular_pipeline(
        [
            node(
                func=identity_one_to_one,
                inputs="ds1",
                outputs="ds2",
                name="f1",
            ),
            node(
                func=lambda x: (x, x),
                inputs="ds2",
                outputs=["ds3", "ds4"],
                name="f2",
            ),
            node(
                func=identity_one_to_one,
                inputs="ds3",
                outputs="ds5",
                name="f3",
            ),
            node(
                func=identity_one_to_one,
                inputs="ds3",
                outputs="ds6",
                name="f4",
            ),
            node(
                func=identity_one_to_one,
                inputs="ds4",
                outputs="ds8",
                name="f5",
            ),
            node(
                func=identity_one_to_one,
                inputs="ds6",
                outputs="ds7",
                name="f6",
            ),
            node(
                func=lambda x, y: x,
                inputs=["ds3", "ds6"],
                outputs="ds9",
                name="f7",
            ),
        ],
    )

    # Compare group membership (by node name) and the lifted dependency map.
    nodes, dependencies = group_memory_nodes(mock_catalog, mock_pipeline)
    sequence = [
        [node_.name for node_ in node_sequence] for node_sequence in nodes.values()
    ]
    assert sequence == expected_nodes
    assert dict(dependencies) == expected_dependencies
|
||
|
||
def test_is_memory_dataset():
    """Only real MemoryDataSet entries count; parameters never do."""
    catalog = DataCatalog()
    entries = {
        "parameters": {"hello": "world"},
        "params:hello": "world",
        "my_dataset": MemoryDataSet(True),
        "test_dataset": TestDataSet(),
    }
    for dataset_name, dataset in entries.items():
        catalog.add(dataset_name, dataset)

    assert _is_memory_dataset(catalog, "my_dataset")
    # Parameter entries and non-memory datasets are all rejected.
    for dataset_name in ("parameters", "params:hello", "test_dataset"):
        assert not _is_memory_dataset(catalog, dataset_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters