Skip to content

Commit

Permalink
Add support for specifying execution cluster labels in pyflyte (#2422)
Browse files Browse the repository at this point in the history
Signed-off-by: va6996 <[email protected]>
  • Loading branch information
va6996 authored Jun 21, 2024
1 parent 0839ce1 commit b06d33e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
13 changes: 12 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class RunLevelParams(PyFlyteParams):
)
limit: int = make_click_option_field(
click.Option(
param_decls=["--limit", "limit"],
param_decls=["--limit"],
required=False,
type=int,
default=50,
Expand All @@ -256,6 +256,16 @@ class RunLevelParams(PyFlyteParams):
help="Assign newly created execution to a given cluster pool",
)
)
execution_cluster_label: str = make_click_option_field(
click.Option(
param_decls=["--execution-cluster-label", "--ecl"],
required=False,
type=str,
default="",
help="Assign newly created execution to a given execution cluster label",
)
)

computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

Expand Down Expand Up @@ -448,6 +458,7 @@ def run_remote(
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
execution_cluster_label=run_level_params.execution_cluster_label,
)

console_url = remote.generate_console_url(execution)
Expand Down
14 changes: 14 additions & 0 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flytekit.models import security
from flytekit.models.core import execution as _core_execution
from flytekit.models.core import identifier as _identifier
from flytekit.models.matchable_resource import ExecutionClusterLabel
from flytekit.models.node_execution import DynamicWorkflowNodeMetadata


Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
execution_cluster_label: Optional[ExecutionClusterLabel] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
Expand All @@ -198,6 +200,7 @@ def __init__(
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
:param tags: Optional list of tags to apply to the execution.
:param execution_cluster_label: Optional execution cluster label to use for this execution.
"""
self._launch_plan = launch_plan
self._metadata = metadata
Expand All @@ -213,6 +216,7 @@ def __init__(
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment
self._execution_cluster_label = execution_cluster_label

@property
def launch_plan(self):
Expand Down Expand Up @@ -295,6 +299,10 @@ def tags(self) -> Optional[typing.List[str]]:
def cluster_assignment(self) -> Optional[ClusterAssignment]:
return self._cluster_assignment

@property
def execution_cluster_label(self) -> Optional[ExecutionClusterLabel]:
return self._execution_cluster_label

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
Expand All @@ -316,6 +324,9 @@ def to_flyte_idl(self):
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
execution_cluster_label=self._execution_cluster_label.to_flyte_idl()
if self._execution_cluster_label
else None,
)

@classmethod
Expand Down Expand Up @@ -345,6 +356,9 @@ def from_flyte_idl(cls, p):
cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
if p.HasField("cluster_assignment")
else None,
execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(p.execution_cluster_label)
if p.HasField("execution_cluster_label")
else None,
)


Expand Down
35 changes: 35 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from flytekit.models.launch_plan import LaunchPlanState
from flytekit.models.literals import Literal, LiteralMap
from flytekit.models.matchable_resource import ExecutionClusterLabel
from flytekit.remote.backfill import create_backfill_workflow
from flytekit.remote.data import download_literal
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
Expand Down Expand Up @@ -1107,6 +1108,7 @@ def _execute(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.
Expand All @@ -1124,6 +1126,7 @@ def _execute(
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
if execution_name is not None and execution_name_prefix is not None:
Expand Down Expand Up @@ -1201,6 +1204,9 @@ def _execute(
envs=common_models.Envs(envs) if envs else None,
tags=tags,
cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None,
execution_cluster_label=ExecutionClusterLabel(execution_cluster_label)
if execution_cluster_label
else None,
),
literal_inputs,
)
Expand Down Expand Up @@ -1261,6 +1267,7 @@ def execute(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
Expand Down Expand Up @@ -1300,6 +1307,7 @@ def execute(
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
.. note:
Expand All @@ -1323,6 +1331,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
Expand All @@ -1339,6 +1348,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceTask):
return self.execute_reference_task(
Expand All @@ -1353,6 +1363,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceWorkflow):
return self.execute_reference_workflow(
Expand All @@ -1367,6 +1378,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceLaunchPlan):
return self.execute_reference_launch_plan(
Expand All @@ -1381,6 +1393,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
Expand All @@ -1398,6 +1411,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
Expand All @@ -1416,6 +1430,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
Expand All @@ -1433,6 +1448,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

Expand All @@ -1454,6 +1470,7 @@ def execute_remote_task_lp(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.
Expand All @@ -1473,6 +1490,7 @@ def execute_remote_task_lp(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_remote_wf(
Expand All @@ -1490,6 +1508,7 @@ def execute_remote_wf(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.
Expand All @@ -1510,6 +1529,7 @@ def execute_remote_wf(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

# Flyte Reference Entities
Expand All @@ -1527,6 +1547,7 @@ def execute_reference_task(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceTask."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1557,6 +1578,7 @@ def execute_reference_task(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_reference_workflow(
Expand All @@ -1572,6 +1594,7 @@ def execute_reference_workflow(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceWorkflow."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1616,6 +1639,7 @@ def execute_reference_workflow(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_reference_launch_plan(
Expand All @@ -1631,6 +1655,7 @@ def execute_reference_launch_plan(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceLaunchPlan."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1661,6 +1686,7 @@ def execute_reference_launch_plan(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

# Flytekit Entities
Expand All @@ -1682,6 +1708,7 @@ def execute_local_task(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
Expand All @@ -1699,6 +1726,7 @@ def execute_local_task(
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1730,6 +1758,7 @@ def execute_local_task(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_local_workflow(
Expand All @@ -1749,6 +1778,7 @@ def execute_local_workflow(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
Expand All @@ -1766,6 +1796,7 @@ def execute_local_workflow(
:param envs:
:param tags:
:param cluster_pool:
:param execution_cluster_label:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1814,6 +1845,7 @@ def execute_local_workflow(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_local_launch_plan(
Expand All @@ -1832,6 +1864,7 @@ def execute_local_launch_plan(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Expand All @@ -1848,6 +1881,7 @@ def execute_local_launch_plan(
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:return: FlyteWorkflowExecution object
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1877,6 +1911,7 @@ def execute_local_launch_plan(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

###################################
Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,28 @@ def local_assertions(*args, **kwargs):
)


def test_execution_cluster_label_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()
remote._client = mock_client

def local_assertions(*args, **kwargs):
execution_spec = args[3]
assert execution_spec.execution_cluster_label.value == "label"

mock_client.create_execution.side_effect = local_assertions

mock_entity = MagicMock()

remote._execute(
mock_entity,
inputs={},
project="proj",
domain="dev",
execution_cluster_label="label",
)


def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()
Expand Down

0 comments on commit b06d33e

Please sign in to comment.