Skip to content

Commit

Permalink
Support ArrayNode mapping over Launch Plans (#2480)
Browse files Browse the repository at this point in the history
* set up array node

Signed-off-by: Paul Dittamo <[email protected]>

* wip array node task wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* support function like callability

Signed-off-by: Paul Dittamo <[email protected]>

* temp check in some progress on python func wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* only support launch plans in new array node class for now

Signed-off-by: Paul Dittamo <[email protected]>

* add map task array node implementation wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* ArrayNode only supports LPs for now

Signed-off-by: Paul Dittamo <[email protected]>

* support local execute for new array node implementation

Signed-off-by: Paul Dittamo <[email protected]>

* add local execute unit tests for array node

Signed-off-by: Paul Dittamo <[email protected]>

* set exeucution version in array node spec

Signed-off-by: Paul Dittamo <[email protected]>

* check input types for local execute

Signed-off-by: Paul Dittamo <[email protected]>

* remove code that is un-needed for now

Signed-off-by: Paul Dittamo <[email protected]>

* clean up array node class

Signed-off-by: Paul Dittamo <[email protected]>

* improve naming

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* utilize enum execution mode to set array node execution path

Signed-off-by: Paul Dittamo <[email protected]>

* default execution mode to FULL_STATE for new array node class

Signed-off-by: Paul Dittamo <[email protected]>

* support min_successes for new array node

Signed-off-by: Paul Dittamo <[email protected]>

* add map task wrapper unit test

Signed-off-by: Paul Dittamo <[email protected]>

* set min successes for array node map task wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* update docstrings

Signed-off-by: Paul Dittamo <[email protected]>

* Install flyteidl from master in plugins tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* clean up min success/ratio setting

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* make array node class callable

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
pvditt and eapolinario authored Jul 31, 2024
1 parent 2b49bb3 commit 676914b
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 3 deletions.
226 changes: 226 additions & 0 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import math
from typing import Any, List, Optional, Set, Tuple, Union

from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
Promise,
VoidPromise,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import TaskMetadata
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Literal, LiteralCollection, Scalar


class ArrayNode:
def __init__(
self,
target: LaunchPlan,
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None,
):
"""
:param target: The target Flyte entity to map over
:param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param bound_inputs: The set of inputs that should be bound to the map task
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
"""
self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
self.id = target.name

if min_successes is not None:
self._min_successes = min_successes
self._min_success_ratio = None
else:
self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0
self._min_successes = 0

n_outputs = len(self.target.python_interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")

self._bound_inputs: Set[str] = bound_inputs or set(bound_inputs) if bound_inputs else set()

output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
collection_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
)
self._collection_interface = collection_interface

self.metadata = None
if isinstance(target, LaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if metadata:
if isinstance(metadata, _workflow_model.NodeMetadata):
self.metadata = metadata
else:
raise Exception("Invalid metadata for LaunchPlan. Should be NodeMetadata.")
else:
raise Exception("Only LaunchPlans are supported for now.")

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
# Part of SupportsNodeCreation interface
# TODO - include passed in metadata
return _workflow_model.NodeMetadata(name=self.target.name)

@property
def name(self) -> str:
# Part of SupportsNodeCreation interface
return self.target.name

@property
def python_interface(self) -> flyte_interface.Interface:
# Part of SupportsNodeCreation interface
return self._collection_interface

@property
def bindings(self) -> List[_literal_models.Binding]:
# Required in get_serializable_node
return []

@property
def upstream_nodes(self) -> List[Node]:
# Required in get_serializable_node
return []

@property
def flyte_entity(self) -> Any:
return self.target

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
outputs_expected = True
if not self.python_interface.outputs:
outputs_expected = False

mapped_entity_count = 0
for k in self.python_interface.inputs.keys():
if k not in self._bound_inputs:
v = kwargs[k]
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], self.target.python_interface.inputs[k]):
mapped_entity_count = len(v)
break
else:
raise ValueError(
f"Expected a list of {self.target.python_interface.inputs[k]} but got {type(v)} instead."
)

failed_count = 0
min_successes = mapped_entity_count
if self._min_successes:
min_successes = self._min_successes
elif self._min_success_ratio:
min_successes = math.ceil(min_successes * self._min_success_ratio)

literals = []
for i in range(mapped_entity_count):
single_instance_inputs = {}
for k in self.python_interface.inputs.keys():
if k not in self._bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
single_instance_inputs[k] = kwargs[k]

# translate Python native inputs to Flyte literals
typed_interface = transform_interface_to_typed_interface(self.target.python_interface)
literal_map = translate_inputs_to_literals(
ctx,
incoming_values=single_instance_inputs,
flyte_interface_types={} if typed_interface is None else typed_interface.inputs,
native_types=self.target.python_interface.inputs,
)
kwargs_literals = {k1: Promise(var=k1, val=v1) for k1, v1 in literal_map.items()}

try:
output = self.target.__call__(**kwargs_literals)
if outputs_expected:
literals.append(output.val)
except Exception as exc:
if outputs_expected:
literal_with_none = Literal(scalar=Scalar(none_type=_literal_models.Void()))
literals.append(literal_with_none)
failed_count += 1
if mapped_entity_count - failed_count < min_successes:
logger.error("The number of successful tasks is lower than the minimum")
raise exc

if outputs_expected:
return Promise(var="o0", val=Literal(collection=LiteralCollection(literals=literals)))
return VoidPromise(self.name)

def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

@property
def min_success_ratio(self) -> Optional[float]:
return self._min_success_ratio

@property
def min_successes(self) -> Optional[int]:
return self._min_successes

@property
def concurrency(self) -> Optional[int]:
return self._concurrency

@property
def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
return self._execution_mode

def __call__(self, *args, **kwargs):
return flyte_entity_call_handler(self, *args, **kwargs)


def array_node(
target: Union[LaunchPlan],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
):
"""
ArrayNode implementation that maps over tasks and other Flyte entities
:param target: The target Flyte entity to map over
:param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
if not isinstance(target, LaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")

node = ArrayNode(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
)

return node
37 changes: 37 additions & 0 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, is_annotated
from flytekit.core.utils import timeit
Expand Down Expand Up @@ -347,6 +349,41 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
target: Union[LaunchPlan, PythonFunctionTask],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
**kwargs,
):
"""
Wrapper that creates a map task utilizing either the existing ArrayNodeMapTask
or the drop in replacement ArrayNode implementation
:param target: The Flyte entity of which will be mapped over
:param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
"""
if isinstance(target, LaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
)
return array_node_map_task(
task_function=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
**kwargs,
)


def array_node_map_task(
task_function: PythonFunctionTask,
concurrency: Optional[int] = None,
# TODO why no min_successes?
Expand Down
6 changes: 5 additions & 1 deletion flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":


class ArrayNode(_common.FlyteIdlEntity):
def __init__(self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None) -> None:
def __init__(
self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None
) -> None:
"""
TODO: docstring
"""
Expand All @@ -390,6 +392,7 @@ def __init__(self, node: "Node", parallelism=None, min_successes=None, min_succe
# TODO either min_successes or min_success_ratio should be set
self._min_successes = min_successes
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode

@property
def node(self) -> "Node":
Expand All @@ -401,6 +404,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
parallelism=self._parallelism,
min_successes=self._min_successes,
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def raw_register(
workflow_model.WorkflowNode,
workflow_model.BranchNode,
workflow_model.TaskNode,
workflow_model.ArrayNode,
),
):
return None
Expand Down
34 changes: 32 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager
from flytekit.core.array_node import ArrayNode
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
Expand Down Expand Up @@ -49,6 +50,7 @@
ReferenceTask,
ReferenceLaunchPlan,
ReferenceEntity,
ArrayNode,
]
FlyteControlPlaneEntity = Union[
TaskSpec,
Expand Down Expand Up @@ -471,15 +473,24 @@ def get_serializable_node(

from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow

if isinstance(entity.flyte_entity, ArrayNodeMapTask):
if isinstance(entity.flyte_entity, ArrayNode):
node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
metadata=entity.flyte_entity.construct_node_metadata(),
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
array_node=get_serializable_array_node(entity_mapping, settings, entity, options=options),
)
elif isinstance(entity.flyte_entity, ArrayNodeMapTask):
node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
array_node=get_serializable_array_node_map_task(entity_mapping, settings, entity, options=options),
)
# TODO: do I need this?
# if entity._aliases:
# node_model._output_aliases = entity._aliases
Expand Down Expand Up @@ -617,6 +628,22 @@ def get_serializable_node(


def get_serializable_array_node(
entity_mapping: OrderedDict,
settings: SerializationSettings,
node: FlyteLocalEntity,
options: Optional[Options] = None,
) -> ArrayNodeModel:
array_node = node.flyte_entity
return ArrayNodeModel(
node=get_serializable_node(entity_mapping, settings, array_node, options=options),
parallelism=array_node.concurrency,
min_successes=array_node.min_successes,
min_success_ratio=array_node.min_success_ratio,
execution_mode=array_node.execution_mode,
)


def get_serializable_array_node_map_task(
entity_mapping: OrderedDict,
settings: SerializationSettings,
node: Node,
Expand Down Expand Up @@ -790,6 +817,9 @@ def get_serializable(
elif isinstance(entity, FlyteLaunchPlan):
cp_entity = entity

elif isinstance(entity, ArrayNode):
cp_entity = get_serializable_array_node(entity_mapping, settings, entity, options)

else:
raise Exception(f"Non serializable type found {type(entity)} Entity {entity}")

Expand Down
Loading

0 comments on commit 676914b

Please sign in to comment.