Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for specifying execution cluster labels in pyflyte #2422

Merged
merged 2 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class RunLevelParams(PyFlyteParams):
)
limit: int = make_click_option_field(
click.Option(
param_decls=["--limit", "limit"],
param_decls=["--limit"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: bit of scope creep ;)

Still valid of course.

required=False,
type=int,
default=50,
Expand All @@ -256,6 +256,16 @@ class RunLevelParams(PyFlyteParams):
help="Assign newly created execution to a given cluster pool",
)
)
execution_cluster_label: str = make_click_option_field(
click.Option(
param_decls=["--execution-cluster-label", "--ecl"],
required=False,
type=str,
default="",
help="Assign newly created execution to a given execution cluster label",
)
)

computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

Expand Down Expand Up @@ -448,6 +458,7 @@ def run_remote(
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
execution_cluster_label=run_level_params.execution_cluster_label,
)

console_url = remote.generate_console_url(execution)
Expand Down
14 changes: 14 additions & 0 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flytekit.models import security
from flytekit.models.core import execution as _core_execution
from flytekit.models.core import identifier as _identifier
from flytekit.models.matchable_resource import ExecutionClusterLabel
from flytekit.models.node_execution import DynamicWorkflowNodeMetadata


Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
execution_cluster_label: Optional[ExecutionClusterLabel] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
Expand All @@ -198,6 +200,7 @@ def __init__(
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
:param tags: Optional list of tags to apply to the execution.
:param execution_cluster_label: Optional execution cluster label to use for this execution.
"""
self._launch_plan = launch_plan
self._metadata = metadata
Expand All @@ -213,6 +216,7 @@ def __init__(
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment
self._execution_cluster_label = execution_cluster_label

@property
def launch_plan(self):
Expand Down Expand Up @@ -295,6 +299,10 @@ def tags(self) -> Optional[typing.List[str]]:
def cluster_assignment(self) -> Optional[ClusterAssignment]:
return self._cluster_assignment

@property
def execution_cluster_label(self) -> Optional[ExecutionClusterLabel]:
return self._execution_cluster_label

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
Expand All @@ -316,6 +324,9 @@ def to_flyte_idl(self):
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
execution_cluster_label=self._execution_cluster_label.to_flyte_idl()
if self._execution_cluster_label
else None,
)

@classmethod
Expand Down Expand Up @@ -345,6 +356,9 @@ def from_flyte_idl(cls, p):
cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
if p.HasField("cluster_assignment")
else None,
execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(p.execution_cluster_label)
if p.HasField("execution_cluster_label")
else None,
)


Expand Down
35 changes: 35 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from flytekit.models.launch_plan import LaunchPlanState
from flytekit.models.literals import Literal, LiteralMap
from flytekit.models.matchable_resource import ExecutionClusterLabel
from flytekit.remote.backfill import create_backfill_workflow
from flytekit.remote.data import download_literal
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
Expand Down Expand Up @@ -1107,6 +1108,7 @@ def _execute(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.

Expand All @@ -1124,6 +1126,7 @@ def _execute(
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
if execution_name is not None and execution_name_prefix is not None:
Expand Down Expand Up @@ -1201,6 +1204,9 @@ def _execute(
envs=common_models.Envs(envs) if envs else None,
tags=tags,
cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None,
execution_cluster_label=ExecutionClusterLabel(execution_cluster_label)
if execution_cluster_label
else None,
),
literal_inputs,
)
Expand Down Expand Up @@ -1261,6 +1267,7 @@ def execute(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
Expand Down Expand Up @@ -1300,6 +1307,7 @@ def execute(
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.

.. note:

Expand All @@ -1323,6 +1331,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
Expand All @@ -1339,6 +1348,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceTask):
return self.execute_reference_task(
Expand All @@ -1353,6 +1363,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceWorkflow):
return self.execute_reference_workflow(
Expand All @@ -1367,6 +1378,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, ReferenceLaunchPlan):
return self.execute_reference_launch_plan(
Expand All @@ -1381,6 +1393,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
Expand All @@ -1398,6 +1411,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
Expand All @@ -1416,6 +1430,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
Expand All @@ -1433,6 +1448,7 @@ def execute(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

Expand All @@ -1454,6 +1470,7 @@ def execute_remote_task_lp(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.

Expand All @@ -1473,6 +1490,7 @@ def execute_remote_task_lp(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_remote_wf(
Expand All @@ -1490,6 +1508,7 @@ def execute_remote_wf(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.

Expand All @@ -1510,6 +1529,7 @@ def execute_remote_wf(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

# Flyte Reference Entities
Expand All @@ -1527,6 +1547,7 @@ def execute_reference_task(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceTask."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1557,6 +1578,7 @@ def execute_reference_task(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_reference_workflow(
Expand All @@ -1572,6 +1594,7 @@ def execute_reference_workflow(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceWorkflow."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1616,6 +1639,7 @@ def execute_reference_workflow(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_reference_launch_plan(
Expand All @@ -1631,6 +1655,7 @@ def execute_reference_launch_plan(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a ReferenceLaunchPlan."""
resolved_identifiers = ResolvedIdentifiers(
Expand Down Expand Up @@ -1661,6 +1686,7 @@ def execute_reference_launch_plan(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

# Flytekit Entities
Expand All @@ -1682,6 +1708,7 @@ def execute_local_task(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
Expand All @@ -1699,6 +1726,7 @@ def execute_local_task(
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1730,6 +1758,7 @@ def execute_local_task(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_local_workflow(
Expand All @@ -1749,6 +1778,7 @@ def execute_local_workflow(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
Expand All @@ -1766,6 +1796,7 @@ def execute_local_workflow(
:param envs:
:param tags:
:param cluster_pool:
:param execution_cluster_label:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1814,6 +1845,7 @@ def execute_local_workflow(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

def execute_local_launch_plan(
Expand All @@ -1832,6 +1864,7 @@ def execute_local_launch_plan(
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""

Expand All @@ -1848,6 +1881,7 @@ def execute_local_launch_plan(
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
:return: FlyteWorkflowExecution object
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1877,6 +1911,7 @@ def execute_local_launch_plan(
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
execution_cluster_label=execution_cluster_label,
)

###################################
Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,28 @@ def local_assertions(*args, **kwargs):
)


def test_execution_cluster_label_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()
remote._client = mock_client

def local_assertions(*args, **kwargs):
execution_spec = args[3]
assert execution_spec.execution_cluster_label.value == "label"

mock_client.create_execution.side_effect = local_assertions

mock_entity = MagicMock()

remote._execute(
mock_entity,
inputs={},
project="proj",
domain="dev",
execution_cluster_label="label",
)


def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()
Expand Down
Loading