diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index f20a33aea8..8736c7b2ef 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -14,7 +14,7 @@
 from urllib.parse import urlparse
 import uuid
 import pytest
-import mock
+from unittest import mock
 
 from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
 from flytekit.configuration import Config, ImageConfig, SerializationSettings
@@ -833,3 +833,22 @@ def test_open_ff():
     url = urlparse(remote_file_path)
     bucket, key = url.netloc, url.path.lstrip("/")
     file_transfer.delete_file(bucket=bucket, key=key)
+
+
+def test_attr_access_sd():
+    """Test accessing StructuredDataset attribute from a dataclass."""
+    # Upload a file to minio s3 bucket
+    file_transfer = SimpleFileTransfer()
+    remote_file_path = file_transfer.upload_file(file_type="parquet")
+
+    try:
+        execution_id = run("attr_access_sd.py", "wf", "--uri", remote_file_path)
+        remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
+        execution = remote.fetch_execution(name=execution_id)
+        execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
+        assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"
+    finally:
+        # Delete the remote file to free the space, even if the run or the assertion above fails
+        url = urlparse(remote_file_path)
+        bucket, key = url.netloc, url.path.lstrip("/")
+        file_transfer.delete_file(bucket=bucket, key=key)
diff --git a/tests/flytekit/integration/remote/utils.py b/tests/flytekit/integration/remote/utils.py
index dadc8c6530..c16a0d0f4d 100644
--- a/tests/flytekit/integration/remote/utils.py
+++ b/tests/flytekit/integration/remote/utils.py
@@ -84,6 +84,9 @@ def _dump_tmp_file(self, file_type: str, tmp_dir: str) -> str:
             tmp_file_path = pathlib.Path(tmp_dir) / "test.json"
             with open(tmp_file_path, "w") as f:
                 json.dump(d, f)
+        elif file_type == "parquet":
+            # Because `upload_file` accepts a single file only, we specify 00000 to make it a single file
+            tmp_file_path = pathlib.Path(__file__).parent / "workflows/basic/data/df.parquet/00000"
 
         return tmp_file_path
 
diff --git a/tests/flytekit/integration/remote/workflows/basic/attr_access_sd.py b/tests/flytekit/integration/remote/workflows/basic/attr_access_sd.py
new file mode 100644
index 0000000000..9d01926081
--- /dev/null
+++ b/tests/flytekit/integration/remote/workflows/basic/attr_access_sd.py
@@ -0,0 +1,47 @@
+"""
+Test accessing StructuredDataset attribute from a dataclass.
+"""
+from dataclasses import dataclass
+
+import pandas as pd
+from flytekit import task, workflow
+from flytekit.types.structured import StructuredDataset
+
+
+@dataclass
+class DC:
+    sd: StructuredDataset
+
+
+@task
+def create_dc(uri: str) -> DC:
+    """Create a dataclass with a StructuredDataset attribute.
+
+    Args:
+        uri: File URI.
+
+    Returns:
+        dc: A dataclass with a StructuredDataset attribute.
+    """
+    dc = DC(sd=StructuredDataset(uri=uri, file_format="parquet"))
+
+    return dc
+
+
+@task
+def read_sd(sd: StructuredDataset) -> StructuredDataset:
+    """Read input StructuredDataset."""
+    print("sd:", sd.open(pd.DataFrame).all())
+
+    return sd
+
+
+@workflow
+def wf(uri: str) -> None:
+    """Create a DC from `uri` and pass its `sd` attribute to `read_sd`."""
+    dc = create_dc(uri=uri)
+    read_sd(sd=dc.sd)
+
+
+if __name__ == "__main__":
+    wf(uri="tests/flytekit/integration/remote/workflows/basic/data/df.parquet")