From c2b5c454e922f5dff1795f2a320868691429de3b Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Mon, 12 Aug 2024 17:12:44 +0300 Subject: [PATCH] Return explicit task execution code not found (#2659) Signed-off-by: Iaroslav Ciupin --- flytekit/core/data_persistence.py | 6 ++++-- flytekit/exceptions/user.py | 5 +++++ flytekit/tools/fast_registration.py | 8 +++++++- tests/flytekit/unit/core/test_checkpoint.py | 6 +++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index a6b401bff8..89556a53d0 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -36,7 +36,7 @@ from flytekit.configuration import DataConfig from flytekit.core.local_fsspec import FlyteLocalFileSystem from flytekit.core.utils import timeit -from flytekit.exceptions.user import FlyteAssertion, FlyteValueException +from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException from flytekit.interfaces.random import random from flytekit.loggers import logger @@ -300,7 +300,7 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): except OSError as oe: logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") if not file_system.exists(from_path): - raise FlyteValueException(from_path, "File not found") + raise FlyteDataNotFoundException(from_path) file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) if file_system is not None: logger.debug(f"Attempting anonymous get with {file_system}") @@ -558,6 +558,8 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) with timeit(f"Download data to local from {remote_path}"): self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) + except FlyteDataNotFoundException: + raise except Exception as ex: raise FlyteAssertion( f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index a4b5caa75a..645754dc35 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -55,6 +55,11 @@ def __init__(self, received_value, error_message): super(FlyteValueException, self).__init__(self._create_verbose_message(received_value, error_message)) +class FlyteDataNotFoundException(FlyteValueException): + def __init__(self, path: str): + super(FlyteDataNotFoundException, self).__init__(path, "File not found") + + class FlyteAssertion(FlyteUserException, AssertionError): _ERROR_CODE = "USER:AssertionError" diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index ca4ab2d2cc..d17bbe8994 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -15,6 +15,7 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.core.utils import timeit +from flytekit.exceptions.user import FlyteDataNotFoundException from flytekit.loggers import logger from flytekit.tools.ignore import DockerIgnore, FlyteIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes @@ -146,7 +147,12 @@ def download_distribution(additional_distribution: str, destination: str): # NOTE the os.path.join(destination, ''). This is to ensure that the given path is in fact a directory and all # downloaded data should be copied into this directory. We do this to account for a difference in behavior in # fsspec, which requires a trailing slash in case of pre-existing directory. - FlyteContextManager.current_context().file_access.get_data(additional_distribution, os.path.join(destination, "")) + try: + FlyteContextManager.current_context().file_access.get_data( + additional_distribution, os.path.join(destination, "") + ) + except FlyteDataNotFoundException as ex: + raise RuntimeError("task execution code was not found") from ex tarfile_name = os.path.basename(additional_distribution) if not tarfile_name.endswith(".tar.gz"): raise RuntimeError("Unrecognized additional distribution format for {}".format(additional_distribution)) diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index 53338ec0ae..96db6da1a9 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -5,7 +5,7 @@ import flytekit from flytekit.core.checkpointer import SyncCheckpoint -from flytekit.exceptions.user import FlyteAssertion +from flytekit.exceptions.user import FlyteDataNotFoundException def test_sync_checkpoint_write(tmpdir): @@ -90,10 +90,10 @@ def test_sync_checkpoint_restore_corrupt(tmpdir): prev.unlink() src.rmdir() - with pytest.raises(FlyteAssertion): + with pytest.raises(FlyteDataNotFoundException): cp.restore(user_dest) - with pytest.raises(FlyteAssertion): + with pytest.raises(FlyteDataNotFoundException): cp.restore(user_dest)