Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

Line structure restored from a copy whose newlines and indentation had been collapsed.
Indentation and blank lines were inferred from Python syntax and from the hunk line
counts; no token of the payload was changed. Verify whitespace against the target
tree before applying.

What the patch changes:
  * flytekit/core/type_engine.py, _make_dataclass_serializable: for an Optional/Union
    python_type the old code recursed with get_args(python_type)[0]. The new code
    walks every Union member, asks TypeEngine.get_transformer(member).assert_type(...)
    whether it accepts the value, and recurses with the first member that does. It
    uses type(None) when the value is None or when no member accepts it.
  * A ValueError is raised when the Union contains more than one of FlyteFile,
    FlyteDirectory and StructuredDataset.
  * tests: a Union of two dataclasses, the ValueError case above, and a
    Union[None, FlyteFile] field (c_prime) in TestFileStruct.

NOTE(review): the multi-Flyte-type check runs for every non-None value, not only for
str values, although its message talks about the "string shortcut" - confirm intended.
NOTE(review): "except Exception: continue" hides every transformer error while
probing Union members - confirm this best-effort behaviour is intended.

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 900afa8562..897c5a6400 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -676,12 +676,33 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
         """
         from flytekit.types.directory import FlyteDirectory
         from flytekit.types.file import FlyteFile
+        from flytekit.types.structured import StructuredDataset
 
         # Handle Optional
         if UnionTransformer.is_optional_type(python_type):
-            if python_val is None:
-                return None
-            return self._make_dataclass_serializable(python_val, get_args(python_type)[0])
+
+            def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
+                if len(set(types) & {FlyteFile, FlyteDirectory, StructuredDataset}) > 1:
+                    raise ValueError(
+                        "Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string."
+                    )
+
+                for t in types:
+                    try:
+                        trans = TypeEngine.get_transformer(t)  # type: ignore
+                        if trans:
+                            trans.assert_type(t, python_val)
+                            return t
+                    except Exception:
+                        continue
+                return type(None)
+
+            # Get the expected type in the Union type
+            expected_type = type(None)
+            if python_val is not None:
+                expected_type = get_expected_type(python_val, get_args(python_type))  # type: ignore
+
+            return self._make_dataclass_serializable(python_val, expected_type)
 
         if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
             if python_val is None:
diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py
index 58dfcd1e45..4e098c254b 100644
--- a/tests/flytekit/unit/core/test_dataclass.py
+++ b/tests/flytekit/unit/core/test_dataclass.py
@@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes:
 
     empty_nested_flyte_types = empty_nested_dc_wf()
     DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)
+
+def test_dataclass_serialize_with_multiple_dataclass_union():
+    @dataclass
+    class A():
+        x: int
+
+    @dataclass
+    class B():
+        x: FlyteFile
+
+    b = B(x="s3://my-bucket/my-file")
+    res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B])
+
+    assert res.x.path == "s3://my-bucket/my-file"
diff --git a/tests/flytekit/unit/core/test_flytetypes.py b/tests/flytekit/unit/core/test_flytetypes.py
new file mode 100644
index 0000000000..366c3547c7
--- /dev/null
+++ b/tests/flytekit/unit/core/test_flytetypes.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+from flytekit.types.file import FlyteFile
+from flytekit.types.structured.structured_dataset import StructuredDataset
+from flytekit.core.type_engine import DataclassTransformer
+from typing import Union
+import pytest
+import re
+
+def test_dataclass_union_with_multiple_flytetypes_error():
+    @dataclass
+    class DC():
+        x: Union[None, StructuredDataset, FlyteFile]
+
+
+    dc = DC(x="s3://my-bucket/my-file")
+    with pytest.raises(ValueError, match=re.escape("Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string.")):
+        DataclassTransformer()._make_dataclass_serializable(dc, DC)
diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index 8721a8d4db..cb641eebe4 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -967,6 +967,7 @@ class TestFileStruct(DataClassJsonMixin):
         b: typing.Optional[FlyteFile]
         b_prime: typing.Optional[FlyteFile]
         c: typing.Union[FlyteFile, None]
+        c_prime: typing.Union[None, FlyteFile]
         d: typing.List[FlyteFile]
         e: typing.List[typing.Optional[FlyteFile]]
         e_prime: typing.List[typing.Optional[FlyteFile]]
@@ -989,6 +990,7 @@ class TestFileStruct(DataClassJsonMixin):
         b=f1,
         b_prime=None,
         c=f1,
+        c_prime=f1,
         d=[f1],
         e=[f1],
         e_prime=[None],
@@ -1011,6 +1013,7 @@ class TestFileStruct(DataClassJsonMixin):
     assert dict_obj["b"]["path"] == remote_path
     assert dict_obj["b_prime"] is None
     assert dict_obj["c"]["path"] == remote_path
+    assert dict_obj["c_prime"]["path"] == remote_path
     assert dict_obj["d"][0]["path"] == remote_path
     assert dict_obj["e"][0]["path"] == remote_path
     assert dict_obj["e_prime"][0] is None
@@ -1028,6 +1031,7 @@ class TestFileStruct(DataClassJsonMixin):
     assert o.b.remote_path == ot.b.remote_source
     assert ot.b_prime is None
     assert o.c.remote_path == ot.c.remote_source
+    assert o.c_prime.remote_path == ot.c_prime.remote_source
     assert o.d[0].remote_path == ot.d[0].remote_source
     assert o.e[0].remote_path == ot.e[0].remote_source
     assert o.e_prime == [None]