Skip to content

Commit

Permalink
[BUG] Open FlyteFile from remote path (flyteorg#2991)
Browse files Browse the repository at this point in the history
* fix: Open FlyteFile from remote path

Signed-off-by: JiaWei Jiang <[email protected]>

* Add integration test

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Use ctx as param instead of recreation

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Clean test logic

1. Remove redundant prints
2. Use `mock.patch.dict` to setup `os.environ` for the current test fn
    * Avoid contaminating other tests running in the same process

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Setup local path and downloader in constructor

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Move SimpleFileTransfer to a utility file

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove redundant env var setup

Please refer to flyteorg#3001

Signed-off-by: JiaWei Jiang <[email protected]>

* test: Add another ff use case

Create ff in one task pod and read it in another task pod.

Signed-off-by: JiaWei Jiang <[email protected]>

---------

Signed-off-by: JiaWei Jiang <[email protected]>
Signed-off-by: Shuying Liang <[email protected]>
  • Loading branch information
JiangJiaWei1103 authored and shuyingliang committed Jan 11, 2025
1 parent 836e27f commit ce04660
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 9 deletions.
65 changes: 57 additions & 8 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,53 @@ def __init__(
self._downloader = downloader
self._downloaded = False
self._remote_path = remote_path
self._remote_source: typing.Optional[str] = None
self._remote_source: typing.Optional[typing.Union[str, os.PathLike]] = None

# Setup local path and downloader for delayed downloading
# We introduce another attribute self._local_path to avoid overriding user-defined self.path
self._local_path = self.path

ctx = FlyteContextManager.current_context()
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = lambda: FlyteFilePathTransformer.downloader(
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
)

def __fspath__(self):
    """
    Implement the file system path protocol so a FlyteFile can be passed to ``open``.

    The following shows two common use cases:

    1. Directly open a FlyteFile with a local path:

        ff = FlyteFile(path=local_path)
        with open(ff, "r") as f:
            # Read your local file here
            # ...

       There's no need to handle downloading of the file because it's on the local file system.
       In this case, a dummy downloading will be done.

    2. Directly open a FlyteFile with a remote path:

        ff = FlyteFile(path=remote_path)
        with open(ff, "r") as f:
            # Read your remote file here
            # ...

       Directly opening a FlyteFile backed by a file in the remote data storage is supported.
       In this case, a delayed downloading of the remote file will be done.

    For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/6090.

    Returns:
        The local path to open. ``self._local_path`` is returned rather than
        ``self.path`` so that the user-defined ``self.path`` is never overridden.
    """
    if not self._downloaded:
        # Download data from remote to local, or run the dummy downloader for a local input path.
        # The flag makes sure the downloader runs at most once per object.
        self._downloader()
        self._downloaded = True
    return self._local_path

def __eq__(self, other):
if isinstance(other, FlyteFile):
Expand Down Expand Up @@ -693,16 +732,26 @@ async def async_to_python_value(

# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

def _downloader():
return ctx.file_access.get_data(uri, local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader)
ff = FlyteFile.__class_getitem__(expected_format)(
path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path)
)
ff._remote_source = uri

return ff

@staticmethod
def downloader(
    ctx: FlyteContext,
    remote_path: typing.Union[str, os.PathLike],
    local_path: typing.Union[str, os.PathLike],
) -> None:
    """
    Fetch the data at ``remote_path`` into ``local_path``.

    The transfer is delegated to ``ctx.file_access.get_data`` with
    ``is_multipart=False``. This is a static method because the behavior is
    logically tied to this class but doesn't need any class or instance data.
    """
    file_access = ctx.file_access
    file_access.get_data(remote_path, local_path, is_multipart=False)

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if (
literal_type.blob is not None
Expand Down
23 changes: 22 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from urllib.parse import urlparse
import uuid
import pytest
from mock import mock, patch
import mock

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
Expand All @@ -29,6 +29,9 @@
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

from tests.flytekit.integration.remote.utils import SimpleFileTransfer


MODULE_PATH = pathlib.Path(__file__).parent / "workflows/basic"
CONFIG = os.environ.get("FLYTECTL_CONFIG", str(pathlib.Path.home() / ".flyte" / "config-sandbox.yaml"))
# Run `make build-dev` to build and push the image to the local registry.
Expand Down Expand Up @@ -812,3 +815,21 @@ def test_get_control_plane_version():
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("localhost:30080", True))
version = client.get_control_plane_version()
assert version == "unknown" or version.startswith("v")


def test_open_ff():
    """Test opening FlyteFile from a remote path.

    A JSON file is uploaded to the minio s3 bucket, the ``wf`` workflow in
    ``flytefile.py`` is run against it, and the execution is expected to
    succeed within five minutes. The uploaded file is deleted afterwards
    whether or not the run succeeds.
    """
    # Upload a file to minio s3 bucket
    file_transfer = SimpleFileTransfer()
    remote_file_path = file_transfer.upload_file(file_type="json")

    try:
        execution_id = run("flytefile.py", "wf", "--remote_file_path", remote_file_path)
        remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
        execution = remote.fetch_execution(name=execution_id)
        execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
        assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"
    finally:
        # Delete the remote file to free the space, even when the run or the
        # assertion above fails; otherwise the uploaded object would leak.
        url = urlparse(remote_file_path)
        bucket, key = url.netloc, url.path.lstrip("/")
        file_transfer.delete_file(bucket=bucket, key=key)
98 changes: 98 additions & 0 deletions tests/flytekit/integration/remote/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Common utilities for flyte remote runs in integration tests.
"""
import os
import json
import tempfile
import pathlib

import botocore.session
from botocore.client import BaseClient
from flytekit.configuration import Config
from flytekit.remote.remote import FlyteRemote


# Defaults shared by the integration-test helpers: the flytectl config file
# (overridable through the FLYTECTL_CONFIG environment variable) and the
# project/domain that uploads are issued against.
CONFIG = os.environ.get(
    "FLYTECTL_CONFIG",
    str(pathlib.Path.home().joinpath(".flyte", "config-sandbox.yaml")),
)
DOMAIN = "development"
PROJECT = "flytesnacks"


class SimpleFileTransfer:
    """Utilities for file transfer to minio s3 bucket.

    Supports uploading a single generated file and deleting a remote object.
    Teardown is explicit: callers remove what they uploaded with
    ``delete_file``.
    """

    def __init__(self) -> None:
        # Remote handle used for uploads; it also carries the s3 data config
        # from which the low-level client below is built.
        self._remote = FlyteRemote(
            config=Config.auto(config_file=CONFIG),
            default_project=PROJECT,
            default_domain=DOMAIN,
        )
        self._s3_client = self._get_minio_s3_client(self._remote)

    def _get_minio_s3_client(self, remote: FlyteRemote) -> BaseClient:
        """Create a botocore s3 client from the s3 data config of ``remote``.

        Args:
            remote: Remote whose ``file_access.data_config.s3`` supplies the
                endpoint and credentials.

        Returns:
            The botocore "s3" client.
        """
        minio_s3_config = remote.file_access.data_config.s3
        sess = botocore.session.get_session()

        return sess.create_client(
            "s3",
            endpoint_url=minio_s3_config.endpoint,
            aws_access_key_id=minio_s3_config.access_key_id,
            aws_secret_access_key=minio_s3_config.secret_access_key,
        )

    def upload_file(self, file_type: str) -> str:
        """Upload a single file to minio s3 bucket.

        Args:
            file_type: File type. Support "txt" and "json".

        Returns:
            remote_file_path: Remote file path.

        Raises:
            ValueError: If ``file_type`` is not supported.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file_path = self._dump_tmp_file(file_type, tmp_dir)

            # Upload to minio s3 bucket while the temporary directory still exists
            _, remote_file_path = self._remote.upload_file(
                to_upload=local_file_path,
                project=PROJECT,
                domain=DOMAIN,
            )

        return remote_file_path

    def _dump_tmp_file(self, file_type: str, tmp_dir: str) -> pathlib.Path:
        """Generate and dump a temporary file locally.

        Args:
            file_type: File type. Support "txt" and "json".
            tmp_dir: Temporary directory.

        Returns:
            tmp_file_path: Temporary local file path.

        Raises:
            ValueError: If ``file_type`` is neither "txt" nor "json".
        """
        if file_type == "txt":
            tmp_file_path = pathlib.Path(tmp_dir) / "test.txt"
            with open(tmp_file_path, "w") as f:
                f.write("Hello World!")
        elif file_type == "json":
            d = {"name": "john", "height": 190}
            tmp_file_path = pathlib.Path(tmp_dir) / "test.json"
            with open(tmp_file_path, "w") as f:
                json.dump(d, f)
        else:
            # Fail with a clear message instead of an UnboundLocalError on return
            raise ValueError(f"Unsupported file type: {file_type!r}. Expected 'txt' or 'json'.")

        return tmp_file_path

    def delete_file(self, bucket: str, key: str) -> None:
        """Delete the remote file from minio s3 bucket to free the space.

        Asserts that the delete response carries HTTP status code 204.

        Args:
            bucket: s3 bucket name.
            key: Key name of the object.
        """
        res = self._s3_client.delete_object(Bucket=bucket, Key=key)
        assert res["ResponseMetadata"]["HTTPStatusCode"] == 204
52 changes: 52 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/flytefile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@task
def create_ff(file_path: str) -> FlyteFile:
    """Wrap ``file_path`` in a FlyteFile and return it."""
    flyte_file = FlyteFile(path=file_path)
    return flyte_file


@task
def read_ff(ff: FlyteFile) -> None:
    """Open the given FlyteFile, read it, and print what it contains.

    Covers the case in which a FlyteFile is created in another task pod
    and then read in this task pod.
    """
    with open(ff, "r") as reader:
        content = reader.read()

    print(f"FILE CONTENT | {content}")


@task
def create_and_read_ff(file_path: str) -> FlyteFile:
    """Build a FlyteFile from a path, read it right away, and return it.

    Creation and reading both happen inside this single task pod.

    Args:
        file_path: File path.

    Returns:
        ff: FlyteFile object.
    """
    ff = FlyteFile(path=file_path)

    with open(ff, "r") as reader:
        content = reader.read()

    print(f"FILE CONTENT | {content}")
    return ff


@workflow
def wf(remote_file_path: str) -> None:
    """Exercise opening a FlyteFile built from ``remote_file_path``.

    The file is read both in a different task from the one that created it
    and inside the same task that created it.
    """
    # Create the FlyteFile in one task and read it in another
    ff_1 = create_ff(file_path=remote_file_path)
    read_ff(ff=ff_1)
    # Create and read the FlyteFile in a single task, then read the returned file again
    ff_2 = create_and_read_ff(file_path=remote_file_path)
    read_ff(ff=ff_2)


if __name__ == "__main__":
    import sys

    # `wf` declares a required `remote_file_path` input, so it cannot be
    # called without arguments; take the path from the command line.
    if len(sys.argv) != 2:
        raise SystemExit(f"usage: {sys.argv[0]} <remote_file_path>")
    wf(remote_file_path=sys.argv[1])

0 comments on commit ce04660

Please sign in to comment.