diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 47e16510cb..323988b340 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -239,7 +239,7 @@ class RunLevelParams(PyFlyteParams): ) limit: int = make_click_option_field( click.Option( - param_decls=["--limit", "limit"], + param_decls=["--limit"], required=False, type=int, default=50, @@ -256,6 +256,16 @@ class RunLevelParams(PyFlyteParams): help="Assign newly created execution to a given cluster pool", ) ) + execution_cluster_label: str = make_click_option_field( + click.Option( + param_decls=["--execution-cluster-label", "--ecl"], + required=False, + type=str, + default="", + help="Assign newly created execution to a given execution cluster label", + ) + ) + computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams) _remote: typing.Optional[FlyteRemote] = None @@ -448,6 +458,7 @@ def run_remote( envs=run_level_params.envvars, tags=run_level_params.tags, cluster_pool=run_level_params.cluster_pool, + execution_cluster_label=run_level_params.execution_cluster_label, ) console_url = remote.generate_console_url(execution) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 11c0f547d7..7e4ff02645 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -17,6 +17,7 @@ from flytekit.models import security from flytekit.models.core import execution as _core_execution from flytekit.models.core import identifier as _identifier +from flytekit.models.matchable_resource import ExecutionClusterLabel from flytekit.models.node_execution import DynamicWorkflowNodeMetadata @@ -181,6 +182,7 @@ def __init__( envs: Optional[_common_models.Envs] = None, tags: Optional[typing.List[str]] = None, cluster_assignment: Optional[ClusterAssignment] = None, + execution_cluster_label: Optional[ExecutionClusterLabel] = None, ): """ :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute @@ -198,6 +200,7 @@ def __init__( :param overwrite_cache: Optional flag to overwrite the cache for this execution. :param envs: flytekit.models.common.Envs environment variables to set for this execution. :param tags: Optional list of tags to apply to the execution. + :param execution_cluster_label: Optional execution cluster label to use for this execution. """ self._launch_plan = launch_plan self._metadata = metadata @@ -213,6 +216,7 @@ def __init__( self._envs = envs self._tags = tags self._cluster_assignment = cluster_assignment + self._execution_cluster_label = execution_cluster_label @property def launch_plan(self): @@ -295,6 +299,10 @@ def tags(self) -> Optional[typing.List[str]]: def cluster_assignment(self) -> Optional[ClusterAssignment]: return self._cluster_assignment + @property + def execution_cluster_label(self) -> Optional[ExecutionClusterLabel]: + return self._execution_cluster_label + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionSpec @@ -316,6 +324,9 @@ def to_flyte_idl(self): envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, + execution_cluster_label=self._execution_cluster_label.to_flyte_idl() + if self._execution_cluster_label + else None, ) @classmethod @@ -345,6 +356,9 @@ def from_flyte_idl(cls, p): cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) if p.HasField("cluster_assignment") else None, + execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(p.execution_cluster_label) + if p.HasField("execution_cluster_label") + else None, ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index ac449dd786..70549c5b30 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -76,6 +76,7 @@ ) from flytekit.models.launch_plan import LaunchPlanState from flytekit.models.literals import Literal, LiteralMap +from flytekit.models.matchable_resource import ExecutionClusterLabel from flytekit.remote.backfill import create_backfill_workflow from flytekit.remote.data import download_literal from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow @@ -1107,6 +1108,7 @@ def _execute( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. @@ -1124,6 +1126,7 @@ def _execute( :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ if execution_name is not None and execution_name_prefix is not None: @@ -1201,6 +1204,9 @@ def _execute( envs=common_models.Envs(envs) if envs else None, tags=tags, cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None, + execution_cluster_label=ExecutionClusterLabel(execution_cluster_label) + if execution_cluster_label + else None, ), literal_inputs, ) @@ -1261,6 +1267,7 @@ def execute( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity. @@ -1300,6 +1307,7 @@ def execute( :param envs: Environment variables to be set for the execution. :param tags: Tags to be set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. .. note: @@ -1323,6 +1331,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, FlyteWorkflow): return self.execute_remote_wf( @@ -1339,6 +1348,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, ReferenceTask): return self.execute_reference_task( @@ -1353,6 +1363,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, ReferenceWorkflow): return self.execute_reference_workflow( @@ -1367,6 +1378,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, ReferenceLaunchPlan): return self.execute_reference_launch_plan( @@ -1381,6 +1393,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, PythonTask): return self.execute_local_task( @@ -1398,6 +1411,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -1416,6 +1430,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) if isinstance(entity, LaunchPlan): return self.execute_local_launch_plan( @@ -1433,6 +1448,7 @@ def execute( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -1454,6 +1470,7 @@ def execute_remote_task_lp( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. @@ -1473,6 +1490,7 @@ def execute_remote_task_lp( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) def execute_remote_wf( @@ -1490,6 +1508,7 @@ def execute_remote_wf( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. @@ -1510,6 +1529,7 @@ def execute_remote_wf( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) # Flyte Reference Entities @@ -1527,6 +1547,7 @@ def execute_reference_task( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a ReferenceTask.""" resolved_identifiers = ResolvedIdentifiers( @@ -1557,6 +1578,7 @@ def execute_reference_task( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) def execute_reference_workflow( @@ -1572,6 +1594,7 @@ def execute_reference_workflow( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a ReferenceWorkflow.""" resolved_identifiers = ResolvedIdentifiers( @@ -1616,6 +1639,7 @@ def execute_reference_workflow( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) def execute_reference_launch_plan( @@ -1631,6 +1655,7 @@ def execute_reference_launch_plan( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a ReferenceLaunchPlan.""" resolved_identifiers = ResolvedIdentifiers( @@ -1661,6 +1686,7 @@ def execute_reference_launch_plan( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) # Flytekit Entities @@ -1682,6 +1708,7 @@ def execute_local_task( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a @task-decorated function or TaskTemplate task. @@ -1699,6 +1726,7 @@ def execute_local_task( :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. :return: FlyteWorkflowExecution object. """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1730,6 +1758,7 @@ def execute_local_task( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) def execute_local_workflow( @@ -1749,6 +1778,7 @@ def execute_local_workflow( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. @@ -1766,6 +1796,7 @@ def execute_local_workflow( :param envs: :param tags: :param cluster_pool: + :param execution_cluster_label: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1814,6 +1845,7 @@ def execute_local_workflow( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) def execute_local_launch_plan( @@ -1832,6 +1864,7 @@ def execute_local_launch_plan( envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, + execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ @@ -1848,6 +1881,7 @@ def execute_local_launch_plan( :param envs: Environment variables to be passed into the execution. :param tags: Tags to be passed into the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. :return: FlyteWorkflowExecution object """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1877,6 +1911,7 @@ def execute_local_launch_plan( envs=envs, tags=tags, cluster_pool=cluster_pool, + execution_cluster_label=execution_cluster_label, ) ################################### diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index d6b0cc711c..f4be6c33c1 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -157,6 +157,28 @@ def local_assertions(*args, **kwargs): ) +def test_execution_cluster_label_attributes(remote, mock_wf_exec): + mock_wf_exec.return_value = True + mock_client = MagicMock() + remote._client = mock_client + + def local_assertions(*args, **kwargs): + execution_spec = args[3] + assert execution_spec.execution_cluster_label.value == "label" + + mock_client.create_execution.side_effect = local_assertions + + mock_entity = MagicMock() + + remote._execute( + mock_entity, + inputs={}, + project="proj", + domain="dev", + execution_cluster_label="label", + ) + + def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec): mock_wf_exec.return_value = True mock_client = MagicMock()