Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor run methods more into abstract method #4353

Merged
merged 25 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fc9c5c1
First refactor step to make _run() implementations more similar
merelcht Nov 4, 2024
62569f7
Move common logic to abstract _run method, use executor for sequentia…
merelcht Nov 5, 2024
dcd6991
Refactor max workers logic into shared helper method
merelcht Nov 5, 2024
58abcf7
Add resume scenario logic
merelcht Nov 5, 2024
9e2e278
Small cleanup
merelcht Nov 5, 2024
3a9dde6
Clean up
merelcht Nov 26, 2024
af955b8
Fix mypy checks
merelcht Nov 26, 2024
39ac276
Merge branch 'main' into refactor/abstract-more-into-run-method
merelcht Nov 26, 2024
a2db9cf
Fix sequential runner test
merelcht Nov 26, 2024
41154a2
Fix thread runner
merelcht Nov 26, 2024
186f845
Ignore coverage for abstract method
merelcht Nov 26, 2024
123f678
Merge branch 'main' into refactor/abstract-more-into-run-method
merelcht Nov 26, 2024
b589fa2
Try fix thread runner test on 3.13
merelcht Nov 26, 2024
00ca950
Fix thread runner test
merelcht Nov 26, 2024
33b3ec6
Fix sequential runner test on windows
merelcht Nov 26, 2024
351e0ef
More flexible options for resume suggestion in thread runner tests
merelcht Nov 26, 2024
6659132
Clean up + make resume tests the same
merelcht Nov 27, 2024
f4b3610
Merge branch 'main' into refactor/abstract-more-into-run-method
merelcht Nov 27, 2024
2fa4fa8
Update tests/runner/test_sequential_runner.py
merelcht Nov 27, 2024
b3268ce
Clean up
merelcht Nov 27, 2024
7a42cd5
Merge branch 'main' into refactor/abstract-more-into-run-method
merelcht Dec 2, 2024
a990155
Address review comments
merelcht Dec 3, 2024
1c76953
Apply suggestions from code review
merelcht Dec 10, 2024
3b36f8a
Merge branch 'main' into refactor/abstract-more-into-run-method
merelcht Dec 10, 2024
be2e285
Fix lint
merelcht Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 13 additions & 69 deletions kedro/runner/parallel_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

from __future__ import annotations

import os
import sys
from collections import Counter
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
from itertools import chain
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.managers import BaseProxy, SyncManager
from multiprocessing.reduction import ForkingPickler
from pickle import PicklingError
Expand All @@ -20,8 +16,7 @@
MemoryDataset,
SharedMemoryDataset,
)
from kedro.runner.runner import AbstractRunner
from kedro.runner.task import Task
from kedro.runner.runner import AbstractRunner, validate_max_workers

if TYPE_CHECKING:
from collections.abc import Iterable
Expand All @@ -31,9 +26,6 @@
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114
_MAX_WINDOWS_WORKERS = 61


class ParallelRunnerManager(SyncManager):
"""``ParallelRunnerManager`` is used to create shared ``MemoryDataset``
Expand Down Expand Up @@ -83,16 +75,7 @@ def __init__(
self._manager = ParallelRunnerManager()
self._manager.start()

# This code comes from the concurrent.futures library
# https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L588
if max_workers is None:
# NOTE: `os.cpu_count` might return None in some weird cases.
# https://github.com/python/cpython/blob/3.7/Modules/posixmodule.c#L11431
max_workers = os.cpu_count() or 1
if sys.platform == "win32":
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)

self._max_workers = max_workers
self._max_workers = validate_max_workers(max_workers)

def __del__(self) -> None:
self._manager.shutdown()
Expand Down Expand Up @@ -189,14 +172,17 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int:

return min(required_processes, self._max_workers)

def _get_executor(self, max_workers: int) -> ProcessPoolExecutor:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
return ProcessPoolExecutor(max_workers=max_workers)

def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The abstract interface for running pipelines.
"""The method implementing parallel pipeline running.

Args:
pipeline: The ``Pipeline`` to run.
Expand All @@ -218,50 +204,8 @@ def _run(
"for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously"
)

nodes = pipeline.nodes
self._validate_catalog(catalog, pipeline)
self._validate_nodes(nodes)
self._set_manager_datasets(catalog, pipeline)
load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))
node_dependencies = pipeline.node_dependencies
todo_nodes = set(node_dependencies.keys())
done_nodes: set[Node] = set()
futures = set()
done = None
max_workers = self._get_required_workers_count(pipeline)

with ProcessPoolExecutor(max_workers=max_workers) as pool:
while True:
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
todo_nodes -= ready
for node in ready:
task = Task(
node=node,
catalog=catalog,
is_async=self._is_async,
session_id=session_id,
parallel=True,
)
futures.add(pool.submit(task))
if not futures:
if todo_nodes:
debug_data = {
"todo_nodes": todo_nodes,
"done_nodes": done_nodes,
"ready_nodes": ready,
"done_futures": done,
}
debug_data_str = "\n".join(
f"{k} = {v}" for k, v in debug_data.items()
)
raise RuntimeError(
f"Unable to schedule new tasks although some nodes "
f"have not been run:\n{debug_data_str}"
)
break # pragma: no cover
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
node = future.result()
done_nodes.add(node)

self._release_datasets(node, catalog, load_counts, pipeline)
super()._run(
pipeline=pipeline,
catalog=catalog,
session_id=session_id,
)
137 changes: 132 additions & 5 deletions kedro/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,31 @@

import inspect
import logging
import os
import sys
import warnings
from abc import ABC, abstractmethod
from collections import deque
from collections import Counter, deque
from concurrent.futures import (
FIRST_COMPLETED,
ProcessPoolExecutor,
ThreadPoolExecutor,
wait,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this import multiprocessing as a dependencies? I recalled in the past we have issues with ShelveStore because even importing the library cause issues on restricted environment like AWS Lambda.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yeah I think it does.. Do you remember what the problem was with that?

)
from itertools import chain
from typing import TYPE_CHECKING, Any

from pluggy import PluginManager

from kedro import KedroDeprecationWarning
from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner.task import Task

# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114
_MAX_WINDOWS_WORKERS = 61

if TYPE_CHECKING:
from collections.abc import Collection, Iterable

Expand Down Expand Up @@ -166,25 +180,97 @@ def run_only_missing(

return self.run(to_rerun, catalog, hook_manager)

@abstractmethod # pragma: no cover
def _get_executor(
merelcht marked this conversation as resolved.
Show resolved Hide resolved
self, max_workers: int
) -> ThreadPoolExecutor | ProcessPoolExecutor:
"""Abstract method to provide the correct executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor)."""
pass

@abstractmethod # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still an abstractmethod?

Copy link
Member Author

@merelcht merelcht Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because it's still necessary to have _run() for any Runner class and you can still overwrite this when creating a custom Runner. But now it's also easier to create a custom runner, because the _run() method has more in place to get started from.

def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The abstract interface for running pipelines, assuming that the
inputs have already been checked and normalized by run().
inputs have already been checked and normalized by run().
This contains the Common pipeline execution logic using an executor.

Args:
pipeline: The ``Pipeline`` to run.
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
hook_manager: The ``PluginManager`` to activate hooks.
session_id: The id of the session.

"""
pass

nodes = pipeline.nodes

self._validate_catalog(catalog, pipeline)
self._validate_nodes(nodes)
self._set_manager_datasets(catalog, pipeline)

load_counts = Counter(chain.from_iterable(n.inputs for n in pipeline.nodes))
node_dependencies = pipeline.node_dependencies
todo_nodes = set(node_dependencies.keys())
done_nodes: set[Node] = set()
futures = set()
done = None
max_workers = self._get_required_workers_count(pipeline)

with self._get_executor(max_workers) as pool:
while True:
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose reducing the complexity of that line in a separate PR because this becomes more critical with the unification of runners. As I understand, the previous implementation of the SequentialRunner (most popular runner) simply executed an ordered list of nodes sequentially. However, the current approach performs a full loop over all rest nodes after each executed node, this results in an overall complexity greater than quadratic.

I believe we can achieve linear complexity by maintaining a count of unmet dependencies for each node. As nodes are completed, we decrement the counter for their dependents and mark a node as ready when its counter reaches zero.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point @DimedS. I'll create a new issue for this, so it's not forgotten.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo_nodes -= ready
for node in ready:
task = Task(
node=node,
catalog=catalog,
hook_manager=hook_manager,
is_async=self._is_async,
session_id=session_id,
)
if isinstance(pool, ProcessPoolExecutor):
task.parallel = True
futures.add(pool.submit(task))
if not futures:
if todo_nodes:
self._raise_runtime_error(todo_nodes, done_nodes, ready, done)
break
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
try:
node = future.result()
except Exception:
self._suggest_resume_scenario(pipeline, done_nodes, catalog)
raise
done_nodes.add(node)
self._logger.info("Completed node: %s", node.name)
self._logger.info(
"Completed %d out of %d tasks", len(done_nodes), len(nodes)
)
self._release_datasets(node, catalog, load_counts, pipeline)

@staticmethod
def _raise_runtime_error(
todo_nodes: set[Node],
done_nodes: set[Node],
ready: set[Node],
done: set[Node] | None,
) -> None:
debug_data = {
"todo_nodes": todo_nodes,
"done_nodes": done_nodes,
"ready_nodes": ready,
"done_futures": done,
}
debug_data_str = "\n".join(f"{k} = {v}" for k, v in debug_data.items())
raise RuntimeError(
f"Unable to schedule new tasks although some nodes "
f"have not been run:\n{debug_data_str}"
)

def _suggest_resume_scenario(
self,
Expand Down Expand Up @@ -240,6 +326,23 @@ def _release_datasets(
if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
catalog.release(dataset)

def _validate_catalog(self, catalog: CatalogProtocol, pipeline: Pipeline) -> None:
# Add catalog validation logic here if needed
pass

def _validate_nodes(self, node: Iterable[Node]) -> None:
# Add node validation logic here if needed
pass

def _set_manager_datasets(
self, catalog: CatalogProtocol, pipeline: Pipeline
) -> None:
# Set up any necessary manager datasets here
pass

def _get_required_workers_count(self, pipeline: Pipeline) -> int:
return 1


def _find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol
Expand Down Expand Up @@ -443,3 +546,27 @@ def run_node(
)
node = task.execute()
return node


def validate_max_workers(max_workers: int | None) -> int:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
"""
Validates and returns the number of workers. Sets to os.cpu_count() or 1 if max_workers is None,
and limits max_workers to 61 on Windows.

Args:
max_workers: Desired number of workers. If None, defaults to os.cpu_count() or 1.

Returns:
A valid number of workers to use.

Raises:
ValueError: If max_workers is set and is not positive.
"""
if max_workers is None:
max_workers = os.cpu_count() or 1
if sys.platform == "win32":
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)
elif max_workers <= 0:
raise ValueError("max_workers should be positive")

return max_workers
43 changes: 15 additions & 28 deletions kedro/runner/sequential_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from __future__ import annotations

from collections import Counter
from itertools import chain
from concurrent.futures import (
ThreadPoolExecutor,
)
from typing import TYPE_CHECKING, Any

from kedro.runner.runner import AbstractRunner
from kedro.runner.task import Task

if TYPE_CHECKING:
from pluggy import PluginManager
Expand Down Expand Up @@ -46,11 +46,16 @@ def __init__(
is_async=is_async, extra_dataset_patterns=self._extra_dataset_patterns
)

def _get_executor(self, max_workers: int) -> ThreadPoolExecutor:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
merelcht marked this conversation as resolved.
Show resolved Hide resolved
return ThreadPoolExecutor(
max_workers=1
) # Single-threaded for sequential execution

def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The method implementing sequential pipeline running.
Expand All @@ -69,27 +74,9 @@ def _run(
"Using synchronous mode for loading and saving data. Use the --async flag "
"for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously"
)
nodes = pipeline.nodes
done_nodes = set()

load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))

for exec_index, node in enumerate(nodes):
try:
Task(
node=node,
catalog=catalog,
hook_manager=hook_manager,
is_async=self._is_async,
session_id=session_id,
).execute()
done_nodes.add(node)
except Exception:
self._suggest_resume_scenario(pipeline, done_nodes, catalog)
raise

self._release_datasets(node, catalog, load_counts, pipeline)

self._logger.info(
"Completed %d out of %d tasks", len(done_nodes), len(nodes)
)
super()._run(
pipeline=pipeline,
catalog=catalog,
hook_manager=hook_manager,
session_id=session_id,
)
Loading