From 1713765348a518c52a4af3ae1abcd71064c7b055 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 28 May 2024 20:24:54 +0530 Subject: [PATCH 01/32] update openai batch test and workflow (#2440) Signed-off-by: Samhita Alla --- flytekit/remote/remote.py | 5 ++++- flytekit/types/file/file.py | 3 --- flytekit/types/iterator/json_iterator.py | 5 ++++- .../flytekitplugins/openai/batch/agent.py | 14 ++++++-------- .../flytekitplugins/openai/batch/workflow.py | 9 ++++++++- .../tests/openai_batch/test_agent.py | 1 - 6 files changed, 22 insertions(+), 15 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 5e1d3fb589..5908462865 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -948,7 +948,10 @@ def _version_from_hash( h.update(bytes(s, "utf-8")) if default_inputs: - h.update(cloudpickle.dumps(default_inputs)) + try: + h.update(cloudpickle.dumps(default_inputs)) + except TypeError: # cannot pickle errors + logger.info("Skip pickling default inputs.") # Omit the character '=' from the version as that's essentially padding used by the base64 encoding # and does not increase entropy of the hash while making it very inconvenient to copy-and-paste. diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 2995bd82f7..9b71eb5b4f 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -526,13 +526,10 @@ def _downloader(): return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: - from flytekit.types.iterator.json_iterator import JSONIteratorTransformer - if ( literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT - and literal_type.blob.format != JSONIteratorTransformer.JSON_ITERATOR_FORMAT ): return FlyteFile.__class_getitem__(literal_type.blob.format) diff --git a/flytekit/types/iterator/json_iterator.py b/flytekit/types/iterator/json_iterator.py index 52ab88497b..d8ed2ce570 100644 --- a/flytekit/types/iterator/json_iterator.py +++ b/flytekit/types/iterator/json_iterator.py @@ -40,6 +40,7 @@ class JSONIteratorTransformer(TypeTransformer[Iterator[JSON]]): """ JSON_ITERATOR_FORMAT = "jsonl" + JSON_ITERATOR_METADATA = "json iterator" def __init__(self): super().__init__("JSON Iterator", Iterator[JSON]) @@ -49,7 +50,8 @@ def get_literal_type(self, t: Type[Iterator[JSON]]) -> LiteralType: blob=_core_types.BlobType( format=self.JSON_ITERATOR_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) + ), + metadata={"format": self.JSON_ITERATOR_METADATA}, ) def to_literal( @@ -103,6 +105,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[Iterator[JSON]]: literal_type.blob is not None and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE and literal_type.blob.format == self.JSON_ITERATOR_FORMAT + and literal_type.metadata == {"format": self.JSON_ITERATOR_METADATA} ): return Iterator[JSON] # type: ignore diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py index 2c9821b204..fa01383ca0 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py @@ -102,14 +102,12 @@ async def get( if data and data[0].message: message = data[0].message - outputs = {"result": {"result": None}} - if current_state in State.Success.value: - result = retrieved_result.to_dict() - - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} - ) + result = retrieved_result.to_dict() + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + literals={"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} + ) return Resource(phase=flyte_phase, outputs=outputs, message=message) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py index 027f006b59..1f0ff30b51 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Iterator -from flytekit import Workflow +from flytekit import Resources, Workflow from flytekit.models.security import Secret from flytekit.types.file import JSONLFile from flytekit.types.iterator import JSON @@ -20,6 +20,8 @@ def create_batch( secret: Secret, config: Dict[str, Any] = {}, is_json_iterator: bool = True, + file_upload_mem: str = "700Mi", + file_download_mem: str = "700Mi", ) -> Workflow: """ Uploads JSON data to a JSONL file, creates a batch, waits for it to complete, and downloads the output/error JSON files. @@ -29,6 +31,8 @@ def create_batch( :param secret: Secret comprising the OpenAI API key. :param config: Additional config for batch creation. :param is_json_iterator: Set to True if you're sending an iterator/generator; if a JSONL file, set to False. + :param file_upload_mem: Memory to allocate to the upload file task. + :param file_download_mem: Memory to allocate to the download file task. """ wf = Workflow(name=f"openai-batch-{name.replace('.', '')}") @@ -64,6 +68,9 @@ def create_batch( batch_endpoint_result=node_2.outputs["result"], ) + node_1.with_overrides(requests=Resources(mem=file_upload_mem), limits=Resources(mem=file_upload_mem)) + node_3.with_overrides(requests=Resources(mem=file_download_mem), limits=Resources(mem=file_download_mem)) + wf.add_workflow_output("batch_output", node_3.outputs["result"], BatchResult) return wf diff --git a/plugins/flytekit-openai/tests/openai_batch/test_agent.py b/plugins/flytekit-openai/tests/openai_batch/test_agent.py index 3dde953741..d9352e918b 100644 --- a/plugins/flytekit-openai/tests/openai_batch/test_agent.py +++ b/plugins/flytekit-openai/tests/openai_batch/test_agent.py @@ -165,7 +165,6 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): mock_retrieve.return_value = batch_retrieve_result_failure resource = await agent.get(metadata) assert resource.phase == TaskExecution.FAILED - assert resource.outputs == {"result": {"result": None}} assert resource.message == "This line is not parseable as valid JSON." # CREATE From 1b6e02befeb5b68ca81ac6e302c5e88faf35d4d5 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 28 May 2024 11:38:02 -0400 Subject: [PATCH 02/32] fix(ImageSpec): Do not build image during executions (#2410) Signed-off-by: Thomas J. Fan --- flytekit/image_spec/image_spec.py | 9 ++++++++- .../unit/core/image_spec/test_image_spec.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 1a055fce84..98f6c05cdc 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -234,7 +234,14 @@ def register(cls, builder_type: str, image_spec_builder: ImageSpecBuilder, prior @classmethod @lru_cache - def build(cls, image_spec: ImageSpec) -> str: + def build(cls, image_spec: ImageSpec): + from flytekit.core.context_manager import FlyteContextManager + + execution_mode = FlyteContextManager.current_context().execution_state.mode + # Do not build in executions + if execution_mode is not None: + return + if isinstance(image_spec.base_image, ImageSpec): cls.build(image_spec.base_image) image_spec.base_image = image_spec.base_image.image_name() diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index 4d204e11a0..4a596c1e1e 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -122,3 +122,17 @@ def test_custom_tag(): ) spec_hash = calculate_hash_from_image_spec(spec) assert spec.image_name() == f"my_image:{spec_hash}-dev" + + +def test_no_build_during_execution(): + # Check that no builds are called during executions + ImageBuildEngine._build_image = Mock() + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ): + spec = ImageSpec(name="my_image_v2", python_version="3.12") + ImageBuildEngine.build(spec) + + ImageBuildEngine._build_image.assert_not_called() From f8355b3e2fe6d71614c23b54efc687e8cd4e6d24 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 28 May 2024 08:40:28 -0700 Subject: [PATCH 03/32] Bump requests from 2.31.0 to 2.32.2 in /plugins/flytekit-airflow (#2441) Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2) --- updated-dependencies: - dependency-name: requests dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- plugins/flytekit-airflow/dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-airflow/dev-requirements.txt b/plugins/flytekit-airflow/dev-requirements.txt index 2776c1fe9d..4ff135ad1a 100644 --- a/plugins/flytekit-airflow/dev-requirements.txt +++ b/plugins/flytekit-airflow/dev-requirements.txt @@ -846,7 +846,7 @@ referencing==0.30.2 # jsonschema-specifications regex==2023.10.3 # via apache-beam -requests==2.31.0 +requests==2.32.2 # via # apache-airflow-providers-http # apache-beam From 88ee93ec44dd44d362bfccec4a802a5b8b66613b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 28 May 2024 23:47:37 +0800 Subject: [PATCH 04/32] Use default logger if fail to initialize RichHandler (#2423) Signed-off-by: Kevin Su --- flytekit/loggers.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 1de634de9c..1a0165f007 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -83,11 +83,6 @@ def initialize_global_loggers(): """ Initializes the global loggers to the default configuration. """ - # Use Rich logging while running in the local execution - if os.environ.get("FLYTE_INTERNAL_EXECUTION_ID", None) is None: - upgrade_to_rich_logging() - return - handler = logging.StreamHandler() handler.setLevel(logging.DEBUG) formatter = logging.Formatter(fmt="[%(name)s] %(message)s") @@ -98,6 +93,10 @@ def initialize_global_loggers(): set_flytekit_log_properties(handler, None, _get_env_logging_level()) set_user_logger_properties(handler, None, logging.INFO) + # Use Rich logging while running in the local execution + if os.environ.get("FLYTE_INTERNAL_EXECUTION_ID", None) is None or interactive.ipython_check(): + upgrade_to_rich_logging() + def is_rich_logging_enabled() -> bool: return os.environ.get(LOGGING_RICH_FMT_ENV_VAR) != "0" @@ -146,8 +145,5 @@ def get_level_from_cli_verbosity(verbosity: int) -> int: return logging.DEBUG -if interactive.ipython_check(): - upgrade_to_rich_logging() -else: - # Default initialization - initialize_global_loggers() +# Default initialization +initialize_global_loggers() From 1a0136833f54a4f9c4fd9a1c957008383e02f14e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 28 May 2024 23:56:47 +0800 Subject: [PATCH 05/32] Remove databricks token from the task config (#2429) Signed-off-by: Kevin Su --- .../flytekit-spark/flytekitplugins/spark/models.py | 14 -------------- .../flytekit-spark/flytekitplugins/spark/task.py | 3 --- plugins/flytekit-spark/tests/test_spark_task.py | 3 --- 3 files changed, 20 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index 28e67ac631..df8191304a 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -26,7 +26,6 @@ def __init__( hadoop_conf: Dict[str, str], executor_path: str, databricks_conf: Dict[str, Dict[str, Dict]] = {}, - databricks_token: Optional[str] = None, databricks_instance: Optional[str] = None, ): """ @@ -36,7 +35,6 @@ def __init__( :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. :param Optional[dict[Text, dict]] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit. - :param Optional[str] databricks_token: databricks access token. :param Optional[str] databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. """ self._application_file = application_file @@ -46,7 +44,6 @@ def __init__( self._spark_conf = spark_conf self._hadoop_conf = hadoop_conf self._databricks_conf = databricks_conf - self._databricks_token = databricks_token self._databricks_instance = databricks_instance def with_overrides( @@ -71,7 +68,6 @@ def with_overrides( spark_conf=new_spark_conf, hadoop_conf=new_hadoop_conf, databricks_conf=new_databricks_conf, - databricks_token=self.databricks_token, databricks_instance=self.databricks_instance, executor_path=self.executor_path, ) @@ -133,14 +129,6 @@ def databricks_conf(self) -> Dict[str, Dict]: """ return self._databricks_conf - @property - def databricks_token(self) -> str: - """ - Databricks access token - :rtype: str - """ - return self._databricks_token - @property def databricks_instance(self) -> str: """ @@ -176,7 +164,6 @@ def to_flyte_idl(self): sparkConf=self.spark_conf, hadoopConf=self.hadoop_conf, databricksConf=databricks_conf, - databricksToken=self.databricks_token, databricksInstance=self.databricks_instance, ) @@ -203,6 +190,5 @@ def from_flyte_idl(cls, pb2_object): hadoop_conf=pb2_object.hadoopConf, executor_path=pb2_object.executorPath, databricks_conf=json_format.MessageToDict(pb2_object.databricksConf), - databricks_token=pb2_object.databricksToken, databricks_instance=pb2_object.databricksInstance, ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 079cf8815c..8a8c3b2b5b 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -54,12 +54,10 @@ class Databricks(Spark): databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html - databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_token: Optional[str] = None databricks_instance: Optional[str] = None @@ -156,7 +154,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: if isinstance(self.task_config, Databricks): cfg = cast(Databricks, self.task_config) job._databricks_conf = cfg.databricks_conf - job._databricks_token = cfg.databricks_token job._databricks_instance = cfg.databricks_instance return MessageToDict(job.to_flyte_idl()) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 4c4db817e2..2a541b7f11 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -78,7 +78,6 @@ def my_spark(a: str) -> int: assert ("spark", "1") in configs assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs - databricks_token = "token" databricks_instance = "account.cloud.databricks.com" @task( @@ -86,7 +85,6 @@ def my_spark(a: str) -> int: spark_conf={"spark": "2"}, databricks_conf=databricks_conf, databricks_instance="account.cloud.databricks.com", - databricks_token="token", ) ) def my_databricks(a: int) -> int: @@ -98,7 +96,6 @@ def my_databricks(a: int) -> int: assert my_databricks.task_config.spark_conf == {"spark": "2"} assert my_databricks.task_config.databricks_conf == databricks_conf assert my_databricks.task_config.databricks_instance == databricks_instance - assert my_databricks.task_config.databricks_token == databricks_token assert my_databricks(a=3) == 3 From 5c9cb975e5419782760fc7834ceb1ad3865864e3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 28 May 2024 10:10:35 -0700 Subject: [PATCH 06/32] Remote fetch array node (#2442) Signed-off-by: Yee Hing Tong --- flytekit/models/node_execution.py | 72 +++++-------------- flytekit/remote/entities.py | 7 +- flytekit/remote/executions.py | 4 +- flytekit/remote/remote.py | 15 +++- .../integration/remote/test_remote.py | 13 ++++ .../remote/workflows/basic/array_map.py | 14 ++++ .../unit/models/admin/test_node_executions.py | 6 -- 7 files changed, 69 insertions(+), 62 deletions(-) create mode 100644 tests/flytekit/integration/remote/workflows/basic/array_map.py diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 50e685f3e8..8aced0707c 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -2,7 +2,7 @@ import typing from datetime import timezone as _timezone -import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 +import flyteidl.admin.node_execution_pb2 as admin_node_execution_pb2 from flytekit.models import common as _common_models from flytekit.models.core import catalog as catalog_models @@ -19,13 +19,13 @@ def __init__(self, execution_id: _identifier.WorkflowExecutionIdentifier): def execution_id(self) -> _identifier.WorkflowExecutionIdentifier: return self._execution_id - def to_flyte_idl(self) -> _node_execution_pb2.WorkflowNodeMetadata: - return _node_execution_pb2.WorkflowNodeMetadata( + def to_flyte_idl(self) -> admin_node_execution_pb2.WorkflowNodeMetadata: + return admin_node_execution_pb2.WorkflowNodeMetadata( executionId=self.execution_id.to_flyte_idl(), ) @classmethod - def from_flyte_idl(cls, p: _node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata": + def from_flyte_idl(cls, p: admin_node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata": return cls( execution_id=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(p.executionId), ) @@ -44,14 +44,14 @@ def id(self) -> _identifier.Identifier: def compiled_workflow(self) -> core_compiler_models.CompiledWorkflowClosure: return self._compiled_workflow - def to_flyte_idl(self) -> _node_execution_pb2.DynamicWorkflowNodeMetadata: - return _node_execution_pb2.DynamicWorkflowNodeMetadata( + def to_flyte_idl(self) -> admin_node_execution_pb2.DynamicWorkflowNodeMetadata: + return admin_node_execution_pb2.DynamicWorkflowNodeMetadata( id=self.id.to_flyte_idl(), compiled_workflow=self.compiled_workflow.to_flyte_idl(), ) @classmethod - def from_flyte_idl(cls, p: _node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata": + def from_flyte_idl(cls, p: admin_node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata": yy = cls( id=_identifier.Identifier.from_flyte_idl(p.id), compiled_workflow=core_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow), @@ -72,14 +72,14 @@ def cache_status(self) -> int: def catalog_key(self) -> catalog_models.CatalogMetadata: return self._catalog_key - def to_flyte_idl(self) -> _node_execution_pb2.TaskNodeMetadata: - return _node_execution_pb2.TaskNodeMetadata( + def to_flyte_idl(self) -> admin_node_execution_pb2.TaskNodeMetadata: + return admin_node_execution_pb2.TaskNodeMetadata( cache_status=self.cache_status, catalog_key=self.catalog_key.to_flyte_idl(), ) @classmethod - def from_flyte_idl(cls, p: _node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata": + def from_flyte_idl(cls, p: admin_node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata": return cls( cache_status=p.cache_status, catalog_key=catalog_models.CatalogMetadata.from_flyte_idl(p.catalog_key), @@ -185,7 +185,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionClosure """ - obj = _node_execution_pb2.NodeExecutionClosure( + obj = admin_node_execution_pb2.NodeExecutionClosure( phase=self.phase, output_uri=self.output_uri, deck_uri=self.deck_uri, @@ -227,47 +227,13 @@ def from_flyte_idl(cls, p): ) -class NodeExecutionMetaData(_common_models.FlyteIdlEntity): - def __init__(self, retry_group: str, is_parent_node: bool, spec_node_id: str): - self._retry_group = retry_group - self._is_parent_node = is_parent_node - self._spec_node_id = spec_node_id - - @property - def retry_group(self) -> str: - return self._retry_group - - @property - def is_parent_node(self) -> bool: - return self._is_parent_node - - @property - def spec_node_id(self) -> str: - return self._spec_node_id - - def to_flyte_idl(self) -> _node_execution_pb2.NodeExecutionMetaData: - return _node_execution_pb2.NodeExecutionMetaData( - retry_group=self.retry_group, - is_parent_node=self.is_parent_node, - spec_node_id=self.spec_node_id, - ) - - @classmethod - def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecutionMetaData) -> "NodeExecutionMetaData": - return cls( - retry_group=p.retry_group, - is_parent_node=p.is_parent_node, - spec_node_id=p.spec_node_id, - ) - - class NodeExecution(_common_models.FlyteIdlEntity): - def __init__(self, id, input_uri, closure, metadata): + def __init__(self, id, input_uri, closure, metadata: admin_node_execution_pb2.NodeExecutionMetaData): """ :param flytekit.models.core.identifier.NodeExecutionIdentifier id: :param Text input_uri: :param NodeExecutionClosure closure: - :param NodeExecutionMetaData metadata: + :param metadata: """ self._id = id self._input_uri = input_uri @@ -296,22 +262,22 @@ def closure(self): return self._closure @property - def metadata(self) -> NodeExecutionMetaData: + def metadata(self) -> admin_node_execution_pb2.NodeExecutionMetaData: return self._metadata - def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution: - return _node_execution_pb2.NodeExecution( + def to_flyte_idl(self) -> admin_node_execution_pb2.NodeExecution: + return admin_node_execution_pb2.NodeExecution( id=self.id.to_flyte_idl(), input_uri=self.input_uri, closure=self.closure.to_flyte_idl(), - metadata=self.metadata.to_flyte_idl(), + metadata=self.metadata, ) @classmethod - def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution": + def from_flyte_idl(cls, p: admin_node_execution_pb2.NodeExecution) -> "NodeExecution": return cls( id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id), input_uri=p.input_uri, closure=NodeExecutionClosure.from_flyte_idl(p.closure), - metadata=NodeExecutionMetaData.from_flyte_idl(p.metadata), + metadata=p.metadata, ) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 2af0db3afb..1f09ebb19d 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -349,7 +349,12 @@ def promote_from_model(cls, model: _workflow_model.GateNode): class FlyteArrayNode(_workflow_model.ArrayNode): @classmethod def promote_from_model(cls, model: _workflow_model.ArrayNode): - return cls(model._parallelism, model._node, model._min_success_ratio, model._min_successes) + return cls( + node=model._node, + parallelism=model._parallelism, + min_successes=model._min_successes, + min_success_ratio=model._min_success_ratio, + ) class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index bd5e182952..c06ee06739 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from abc import abstractmethod from typing import Dict, List, Optional, Union @@ -9,6 +10,7 @@ from flytekit.models import node_execution as node_execution_models from flytekit.models.admin import task_execution as admin_task_execution_models from flytekit.models.core import execution as core_execution_models +from flytekit.models.interface import TypedInterface from flytekit.remote.entities import FlyteTask, FlyteWorkflow @@ -148,7 +150,7 @@ def __init__(self, *args, **kwargs): self._task_executions = None self._workflow_executions = [] self._underlying_node_executions = None - self._interface = None + self._interface: typing.Optional[TypedInterface] = None self._flyte_node = None @property diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 5908462865..584501f137 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2062,7 +2062,7 @@ def sync_node_execution( return execution # If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set. - if execution.metadata.is_parent_node: + if execution.metadata.is_parent_node or execution.metadata.is_array: # We'll need to query child node executions regardless since this is a parent node child_node_executions = iterate_node_executions( self.client, @@ -2115,6 +2115,19 @@ def sync_node_execution( "not have inputs and outputs filled in" ) return execution + elif execution._node.array_node is not None: + # if there's a task node underneath the array node, let's fetch the interface for it + if execution._node.array_node.node.task_node is not None: + tid = execution._node.array_node.node.task_node.reference_id + t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version) + if t.interface: + execution._interface = t.interface + else: + logger.error(f"Fetched map task does not have an interface, skipping i/o {t}") + return execution + else: + logger.error(f"Array node not over task, skipping i/o {t}") + return execution else: logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") raise Exception(f"Node execution undeterminable, entity has type {type(execution._node)}") diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 78b30c8276..f81031361c 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -468,3 +468,16 @@ def my_wf(a: int, b: str) -> (int, str): assert execution.spec.envs.envs == {"foo": "bar"} assert execution.spec.tags == ["flyte"] assert execution.spec.cluster_assignment.cluster_pool == "gpu" + + +def test_execute_workflow_with_maptask(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + d: typing.List[int] = [1, 2, 3] + flyte_launch_plan = remote.fetch_launch_plan(name="basic.array_map.workflow_with_maptask", version=VERSION) + execution = remote.execute( + flyte_launch_plan, + inputs={"data": d, "y": 3}, + version=VERSION, + wait=True, + ) + assert execution.outputs["o0"] == [4, 5, 6] diff --git a/tests/flytekit/integration/remote/workflows/basic/array_map.py b/tests/flytekit/integration/remote/workflows/basic/array_map.py new file mode 100644 index 0000000000..24bbafd15b --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/array_map.py @@ -0,0 +1,14 @@ +from functools import partial + +from flytekit import map_task, task, workflow + + +@task +def fn(x: int, y: int) -> int: + return x + y + + +@workflow +def workflow_with_maptask(data: list[int], y: int) -> list[int]: + partial_fn = partial(fn, y=y) + return map_task(partial_fn)(x=data) diff --git a/tests/flytekit/unit/models/admin/test_node_executions.py b/tests/flytekit/unit/models/admin/test_node_executions.py index b4cd77e5e8..252a91bcc1 100644 --- a/tests/flytekit/unit/models/admin/test_node_executions.py +++ b/tests/flytekit/unit/models/admin/test_node_executions.py @@ -3,12 +3,6 @@ from tests.flytekit.unit.common_tests.test_workflow_promote import get_compiled_workflow_closure -def test_metadata(): - md = node_execution_models.NodeExecutionMetaData(retry_group="0", is_parent_node=True, spec_node_id="n0") - md2 = node_execution_models.NodeExecutionMetaData.from_flyte_idl(md.to_flyte_idl()) - assert md == md2 - - def test_workflow_node_metadata(): wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") From 0e2ca78222204cfbda11ebf16813babc3590236a Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 30 May 2024 21:41:05 +0530 Subject: [PATCH 07/32] add output prefix to `do()` method in sync agent & update boto agent (#2450) Signed-off-by: Kevin Su Signed-off-by: Samhita Alla Co-authored-by: Kevin Su --- flytekit/extend/backend/agent_service.py | 5 ++- flytekit/extend/backend/base_agent.py | 15 +++++--- .../awssagemaker_inference/boto3_agent.py | 36 +++++++++++++------ .../awssagemaker_inference/boto3_mixin.py | 3 -- .../tests/test_boto3_agent.py | 6 +++- pyproject.toml | 2 +- 6 files changed, 47 insertions(+), 20 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index eb2838ca41..a92cef8e36 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -164,6 +164,7 @@ async def ExecuteTaskSync( ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: request = await request_iterator.__anext__() template = TaskTemplate.from_flyte_idl(request.header.template) + output_prefix = request.header.output_prefix task_type = template.type try: with request_latency.labels(task_type=task_type, operation=do_operation).time(): @@ -173,7 +174,9 @@ async def ExecuteTaskSync( request = await request_iterator.__anext__() literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + res = await mirror_async_methods( + agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix + ) if res.outputs is None: outputs = None diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 4d1d8956da..33a03e282b 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -119,7 +119,7 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource: + def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: str, **kwargs) -> Resource: """ This is the method that the agent will run. """ @@ -243,10 +243,11 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) task_template = get_serializable(OrderedDict(), ss, self).template + output_prefix = ctx.file_access.get_random_remote_directory() agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource = asyncio.run(self._do(agent, task_template, kwargs)) + resource = asyncio.run(self._do(agent, task_template, output_prefix, kwargs)) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") @@ -255,12 +256,18 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: return resource.outputs async def _do( - self: PythonTask, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None + self: PythonTask, + agent: SyncAgentBase, + template: TaskTemplate, + output_prefix: str, + inputs: Dict[str, Any] = None, ) -> Resource: try: ctx = FlyteContext.current_context() literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) - return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + return await mirror_async_methods( + agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix + ) except Exception as error_message: raise FlyteUserException(f"Failed to run the task {self.name} with error: {error_message}") diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index 314358c267..f5624127fb 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -4,6 +4,8 @@ from typing_extensions import Annotated from flytekit import FlyteContextManager, kwtypes +from flytekit.core import context_manager +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, @@ -37,7 +39,13 @@ class BotoAgent(SyncAgentBase): def __init__(self): super().__init__(task_type_name="boto") - async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: + async def do( + self, + task_template: TaskTemplate, + output_prefix: str, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> Resource: custom = task_template.custom service = custom.get("service") @@ -60,16 +68,24 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N outputs = {"result": {"result": None}} if result: ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - literals={ - "result": TypeEngine.to_literal( - ctx, - result, - Annotated[dict, kwtypes(allow_pickle=True)], - TypeEngine.to_literal_type(dict), - ) - } + builder = ctx.with_file_access( + FileAccessProvider( + local_sandbox_dir=ctx.file_access.local_sandbox_dir, + raw_output_prefix=output_prefix, + data_config=ctx.file_access.data_config, + ) ) + with context_manager.FlyteContextManager.with_context(builder) as new_ctx: + outputs = LiteralMap( + literals={ + "result": TypeEngine.to_literal( + new_ctx, + result, + Annotated[dict, kwtypes(allow_pickle=True)], + TypeEngine.to_literal_type(dict), + ) + } + ) return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 98b1c513f1..c2596750fc 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -132,9 +132,6 @@ async def _call( :param images: A dict of Docker images to use, for example, when deploying a model on SageMaker. :param inputs: The inputs for the task being created. :param region: The region for the boto3 client. If not provided, the region specified in the constructor will be used. - :param aws_access_key_id: The access key ID to use to access the AWS resources. - :param aws_secret_access_key: The secret access key to use to access the AWS resources - :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. """ args = {} input_region = None diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index ad72a0b7ac..f17e50ea6f 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -4,6 +4,7 @@ import pytest from flyteidl.core.execution_pb2 import TaskExecution +from flytekit import FlyteContext from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.interfaces.cli_identifiers import Identifier @@ -115,7 +116,10 @@ async def test_agent(mock_boto_call, mock_return_value): }, ) - resource = await agent.do(task_template, task_inputs) + ctx = FlyteContext.current_context() + output_prefix = ctx.file_access.get_random_remote_directory() + resource = await agent.do(task_template=task_template, inputs=task_inputs, output_prefix=output_prefix) + assert resource.phase == TaskExecution.SUCCEEDED if mock_return_value: diff --git a/pyproject.toml b/pyproject.toml index d3bcf722ec..126f05050a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", "pyyaml!=6.0.0,!=5.4.0,!=5.4.1", # pyyaml is broken with cython 3: https://github.com/yaml/pyyaml/issues/601 - "requests>=2.18.4,<3.0.0,!=2.32.0,!=2.32.1,!=2.32.2", + "requests>=2.18.4,<3.0.0,!=2.32.0,!=2.32.1,!=2.32.2,!=2.32.3", "rich", "rich_click", "s3fs>=2023.3.0,!=2024.3.1", From 59413c2926311a42e691aaf7ec32416201f16614 Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Thu, 30 May 2024 20:55:09 +0200 Subject: [PATCH 08/32] Make error messages in LaunchPlan.get_or_create actionable (#2451) Signed-off-by: Fabio Graetz --- flytekit/core/launch_plan.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 0b097ad847..9018184837 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -269,19 +269,28 @@ def get_or_create( ), ) - if ( - workflow != cached_outputs["_workflow"] - or schedule != cached_outputs["_schedule"] - or notifications != cached_outputs["_notifications"] - or default_inputs != cached_outputs["_saved_inputs"] - or labels != cached_outputs["_labels"] - or annotations != cached_outputs["_annotations"] - or raw_output_data_config != cached_outputs["_raw_output_data_config"] - or max_parallelism != cached_outputs["_max_parallelism"] - or security_context != cached_outputs["_security_context"] - or overwrite_cache != cached_outputs["_overwrite_cache"] - ): - raise AssertionError("The cached values aren't the same as the current call arguments") + if workflow != cached_outputs["_workflow"]: + raise AssertionError( + f"Trying to create two launch plans both named '{name}' for the workflows '{workflow.name}' " + f"and '{cached_outputs['_workflow'].name}' - please ensure unique names." + ) + + for arg_name, new, cached in [ + ("schedule", schedule, cached_outputs["_schedule"]), + ("notifications", notifications, cached_outputs["_notifications"]), + ("default_inputs", default_inputs, cached_outputs["_saved_inputs"]), + ("labels", labels, cached_outputs["_labels"]), + ("annotations", annotations, cached_outputs["_annotations"]), + ("raw_output_data_config", raw_output_data_config, cached_outputs["_raw_output_data_config"]), + ("max_parallelism", max_parallelism, cached_outputs["_max_parallelism"]), + ("security_context", security_context, cached_outputs["_security_context"]), + ("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]), + ]: + if new != cached: + raise AssertionError( + f"Trying to create two launch plans for workflow '{workflow.name}' both named '{name}' " + f"but with different values for '{arg_name}' - please use different launch plan names." + ) return LaunchPlan.CACHE[name] elif name is None and workflow.name in LaunchPlan.CACHE: From 6d4041358f67ea7f17ee22c29cae2c31cae2f106 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 20:00:32 -0700 Subject: [PATCH 09/32] Bump jinja2 (#2392) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../remote/mock_flyte_repo/workflows/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 6822279040..164b9db345 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -155,7 +155,7 @@ jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.3 +jinja2==3.1.4 # via cookiecutter jmespath==1.0.1 # via botocore From 69a3218abbc787936074b17bd09435cc48e62596 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 3 Jun 2024 11:19:35 +0800 Subject: [PATCH 10/32] Fix DBT plugin test (#2454) Signed-off-by: Future-Outlier Signed-off-by: Kevin Su Co-authored-by: pingsutw --- dev-requirements.txt | 198 ++++++++++++---------- plugins/flytekit-dbt/dev-requirements.in | 4 +- plugins/flytekit-dbt/dev-requirements.txt | 124 -------------- plugins/flytekit-dbt/setup.py | 5 +- 4 files changed, 111 insertions(+), 220 deletions(-) delete mode 100644 plugins/flytekit-dbt/dev-requirements.txt diff --git a/dev-requirements.txt b/dev-requirements.txt index 28fab54d4e..d3279adcac 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,13 +4,13 @@ # # pip-compile dev-requirements.in # --e file:.#egg=flytekit +-e file:. # via -r dev-requirements.in -adlfs==2023.9.0 +adlfs==2024.4.1 # via flytekit -aiobotocore==2.5.4 +aiobotocore==2.13.0 # via s3fs -aiohttp==3.9.3 +aiohttp==3.9.5 # via # adlfs # aiobotocore @@ -27,22 +27,22 @@ attrs==23.2.0 # aiohttp # hypothesis # jsonlines -autoflake==2.2.1 +autoflake==2.3.1 # via -r dev-requirements.in -azure-core==1.30.0 +azure-core==1.30.1 # via # adlfs # azure-identity # azure-storage-blob azure-datalake-store==0.0.53 # via adlfs -azure-identity==1.15.0 +azure-identity==1.16.0 # via adlfs -azure-storage-blob==12.19.0 +azure-storage-blob==12.20.0 # via adlfs -botocore==1.31.17 +botocore==1.34.106 # via aiobotocore -cachetools==5.3.2 +cachetools==5.3.3 # via google-auth certifi==2024.2.2 # via @@ -62,15 +62,15 @@ click==8.1.7 # rich-click cloudpickle==3.0.0 # via flytekit -codespell==2.2.6 +codespell==2.3.0 # via -r dev-requirements.in -coverage[toml]==7.4.1 +coverage[toml]==7.5.3 # via # -r dev-requirements.in # pytest-cov -croniter==2.0.1 +croniter==2.0.5 # via flytekit -cryptography==42.0.2 +cryptography==42.0.7 # via # azure-identity # azure-storage-blob @@ -88,13 +88,13 @@ distlib==0.3.8 # via virtualenv docker==6.1.3 # via flytekit -docstring-parser==0.15 +docstring-parser==0.16 # via flytekit -execnet==2.0.2 +execnet==2.1.1 # via pytest-xdist executing==2.0.1 # via stack-data -filelock==3.13.1 +filelock==3.14.0 # via virtualenv flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl # via @@ -104,39 +104,41 @@ frozenlist==1.4.1 # via # aiohttp # aiosignal -fsspec==2023.9.2 +fsspec==2024.5.0 # via # adlfs # flytekit # gcsfs # s3fs -gcsfs==2023.9.2 +gcsfs==2024.5.0 # via flytekit -google-api-core[grpc]==2.16.2 +google-api-core[grpc]==2.19.0 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core # google-cloud-storage -google-auth==2.27.0 +google-auth==2.29.0 # via # gcsfs # google-api-core # google-auth-oauthlib + # google-cloud-bigquery + # google-cloud-bigquery-storage # google-cloud-core # google-cloud-storage # kubernetes google-auth-oauthlib==1.2.0 # via gcsfs -google-cloud-bigquery==3.17.1 +google-cloud-bigquery==3.23.1 # via -r dev-requirements.in -google-cloud-bigquery-storage==2.24.0 +google-cloud-bigquery-storage==2.25.0 # via -r dev-requirements.in google-cloud-core==2.4.1 # via # google-cloud-bigquery # google-cloud-storage -google-cloud-storage==2.14.0 +google-cloud-storage==2.16.0 # via gcsfs google-crc32c==1.5.0 # via @@ -146,64 +148,72 @@ google-resumable-media==2.7.0 # via # google-cloud-bigquery # google-cloud-storage -googleapis-common-protos==1.62.0 +googleapis-common-protos==1.63.0 # via # flyteidl # flytekit # google-api-core # grpcio-status # protoc-gen-openapiv2 -grpcio==1.60.1 +grpcio==1.64.0 # via # flytekit # google-api-core # grpcio-status -grpcio-status==1.60.1 +grpcio-status==1.62.2 # via # flytekit # google-api-core -hypothesis==6.98.2 +hypothesis==6.103.0 # via -r dev-requirements.in icdiff==2.0.7 # via pytest-icdiff -identify==2.5.33 +identify==2.5.36 # via pre-commit -idna==3.6 +idna==3.7 # via # requests # yarl -importlib-metadata==7.0.1 +importlib-metadata==7.1.0 # via flytekit iniconfig==2.0.0 # via pytest -ipython==8.21.0 +ipython==8.25.0 # via -r dev-requirements.in isodate==0.6.1 # via # azure-storage-blob # flytekit -jaraco-classes==3.3.0 +jaraco-classes==3.4.0 # via # keyring # keyrings-alt +jaraco-context==5.3.0 + # via + # keyring + # keyrings-alt +jaraco-functools==4.0.1 + # via keyring jedi==0.19.1 # via ipython -jinja2==3.1.3 - # via -r dev-requirements.in +jinja2==3.1.4 + # via + # -r dev-requirements.in + # flytekit jmespath==1.0.1 # via botocore -joblib==1.3.2 +joblib==1.4.2 # via # -r dev-requirements.in # flytekit # scikit-learn jsonlines==4.0.0 # via flytekit -jsonpickle==3.0.2 +jsonpickle==3.0.4 # via flytekit -keyring==24.3.0 +keyring==25.2.1 # via flytekit -keyrings-alt==5.0.0 +keyrings-alt==5.0.1 # via -r dev-requirements.in kubernetes==29.0.0 # via -r dev-requirements.in @@ -213,7 +223,7 @@ markdown-it-py==3.0.0 # rich markupsafe==2.1.5 # via jinja2 -marshmallow==3.20.2 +marshmallow==3.21.2 # via # dataclasses-json # marshmallow-enum @@ -224,17 +234,19 @@ marshmallow-enum==1.5.1 # flytekit marshmallow-jsonschema==0.13.0 # via flytekit -mashumaro==3.12 +mashumaro==3.13 # via flytekit -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via ipython mdurl==0.1.2 # via markdown-it-py mock==5.1.0 # via -r dev-requirements.in more-itertools==10.2.0 - # via jaraco-classes -msal==1.26.0 + # via + # jaraco-classes + # jaraco-functools +msal==1.28.0 # via # azure-datalake-store # azure-identity @@ -251,7 +263,7 @@ mypy-extensions==1.0.0 # via # mypy # typing-inspect -nodeenv==1.8.0 +nodeenv==1.9.0 # via pre-commit numpy==1.26.4 # via @@ -264,9 +276,9 @@ oauthlib==3.2.2 # via # kubernetes # requests-oauthlib -orjson==3.9.12 +orjson==3.10.3 # via -r dev-requirements.in -packaging==23.2 +packaging==24.0 # via # docker # google-cloud-bigquery @@ -274,31 +286,33 @@ packaging==23.2 # msal-extensions # pytest # setuptools-scm -pandas==2.2.0 +pandas==2.2.2 # via -r dev-requirements.in -parso==0.8.3 +parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -pillow==10.2.0 +pillow==10.3.0 # via -r dev-requirements.in -platformdirs==4.2.0 +platformdirs==4.2.2 # via virtualenv -pluggy==1.4.0 +pluggy==1.5.0 # via pytest portalocker==2.8.2 # via msal-extensions pprintpp==0.4.0 # via pytest-icdiff -pre-commit==3.6.0 +pre-commit==3.7.1 # via -r dev-requirements.in -prometheus-client==0.19.0 +prometheus-client==0.20.0 # via -r dev-requirements.in -prompt-toolkit==3.0.43 +prompt-toolkit==3.0.45 # via ipython proto-plus==1.23.0 - # via google-cloud-bigquery-storage -protobuf==4.23.4 + # via + # google-api-core + # google-cloud-bigquery-storage +protobuf==4.25.3 # via # flyteidl # flytekit @@ -314,19 +328,19 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pyarrow==15.0.0 +pyarrow==16.1.0 # via flytekit -pyasn1==0.5.1 +pyasn1==0.6.0 # via # pyasn1-modules # rsa -pyasn1-modules==0.3.0 +pyasn1-modules==0.4.0 # via google-auth -pycparser==2.21 +pycparser==2.22 # via cffi pyflakes==3.2.0 # via autoflake -pygments==2.17.2 +pygments==2.18.0 # via # flytekit # ipython @@ -335,7 +349,7 @@ pyjwt[crypto]==2.8.0 # via # msal # pyjwt -pytest==7.4.4 +pytest==8.2.1 # via # -r dev-requirements.in # pytest-asyncio @@ -344,19 +358,19 @@ pytest==7.4.4 # pytest-mock # pytest-timeout # pytest-xdist -pytest-asyncio==0.23.4 +pytest-asyncio==0.23.7 # via -r dev-requirements.in -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r dev-requirements.in pytest-icdiff==0.9 # via -r dev-requirements.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r dev-requirements.in -pytest-timeout==2.2.0 +pytest-timeout==2.3.1 # via -r dev-requirements.in -pytest-xdist==3.5.0 +pytest-xdist==3.6.1 # via -r dev-requirements.in -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # botocore # croniter @@ -378,7 +392,7 @@ pyyaml==6.0.1 # flytekit # kubernetes # pre-commit -requests==2.31.0 +requests==2.32.3 # via # azure-core # azure-datalake-store @@ -391,25 +405,25 @@ requests==2.31.0 # kubernetes # msal # requests-oauthlib -requests-oauthlib==1.3.1 +requests-oauthlib==2.0.0 # via # google-auth-oauthlib # kubernetes -rich==13.7.0 +rich==13.7.1 # via # flytekit # rich-click -rich-click==1.7.3 +rich-click==1.8.2 # via flytekit rsa==4.9 # via google-auth -s3fs==2023.9.2 +s3fs==2024.5.0 # via flytekit -scikit-learn==1.4.0 +scikit-learn==1.5.0 # via -r dev-requirements.in -scipy==1.12.0 +scipy==1.13.1 # via scikit-learn -setuptools-scm==8.0.4 +setuptools-scm==8.1.0 # via -r dev-requirements.in six==1.16.0 # via @@ -424,25 +438,23 @@ stack-data==0.6.3 # via ipython statsd==3.3.0 # via flytekit -threadpoolctl==3.2.0 +threadpoolctl==3.5.0 # via scikit-learn -traitlets==5.14.1 +traitlets==5.14.3 # via # ipython # matplotlib-inline -types-croniter==2.0.0.20240106 +types-croniter==2.0.0.20240423 # via -r dev-requirements.in types-decorator==5.1.8.20240310 # via -r dev-requirements.in -types-mock==5.1.0.20240106 +types-mock==5.1.0.20240425 # via -r dev-requirements.in -types-protobuf==4.24.0.20240129 +types-protobuf==5.26.0.20240422 # via -r dev-requirements.in -types-requests==2.31.0.6 +types-requests==2.32.0.20240523 # via -r dev-requirements.in -types-urllib3==1.26.25.14 - # via types-requests -typing-extensions==4.9.0 +typing-extensions==4.12.0 # via # azure-core # azure-storage-blob @@ -450,32 +462,32 @@ typing-extensions==4.9.0 # mashumaro # mypy # rich-click - # setuptools-scm # typing-inspect typing-inspect==0.9.0 # via dataclasses-json -tzdata==2023.4 +tzdata==2024.1 # via pandas -urllib3==1.26.18 +urllib3==2.2.1 # via # botocore # docker # flytekit # kubernetes # requests -virtualenv==20.25.0 + # types-requests +virtualenv==20.26.2 # via pre-commit wcwidth==0.2.13 # via prompt-toolkit -websocket-client==1.7.0 +websocket-client==1.8.0 # via # docker # kubernetes -wrapt==1.14.1 +wrapt==1.16.0 # via aiobotocore yarl==1.9.4 # via aiohttp -zipp==3.17.0 +zipp==3.19.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/plugins/flytekit-dbt/dev-requirements.in b/plugins/flytekit-dbt/dev-requirements.in index 6a7786f5fa..474972f0d1 100644 --- a/plugins/flytekit-dbt/dev-requirements.in +++ b/plugins/flytekit-dbt/dev-requirements.in @@ -1,2 +1,4 @@ +dbt-core==1.4.5 dbt-sqlite==1.4.0 -dbt-core>=1.0.0,<1.4.6 +dbt-semantic-interfaces<0.5.0 +numpy==1.26.4 diff --git a/plugins/flytekit-dbt/dev-requirements.txt b/plugins/flytekit-dbt/dev-requirements.txt deleted file mode 100644 index da14dc11a6..0000000000 --- a/plugins/flytekit-dbt/dev-requirements.txt +++ /dev/null @@ -1,124 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile dev-requirements.in -# -agate==1.7.0 - # via dbt-core -attrs==23.1.0 - # via jsonschema -babel==2.13.1 - # via agate -betterproto==1.2.5 - # via dbt-core -certifi==2023.7.22 - # via requests -cffi==1.16.0 - # via dbt-core -charset-normalizer==3.3.2 - # via requests -click==8.1.7 - # via dbt-core -colorama==0.4.6 - # via dbt-core -dbt-core==1.4.5 - # via - # -r dev-requirements.in - # dbt-sqlite -dbt-extractor==0.4.1 - # via dbt-core -dbt-sqlite==1.4.0 - # via -r dev-requirements.in -future==0.18.3 - # via parsedatetime -grpclib==0.4.6 - # via betterproto -h2==4.1.0 - # via grpclib -hologram==0.0.15 - # via dbt-core -hpack==4.0.0 - # via h2 -hyperframe==6.0.1 - # via h2 -idna==3.4 - # via - # dbt-core - # requests -isodate==0.6.1 - # via - # agate - # dbt-core -jinja2==3.1.2 - # via dbt-core -jsonschema==3.2.0 - # via hologram -leather==0.3.4 - # via agate -logbook==1.5.3 - # via dbt-core -markupsafe==2.1.3 - # via - # jinja2 - # werkzeug -mashumaro[msgpack]==3.3.1 - # via - # dbt-core - # mashumaro -minimal-snowplow-tracker==0.0.2 - # via dbt-core -msgpack==1.0.7 - # via mashumaro -multidict==6.0.4 - # via grpclib -networkx==2.8.8 - # via dbt-core -packaging==23.2 - # via dbt-core -parsedatetime==2.4 - # via agate -pathspec==0.10.3 - # via dbt-core -pycparser==2.21 - # via cffi -pyrsistent==0.20.0 - # via jsonschema -python-dateutil==2.8.2 - # via hologram -python-slugify==8.0.1 - # via agate -pytimeparse==1.1.8 - # via agate -pytz==2023.3.post1 - # via dbt-core -pyyaml==6.0.1 - # via dbt-core -requests==2.31.0 - # via - # dbt-core - # minimal-snowplow-tracker -six==1.16.0 - # via - # isodate - # jsonschema - # leather - # minimal-snowplow-tracker - # python-dateutil -sqlparse==0.4.4 - # via dbt-core -stringcase==1.2.0 - # via betterproto -text-unidecode==1.3 - # via python-slugify -typing-extensions==4.8.0 - # via - # dbt-core - # mashumaro -urllib3==2.0.7 - # via requests -werkzeug==2.3.8 - # via dbt-core - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/plugins/flytekit-dbt/setup.py b/plugins/flytekit-dbt/setup.py index 943386bed1..aca9ddd6a7 100644 --- a/plugins/flytekit-dbt/setup.py +++ b/plugins/flytekit-dbt/setup.py @@ -5,8 +5,8 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.3.0b2,<2.0.0", - "dbt-core>=1.0.0", + "flytekit>=1.3.0b2", + "dbt-core<1.8.0", ] __version__ = "0.0.0+develop" @@ -33,6 +33,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", From e2847f0c4457016a1fc52cf18f284cf4c7b7a1cc Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 4 Jun 2024 03:53:48 +0800 Subject: [PATCH 11/32] Add . to the PYTHONPATH in ImageSpec (#2447) Signed-off-by: Kevin Su --- plugins/flytekit-envd/flytekitplugins/envd/image_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py index 59b8ca07eb..344003bafd 100644 --- a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py +++ b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py @@ -91,7 +91,7 @@ def create_envd_config(image_spec: ImageSpec) -> str: run_commands = _create_str_from_package_list(image_spec.commands) conda_channels = _create_str_from_package_list(image_spec.conda_channels) apt_packages = _create_str_from_package_list(image_spec.apt_packages) - env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()} + env = {"PYTHONPATH": "/root:", _F_IMG_ID: image_spec.image_name()} if image_spec.env: env.update(image_spec.env) From 4a88d113f73bde2cdb4a6eeb284ae0939e275c4f Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:16:56 -0700 Subject: [PATCH 12/32] Remove pod-plugin and scikit-learn from official docker image (#2455) Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0a20d8aefa..d9c113c9f5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,9 +23,7 @@ ARG DOCKER_IMAGE RUN apt-get update && apt-get install build-essential -y \ && pip install uv \ && uv pip install --system --no-cache-dir -U flytekit==$VERSION \ - flytekitplugins-pod==$VERSION \ flytekitplugins-deck-standard==$VERSION \ - scikit-learn \ && apt-get clean autoclean \ && apt-get autoremove --yes \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ From db2ff9e4e14bdcec2c96388af33f0a3357459fe0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 4 Jun 2024 11:06:31 +0800 Subject: [PATCH 13/32] Add ImageConfig to the serialization context for dynamic task (#2456) Signed-off-by: Kevin Su --- flytekit/core/python_auto_container.py | 10 +++++-- flytekit/image_spec/image_spec.py | 15 ++++++++++ flytekit/tools/translator.py | 19 +++++++++++-- .../flytekit-envd/tests/test_image_spec.py | 6 ++-- .../flytekit/unit/core/test_node_creation.py | 5 +++- .../unit/core/test_python_auto_container.py | 12 ++++++-- .../flytekit/unit/core/test_serialization.py | 28 ++++++++++++++++--- 7 files changed, 80 insertions(+), 15 deletions(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 29167ac031..2c4703cdd3 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -16,7 +16,7 @@ from flytekit.core.tracker import TrackedInstance, extract_task_module from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit from flytekit.extras.accelerators import BaseAccelerator -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -276,8 +276,12 @@ def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: :return: """ if isinstance(img, ImageSpec): - ImageBuildEngine.build(img) - return img.image_name() + image = cfg.find_image(_calculate_deduped_hash_from_image_spec(img)) + image_name = image.full if image else None + if not image_name: + ImageBuildEngine.build(img) + image_name = img.image_name() + return image_name if img is not None and img != "": matches = _IMAGE_REPLACE_REGEX.findall(img) diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 98f6c05cdc..37f87549d0 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -280,6 +280,21 @@ def _build_image(cls, builder, image_spec, img_name): cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name +@lru_cache +def _calculate_deduped_hash_from_image_spec(image_spec: ImageSpec): + """ + Calculate this special hash from the image spec, + and it used to identify the imageSpec in the ImageConfig in the serialization context. + + ImageConfig: + - deduced hash 1: flyteorg/flytekit: 123 + - deduced hash 2: flyteorg/flytekit: 456 + """ + image_spec_bytes = asdict(image_spec).__str__().encode("utf-8") + # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. + return base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii").rstrip("=") + + @lru_cache def calculate_hash_from_image_spec(image_spec: ImageSpec): """ diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b49639d23a..a77e0a0bf5 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -6,9 +6,10 @@ from flyteidl.admin import schedule_pb2 -from flytekit import PythonFunctionTask, SourceCode -from flytekit.configuration import SerializationSettings +from flytekit import ImageSpec, PythonFunctionTask, SourceCode +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import constants as _common_constants +from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode @@ -22,6 +23,7 @@ from flytekit.core.task import ReferenceTask from flytekit.core.utils import ClassDecorator, _dnsify from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase +from flytekit.image_spec.image_spec import _calculate_deduped_hash_from_image_spec from flytekit.models import common as _common_models from flytekit.models import common as common_models from flytekit.models import interface as interface_models @@ -176,6 +178,19 @@ def get_serializable_task( ) if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: + for e in context_manager.FlyteEntities.entities: + if isinstance(e, PythonAutoContainerTask): + # 1. Build the ImageSpec for all the entities that are inside the current context, + # 2. Add images to the serialization context, so the dynamic task can look it up at runtime. + if isinstance(e.container_image, ImageSpec): + if settings.image_config.images is None: + settings.image_config = ImageConfig.create_from(settings.image_config.default_image) + settings.image_config.images.append( + Image.look_up_image_info( + _calculate_deduped_hash_from_image_spec(e.container_image), e.get_image(settings) + ) + ) + # In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct the state # from the serialization context. This is passed through an environment variable, that is read from # during dynamic serialization diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py index 7fd3cd1be0..5b7b73f755 100644 --- a/plugins/flytekit-envd/tests/test_image_spec.py +++ b/plugins/flytekit-envd/tests/test_image_spec.py @@ -57,7 +57,7 @@ def build(): run(commands=["echo hello"]) install.python_packages(name=["pandas"]) install.apt_packages(name=["git"]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://private-pip-index/simple") install.python(version="3.8") io.copy(source="./", target="/root") @@ -88,7 +88,7 @@ def build(): run(commands=[]) install.python_packages(name=["flytekit"]) install.apt_packages(name=[]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://pypi.org/simple") install.conda(use_mamba=True) install.conda_packages(name=["pytorch", "cpuonly"], channel=["pytorch"]) @@ -122,7 +122,7 @@ def build(): run(commands=[]) install.python_packages(name=["-U --pre pandas", "torch", "torchvision"]) install.apt_packages(name=[]) - runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) + runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) config.pip_index(url="https://pypi.org/simple", extra_url="https://download.pytorch.org/whl/cpu https://pypi.anaconda.org/scientific-python-nightly-wheels/simple") """ ) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index fc3284ca10..684f49031b 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -14,12 +14,15 @@ from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion from flytekit.extras.accelerators import A100, T4 +from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable -def test_normal_task(): +def test_normal_task(mock_image_spec_builder): + ImageBuildEngine.register("test", mock_image_spec_builder) + @task def t1(a: str) -> str: return a + " world" diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 58492fca06..5068da53de 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -9,7 +9,7 @@ from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image from flytekit.core.resources import Resources -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, _calculate_deduped_hash_from_image_spec from flytekit.tools.translator import get_serializable_task @@ -55,9 +55,17 @@ def serialization_settings(request): def test_image_name_interpolation(default_image_config): + image_spec = ImageSpec(name="image-1", registry="localhost:30000", builder="test") + + new_img_cfg = ImageConfig.create_from( + default_image_config.default_image, + other_images=[Image.look_up_image_info(_calculate_deduped_hash_from_image_spec(image_spec), "flyte/test:d1")], + ) img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special" - img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config) + img = get_registerable_container_image(img=img_to_interpolate, cfg=new_img_cfg) assert img == "docker.io/xyz:some-git-hash-special" + img = get_registerable_container_image(img=image_spec, cfg=new_img_cfg) + assert img == "flyte/test:d1" class DummyAutoContainerTask(PythonAutoContainerTask): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 9b11a2a16a..88297f43f4 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -6,12 +6,13 @@ import pytest import flytekit.configuration -from flytekit import ContainerTask, kwtypes +from flytekit import ContainerTask, ImageSpec, kwtypes from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.types import SimpleType from flytekit.tools.translator import get_serializable @@ -250,7 +251,9 @@ def test_bad_configuration(): get_registerable_container_image(container_image, image_config) -def test_serialization_images(): +def test_serialization_images(mock_image_spec_builder): + ImageBuildEngine.register("test", mock_image_spec_builder) + @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}") def t1(a: int) -> int: return a @@ -271,10 +274,24 @@ def t5(a: int) -> int: def t6(a: int) -> int: return a + image_spec = ImageSpec( + packages=["mypy"], + apt_packages=["git"], + registry="ghcr.io/flyteorg", + builder="test", + ) + + @task(container_image=image_spec) + def t7(a: int) -> int: + return a + with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}): imgs = ImageConfig.auto( config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config") ) + imgs.images.append( + Image(name=_calculate_deduped_hash_from_image_spec(image_spec), fqn="docker.io/t7", tag="latest") + ) rs = flytekit.configuration.SerializationSettings( project="project", domain="domain", @@ -295,8 +312,11 @@ def t6(a: int) -> int: t5_spec = get_serializable(OrderedDict(), rs, t5) assert t5_spec.template.container.image == "docker.io/org/myimage:latest" - t5_spec = get_serializable(OrderedDict(), rs, t6) - assert t5_spec.template.container.image == "docker.io/xyz_123:v1" + t6_spec = get_serializable(OrderedDict(), rs, t6) + assert t6_spec.template.container.image == "docker.io/xyz_123:v1" + + t7_spec = get_serializable(OrderedDict(), rs, t7) + assert t7_spec.template.container.image == "docker.io/t7:latest" def test_serialization_command1(): From c388a437a8f75aadaf5d59d2016a5568aad13b85 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Tue, 4 Jun 2024 13:11:56 -0700 Subject: [PATCH 14/32] Only run CI checks using protobuf 4 (#2459) * Only run CI checks using protobuf 4 Signed-off-by: Eduardo Apolinario * Fix lint error Signed-off-by: Eduardo Apolinario * Install the right version of types-protobuf Signed-off-by: Eduardo Apolinario * Force protobuf<5 in plugins tests Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- .github/workflows/pythonbuild.yml | 3 ++- dev-requirements.in | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index b18cdc229a..f56caffc0d 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -417,7 +417,8 @@ jobs: cd plugins/${{ matrix.plugin-names }} uv pip install --system . if [ -f dev-requirements.in ]; then uv pip install --system -r dev-requirements.in; fi - uv pip install --system -U $GITHUB_WORKSPACE + # TODO: move to protobuf>=5. Github issue: https://github.com/flyteorg/flyte/issues/5448 + uv pip install --system -U $GITHUB_WORKSPACE "protobuf<5" uv pip freeze - name: Test with coverage run: | diff --git a/dev-requirements.in b/dev-requirements.in index 7f42851061..ca37177df7 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -37,7 +37,12 @@ torch; python_version<'3.12' # Once a solution is found, this should be updated to support Windows as well. python-magic; (platform_system=='Darwin' or platform_system=='Linux') -types-protobuf +# Google released a new major version of the protobuf library and once that started being used in the ecosystem at large, +# including `googleapis-common-protos` we started seeing errors in CI, so let's constrain that for now. +# The issue to support protobuf 5 is being tracked in https://github.com/flyteorg/flyte/issues/5448. +protobuf<5 +types-protobuf<5 + types-croniter types-decorator types-mock From 9872e65cf7b72c803bb347b1be746acdcfc9308c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 06:47:24 +0800 Subject: [PATCH 15/32] Support Passing Dataclass Values via Command Line (#2446) Signed-off-by: Future-Outlier --- flytekit/interaction/click_types.py | 7 ++++++- .../unit/interaction/test_click_types.py | 20 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 2c26ed0cbc..c16339a236 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -10,7 +10,7 @@ import cloudpickle import rich_click as click import yaml -from dataclasses_json import DataClassJsonMixin +from dataclasses_json import DataClassJsonMixin, dataclass_json from pytimeparse import parse from flytekit import BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, StructuredDataset @@ -273,6 +273,11 @@ def convert( if is_pydantic_basemodel(self._python_type): return self._python_type.parse_raw(json.dumps(parsed_value)) # type: ignore + + # Ensure that the python type has `from_json` function + if not hasattr(self._python_type, "from_json"): + self._python_type = dataclass_json(self._python_type) + return cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(parsed_value)) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 83a191c449..cb32982916 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -206,3 +206,23 @@ def test_query_passing(param_type: click.ParamType): query = a.query() assert param_type.convert(value=query, param=None, ctx=None) is query + + +def test_dataclass_type(): + from dataclasses import dataclass + + @dataclass + class Datum: + x: int + y: str + z: dict[int, str] + w: list[int] + + t = JsonParamType(Datum) + value = '{ "x": 1, "y": "2", "z": { "1": "one", "2": "two" }, "w": [1, 2, 3] }' + v = t.convert(value=value, param=None, ctx=None) + + assert v.x == 1 + assert v.y == "2" + assert v.z == {1: "one", 2: "two"} + assert v.w == [1, 2, 3] From 750a3830f1b4ab293dfb30bcc9aa496778de36b5 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 07:32:17 +0800 Subject: [PATCH 16/32] Fix output_prefix in do() method for ChatGPT Agent (#2457) Signed-off-by: Future-Outlier Co-authored-by: pingsutw --- flytekit/extend/backend/base_agent.py | 8 +++++-- .../flytekitplugins/openai/chatgpt/agent.py | 1 + .../tests/chatgpt/test_chatgpt.py | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 33a03e282b..e8ec18806e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -119,7 +119,9 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: str, **kwargs) -> Resource: + def do( + self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs + ) -> Resource: """ This is the method that the agent will run. """ @@ -247,7 +249,9 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource = asyncio.run(self._do(agent, task_template, output_prefix, kwargs)) + resource = asyncio.run( + self._do(agent=agent, template=task_template, output_prefix=output_prefix, inputs=kwargs) + ) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index afd3af1321..e4f24baa5a 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -27,6 +27,7 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, + **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) diff --git a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py index 6298bdf52c..12de3da23b 100644 --- a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py +++ b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from unittest import mock from flytekitplugins.openai import ChatGPTTask @@ -7,6 +8,14 @@ from flytekit.models.types import SimpleType +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + def test_chatgpt_task(): chatgpt_task = ChatGPTTask( name="chatgpt", @@ -40,3 +49,16 @@ def test_chatgpt_task(): assert chatgpt_task_spec.template.interface.inputs["message"].type.simple == SimpleType.STRING assert chatgpt_task_spec.template.interface.outputs["o0"].type.simple == SimpleType.STRING + + with mock.patch("openai.resources.chat.completions.AsyncCompletions.create", new=mock_acreate): + chatgpt_task = ChatGPTTask( + name="chatgpt", + openai_organization="TEST ORGANIZATION ID", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, + ) + + response = chatgpt_task(message="hi") + assert response == "mocked_message" From 57ee143ae256925324129bd66bde503b0af8a7a7 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 08:52:20 +0800 Subject: [PATCH 17/32] Fix CI by restricting hugging face dataset version (#2461) * init Signed-off-by: Future-Outlier * Update pingsu's advice Signed-off-by: Future-Outlier Co-authored-by: pingsutw --------- Signed-off-by: Future-Outlier Co-authored-by: pingsutw --- plugins/flytekit-huggingface/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py index 9c1debaba0..0d5a9c7b3a 100644 --- a/plugins/flytekit-huggingface/setup.py +++ b/plugins/flytekit-huggingface/setup.py @@ -6,7 +6,7 @@ plugin_requires = [ "flytekit>=1.3.0b2,<2.0.0", - "datasets>=2.4.0", + "datasets>=2.4.0,!=2.19.2", ] __version__ = "0.0.0+develop" From 72af7c6d8982892047e7051e1af10d628f9f9b39 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 5 Jun 2024 10:32:55 -0700 Subject: [PATCH 18/32] Add escape for scalars to binding during union handling (#2460) Signed-off-by: Yee Hing Tong --- flytekit/core/promise.py | 9 +++++ tests/flytekit/unit/core/test_promise.py | 17 +++++++++ tests/flytekit/unit/core/test_type_engine.py | 36 ++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 220fc3fb89..ba326ec27e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -697,6 +697,15 @@ def binding_data_from_python_std( ) elif t_value is not None and expected_literal_type.union_type is not None: + # If the value is not a container type, then we can directly convert it to a scalar in the Union case. + # This pushes the handling of the Union types to the type engine. + if not isinstance(t_value, list) and not isinstance(t_value, dict): + scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar + return _literals_models.BindingData(scalar=scalar) + + # If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is + # akin to what the Type Engine does when it finds a Union type (see the UnionTransformer), but we can't rely on + # that in this case, because of the mix and match of realized values, and Promises. for i in range(len(expected_literal_type.union_type.variants)): try: lt_type = expected_literal_type.union_type.variants[i] diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 74f3db99e3..e022c875e0 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -13,6 +13,7 @@ from flytekit.core.promise import ( Promise, VoidPromise, + binding_data_from_python_std, create_and_link_node, create_and_link_node_from_remote, resolve_attr_path_in_promise, @@ -20,6 +21,7 @@ ) from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException +from flytekit.models.types import LiteralType, SimpleType, TypeStructure from flytekit.types.pickle.pickle import BatchSize @@ -234,3 +236,18 @@ class Foo: # exception with pytest.raises(FlytePromiseAttributeResolveException): tgt_promise = resolve_attr_path_in_promise(src_promise["c"]) + + +def test_prom_with_union_literals(): + ctx = FlyteContextManager.current_context() + pt = typing.Union[str, int] + lt = TypeEngine.to_literal_type(pt) + assert lt.union_type.variants == [ + LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), + LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), + ] + + bd = binding_data_from_python_std(ctx, lt, 3, pt, []) + assert bd.scalar.union.stored_type.structure.tag == "int" + bd = binding_data_from_python_std(ctx, lt, "hello", pt, []) + assert bd.scalar.union.stored_type.structure.tag == "str" diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d4db4f34fe..0546b9dc7a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1759,6 +1759,42 @@ def test_annotated_union_type(): assert v == "hello" +def test_union_type_simple(): + pt = typing.Union[str, int] + lt = TypeEngine.to_literal_type(pt) + assert lt.union_type.variants == [ + LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), + LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), + ] + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, 3, pt, lt) + assert lv.scalar.union is not None + assert lv.scalar.union.stored_type.structure.tag == "int" + assert lv.scalar.union.stored_type.structure.dataclass_type is None + + +def test_union_containers(): + pt = typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int] + lt = TypeEngine.to_literal_type(pt) + + list_of_maps_of_list_ints = [ + {"first_map_a": [42], "first_map_b": [42, 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + map_of_list_ints = { + "ll_1": [1, 23, 3], + "ll_2": [4, 5, 6], + } + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, list_of_maps_of_list_ints, pt, lt) + assert lv.scalar.union.stored_type.structure.tag == "Typed List" + lv = TypeEngine.to_literal(ctx, map_of_list_ints, pt, lt) + assert lv.scalar.union.stored_type.structure.tag == "Python Dictionary" + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_optional_type(): pt = typing.Optional[int] From c03eaadd72fbe302003baf98e3e8e2fac237eaac Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Thu, 6 Jun 2024 10:03:54 -0700 Subject: [PATCH 19/32] Prevent mutable default arguments (#2443) * Prevent mutable default arguments Signed-off-by: Eduardo Apolinario * Do not lint tests Signed-off-by: Eduardo Apolinario * Bump precommit hook version Signed-off-by: Eduardo Apolinario * Remove default value for dictionary in Spark class Signed-off-by: Eduardo Apolinario * Run precommit hook locally Signed-off-by: Eduardo Apolinario * Fix lint error in openai plugin Signed-off-by: Eduardo Apolinario * Handle nil databricks_conf in spark plugin Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- .pre-commit-config.yaml | 2 +- flytekit/__init__.py | 1 + flytekit/clients/auth/authenticator.py | 6 ++--- flytekit/configuration/__init__.py | 1 + flytekit/configuration/plugin.py | 1 + flytekit/core/artifact.py | 9 +++---- flytekit/core/context_manager.py | 3 +-- flytekit/core/data_persistence.py | 1 + flytekit/core/dynamic_workflow_task.py | 1 + flytekit/core/interface.py | 3 +-- flytekit/core/legacy_map_task.py | 1 + flytekit/core/local_fsspec.py | 1 + flytekit/core/notification.py | 1 + flytekit/core/promise.py | 24 +++++++------------ flytekit/core/reference_entity.py | 3 +-- flytekit/core/schedule.py | 3 +-- flytekit/core/task.py | 6 ++--- flytekit/core/type_engine.py | 3 +-- flytekit/core/workflow.py | 6 ++--- flytekit/extras/accelerators.py | 4 ++-- flytekit/extras/pytorch/__init__.py | 1 + flytekit/extras/sklearn/__init__.py | 1 + flytekit/remote/entities.py | 1 + flytekit/remote/executions.py | 6 ++--- flytekit/remote/remote.py | 1 + flytekit/remote/remote_callable.py | 3 +-- flytekit/types/directory/types.py | 3 +-- flytekit/types/file/file.py | 3 +-- flytekit/types/schema/types.py | 9 +++---- flytekit/types/structured/__init__.py | 1 - .../types/structured/structured_dataset.py | 11 ++++----- .../flytekitplugins/async_fsspec/__init__.py | 1 + .../flytekitplugins/kfmpi/task.py | 1 + .../flytekit-kf-mpi/tests/test_mpi_task.py | 3 +-- .../kfpytorch/error_handling.py | 1 + .../flytekitplugins/kfpytorch/task.py | 1 + .../flytekitplugins/kftensorflow/task.py | 1 + .../flytekitplugins/modin/schema.py | 6 ++--- .../flytekitplugins/openai/batch/task.py | 2 +- .../flytekitplugins/openai/batch/workflow.py | 6 +++-- .../flytekitplugins/pandera/schema.py | 6 ++--- .../pydantic/basemodel_transformer.py | 2 +- .../flytekitplugins/pydantic/serialization.py | 1 + .../flytekitplugins/spark/models.py | 4 +++- pyproject.toml | 6 ++++- 45 files changed, 77 insertions(+), 84 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 703bcda938..a0fc842ba2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.2.2 + rev: v0.4.7 hooks: # Run the linter. - id: ruff diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 63ad935b47..33bdad747a 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -202,6 +202,7 @@ SourceCode """ + import os import sys from typing import Generator diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 0ed780509e..95a89422be 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -35,8 +35,7 @@ class ClientConfigStore(object): """ @abstractmethod - def get_client_config(self) -> ClientConfig: - ... + def get_client_config(self) -> ClientConfig: ... class StaticClientConfigStore(ClientConfigStore): @@ -81,8 +80,7 @@ def fetch_grpc_call_auth_metadata(self) -> typing.Optional[typing.Tuple[str, str return None @abstractmethod - def refresh_credentials(self): - ... + def refresh_credentials(self): ... class PKCEAuthenticator(Authenticator): diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index c29bd71c88..97a9940425 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -126,6 +126,7 @@ ~DataConfig """ + from __future__ import annotations import base64 diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index cc8750deaa..3d43844d39 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -17,6 +17,7 @@ my_plugin = "my_module:MyCustomPlugin" ``` """ + from typing import Optional, Protocol, runtime_checkable from click import Group diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index f2e08042bc..954151504f 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -607,14 +607,11 @@ class ArtifactSerializationHandler(typing.Protocol): This protocol defines the interface for serializing artifact-related entities down to Flyte IDL. """ - def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: - ... + def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: ... - def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]: - ... + def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]: ... - def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: - ... + def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: ... class DefaultArtifactSerializationHandler(ArtifactSerializationHandler): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index a31d058774..c51b60c1c9 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -560,8 +560,7 @@ class SerializableToString(typing.Protocol): and then added to a literal's metadata. """ - def serialize_to_string(self, ctx: FlyteContext, variable_name: str) -> typing.Tuple[str, str]: - ... + def serialize_to_string(self, ctx: FlyteContext, variable_name: str) -> typing.Tuple[str, str]: ... @dataclass diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 2ca84ad8fd..5c8036d179 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -17,6 +17,7 @@ FileAccessProvider """ + import io import os import pathlib diff --git a/flytekit/core/dynamic_workflow_task.py b/flytekit/core/dynamic_workflow_task.py index a0f84927bf..a9ff5055db 100644 --- a/flytekit/core/dynamic_workflow_task.py +++ b/flytekit/core/dynamic_workflow_task.py @@ -12,6 +12,7 @@ dynamic workflows to under fifty tasks. For large-scale identical runs, we recommend the upcoming map task. """ + import functools from flytekit.core import task diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index c139641278..13b6af2d4b 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -109,8 +109,7 @@ def runs_before(self, *args, **kwargs): where runs_before is manually called. """ - def __rshift__(self, *args, **kwargs): - ... # See runs_before + def __rshift__(self, *args, **kwargs): ... # See runs_before self._output_tuple_class = Output self._docstring = docstring diff --git a/flytekit/core/legacy_map_task.py b/flytekit/core/legacy_map_task.py index fe8d353027..99c67ad12c 100644 --- a/flytekit/core/legacy_map_task.py +++ b/flytekit/core/legacy_map_task.py @@ -2,6 +2,7 @@ Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution concurrency and failure tolerations. """ + import functools import hashlib import logging diff --git a/flytekit/core/local_fsspec.py b/flytekit/core/local_fsspec.py index b452b3006e..91fe93ad6f 100644 --- a/flytekit/core/local_fsspec.py +++ b/flytekit/core/local_fsspec.py @@ -14,6 +14,7 @@ FlyteLocalFileSystem """ + import os from fsspec.implementations.local import LocalFileSystem diff --git a/flytekit/core/notification.py b/flytekit/core/notification.py index cecfe43367..c964c67568 100644 --- a/flytekit/core/notification.py +++ b/flytekit/core/notification.py @@ -15,6 +15,7 @@ .. autoclass:: flytekit.core.notification.Notification """ + from typing import List from flytekit.models import common as _common_model diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index ba326ec27e..931d036d02 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -900,28 +900,22 @@ def with_attr(self, key) -> NodeOutput: class SupportsNodeCreation(Protocol): @property - def name(self) -> str: - ... + def name(self) -> str: ... @property - def python_interface(self) -> flyte_interface.Interface: - ... + def python_interface(self) -> flyte_interface.Interface: ... - def construct_node_metadata(self) -> _workflow_model.NodeMetadata: - ... + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: ... class HasFlyteInterface(Protocol): @property - def name(self) -> str: - ... + def name(self) -> str: ... @property - def interface(self) -> _interface_models.TypedInterface: - ... + def interface(self) -> _interface_models.TypedInterface: ... - def construct_node_metadata(self) -> _workflow_model.NodeMetadata: - ... + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: ... def extract_obj_name(name: str) -> str: @@ -1148,11 +1142,9 @@ def create_and_link_node( class LocallyExecutable(Protocol): - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: - ... + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... - def local_execution_mode(self) -> ExecutionState.Mode: - ... + def local_execution_mode(self) -> ExecutionState.Mode: ... def flyte_entity_call_handler( diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 0d861db513..b54c4d67f6 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -37,8 +37,7 @@ def id(self) -> _identifier_model.Identifier: @property @abstractmethod - def resource_type(self) -> int: - ... + def resource_type(self) -> int: ... @dataclass diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index ada3c69fd1..891fb17a24 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -16,8 +16,7 @@ class LaunchPlanTriggerBase(Protocol): - def to_flyte_idl(self, *args, **kwargs) -> google_message.Message: - ... + def to_flyte_idl(self, *args, **kwargs) -> google_message.Message: ... # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass. diff --git a/flytekit/core/task.py b/flytekit/core/task.py index ed15d3e0af..d30947509d 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -117,8 +117,7 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., -) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: - ... +) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... @overload @@ -155,8 +154,7 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., -) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: - ... +) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ... def task( diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9c8ffabf66..f5beb53f52 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -121,8 +121,7 @@ def modify_literal_uris(lit: Literal): ) -class TypeTransformerFailedError(TypeError, AssertionError, ValueError): - ... +class TypeTransformerFailedError(TypeError, AssertionError, ValueError): ... class TypeTransformer(typing.Generic[T]): diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0f25374717..58f8157983 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -804,8 +804,7 @@ def workflow( interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., -) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: - ... +) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ... @overload @@ -815,8 +814,7 @@ def workflow( interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., -) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]: - ... +) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]: ... def workflow( diff --git a/flytekit/extras/accelerators.py b/flytekit/extras/accelerators.py index 139237e1fb..8a9d3e56a5 100644 --- a/flytekit/extras/accelerators.py +++ b/flytekit/extras/accelerators.py @@ -93,6 +93,7 @@ def my_task() -> None: A100_80GB """ + import abc import copy from typing import ClassVar, Generic, Optional, Type, TypeVar @@ -109,8 +110,7 @@ class BaseAccelerator(abc.ABC, Generic[T]): """ @abc.abstractmethod - def to_flyte_idl(self) -> T: - ... + def to_flyte_idl(self) -> T: ... class GPUAccelerator(BaseAccelerator): diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py index a29d8e89e6..12c507afb9 100644 --- a/flytekit/extras/pytorch/__init__.py +++ b/flytekit/extras/pytorch/__init__.py @@ -10,6 +10,7 @@ PyTorchModuleTransformer PyTorchTensorTransformer """ + from flytekit.loggers import logger # TODO: abstract this out so that there's an established pattern for registering plugins diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py index 1d16f6080f..d22546dbe2 100644 --- a/flytekit/extras/sklearn/__init__.py +++ b/flytekit/extras/sklearn/__init__.py @@ -7,6 +7,7 @@ SklearnEstimatorTransformer """ + from flytekit.loggers import logger # TODO: abstract this out so that there's an established pattern for registering plugins diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 1f09ebb19d..fd78d4c3c4 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -2,6 +2,7 @@ This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. The goal is to enable easy access, manipulation of these entities. """ + from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index c06ee06739..4aba363f3e 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -26,13 +26,11 @@ def inputs(self) -> Optional[LiteralsResolver]: @property @abstractmethod - def error(self) -> core_execution_models.ExecutionError: - ... + def error(self) -> core_execution_models.ExecutionError: ... @property @abstractmethod - def is_done(self) -> bool: - ... + def is_done(self) -> bool: ... @property def outputs(self) -> Optional[LiteralsResolver]: diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 584501f137..b6fdffec47 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -3,6 +3,7 @@ with a Flyte backend in an interactive and programmatic way. This of this experience as kind of like the web UI but in Python object form. """ + from __future__ import annotations import asyncio diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py index 5b177bf7c4..ccd979bcdc 100644 --- a/flytekit/remote/remote_callable.py +++ b/flytekit/remote/remote_callable.py @@ -18,8 +18,7 @@ def __init__(self, *args, **kwargs): @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... def construct_node_metadata(self) -> NodeMetadata: """ diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 5c50bab9a5..ca8228b8a8 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -28,8 +28,7 @@ PathType = typing.Union[str, os.PathLike] -def noop(): - ... +def noop(): ... @dataclass diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9b71eb5b4f..5720753e6f 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -21,8 +21,7 @@ from flytekit.types.pickle.pickle import FlytePickleTransformer -def noop(): - ... +def noop(): ... T = typing.TypeVar("T") diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index e7fb1d6c09..349b5aaf5a 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -68,12 +68,10 @@ def column_names(self) -> typing.Optional[typing.List[str]]: return None @abstractmethod - def iter(self, **kwargs) -> typing.Generator[T, None, None]: - ... + def iter(self, **kwargs) -> typing.Generator[T, None, None]: ... @abstractmethod - def all(self, **kwargs) -> T: - ... + def all(self, **kwargs) -> T: ... class SchemaWriter(typing.Generic[T]): @@ -95,8 +93,7 @@ def column_names(self) -> typing.Optional[typing.List[str]]: return None @abstractmethod - def write(self, *dfs, **kwargs): - ... + def write(self, *dfs, **kwargs): ... class LocalIOSchemaReader(SchemaReader[T]): diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 7c92be78b1..7dffa49eec 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -12,7 +12,6 @@ StructuredDatasetDecoder """ - from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer from flytekit.loggers import logger diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index b2c8f52b9b..1d2209a2b4 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import _datetime import collections import types import typing @@ -7,7 +8,6 @@ from dataclasses import dataclass, field, is_dataclass from typing import Dict, Generator, Optional, Type, Union -import _datetime from dataclasses_json import config from fsspec.utils import get_protocol from marshmallow import fields @@ -348,8 +348,7 @@ def get_supported_types(): return _SUPPORTED_TYPES -class DuplicateHandlerError(ValueError): - ... +class DuplicateHandlerError(ValueError): ... class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): @@ -861,9 +860,9 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any] original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information - converted_cols: typing.List[ - StructuredDatasetType.DatasetColumn - ] = self._convert_ordered_dict_of_columns_to_list(column_map) + converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = ( + self._convert_ordered_dict_of_columns_to_list(column_map) + ) return StructuredDatasetType( columns=converted_cols, diff --git a/plugins/flytekit-async-fsspec/flytekitplugins/async_fsspec/__init__.py b/plugins/flytekit-async-fsspec/flytekitplugins/async_fsspec/__init__.py index 3cc0de14e7..09f6fd5dbd 100644 --- a/plugins/flytekit-async-fsspec/flytekitplugins/async_fsspec/__init__.py +++ b/plugins/flytekit-async-fsspec/flytekitplugins/async_fsspec/__init__.py @@ -9,6 +9,7 @@ AsyncS3FileSystem """ + import fsspec from .s3fs.s3fs import AsyncS3FileSystem diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 665e195b4b..7c8416d007 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -2,6 +2,7 @@ This Plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on Kubernetes. It leverages `MPI Job `_ Plugin from kubeflow. """ + from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, List, Optional, Union diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index f2b453fcce..deec3ff385 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -175,8 +175,7 @@ def test_horovod_task(serialization_settings): ), ), ) - def my_horovod_task(): - ... + def my_horovod_task(): ... cmd = my_horovod_task.get_command(serialization_settings) assert "horovodrun" in cmd diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/error_handling.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/error_handling.py index f3c509207e..f1071678c4 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/error_handling.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/error_handling.py @@ -1,4 +1,5 @@ """Handle errors in elastic training jobs.""" + import os RECOVERABLE_ERROR_FILE_NAME = "recoverable_error" diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 46eb086ad0..94b575e2a9 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -2,6 +2,7 @@ This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ + import os from dataclasses import dataclass, field from enum import Enum diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 7be1f7d030..62cd482416 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -2,6 +2,7 @@ This Plugin adds the capability of running distributed tensorflow training to Flyte using backend plugins, natively on Kubernetes. It leverages `TF Job `_ Plugin from kubeflow. """ + from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, Optional, Union diff --git a/plugins/flytekit-modin/flytekitplugins/modin/schema.py b/plugins/flytekit-modin/flytekitplugins/modin/schema.py index f5ab78489a..0504c38746 100644 --- a/plugins/flytekit-modin/flytekitplugins/modin/schema.py +++ b/plugins/flytekit-modin/flytekitplugins/modin/schema.py @@ -61,9 +61,9 @@ class ModinPandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]): Transforms ModinPandas DataFrame's to and from a Schema (typed/untyped) """ - _SUPPORTED_TYPES: typing.Dict[ - type, SchemaType.SchemaColumn.SchemaColumnType - ] = FlyteSchemaTransformer._SUPPORTED_TYPES + _SUPPORTED_TYPES: typing.Dict[type, SchemaType.SchemaColumn.SchemaColumnType] = ( + FlyteSchemaTransformer._SUPPORTED_TYPES + ) def __init__(self): super().__init__("pandas-df-transformer", pandas.DataFrame) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py index f5054582dd..695e8882e6 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py @@ -33,7 +33,7 @@ def __init__( self, name: str, openai_organization: str, - config: Dict[str, Any] = {}, + config: Dict[str, Any], **kwargs, ): super().__init__( diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py index 1f0ff30b51..209bd0d981 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, Optional from flytekit import Resources, Workflow from flytekit.models.security import Secret @@ -18,7 +18,7 @@ def create_batch( name: str, openai_organization: str, secret: Secret, - config: Dict[str, Any] = {}, + config: Optional[Dict[str, Any]] = None, is_json_iterator: bool = True, file_upload_mem: str = "700Mi", file_download_mem: str = "700Mi", @@ -45,6 +45,8 @@ def create_batch( name=f"openai-file-upload-{name.replace('.', '')}", task_config=OpenAIFileConfig(openai_organization=openai_organization, secret=secret), ) + if config is None: + config = {} batch_endpoint_task_obj = BatchEndpointTask( name=f"openai-batch-{name.replace('.', '')}", openai_organization=openai_organization, diff --git a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py index 1c589e4c0f..6fe833d836 100644 --- a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py +++ b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py @@ -16,9 +16,9 @@ class PanderaTransformer(TypeTransformer[pandera.typing.DataFrame]): - _SUPPORTED_TYPES: typing.Dict[ - type, SchemaType.SchemaColumn.SchemaColumnType - ] = FlyteSchemaTransformer._SUPPORTED_TYPES + _SUPPORTED_TYPES: typing.Dict[type, SchemaType.SchemaColumn.SchemaColumnType] = ( + FlyteSchemaTransformer._SUPPORTED_TYPES + ) def __init__(self): super().__init__("Pandera Transformer", pandera.typing.DataFrame) # type: ignore diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 4854360a01..50552ab108 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,4 +1,4 @@ -"""Serializes & deserializes the pydantic basemodels """ +"""Serializes & deserializes the pydantic basemodels""" from typing import Dict, Type diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index dff07883bf..5951803fdc 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -8,6 +8,7 @@ 3. Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} """ + import uuid from typing import Any, Dict, Union, cast diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index df8191304a..e74a9fbe3f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -25,7 +25,7 @@ def __init__( spark_conf: Dict[str, str], hadoop_conf: Dict[str, str], executor_path: str, - databricks_conf: Dict[str, Dict[str, Dict]] = {}, + databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None, databricks_instance: Optional[str] = None, ): """ @@ -43,6 +43,8 @@ def __init__( self._executor_path = executor_path self._spark_conf = spark_conf self._hadoop_conf = hadoop_conf + if databricks_conf is None: + databricks_conf = {} self._databricks_conf = databricks_conf self._databricks_instance = databricks_instance diff --git a/pyproject.toml b/pyproject.toml index 126f05050a..e0fd189bd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ branch = true [tool.ruff] line-length = 120 -lint.select = ["E", "W", "F", "I"] +lint.select = ["E", "W", "F", "I", "B006"] lint.ignore = [ # Whitespace before '{symbol}' "E203", @@ -135,6 +135,10 @@ lint.ignore = [ # Do not assign a lambda expression, use a def "E731", ] +extend-exclude = [ + "tests/", + "**/tests/**", +] [tool.ruff.lint.extend-per-file-ignores] "*/__init__.py" = [ From 47f2a294e8d84ce4f86655356055ffd801bc006a Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Fri, 7 Jun 2024 01:12:18 +0800 Subject: [PATCH 20/32] feat(bindings): Task arguments default value binding (#2401) flyteorg/flyte#5321 if the key is not in `kwargs` but in `interface.inputs_with_defaults`, add the value in `interface.inputs_with_defaults` to `kwargs`. Signed-off-by: Chi-Sheng Liu --- flytekit/core/promise.py | 56 +-- flytekit/core/type_engine.py | 2 +- tests/flytekit/unit/core/test_composition.py | 14 - tests/flytekit/unit/core/test_dynamic.py | 37 ++ tests/flytekit/unit/core/test_promise.py | 2 +- .../flytekit/unit/core/test_serialization.py | 450 +++++++++++++++++- .../unit/types/pickle/test_flyte_pickle.py | 52 +- .../test_structured_dataset.py | 40 +- 8 files changed, 607 insertions(+), 46 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 931d036d02..afd3c069f6 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -4,10 +4,10 @@ import inspect from copy import deepcopy from enum import Enum -from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args from google.protobuf import struct_pb2 as _struct -from typing_extensions import Protocol, get_args +from typing_extensions import Protocol from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -23,7 +23,13 @@ ) from flytekit.core.interface import Interface from flytekit.core.node import Node -from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError +from flytekit.core.type_engine import ( + DictTransformer, + ListTransformer, + TypeEngine, + TypeTransformerFailedError, + UnionTransformer, +) from flytekit.exceptions import user as _user_exceptions from flytekit.exceptions.user import FlytePromiseAttributeResolveException from flytekit.loggers import logger @@ -774,7 +780,13 @@ def binding_from_python_std( t_value_type: type, ) -> Tuple[_literals_models.Binding, List[Node]]: nodes: List[Node] = [] - binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type, nodes) + binding_data = binding_data_from_python_std( + ctx, + expected_literal_type, + t_value, + t_value_type, + nodes, + ) return _literals_models.Binding(var=var_name, binding=binding_data), nodes @@ -1060,32 +1072,22 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] + if var.type.simple == SimpleType.NONE: + raise TypeError("Arguments do not have type annotation") if k not in kwargs: - is_optional = False - if var.type.union_type: - for variant in var.type.union_type.variants: - if variant.simple == SimpleType.NONE: - val, _default = interface.inputs_with_defaults[k] - if _default is not None: - raise ValueError( - f"The default value for the optional type must be None, but got {_default}" - ) - is_optional = True - if not is_optional: - from flytekit.core.base_task import Task - + # interface.inputs_with_defaults[k][0] is the type of the default argument + # interface.inputs_with_defaults[k][1] is the value of the default argument + if k in interface.inputs_with_defaults and ( + interface.inputs_with_defaults[k][1] is not None + or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0]) + ): + default_val = interface.inputs_with_defaults[k][1] + if not isinstance(default_val, Hashable): + raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument") + kwargs[k] = default_val + else: error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}" - - _, _default = interface.inputs_with_defaults[k] - if isinstance(entity, Task) and _default is not None: - error_msg += ( - ". Flyte workflow syntax is a domain-specific language (DSL) for building execution graphs which " - "supports a subset of Python’s semantics. When calling tasks, all kwargs have to be provided." - ) - raise _user_exceptions.FlyteAssertion(error_msg) - else: - continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f5beb53f52..55d6368e43 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1554,7 +1554,7 @@ def __init__(self): super().__init__("Typed Union", typing.Union) @staticmethod - def is_optional_type(t: Type[T]) -> bool: + def is_optional_type(t: Type) -> bool: """Return True if `t` is a Union or Optional type.""" return _is_union_type(t) or type(None) in get_args(t) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 6fe2b01e61..0073baec53 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,7 +1,5 @@ from typing import Dict, List, NamedTuple, Optional, Union -import pytest - from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]: return t2(a=a) assert wf() is None - - with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): - - @task() - def t3(c: Optional[int] = 3) -> Optional[int]: - ... - - @workflow - def wf(): - return t3() - - wf() diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 7964548674..d3a7237391 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -95,6 +95,43 @@ def ranged_int_to_str(a: int) -> typing.List[str]: assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"] +@pytest.mark.parametrize( + "input_val,output_val", + [ + (4, 0), + (5, 5), + ], +) +def test_dynamic_local_default_args_task(input_val, output_val): + @task + def t1(a: int = 0) -> int: + return a + + @dynamic + def dt(a: int) -> int: + if a % 2 == 0: + return t1() + return t1(a=a) + + assert dt(a=input_val) == output_val + + with context_manager.FlyteContextManager.with_context( + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) + ) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + ) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val}) + dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map) + assert len(dynamic_job_spec.nodes) == 1 + assert len(dynamic_job_spec.tasks) == 1 + assert dynamic_job_spec.nodes[0].inputs[0].binding.scalar.primitive is not None + + def test_nested_dynamic_local(): @task def t1(a: int) -> str: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index e022c875e0..a6d223f21a 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -46,7 +46,7 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]: p = create_and_link_node(ctx, t2) assert p.ref.var == "o0" - assert len(p.ref.node.bindings) == 0 + assert len(p.ref.node.bindings) == 1 def test_create_and_link_node_from_remote(): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 88297f43f4..2fcf8bbd94 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -12,9 +12,20 @@ from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.models.admin.workflow import WorkflowSpec -from flytekit.models.types import SimpleType +from flytekit.models.literals import ( + BindingData, + BindingDataCollection, + BindingDataMap, + Literal, + Primitive, + Scalar, + Union, + Void, +) +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.tools.translator import get_serializable from flytekit.types.error.error import FlyteError @@ -495,3 +506,440 @@ def z(a: int, b: str) -> typing.Tuple[int, str]: assert task_spec.template.interface.inputs["a"].description == "foo" assert task_spec.template.interface.inputs["b"].description == "bar" assert task_spec.template.interface.outputs["o0"].description == "ramen" + + +def test_default_args_task_int_type(): + default_val = 0 + input_val = 100 + + @task + def t1(a: int = default_val) -> int: + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[int, int]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(integer=default_val) + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(integer=input_val) + ) + + output_type = LiteralType(simple=SimpleType.INTEGER) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_str_type(): + default_val = "" + input_val = "foo" + + @task + def t1(a: str = default_val) -> str: + return a + + @workflow + def wf_no_input() -> str: + return t1() + + @workflow + def wf_with_input() -> str: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[str, str]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(string_value=default_val) + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(string_value=input_val) + ) + + output_type = LiteralType(simple=SimpleType.STRING) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_int_type_default_none(): + default_val = None + input_val = 100 + + @task + def t1(a: typing.Optional[int] = default_val) -> typing.Optional[int]: + return a + + @workflow + def wf_no_input() -> typing.Optional[int]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[int]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar(primitive=Primitive(integer=input_val)), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_int_type_default_int(): + default_val = 10 + input_val = 100 + + @task + def t1(a: typing.Optional[int] = default_val) -> typing.Optional[int]: + return a + + @workflow + def wf_no_input() -> typing.Optional[int]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[int]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + primitive=Primitive(integer=default_val), + ), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar(primitive=Primitive(integer=input_val)), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_no_type_hint(): + @task + def t1(a=0) -> int: + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a=100) + + with pytest.raises(TypeError, match="Arguments do not have type annotation"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + with pytest.raises(TypeError, match="Arguments do not have type annotation"): + get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + +def test_default_args_task_mismatch_type(): + @task + def t1(a: int = "foo") -> int: # type: ignore + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a="bar") + + with pytest.raises(AssertionError, match="Failed to Bind variable"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + with pytest.raises(AssertionError, match="Failed to Bind variable"): + get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + +def test_default_args_task_list_type(): + input_val = [1, 2, 3] + + @task + def t1(a: list[int] = []) -> list[int]: + return a + + @workflow + def wf_no_input() -> list[int]: + return t1() + + @workflow + def wf_with_input() -> list[int]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER) + ) + + assert wf_with_input() == input_val + + +def test_default_args_task_dict_type(): + input_val = {"a": 1, "b": 2} + + @task + def t1(a: dict[str, int] = {}) -> dict[str, int]: + return a + + @workflow + def wf_no_input() -> dict[str, int]: + return t1() + + @workflow + def wf_with_input() -> dict[str, int]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataMap( + bindings={ + "a": BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + "b": BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + } + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + map_value_type=LiteralType(simple=SimpleType.INTEGER) + ) + + assert wf_with_input() == input_val + + +def test_default_args_task_optional_list_type_default_none(): + default_val = None + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = default_val) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[typing.List[int]], typing.Optional[typing.List[int]]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_list_type_default_list(): + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = []) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + + assert wf_with_input() == input_val diff --git a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py index 7c2da727c1..53cdc7dc20 100644 --- a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py +++ b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py @@ -1,7 +1,7 @@ import sys from collections import OrderedDict from collections.abc import Sequence -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import pytest @@ -11,6 +11,7 @@ from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.task import task +from flytekit.core.workflow import workflow from flytekit.models.core.types import BlobType from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType @@ -126,3 +127,52 @@ def t1(a: int) -> Annotated[Foo, a1(a="bar")]: task_spec = get_serializable(OrderedDict(), serialization_settings, t1) md = task_spec.template.interface.outputs["o0"].type.metadata["python_class_name"] assert "0x" not in str(md) + + +def test_default_args_task(): + default_val = 123 + input_val = "foo" + + @task + def t1(a: Any = default_val) -> Any: + return a + + @workflow + def wf_no_input() -> Any: + return t1() + + @workflow + def wf_with_input() -> Any: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[Any, Any]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + metadata = BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ) + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value.blob.metadata == metadata + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value.blob.metadata == metadata + + output_type = LiteralType( + blob=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + metadata={ + "python_class_name": "typing.Any", + }, + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py index 9a5628af0f..18c3ce82db 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py @@ -1,6 +1,7 @@ import os import tempfile import typing +from collections import OrderedDict import google.cloud.bigquery import pyarrow as pa @@ -16,10 +17,12 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.lazy_import.lazy_module import is_imported from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType +from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType +from flytekit.tools.translator import get_serializable from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -545,3 +548,38 @@ def test_reregister_encoder(): df_literal_type = TypeEngine.to_literal_type(pd.DataFrame) TypeEngine.to_literal(ctx, sd, python_type=pd.DataFrame, expected=df_literal_type) + + +def test_default_args_task(): + input_val = generate_pandas() + + @task + def t1(a: pd.DataFrame = pd.DataFrame()) -> pd.DataFrame: + return a + + @workflow + def wf_no_input() -> pd.DataFrame: + return t1() + + @workflow + def wf_with_input() -> pd.DataFrame: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[ + 0 + ].binding.value.structured_dataset.metadata == StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType( + format="parquet", + ), + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + structured_dataset_type=StructuredDatasetType() + ) + + pd.testing.assert_frame_equal(wf_with_input(), input_val) From 4309c5ea92014f67204d245108e536ed2b82d83e Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 6 Jun 2024 10:27:22 -0700 Subject: [PATCH 21/32] Use new fileaccess to build execution state (#2452) Signed-off-by: Yee Hing Tong --- flytekit/bin/entrypoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ed20561ac3..a7fc1ed485 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -289,13 +289,15 @@ def setup_execution( logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}") raise + ctx = ctx.new_builder().with_file_access(file_access).build() + es = ctx.new_execution_state().with_params( mode=ExecutionState.Mode.TASK_EXECUTION, user_space_params=execution_parameters, ) # create new output metadata tracker omt = OutputMetadataTracker() - cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es).with_output_metadata_tracker(omt) + cb = ctx.new_builder().with_execution_state(es).with_output_metadata_tracker(omt) if compressed_serialization_settings: ss = SerializationSettings.from_transport(compressed_serialization_settings) From 3555f2ee54ae5750bf3dec62b7d026d7ebec2312 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Thu, 6 Jun 2024 14:59:39 -0400 Subject: [PATCH 22/32] Add Image builder `should_build` method for more customizable build behavior (#2458) * add ImageSpecBuilder.should_build method to make building logic customizeable by builder plugin Signed-off-by: Niels Bantilan * updates Signed-off-by: Niels Bantilan * ping's update Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su --------- Signed-off-by: Niels Bantilan Signed-off-by: Kevin Su Co-authored-by: Kevin Su --- flytekit/image_spec/image_spec.py | 38 ++++++++++++++----- .../unit/core/image_spec/test_image_spec.py | 10 ++--- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 37f87549d0..4e2896867e 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -216,6 +216,27 @@ def build_image(self, image_spec: ImageSpec) -> Optional[str]: """ raise NotImplementedError("This method is not implemented in the base class.") + def should_build(self, image_spec: ImageSpec) -> bool: + """ + Whether or not the builder should build the ImageSpec. + + Args: + image_spec: image spec of the task. + + Returns: + True if the image should be built, otherwise it returns False. + """ + img_name = image_spec.image_name() + if not image_spec.exist(): + click.secho(f"Image {img_name} not found. building...", fg="blue") + return True + if image_spec._is_force_push: + click.secho(f"Image {img_name} found but overwriting existing image.", fg="blue") + return True + + click.secho(f"Image {img_name} found. Skip building.", fg="blue") + return False + class ImageBuildEngine: """ @@ -252,18 +273,11 @@ def build(cls, image_spec: ImageSpec): builder = image_spec.builder img_name = image_spec.image_name() - if image_spec.exist(): - if image_spec._is_force_push: - click.secho(f"Image {img_name} found. but overwriting existing image.", fg="blue") - cls._build_image(builder, image_spec, img_name) - else: - click.secho(f"Image {img_name} found. Skip building.", fg="blue") - else: - click.secho(f"Image {img_name} not found. building...", fg="blue") + if cls._get_builder(builder).should_build(image_spec): cls._build_image(builder, image_spec, img_name) @classmethod - def _build_image(cls, builder, image_spec, img_name): + def _get_builder(cls, builder: str) -> ImageSpecBuilder: if builder not in cls._REGISTRY: raise Exception(f"Builder {builder} is not registered.") if builder == "envd": @@ -275,7 +289,11 @@ def _build_image(cls, builder, image_spec, img_name): f"envd version {envd_version} is not compatible with flytekit>v1.10.2." f" Please upgrade envd to v0.3.39+." ) - fully_qualified_image_name = cls._REGISTRY[builder][0].build_image(image_spec) + return cls._REGISTRY[builder][0] + + @classmethod + def _build_image(cls, builder: str, image_spec: ImageSpec, img_name: str): + fully_qualified_image_name = cls._get_builder(builder).build_image(image_spec) if fully_qualified_image_name is not None: cls._IMAGE_NAME_TO_REAL_NAME[img_name] = fully_qualified_image_name diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index 4a596c1e1e..011828d4ce 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -104,14 +104,14 @@ def test_image_spec_engine_priority(): def test_build_existing_image_with_force_push(): - image_spec = Mock() - image_spec.exist.return_value = True - image_spec._is_force_push = True + image_spec = ImageSpec(name="hello", builder="test").force_push() - ImageBuildEngine._build_image = Mock() + builder = Mock() + builder.build_image.return_value = "new_image_name" + ImageBuildEngine.register("test", builder) ImageBuildEngine.build(image_spec) - ImageBuildEngine._build_image.assert_called_once() + builder.build_image.assert_called_once() def test_custom_tag(): From 7521acbd361e4b8ec2e03e5e8742cf72ee6c0421 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 7 Jun 2024 22:47:59 +0800 Subject: [PATCH 23/32] Async Agent Interface Refactor (#2467) --- flytekit/extend/backend/base_agent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index e8ec18806e..1630bf71aa 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -152,8 +152,8 @@ def metadata_type(self) -> ResourceMeta: def create( self, task_template: TaskTemplate, + output_prefix: str, inputs: Optional[LiteralMap], - output_prefix: Optional[str], task_execution_metadata: Optional[TaskExecutionMetadata], **kwargs, ) -> ResourceMeta: @@ -297,7 +297,9 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: task_template = get_serializable(OrderedDict(), ss, self).template self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs)) + resource_mata = asyncio.run( + self._create(task_template=task_template, output_prefix=output_prefix, inputs=kwargs) + ) resource = asyncio.run(self._get(resource_meta=resource_mata)) if resource.phase != TaskExecution.SUCCEEDED: From 070027d8781d9b9f830723a9a4cae5daa6359a33 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 06:01:03 +0800 Subject: [PATCH 24/32] fix(core): Handle missing metadata in DictTransformer (#2469) --- flytekit/core/type_engine.py | 12 ++++++------ tests/flytekit/unit/core/test_type_engine.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 55d6368e43..fa5772a3f2 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1805,17 +1805,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict # evaluates to false if lv and lv.scalar and lv.scalar.generic is not None: - if lv.metadata["format"] == "json": - try: - return json.loads(_json_format.MessageToJson(lv.scalar.generic)) - except TypeError: - raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") - elif lv.metadata["format"] == "pickle": + if lv.metadata and lv.metadata.get("format", None) == "pickle": from flytekit.types.pickle import FlytePickle uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file") return FlytePickle.from_pickle(uri) + try: + return json.loads(_json_format.MessageToJson(lv.scalar.generic)) + except TypeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0546b9dc7a..83d7ec73c4 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -560,6 +560,16 @@ def recursive_assert( typing.Dict[str, int], ) + lv = d.to_literal( + ctx, + {"x": "hello"}, + dict, + LiteralType(simple=SimpleType.STRUCT), + ) + + lv._metadata = None + assert d.to_python_value(ctx, lv, dict) == {"x": "hello"} + def test_convert_marshmallow_json_schema_to_python_class(): @dataclass From 9dfd6f7c95a43221c3460039001db543e5aab6aa Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Sat, 8 Jun 2024 15:06:41 -0700 Subject: [PATCH 25/32] Improve the error handling for dataclass (#2466) Signed-off-by: Ketan Umare Signed-off-by: Kevin Su Co-authored-by: Ketan Umare Co-authored-by: Kevin Su --- flytekit/clis/sdk_in_container/utils.py | 8 + flytekit/core/promise.py | 56 +++- flytekit/image_spec/image_spec.py | 4 +- flytekit/models/common.py | 2 +- tests/flytekit/unit/core/test_imperative.py | 2 +- tests/flytekit/unit/core/test_promise.py | 9 +- tests/flytekit/unit/core/test_type_engine.py | 324 ++++++++++--------- tests/flytekit/unit/core/test_type_hints.py | 4 +- tests/flytekit/unit/models/test_common.py | 4 +- 9 files changed, 233 insertions(+), 180 deletions(-) diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index 13d07d025c..239fef016b 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -116,6 +116,14 @@ def pretty_print_exception(e: Exception): pretty_print_grpc_error(e) return + if isinstance(e, AssertionError): + click.secho(f"Assertion Error: {e}", fg="red") + return + + if isinstance(e, ValueError): + click.secho(f"Value Error: {e}", fg="red") + return + click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa pretty_print_traceback(e) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index afd3c069f6..557d621dd4 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -2,6 +2,7 @@ import collections import inspect +import typing from copy import deepcopy from enum import Enum from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args @@ -80,12 +81,12 @@ def my_wf(in1: int, in2: int) -> int: :param native_types: Map to native Python type. """ if incoming_values is None: - raise ValueError("Incoming values cannot be None, must be a dict") + raise AssertionError("Incoming values cannot be None, must be a dict") result = {} # So as to not overwrite the input_kwargs for k, v in incoming_values.items(): if k not in flyte_interface_types: - raise ValueError(f"Received unexpected keyword argument {k}") + raise AssertionError(f"Received unexpected keyword argument {k}") var = flyte_interface_types[k] t = native_types[k] try: @@ -372,12 +373,18 @@ def t1() -> (int, str): ... # TODO: Currently, NodeOutput we're creating is the slimmer core package Node class, but since only the # id is used, it's okay for now. Let's clean all this up though. - def __init__(self, var: str, val: Union[NodeOutput, _literals_models.Literal]): + def __init__( + self, + var: str, + val: Union[NodeOutput, _literals_models.Literal], + type: typing.Optional[_type_models.LiteralType] = None, + ): self._var = var self._promise_ready = True self._val = val self._ref = None self._attr_path: List[Union[str, int]] = [] + self._type = type if val and isinstance(val, NodeOutput): self._ref = val self._promise_ready = False @@ -525,6 +532,23 @@ def wf(): The attribute keys are appended on the promise and a new promise is returned with the updated attribute path. We don't modify the original promise because it might be used in other places as well. """ + + if self.ref and self._type: + if self._type.simple == SimpleType.STRUCT and self._type.metadata is None: + raise ValueError(f"Trying to index into a unschematized struct type {self.var}[{key}].") + if isinstance(self.val, _literals_models.Literal): + if self.val.scalar and self.val.scalar.generic: + if self._type and self._type.metadata is None: + raise ValueError( + f"Trying to index into a generic type {self.var}[{key}]." + f" It seems the upstream type is not indexable." + f" Prefer using `typing.Dict[str, ...]` or `@dataclass`" + f" Note: {self.var} is the name of the variable in your workflow function." + ) + raise ValueError( + f"Trying to index into a struct {self.var}[{key}]. Use {self.var}.{key} instead." + f" Note: {self.var} is the name of the variable in your workflow function." + ) return self._append_attr(key) def __iter__(self): @@ -533,8 +557,8 @@ def __iter__(self): But it still doesn't make sense to """ raise ValueError( - "Promise objects are not iterable - can't range() over a promise." - " But you can use [index] or the still stabilizing @eager" + f" {self.var} is a Promise. Promise objects are not iterable - can't range() over a promise." + " But you can use [index] or the alpha version of @eager workflows" ) def __getattr__(self, key) -> Promise: @@ -551,7 +575,15 @@ def wf(): The attribute keys are appended on the promise and a new promise is returned with the updated attribute path. We don't modify the original promise because it might be used in other places as well. """ - + if isinstance(self.val, _literals_models.Literal): + if self.val.scalar and self.val.scalar.generic: + if self._type and self._type.metadata is None: + raise ValueError( + f"Trying to index into a generic type {self.var}[{key}]." + f" It seems the upstream type is not indexable." + f" Prefer using `typing.Dict[str, ...]` or `@dataclass`" + f" Note: {self.var} is the name of the variable in your workflow function." + ) return self._append_attr(key) def _append_attr(self, key) -> Promise: @@ -1037,7 +1069,9 @@ def create_and_link_node_from_remote( # Create a node output object for each output, they should all point to this node of course. node_outputs = [] for output_name, output_var_model in typed_interface.outputs.items(): - node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name))) + node_outputs.append( + Promise(output_name, NodeOutput(node=flytekit_node, var=output_name), type=output_var_model.type) + ) return create_task_output(node_outputs) @@ -1137,7 +1171,9 @@ def create_and_link_node( # Create a node output object for each output, they should all point to this node of course. node_outputs = [] for output_name, output_var_model in typed_interface.outputs.items(): - node_outputs.append(Promise(output_name, NodeOutput(node=flytekit_node, var=output_name))) + node_outputs.append( + Promise(output_name, NodeOutput(node=flytekit_node, var=output_name), output_var_model.type) + ) # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break return create_task_output(node_outputs, interface) @@ -1176,7 +1212,7 @@ def flyte_entity_call_handler( # Make sure arguments are part of interface for k, v in kwargs.items(): if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: - raise ValueError( + raise AssertionError( f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" ) @@ -1234,7 +1270,7 @@ def flyte_entity_call_handler( ): return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) - raise ValueError( + raise AssertionError( f"Expected outputs and actual outputs do not match." f"Result {result}. " f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}" diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 4e2896867e..4e5f103782 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -278,8 +278,10 @@ def build(cls, image_spec: ImageSpec): @classmethod def _get_builder(cls, builder: str) -> ImageSpecBuilder: + if builder is None: + raise AssertionError("There is no image builder registered.") if builder not in cls._REGISTRY: - raise Exception(f"Builder {builder} is not registered.") + raise AssertionError(f"Image builder {builder} is not registered.") if builder == "envd": envd_version = metadata.version("envd") # flytekit v1.10.2+ copies the workflow code to the WorkDir specified in the Dockerfile. However, envd<0.3.39 diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 8fdd0837a8..77ae72e703 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -62,7 +62,7 @@ def short_string(self): """ literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip() type_str = type(self).__name__ - return f"" + return f"[Flyte Serialized object: Type: <{type_str}> Value: <{literal_str}>]" def verbose_string(self): """ diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index 24f37ff186..aee88e19d1 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -309,7 +309,7 @@ def t1(a: str) -> str: with pytest.raises(AssertionError): wb(3) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): wb(in2="hello") diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index a6d223f21a..4a3826220d 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -86,13 +86,16 @@ def wf(i: int, j: int): # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, # which is incorrect - with pytest.raises(FlyteAssertion, match=r"Missing input `i` type ``"): + with pytest.raises( + FlyteAssertion, + match=r"Missing input `i` type `\[Flyte Serialized object: Type: Value: \]`", + ): create_and_link_node_from_remote(ctx, lp) - # Even if j is not provided it will default + # Even if j is not provided, it will default create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}) - # Even if i,j is not provided it will default + # Even if i,j is not provided, it will default create_and_link_node_from_remote( ctx, lp_without_fixed_inpus, _inputs_not_allowed=None, _ignorable_inputs={"i", "j"} ) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 83d7ec73c4..abdc314f29 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -705,51 +705,44 @@ def test_zero_floats(): assert TypeEngine.to_python_value(ctx, l1, float) == 0 -@dataclass -class InnerStruct(DataClassJsonMixin): - a: int - b: typing.Optional[str] - c: typing.List[int] - - -@dataclass -class TestStruct(DataClassJsonMixin): - s: InnerStruct - m: typing.Dict[str, str] - - -@dataclass -class TestStructB(DataClassJsonMixin): - s: InnerStruct - m: typing.Dict[int, str] - n: typing.Optional[typing.List[typing.List[int]]] = None - o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None - - -@dataclass -class TestStructC(DataClassJsonMixin): - s: InnerStruct - m: typing.Dict[str, int] - +def test_dataclass_transformer(): + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] -@dataclass -class TestStructD(DataClassJsonMixin): - s: InnerStruct - m: typing.Dict[str, typing.List[int]] + @dataclass + class TestStruct(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, str] + @dataclass + class TestStructB(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[int, str] + n: typing.Optional[typing.List[typing.List[int]]] = None + o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None -class UnsupportedSchemaType: - def __init__(self): - self._a = "Hello" + @dataclass + class TestStructC(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, int] + @dataclass + class TestStructD(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, typing.List[int]] -@dataclass -class UnsupportedNestedStruct(DataClassJsonMixin): - a: int - s: UnsupportedSchemaType + class UnsupportedSchemaType: + def __init__(self): + self._a = "Hello" + @dataclass + class UnsupportedNestedStruct(DataClassJsonMixin): + a: int + s: UnsupportedSchemaType -def test_dataclass_transformer(): schema = { "$ref": "#/definitions/TeststructSchema", "$schema": "http://json-schema.org/draft-07/schema#", @@ -807,51 +800,27 @@ def test_dataclass_transformer(): assert t.metadata is None -@dataclass -class InnerStruct_transformer(DataClassJSONMixin): - a: int - b: typing.Optional[str] - c: typing.List[int] - - -@dataclass -class TestStruct_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, str] - - -@dataclass -class TestStructB_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[int, str] - n: typing.Optional[typing.List[typing.List[int]]] = None - o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None - - -@dataclass -class TestStructC_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, int] - - -@dataclass -class TestStructD_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, typing.List[int]] - - -@dataclass -class UnsupportedSchemaType_transformer: - _a: str = "Hello" +def test_dataclass_transformer_with_dataclassjsonmixin(): + @dataclass + class InnerStruct_transformer(DataClassJSONMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + @dataclass + class TestStruct_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, str] -@dataclass -class UnsupportedNestedStruct_transformer(DataClassJSONMixin): - a: int - s: UnsupportedSchemaType_transformer + class UnsupportedSchemaType: + def __init__(self): + self._a = "Hello" + @dataclass + class UnsupportedNestedStruct(DataClassJsonMixin): + a: int + s: UnsupportedSchemaType -def test_dataclass_transformer_with_dataclassjsonmixin(): schema = { "type": "object", "title": "TestStruct_transformer", @@ -900,8 +869,30 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): def test_dataclass_int_preserving(): - ctx = FlyteContext.current_context() + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + @dataclass + class TestStructB(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[int, str] + n: typing.Optional[typing.List[typing.List[int]]] = None + o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None + + @dataclass + class TestStructC(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, int] + @dataclass + class TestStructD(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, typing.List[int]] + + ctx = FlyteContext.current_context() o = InnerStruct(a=5, b=None, c=[1, 2, 3]) tf = DataclassTransformer() lv = tf.to_literal(ctx, o, InnerStruct, tf.get_literal_type(InnerStruct)) @@ -1013,31 +1004,29 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) -@dataclass -class A_optional_flytefile(DataClassJSONMixin): - a: int - - -@dataclass -class TestFileStruct_optional_flytefile(DataClassJSONMixin): - a: FlyteFile - b: typing.Optional[FlyteFile] - b_prime: typing.Optional[FlyteFile] - c: typing.Union[FlyteFile, None] - d: typing.List[FlyteFile] - e: typing.List[typing.Optional[FlyteFile]] - e_prime: typing.List[typing.Optional[FlyteFile]] - f: typing.Dict[str, FlyteFile] - g: typing.Dict[str, typing.Optional[FlyteFile]] - g_prime: typing.Dict[str, typing.Optional[FlyteFile]] - h: typing.Optional[FlyteFile] = None - h_prime: typing.Optional[FlyteFile] = None - i: typing.Optional[A_optional_flytefile] = None - i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) - - @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): + @dataclass + class A_optional_flytefile(DataClassJSONMixin): + a: int + + @dataclass + class TestFileStruct_optional_flytefile(DataClassJSONMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A_optional_flytefile] = None + i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) + remote_path = "s3://tmp/file" mock_upload_dir.return_value = remote_path @@ -1146,22 +1135,20 @@ class TestFileStruct(DataClassJsonMixin): assert not ctx.file_access.is_remote(ot.b.e["hello"].path) -@dataclass -class TestInnerFileStruct_flyte_file(DataClassJSONMixin): - a: JPEGImageFile - b: typing.List[FlyteFile] - c: typing.Dict[str, FlyteFile] - d: typing.List[FlyteFile] - e: typing.Dict[str, FlyteFile] - - -@dataclass -class TestFileStruct_flyte_file(DataClassJSONMixin): - a: FlyteFile - b: TestInnerFileStruct_flyte_file +def test_flyte_file_in_dataclassjsonmixin(): + @dataclass + class TestInnerFileStruct_flyte_file(DataClassJSONMixin): + a: JPEGImageFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + d: typing.List[FlyteFile] + e: typing.Dict[str, FlyteFile] + @dataclass + class TestFileStruct_flyte_file(DataClassJSONMixin): + a: FlyteFile + b: TestInnerFileStruct_flyte_file -def test_flyte_file_in_dataclassjsonmixin(): remote_path = "s3://tmp/file" f1 = FlyteFile(remote_path) f2 = FlyteFile("/tmp/file") @@ -1249,22 +1236,20 @@ class TestFileStruct(DataClassJsonMixin): assert o.b.e["hello"].path == ot.b.e["hello"].remote_source -@dataclass -class TestInnerFileStruct_flyte_directory(DataClassJSONMixin): - a: TensorboardLogs - b: typing.List[FlyteDirectory] - c: typing.Dict[str, FlyteDirectory] - d: typing.List[FlyteDirectory] - e: typing.Dict[str, FlyteDirectory] - - -@dataclass -class TestFileStruct_flyte_directory(DataClassJSONMixin): - a: FlyteDirectory - b: TestInnerFileStruct_flyte_directory +def test_flyte_directory_in_dataclassjsonmixin(): + @dataclass + class TestInnerFileStruct_flyte_directory(DataClassJSONMixin): + a: TensorboardLogs + b: typing.List[FlyteDirectory] + c: typing.Dict[str, FlyteDirectory] + d: typing.List[FlyteDirectory] + e: typing.Dict[str, FlyteDirectory] + @dataclass + class TestFileStruct_flyte_directory(DataClassJSONMixin): + a: FlyteDirectory + b: TestInnerFileStruct_flyte_directory -def test_flyte_directory_in_dataclassjsonmixin(): remote_path = "s3://tmp/file" tempdir = tempfile.mkdtemp(prefix="flyte-") f1 = FlyteDirectory(tempdir) @@ -1341,15 +1326,14 @@ class DatasetStruct(DataClassJsonMixin): assert "parquet" == ot.b.c["hello"].file_format -@dataclass -class InnerDatasetStructDataclassJsonMixin(DataClassJSONMixin): - a: StructuredDataset - b: typing.List[Annotated[StructuredDataset, "parquet"]] - c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] - - @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") def test_structured_dataset_in_dataclassjsonmixin(): + @dataclass + class InnerDatasetStructDataclassJsonMixin(DataClassJSONMixin): + a: StructuredDataset + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] + import pandas as pd from pandas._testing import assert_frame_equal @@ -2142,28 +2126,6 @@ class Datum(DataClassJSONMixin): } ), ), - ( - {"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})}, - {"p1": TestStructD}, - LiteralMap( - literals={ - "p1": Literal( - scalar=Scalar( - generic=_json_format.Parse( - typing.cast( - DataClassJsonMixin, - TestStructD( - s=InnerStruct(a=5, b=None, c=[1, 2, 3]), - m={"a": [5]}, - ), - ).to_json(), - _struct.Struct(), - ) - ) - ) - } - ), - ), ( {"p1": "s3://tmp/file.jpeg"}, {"p1": JPEGImageFile}, @@ -2193,6 +2155,42 @@ def test_dict_to_literal_map(python_value, python_types, expected_literal_map): assert TypeEngine.dict_to_literal_map(ctx, python_value, python_types) == expected_literal_map +def test_dict_to_literal_map_with_dataclass(): + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + @dataclass + class TestStructD(DataClassJsonMixin): + s: InnerStruct + m: typing.Dict[str, typing.List[int]] + + ctx = FlyteContext.current_context() + python_value = {"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})} + python_types = {"p1": TestStructD} + expected_literal_map = LiteralMap( + literals={ + "p1": Literal( + scalar=Scalar( + generic=_json_format.Parse( + typing.cast( + DataClassJsonMixin, + TestStructD( + s=InnerStruct(a=5, b=None, c=[1, 2, 3]), + m={"a": [5]}, + ), + ).to_json(), + _struct.Struct(), + ) + ) + ) + } + ) + assert TypeEngine.dict_to_literal_map(ctx, python_value, python_types) == expected_literal_map + + def test_dict_to_literal_map_with_wrong_input_type(): ctx = FlyteContext.current_context() input = {"a": 1} @@ -2297,6 +2295,12 @@ def constant_hash(df: pd.DataFrame) -> str: def test_annotated_simple_types(): + @dataclass + class InnerStruct(DataClassJsonMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + def _check_annotation(t, annotation): lt = TypeEngine.to_literal_type(t) assert isinstance(lt.annotation, TypeAnnotation) diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index e3b0978565..8879119eeb 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -1701,8 +1701,8 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]: match=re.escape( "Error encountered while executing 'wf2':\n" f" Failed to convert inputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.t2':\n" - ' Cannot convert from to typing.Union[float, dict] (using tag str)' + ' Cannot convert from [Flyte Serialized object: Type: Value: ] to typing.Union[float, dict] (using tag str)' ), ): assert wf2(a="2") == "2" diff --git a/tests/flytekit/unit/models/test_common.py b/tests/flytekit/unit/models/test_common.py index b3754c16ff..b966053ee6 100644 --- a/tests/flytekit/unit/models/test_common.py +++ b/tests/flytekit/unit/models/test_common.py @@ -108,5 +108,5 @@ def test_auth_role_empty(): def test_short_string_raw_output_data_config(): """""" obj = _common.RawOutputDataConfig("s3://bucket") - assert "FlyteLiteral(RawOutputDataConfig)" in obj.short_string() - assert "FlyteLiteral(RawOutputDataConfig)" in repr(obj) + assert "Flyte Serialized object: Type: Value" in obj.short_string() + assert "Flyte Serialized object: Type: Value" in repr(obj) From 7ddfb9b0873f33500045677b8a10942f41363048 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Sat, 8 Jun 2024 18:07:00 -0400 Subject: [PATCH 26/32] enable string representation for scalar schema (#2468) Signed-off-by: redartera <120470035+redartera@users.noreply.github.com> --- flytekit/interaction/string_literals.py | 2 ++ tests/flytekit/unit/interaction/test_string_literals.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/flytekit/interaction/string_literals.py b/flytekit/interaction/string_literals.py index 8bc9421334..0bfb3c866a 100644 --- a/flytekit/interaction/string_literals.py +++ b/flytekit/interaction/string_literals.py @@ -37,6 +37,8 @@ def scalar_to_string(scalar: Scalar) -> typing.Any: return scalar.error.message if scalar.structured_dataset: return scalar.structured_dataset.uri + if scalar.schema: + return scalar.schema.uri if scalar.blob: return scalar.blob.uri if scalar.binary: diff --git a/tests/flytekit/unit/interaction/test_string_literals.py b/tests/flytekit/unit/interaction/test_string_literals.py index 1ddd1bdc98..06666feae8 100644 --- a/tests/flytekit/unit/interaction/test_string_literals.py +++ b/tests/flytekit/unit/interaction/test_string_literals.py @@ -22,10 +22,11 @@ Primitive, Scalar, StructuredDataset, + Schema, Union, Void, ) -from flytekit.models.types import Error, LiteralType, SimpleType +from flytekit.models.types import Error, LiteralType, SimpleType, SchemaType def test_primitive_to_string(): @@ -64,6 +65,9 @@ def test_scalar_to_string(): scalar = Scalar(structured_dataset=StructuredDataset(uri="uri")) assert scalar_to_string(scalar) == "uri" + scalar = Scalar(schema=Schema(uri="schema_uri", type=SchemaType(columns=[]))) + assert scalar_to_string(scalar) == "schema_uri" + scalar = Scalar( blob=Blob( metadata=BlobMetadata(BlobType(format="", dimensionality=BlobType.BlobDimensionality.SINGLE)), uri="uri" From daeff3f5f0f36a1a9a1f86c5e024d1b76cdfd5cb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 12 Jun 2024 09:06:19 +0800 Subject: [PATCH 27/32] Make the value in the label optional (#2465) Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/run.py | 4 ++-- flytekit/interaction/click_types.py | 16 ++++++++++++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 10 ++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index b5278be053..47e16510cb 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -32,7 +32,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException -from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback +from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback, labels_callback from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security @@ -174,7 +174,7 @@ class RunLevelParams(PyFlyteParams): multiple=True, type=str, show_default=True, - callback=key_value_callback, + callback=labels_callback, help="Labels to be attached to the execution of the format `label_key=label_value`.", ) ) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index c16339a236..4eb597d8df 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -57,6 +57,22 @@ def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> t return result +def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]: + """ + Callback for click to parse labels. + """ + if not values: + return None + result = {} + for v in values: + if "=" not in v: + result[v.strip()] = "" + else: + k, v = v.split("=", 1) + result[k.strip()] = v.strip() + return result + + class DirParamType(click.ParamType): name = "directory path" diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 28bee1dea7..6957828743 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -76,6 +76,16 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file): assert result.exit_code == 0 +def test_pyflyte_run_with_labels(): + workflow_file = pathlib.Path(__file__).parent / "workflow.py" + with mock.patch("flytekit.configuration.plugin.FlyteRemote"): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, ["run", "--remote", str(workflow_file), "my_wf", "--help"], catch_exceptions=False + ) + assert result.exit_code == 0 + + def test_imperative_wf(): runner = CliRunner() result = runner.invoke( From 63f190e691e150dd83b031f8561d698bdace21c2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 13 Jun 2024 07:51:18 +0800 Subject: [PATCH 28/32] Fix test_image_spec (#2477) Signed-off-by: Kevin Su --- flytekit/image_spec/image_spec.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 4e5f103782..1aadac3bfd 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -104,7 +104,6 @@ def is_container(self) -> bool: return os.environ.get(_F_IMG_ID) == self.image_name() return True - @lru_cache def exist(self) -> bool: """ Check if the image exists in the registry. @@ -127,10 +126,14 @@ def exist(self) -> bool: return False except Exception as e: tag = calculate_hash_from_image_spec(self) - # if docker engine is not running locally - container_registry = DOCKER_HUB - if self.registry and "/" in self.registry: + # if docker engine is not running locally, use requests to check if the image exists. + if "localhost:" in self.registry: + container_registry = self.registry + elif self.registry and "/" in self.registry: container_registry = self.registry.split("/")[0] + else: + # Assume the image is in docker hub if users don't specify a registry, such as ghcr.io, docker.io. + container_registry = DOCKER_HUB if container_registry == DOCKER_HUB: url = f"https://hub.docker.com/v2/repositories/{self.registry}/{self.name}/tags/{tag}" response = requests.get(url) From 17835e04bd6e58632e80391eadb0c6fc0a900a52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bu=C4=9Fra=20Gedik?= Date: Thu, 13 Jun 2024 13:38:23 -0700 Subject: [PATCH 29/32] Add copy all options to register script (#2464) Signed-off-by: bugra.gedik --- flytekit/remote/remote.py | 21 +++++--- flytekit/tools/fast_registration.py | 31 ++++++++++-- .../unit/tools/test_fast_registration.py | 48 ++++++++++++++++++- 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index b6fdffec47..3f5fae8cb3 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -84,7 +84,7 @@ from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.remote.remote_fs import get_flyte_fs -from flytekit.tools.fast_registration import fast_package +from flytekit.tools.fast_registration import FastPackageOptions, fast_package from flytekit.tools.interactive import ipython_check from flytekit.tools.script_mode import _find_project_root, compress_scripts, hash_file from flytekit.tools.translator import ( @@ -847,18 +847,23 @@ def register_workflow( fwf._python_interface = entity.python_interface return fwf - def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: str = None) -> (bytes, str): + def fast_package( + self, + root: os.PathLike, + deref_symlinks: bool = True, + output: str = None, + options: typing.Optional[FastPackageOptions] = None, + ) -> typing.Tuple[bytes, str]: """ Packages the given paths into an installable zip and returns the md5_bytes and the URL of the uploaded location :param root: path to the root of the package system that should be uploaded :param output: output path. Optional, will default to a tempdir :param deref_symlinks: if symlinks should be dereferenced. Defaults to True + :param options: additional options to customize fast_package behavior :return: md5_bytes, url """ # Create a zip file containing all the entries. - zip_file = fast_package(root, output, deref_symlinks) - md5_bytes, _, _ = hash_file(pathlib.Path(zip_file)) - + zip_file = fast_package(root, output, deref_symlinks, options) # Upload zip file to Admin using FlyteRemote. return self.upload_file(pathlib.Path(zip_file)) @@ -972,6 +977,7 @@ def register_script( source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + fast_package_options: typing.Optional[FastPackageOptionas] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. @@ -987,6 +993,7 @@ def register_script( :param source_path: The root of the project path :param module_name: the name of the module :param envs: Environment variables to be passed to the serialization + :param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False. :return: """ if image_config is None: @@ -994,7 +1001,9 @@ def register_script( with tempfile.TemporaryDirectory() as tmp_dir: if copy_all: - md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir) + md5_bytes, upload_native_url = self.fast_package( + pathlib.Path(source_path), False, tmp_dir, fast_package_options + ) else: archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) compress_scripts(source_path, str(archive_fname), module_name) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index e596e62f38..c67873c9d8 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -8,20 +8,36 @@ import tarfile import tempfile import typing +from dataclasses import dataclass from typing import Optional import click from flytekit.core.context_manager import FlyteContextManager from flytekit.core.utils import timeit -from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore +from flytekit.tools.ignore import DockerIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes FAST_PREFIX = "fast" FAST_FILEENDING = ".tar.gz" -def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: bool = False) -> os.PathLike: +@dataclass(frozen=True) +class FastPackageOptions: + """ + FastPackageOptions is used to set configuration options when packaging files. + """ + + ignores: list[Ignore] + keep_default_ignores: bool = True + + +def fast_package( + source: os.PathLike, + output_dir: os.PathLike, + deref_symlinks: bool = False, + options: Optional[FastPackageOptions] = None, +) -> os.PathLike: """ Takes a source directory and packages everything not covered by common ignores into a tarball named after a hexdigest of the included files. @@ -30,7 +46,16 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: b :param bool deref_symlinks: Enables dereferencing symlinks when packaging directory :return os.PathLike: """ - ignore = IgnoreGroup(source, [GitIgnore, DockerIgnore, StandardIgnore]) + default_ignores = [GitIgnore, DockerIgnore, StandardIgnore] + if options is not None: + if options.keep_default_ignores: + ignores = options.ignores + default_ignores + else: + ignores = options.ignores + else: + ignores = default_ignores + ignore = IgnoreGroup(source, ignores) + digest = compute_digest(source, ignore.is_ignored) archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" diff --git a/tests/flytekit/unit/tools/test_fast_registration.py b/tests/flytekit/unit/tools/test_fast_registration.py index dd68e22aa8..a150002cb5 100644 --- a/tests/flytekit/unit/tools/test_fast_registration.py +++ b/tests/flytekit/unit/tools/test_fast_registration.py @@ -7,11 +7,12 @@ from flytekit.tools.fast_registration import ( FAST_FILEENDING, FAST_PREFIX, + FastPackageOptions, compute_digest, fast_package, get_additional_distribution_loc, ) -from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore +from flytekit.tools.ignore import DockerIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore from tests.flytekit.unit.tools.test_ignore import make_tree @@ -64,6 +65,51 @@ def test_package(flyte_project, tmp_path): assert str(archive_fname).endswith(FAST_FILEENDING) +def test_package_with_ignore(flyte_project, tmp_path): + class TestIgnore(Ignore): + def _is_ignored(self, path: str) -> bool: + return path.startswith("utils") + + options = FastPackageOptions(ignores=[TestIgnore]) + archive_fname = fast_package(source=flyte_project, output_dir=tmp_path, deref_symlinks=False, options=options) + with tarfile.open(archive_fname) as tar: + assert sorted(tar.getnames()) == [ + ".dockerignore", + ".gitignore", + "keep.foo", + "src", + "src/util", + "src/workflows", + "src/workflows/__pycache__", + "src/workflows/hello_world.py", + ] + assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX) + assert str(archive_fname).endswith(FAST_FILEENDING) + + +def test_package_with_ignore_without_defaults(flyte_project, tmp_path): + class TestIgnore(Ignore): + def _is_ignored(self, path: str) -> bool: + return path.startswith("utils") + + options = FastPackageOptions(ignores=[TestIgnore, GitIgnore, DockerIgnore], keep_default_ignores=False) + archive_fname = fast_package(source=flyte_project, output_dir=tmp_path, deref_symlinks=False, options=options) + with tarfile.open(archive_fname) as tar: + assert sorted(tar.getnames()) == [ + ".dockerignore", + ".gitignore", + "keep.foo", + "src", + "src/util", + "src/workflows", + "src/workflows/__pycache__", + "src/workflows/__pycache__/some.pyc", + "src/workflows/hello_world.py", + ] + assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX) + assert str(archive_fname).endswith(FAST_FILEENDING) + + def test_package_with_symlink(flyte_project, tmp_path): archive_fname = fast_package(source=flyte_project / "src", output_dir=tmp_path, deref_symlinks=True) with tarfile.open(archive_fname, dereference=True) as tar: From 76fe8cbb978a6a5e2376f28e7069b4786980a9e3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 13 Jun 2024 15:59:58 -0700 Subject: [PATCH 30/32] spelling (#2481) Signed-off-by: Yee Hing Tong --- flytekit/remote/remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 3f5fae8cb3..ac449dd786 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -977,7 +977,7 @@ def register_script( source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, envs: typing.Optional[typing.Dict[str, str]] = None, - fast_package_options: typing.Optional[FastPackageOptionas] = None, + fast_package_options: typing.Optional[FastPackageOptions] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. From 8562fd9249a868ead7b4264712325a53a176df7c Mon Sep 17 00:00:00 2001 From: Noah Jackson Date: Thu, 13 Jun 2024 23:32:12 -0700 Subject: [PATCH 31/32] Add identity to task execution metadata (#2315) Signed-off-by: noahjax Signed-off-by: ddl-rliu Co-authored-by: ddl-rliu --- flytekit/models/security.py | 3 +++ flytekit/models/task.py | 9 +++++++++ pyproject.toml | 2 +- tests/flytekit/unit/extend/test_agent.py | 2 ++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 748b4d09eb..e210c910b7 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -92,12 +92,14 @@ class Identity(_common.FlyteIdlEntity): iam_role: Optional[str] = None k8s_service_account: Optional[str] = None oauth2_client: Optional[OAuth2Client] = None + execution_identity: Optional[str] = None def to_flyte_idl(self) -> _sec.Identity: return _sec.Identity( iam_role=self.iam_role if self.iam_role else None, k8s_service_account=self.k8s_service_account if self.k8s_service_account else None, oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None, + execution_identity=self.execution_identity if self.execution_identity else None, ) @classmethod @@ -108,6 +110,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity": oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client) if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize() else None, + execution_identity=pb2_object.execution_identity if pb2_object.execution_identity else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 5072f03757..0532b276e2 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -528,6 +528,7 @@ def __init__( annotations, k8s_service_account, environment_variables, + identity, ): """ Runtime task execution metadata. @@ -539,6 +540,7 @@ def __init__( :param dict[str, str] annotations: Annotations to use for the execution of this task. :param Text k8s_service_account: Service account to use for execution of this task. :param dict[str, str] environment_variables: Environment variables for this task. + :param flytekit.models.security.Identity identity: Identity of user executing this task """ self._task_execution_id = task_execution_id self._namespace = namespace @@ -546,6 +548,7 @@ def __init__( self._annotations = annotations self._k8s_service_account = k8s_service_account self._environment_variables = environment_variables + self._identity = identity @property def task_execution_id(self): @@ -571,6 +574,10 @@ def k8s_service_account(self): def environment_variables(self): return self._environment_variables + @property + def identity(self): + return self._identity + def to_flyte_idl(self): """ :rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata @@ -584,6 +591,7 @@ def to_flyte_idl(self): environment_variables={k: v for k, v in self.environment_variables.items()} if self.labels is not None else None, + identity=self.identity.to_flyte_idl() if self.identity else None, ) return task_execution_metadata @@ -604,6 +612,7 @@ def from_flyte_idl(cls, pb2_object): environment_variables={k: v for k, v in pb2_object.environment_variables.items()} if pb2_object.environment_variables is not None else None, + identity=_sec.Identity.from_flyte_idl(pb2_object.identity) if pb2_object.identity else None, ) diff --git a/pyproject.toml b/pyproject.toml index e0fd189bd6..5aece34595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.11.0b1", + "flyteidl>=1.12.0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 17db5c2788..3226313079 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -47,6 +47,7 @@ WorkflowExecutionIdentifier, ) from flytekit.models.literals import LiteralMap +from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable @@ -159,6 +160,7 @@ def simple_task(i: int): annotations={"annotation_key": "annotation_val"}, k8s_service_account="k8s service account", environment_variables={"env_var_key": "env_var_val"}, + identity=Identity(execution_identity="task executor"), ) From 4b2c06bb99fa78511570770c623c6c0cab13d1e5 Mon Sep 17 00:00:00 2001 From: Daniel Sola <40698988+dansola@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:24:16 -0700 Subject: [PATCH 32/32] Add functionality for .flyteignore file (#2479) * Add functionality for .flyteignore file Signed-off-by: Daniel Sola * copy docker ignore for flyteignore --------- Signed-off-by: Daniel Sola --- flytekit/tools/fast_registration.py | 4 +-- flytekit/tools/ignore.py | 21 ++++++++++++++ tests/flytekit/unit/tools/test_ignore.py | 36 ++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index c67873c9d8..6108048533 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -15,7 +15,7 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.core.utils import timeit -from flytekit.tools.ignore import DockerIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore +from flytekit.tools.ignore import DockerIgnore, FlyteIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes FAST_PREFIX = "fast" @@ -46,7 +46,7 @@ def fast_package( :param bool deref_symlinks: Enables dereferencing symlinks when packaging directory :return os.PathLike: """ - default_ignores = [GitIgnore, DockerIgnore, StandardIgnore] + default_ignores = [GitIgnore, DockerIgnore, StandardIgnore, FlyteIgnore] if options is not None: if options.keep_default_ignores: ignores = options.ignores + default_ignores diff --git a/flytekit/tools/ignore.py b/flytekit/tools/ignore.py index 4a427c2734..e41daf0904 100644 --- a/flytekit/tools/ignore.py +++ b/flytekit/tools/ignore.py @@ -87,6 +87,27 @@ def _is_ignored(self, path: str) -> bool: return self.pm.matches(path) +class FlyteIgnore(Ignore): + """Uses a .flyteignore file to determine ignored files.""" + + def __init__(self, root: Path): + super().__init__(root) + self.pm = self._parse() + + def _parse(self) -> PatternMatcher: + patterns = [] + flyteignore = os.path.join(self.root, ".flyteignore") + if os.path.isfile(flyteignore): + with open(flyteignore, "r") as f: + patterns = [l.strip() for l in f.readlines() if l and not l.startswith("#")] + else: + logger.info(f"No .flyteignore found in {self.root}, not applying any filters") + return PatternMatcher(patterns) + + def _is_ignored(self, path: str) -> bool: + return self.pm.matches(path) + + class StandardIgnore(Ignore): """Retains the standard ignore functionality that previously existed. Could in theory by fed with custom ignore patterns from cli.""" diff --git a/tests/flytekit/unit/tools/test_ignore.py b/tests/flytekit/unit/tools/test_ignore.py index 3b24eeb8a0..614269fc26 100644 --- a/tests/flytekit/unit/tools/test_ignore.py +++ b/tests/flytekit/unit/tools/test_ignore.py @@ -8,7 +8,7 @@ import pytest from docker.utils.build import PatternMatcher -from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore +from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore, FlyteIgnore def make_tree(root: Path, tree: Dict): @@ -102,6 +102,19 @@ def all_ignore(tmp_path): return tmp_path +@pytest.fixture +def simple_flyteignore(tmp_path): + tree = { + "sub": {"some.bar": ""}, + "test.foo": "", + "keep.foo": "", + ".flyteignore": "\n".join(["*.foo", "!keep.foo", "# A comment", "sub"]), + } + + make_tree(tmp_path, tree) + return tmp_path + + def test_simple_gitignore(simple_gitignore): gitignore = GitIgnore(simple_gitignore) assert gitignore.is_ignored(str(simple_gitignore / "test.foo")) @@ -219,7 +232,7 @@ def test_all_ignore(all_ignore): def test_all_ignore_tar_filter(all_ignore): """Test tar_filter method of all ignores grouped together""" - ignore = IgnoreGroup(all_ignore, [GitIgnore, DockerIgnore, StandardIgnore]) + ignore = IgnoreGroup(all_ignore, [GitIgnore, DockerIgnore, StandardIgnore, FlyteIgnore]) assert ignore.tar_filter(TarInfo(name="sub")).name == "sub" assert ignore.tar_filter(TarInfo(name="sub/some.bar")).name == "sub/some.bar" assert not ignore.tar_filter(TarInfo(name="sub/__pycache__/")) @@ -232,4 +245,23 @@ def test_all_ignore_tar_filter(all_ignore): assert ignore.tar_filter(TarInfo(name="keep.foo")).name == "keep.foo" assert ignore.tar_filter(TarInfo(name=".gitignore")).name == ".gitignore" assert ignore.tar_filter(TarInfo(name=".dockerignore")).name == ".dockerignore" + assert ignore.tar_filter(TarInfo(name=".flyteignore")).name == ".flyteignore" assert not ignore.tar_filter(TarInfo(name=".git")) + + +def test_flyteignore_parse(simple_flyteignore): + """Test .flyteignore file parsing""" + flyteignore = FlyteIgnore(simple_flyteignore) + assert flyteignore.pm.matches("whatever.foo") + assert not flyteignore.pm.matches("keep.foo") + assert flyteignore.pm.matches("sub") + assert flyteignore.pm.matches("sub/stuff.txt") + + +def test_simple_flyteignore(simple_flyteignore): + flyteignore = FlyteIgnore(simple_flyteignore) + assert flyteignore.is_ignored(str(simple_flyteignore / "test.foo")) + assert flyteignore.is_ignored(str(simple_flyteignore / "sub")) + assert flyteignore.is_ignored(str(simple_flyteignore / "sub" / "some.bar")) + assert not flyteignore.is_ignored(str(simple_flyteignore / "keep.foo")) + assert not flyteignore.is_ignored(str(simple_flyteignore / ".flyteignore"))