From 0caf47a85fe5c3ab3d6b05aacfcf2556c6292ec3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 12 Jun 2024 09:06:19 +0800 Subject: [PATCH] Make the value in the label optional (#2465) Signed-off-by: Kevin Su Signed-off-by: bugra.gedik --- flytekit/clis/sdk_in_container/run.py | 4 ++-- flytekit/interaction/click_types.py | 16 ++++++++++++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 10 ++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index b5278be053..47e16510cb 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -32,7 +32,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException -from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback +from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback, labels_callback from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security @@ -174,7 +174,7 @@ class RunLevelParams(PyFlyteParams): multiple=True, type=str, show_default=True, - callback=key_value_callback, + callback=labels_callback, help="Labels to be attached to the execution of the format `label_key=label_value`.", ) ) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index c16339a236..4eb597d8df 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -57,6 +57,22 @@ def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> t return result +def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]: + """ + Callback for click to parse labels. + """ + if not values: + return None + result = {} + for v in values: + if "=" not in v: + result[v.strip()] = "" + else: + k, v = v.split("=", 1) + result[k.strip()] = v.strip() + return result + + class DirParamType(click.ParamType): name = "directory path" diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 28bee1dea7..6957828743 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -76,6 +76,16 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file): assert result.exit_code == 0 +def test_pyflyte_run_with_labels(): + workflow_file = pathlib.Path(__file__).parent / "workflow.py" + with mock.patch("flytekit.configuration.plugin.FlyteRemote"): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, ["run", "--remote", str(workflow_file), "my_wf", "--help"], catch_exceptions=False + ) + assert result.exit_code == 0 + + def test_imperative_wf(): runner = CliRunner() result = runner.invoke(