-
Notifications
You must be signed in to change notification settings - Fork 915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor run methods more into abstract method #4353
Changes from all commits
fc9c5c1
62569f7
dcd6991
58abcf7
9e2e278
3a9dde6
af955b8
39ac276
a2db9cf
41154a2
186f845
123f678
b589fa2
00ca950
33b3ec6
351e0ef
6659132
f4b3610
2fa4fa8
b3268ce
7a42cd5
a990155
1c76953
3b36f8a
be2e285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,17 +6,26 @@ | |
|
||
import inspect | ||
import logging | ||
import os | ||
import sys | ||
import warnings | ||
from abc import ABC, abstractmethod | ||
from collections import deque | ||
from collections import Counter, deque | ||
from concurrent.futures import FIRST_COMPLETED, Executor, ProcessPoolExecutor, wait | ||
from itertools import chain | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from pluggy import PluginManager | ||
|
||
from kedro import KedroDeprecationWarning | ||
from kedro.framework.hooks.manager import _NullPluginManager | ||
from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset | ||
from kedro.pipeline import Pipeline | ||
from kedro.runner.task import Task | ||
|
||
# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114 | ||
_MAX_WINDOWS_WORKERS = 61 | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Collection, Iterable | ||
|
||
|
@@ -166,25 +175,95 @@ def run_only_missing( | |
|
||
return self.run(to_rerun, catalog, hook_manager) | ||
|
||
@abstractmethod # pragma: no cover | ||
def _get_executor(self, max_workers: int) -> Executor: | ||
"""Abstract method to provide the correct executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor).""" | ||
pass | ||
|
||
@abstractmethod # pragma: no cover | ||
def _run( | ||
self, | ||
pipeline: Pipeline, | ||
catalog: CatalogProtocol, | ||
hook_manager: PluginManager, | ||
hook_manager: PluginManager | None = None, | ||
session_id: str | None = None, | ||
) -> None: | ||
"""The abstract interface for running pipelines, assuming that the | ||
inputs have already been checked and normalized by run(). | ||
inputs have already been checked and normalized by run(). | ||
This contains the Common pipeline execution logic using an executor. | ||
|
||
Args: | ||
pipeline: The ``Pipeline`` to run. | ||
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. | ||
hook_manager: The ``PluginManager`` to activate hooks. | ||
session_id: The id of the session. | ||
|
||
""" | ||
pass | ||
|
||
nodes = pipeline.nodes | ||
|
||
self._validate_catalog(catalog, pipeline) | ||
self._validate_nodes(nodes) | ||
self._set_manager_datasets(catalog, pipeline) | ||
|
||
load_counts = Counter(chain.from_iterable(n.inputs for n in pipeline.nodes)) | ||
node_dependencies = pipeline.node_dependencies | ||
todo_nodes = set(node_dependencies.keys()) | ||
done_nodes: set[Node] = set() | ||
futures = set() | ||
done = None | ||
max_workers = self._get_required_workers_count(pipeline) | ||
|
||
with self._get_executor(max_workers) as pool: | ||
while True: | ||
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I propose reducing the complexity of that line in a separate PR because this becomes more critical with the unification of runners. As I understand, the previous implementation of the SequentialRunner (most popular runner) simply executed an ordered list of nodes sequentially. However, the current approach performs a full loop over all rest nodes after each executed node, this results in an overall complexity greater than quadratic. I believe we can achieve linear complexity by maintaining a count of unmet dependencies for each node. As nodes are completed, we decrement the counter for their dependents and mark a node as ready when its counter reaches zero. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point @DimedS. I'll create a new issue for this, so it's not forgotten. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
todo_nodes -= ready | ||
for node in ready: | ||
task = Task( | ||
node=node, | ||
catalog=catalog, | ||
hook_manager=hook_manager, | ||
is_async=self._is_async, | ||
session_id=session_id, | ||
) | ||
if isinstance(pool, ProcessPoolExecutor): | ||
task.parallel = True | ||
futures.add(pool.submit(task)) | ||
if not futures: | ||
if todo_nodes: | ||
self._raise_runtime_error(todo_nodes, done_nodes, ready, done) | ||
break | ||
done, futures = wait(futures, return_when=FIRST_COMPLETED) | ||
for future in done: | ||
try: | ||
node = future.result() | ||
except Exception: | ||
self._suggest_resume_scenario(pipeline, done_nodes, catalog) | ||
raise | ||
done_nodes.add(node) | ||
self._logger.info("Completed node: %s", node.name) | ||
self._logger.info( | ||
"Completed %d out of %d tasks", len(done_nodes), len(nodes) | ||
) | ||
self._release_datasets(node, catalog, load_counts, pipeline) | ||
|
||
@staticmethod | ||
def _raise_runtime_error( | ||
todo_nodes: set[Node], | ||
done_nodes: set[Node], | ||
ready: set[Node], | ||
done: set[Node] | None, | ||
) -> None: | ||
debug_data = { | ||
"todo_nodes": todo_nodes, | ||
"done_nodes": done_nodes, | ||
"ready_nodes": ready, | ||
"done_futures": done, | ||
} | ||
debug_data_str = "\n".join(f"{k} = {v}" for k, v in debug_data.items()) | ||
raise RuntimeError( | ||
f"Unable to schedule new tasks although some nodes " | ||
f"have not been run:\n{debug_data_str}" | ||
) | ||
|
||
def _suggest_resume_scenario( | ||
self, | ||
|
@@ -240,6 +319,47 @@ def _release_datasets( | |
if load_counts[dataset] < 1 and dataset not in pipeline.outputs(): | ||
catalog.release(dataset) | ||
|
||
def _validate_catalog(self, catalog: CatalogProtocol, pipeline: Pipeline) -> None: | ||
# Add catalog validation logic here if needed | ||
pass | ||
|
||
def _validate_nodes(self, node: Iterable[Node]) -> None: | ||
# Add node validation logic here if needed | ||
pass | ||
|
||
def _set_manager_datasets( | ||
self, catalog: CatalogProtocol, pipeline: Pipeline | ||
) -> None: | ||
# Set up any necessary manager datasets here | ||
pass | ||
|
||
def _get_required_workers_count(self, pipeline: Pipeline) -> int: | ||
return 1 | ||
|
||
@classmethod | ||
def _validate_max_workers(cls, max_workers: int | None) -> int: | ||
""" | ||
Validates and returns the number of workers. Sets to os.cpu_count() or 1 if max_workers is None, | ||
and limits max_workers to 61 on Windows. | ||
|
||
Args: | ||
max_workers: Desired number of workers. If None, defaults to os.cpu_count() or 1. | ||
|
||
Returns: | ||
A valid number of workers to use. | ||
|
||
Raises: | ||
ValueError: If max_workers is set and is not positive. | ||
""" | ||
if max_workers is None: | ||
max_workers = os.cpu_count() or 1 | ||
if sys.platform == "win32": | ||
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers) | ||
elif max_workers <= 0: | ||
raise ValueError("max_workers should be positive") | ||
|
||
return max_workers | ||
|
||
|
||
def _find_nodes_to_resume_from( | ||
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this still an abstractmethod?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, because it's still necessary to have
_run()
for any Runner class and you can still overwrite this when creating a custom Runner. But now it's also easier to create a custom runner, because the _run() method has more in place to get started from.