diff --git a/flytekit/core/auto_cache.py b/flytekit/core/auto_cache.py index adf4741a0b..2915abb729 100644 --- a/flytekit/core/auto_cache.py +++ b/flytekit/core/auto_cache.py @@ -1,4 +1,21 @@ -from typing import Any, Callable, Protocol, runtime_checkable +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable + +from flytekit.image_spec.image_spec import ImageSpec + + +@dataclass +class VersionParameters: + """ + Parameters used for version hash generation. + + Args: + func (Optional[Callable]): The function to generate a version for + container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for + """ + + func: Optional[Callable[..., Any]] = None + container_image: Optional[Union[str, ImageSpec]] = None @runtime_checkable @@ -6,32 +23,54 @@ class AutoCache(Protocol): """ A protocol that defines the interface for a caching mechanism that generates a version hash of a function based on its source code. - - Attributes: - salt (str): A string used to add uniqueness to the generated hash. Default is "salt". - - Methods: - get_version(func: Callable[..., Any]) -> str: - Given a function, generates a version hash based on its source code and the salt. """ - def __init__(self, salt: str = "salt") -> None: + salt: str + + def get_version(self, params: VersionParameters) -> str: """ - Initialize the AutoCache instance with a salt value. + Generate a version hash based on the provided parameters. Args: - salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". + params (VersionParameters): Parameters to use for hash generation. + + Returns: + str: The generated version hash. """ + ... + + +class CachePolicy: + """ + A class that combines multiple caching mechanisms to generate a version hash. + + Args: + *cache_objects: Variable number of AutoCache instances + salt: Optional salt string to add uniqueness to the hash + """ + + def __init__(self, *cache_objects: AutoCache, salt: str = "") -> None: + self.cache_objects = cache_objects self.salt = salt - def get_version(self, func: Callable[..., Any]) -> str: + def get_version(self, params: VersionParameters) -> str: """ - Generate a version hash for the provided function. + Generate a version hash using all cache objects. Args: - func (Callable[..., Any]): A callable function whose version hash needs to be generated. + params (VersionParameters): Parameters to use for hash generation. Returns: - str: The SHA-256 hash of the function's source code combined with the salt. + str: The combined hash from all cache objects. """ - ... + task_hash = "" + for cache_instance in self.cache_objects: + # Apply the policy's salt to each cache instance + cache_instance.salt = self.salt + task_hash += cache_instance.get_version(params) + + # Generate SHA-256 hash + import hashlib + + hash_obj = hashlib.sha256(task_hash.encode()) + return hash_obj.hexdigest() diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 7519f341d6..cfd509b278 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,7 +5,7 @@ from functools import update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload -from flytekit.core.auto_cache import AutoCache +from flytekit.core.auto_cache import CachePolicy, VersionParameters from flytekit.core.utils import str2bool try: @@ -100,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction def task( _task_function: None = ..., task_config: Optional[T] = ..., - cache: Union[bool, list[AutoCache]] = ..., + cache: Union[bool, CachePolicy] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -136,9 +136,9 @@ def task( @overload def task( - _task_function: Callable[P, FuncOut], + _task_function: Callable[..., FuncOut], task_config: Optional[T] = ..., - cache: Union[bool, list[AutoCache]] = ..., + cache: Union[bool, CachePolicy] = ..., cache_serialize: bool = ..., cache_version: str = ..., cache_ignore_input_vars: Tuple[str, ...] = ..., @@ -169,13 +169,13 @@ def task( pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., -) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... +) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ... def task( - _task_function: Optional[Callable[P, FuncOut]] = None, + _task_function: Optional[Callable[..., FuncOut]] = None, task_config: Optional[T] = None, - cache: Union[bool, list[AutoCache]] = False, + cache: Union[bool, CachePolicy] = False, cache_serialize: bool = False, cache_version: str = "", cache_ignore_input_vars: Tuple[str, ...] = (), @@ -213,8 +213,8 @@ def task( pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, ) -> Union[ - Callable[P, FuncOut], - Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], + Callable[..., FuncOut], + Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], ]: """ @@ -343,17 +343,19 @@ def launch_dynamically(): :param accelerator: The accelerator to use for this task. """ - def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: - if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache): - cache_versions = [item.get_version() for item in cache] - task_hash = "".join(cache_versions) + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: + if isinstance(cache, CachePolicy): + params = VersionParameters(func=fn, container_image=container_image) + cache_version_val = cache.get_version(params=params) + cache_val = True else: - task_hash = "" + cache_val = cache + cache_version_val = cache_version _metadata = TaskMetadata( - cache=cache, + cache=cache_val, cache_serialize=cache_serialize, - cache_version=cache_version if not task_hash else task_hash, + cache_version=cache_version_val, cache_ignore_input_vars=cache_ignore_input_vars, retries=retries, interruptible=interruptible, @@ -439,7 +441,7 @@ def wrapper(fn) -> ReferenceTask: return wrapper -def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]: +def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]: """ Decorates the task with additional functionality if necessary. diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index de0f620e96..07f99103e2 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -843,23 +843,23 @@ def workflow( @overload def workflow( - _workflow_function: Callable[P, FuncOut], + _workflow_function: Callable[..., FuncOut], failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., on_failure: Optional[Union[WorkflowBase, Task]] = ..., docs: Optional[Documentation] = ..., default_options: Optional[Options] = ..., -) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ... +) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ... def workflow( - _workflow_function: Optional[Callable[P, FuncOut]] = None, + _workflow_function: Optional[Callable[..., FuncOut]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, on_failure: Optional[Union[WorkflowBase, Task]] = None, docs: Optional[Documentation] = None, default_options: Optional[Options] = None, -) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: +) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. @@ -894,7 +894,7 @@ def workflow( the labels and annotations are allowed to be set as defaults. """ - def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow: + def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py index 6751b53c0f..830aee94f0 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py @@ -4,6 +4,8 @@ import textwrap from typing import Any, Callable +from flytekit.core.auto_cache import VersionParameters + class CacheFunctionBody: """ @@ -27,8 +29,10 @@ def __init__(self, salt: str = "salt") -> None: """ self.salt = salt - def get_version(self, func: Callable[..., Any]) -> str: - return self._get_version(func=func) + def get_version(self, params: VersionParameters) -> str: + if params.func is None: + raise ValueError("Function-based cache requires a function parameter") + return self._get_version(func=params.func) def _get_version(self, func: Callable[..., Any]) -> str: """ diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py new file mode 100644 index 0000000000..cf57b645ce --- /dev/null +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py @@ -0,0 +1,22 @@ +import hashlib + +from flytekit.core.auto_cache import VersionParameters +from flytekit.image_spec.image_spec import ImageSpec + + +class CacheImage: + def __init__(self, salt: str): + self.salt = salt + + def get_version(self, params: VersionParameters) -> str: + if params.container_image is None: + raise ValueError("Image-based cache requires a container_image parameter") + + # If the image is an ImageSpec, combine tag with salt + if isinstance(params.container_image, ImageSpec): + combined = params.container_image.tag + self.salt + return hashlib.sha256(combined.encode("utf-8")).hexdigest() + + # If the image is a string, combine with salt + combined = params.container_image + self.salt + return hashlib.sha256(combined.encode("utf-8")).hexdigest() diff --git a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py index bee640b240..64781f9740 100644 --- a/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py +++ b/plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import Any, Callable, Set, Union +from flytekit.core.auto_cache import VersionParameters + @contextmanager def temporarily_add_to_syspath(path): @@ -24,9 +26,12 @@ def __init__(self, salt: str, root_dir: str): self.salt = salt self.root_dir = Path(root_dir).resolve() - def get_version(self, func: Callable[..., Any]) -> str: - hash_components = [self._get_version(func)] - dependencies = self._get_function_dependencies(func, set()) + def get_version(self, params: VersionParameters) -> str: + if params.func is None: + raise ValueError("Function-based cache requires a function parameter") + + hash_components = [self._get_version(params.func)] + dependencies = self._get_function_dependencies(params.func, set()) for dep in dependencies: hash_components.append(self._get_version(dep)) # Combine all component hashes into a single version hash diff --git a/plugins/flytekit-auto-cache/tests/test_function_body.py b/plugins/flytekit-auto-cache/tests/test_function_body.py index efebe8ad0a..351b7e62af 100644 --- a/plugins/flytekit-auto-cache/tests/test_function_body.py +++ b/plugins/flytekit-auto-cache/tests/test_function_body.py @@ -1,6 +1,7 @@ from dummy_functions.dummy_function import dummy_function from dummy_functions.dummy_function_comments_formatting_change import dummy_function as dummy_function_comments_formatting_change from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change +from flytekit.core.auto_cache import VersionParameters from flytekitplugins.auto_cache import CacheFunctionBody @@ -11,9 +12,11 @@ def test_get_version_with_same_function_and_salt(): cache1 = CacheFunctionBody(salt="salt") cache2 = CacheFunctionBody(salt="salt") + params = VersionParameters(func=dummy_function) + # Both calls should return the same hash since the function and salt are the same - version1 = cache1.get_version(dummy_function) - version2 = cache2.get_version(dummy_function) + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) assert version1 == version2, f"Expected {version1}, but got {version2}" @@ -25,9 +28,11 @@ def test_get_version_with_different_salt(): cache1 = CacheFunctionBody(salt="salt1") cache2 = CacheFunctionBody(salt="salt2") + params = VersionParameters(func=dummy_function) + # The hashes should be different because the salts are different - version1 = cache1.get_version(dummy_function) - version2 = cache2.get_version(dummy_function) + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) assert version1 != version2, f"Expected different hashes but got the same: {version1}" @@ -38,8 +43,12 @@ def test_get_version_with_different_logic(): Test that functions with the same name but different logic produce different hashes. """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(dummy_function) - version2 = cache.get_version(dummy_function_logic_change) + + params1 = VersionParameters(func=dummy_function) + params2 = VersionParameters(func=dummy_function_logic_change) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 != version2, ( f"Hashes should be different for functions with same name but different logic. " @@ -61,8 +70,11 @@ def test_get_version_with_different_function_names(): """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(function_one) - version2 = cache.get_version(function_two) + params1 = VersionParameters(func=function_one) + params2 = VersionParameters(func=function_two) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 != version2, ( f"Hashes should be different for functions with different names. " @@ -76,8 +88,12 @@ def test_get_version_with_formatting_changes(): """ cache = CacheFunctionBody(salt="salt") - version1 = cache.get_version(dummy_function) - version2 = cache.get_version(dummy_function_comments_formatting_change) + + params1 = VersionParameters(func=dummy_function) + params2 = VersionParameters(func=dummy_function_comments_formatting_change) + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) assert version1 == version2, ( f"Hashes should be the same for functions with same name but different formatting. " diff --git a/plugins/flytekit-auto-cache/tests/test_image.py b/plugins/flytekit-auto-cache/tests/test_image.py new file mode 100644 index 0000000000..8a9020ab5c --- /dev/null +++ b/plugins/flytekit-auto-cache/tests/test_image.py @@ -0,0 +1,93 @@ +import pytest # type: ignore +import hashlib +from flytekit.core.auto_cache import VersionParameters +from flytekit.image_spec.image_spec import ImageSpec +from flytekitplugins.auto_cache import CacheImage + + +def test_get_version_with_same_image_and_salt(): + """ + Test that calling get_version with the same image and salt returns the same hash. + """ + cache1 = CacheImage(salt="salt") + cache2 = CacheImage(salt="salt") + + params = VersionParameters(container_image="python:3.9") + + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) + + assert version1 == version2, f"Expected {version1}, but got {version2}" + + +def test_get_version_with_different_salt(): + """ + Test that calling get_version with different salts returns different hashes for the same image. + """ + cache1 = CacheImage(salt="salt1") + cache2 = CacheImage(salt="salt2") + + params = VersionParameters(container_image="python:3.9") + + version1 = cache1.get_version(params) + version2 = cache2.get_version(params) + + assert version1 != version2, f"Expected different hashes but got the same: {version1}" + + +def test_get_version_with_different_images(): + """ + Test that different images produce different hashes. + """ + cache = CacheImage(salt="salt") + + params1 = VersionParameters(container_image="python:3.9") + params2 = VersionParameters(container_image="python:3.8") + + version1 = cache.get_version(params1) + version2 = cache.get_version(params2) + + assert version1 != version2, ( + f"Hashes should be different for different images. " + f"Got {version1} and {version2}" + ) + + +def test_get_version_with_image_spec(): + """ + Test that ImageSpec objects use their tag directly. + """ + cache = CacheImage(salt="salt") + + image_spec = ImageSpec( + name="my-image", + registry="my-registry", + tag="v1.0.0" + ) + params = VersionParameters(container_image=image_spec) + + version = cache.get_version(params) + expected = hashlib.sha256("v1.0.0".encode("utf-8")).hexdigest() + assert version == expected, f"Expected {expected}, but got {version}" + + +def test_get_version_without_image(): + """ + Test that calling get_version without an image raises ValueError. + """ + cache = CacheImage(salt="salt") + params = VersionParameters(func=lambda x: x) # Only providing func, no image + + with pytest.raises(ValueError, match="Image-based cache requires a container_image parameter"): + cache.get_version(params) + + +def test_get_version_with_none_image(): + """ + Test that calling get_version with None image raises ValueError. + """ + cache = CacheImage(salt="salt") + params = VersionParameters(container_image=None) + + with pytest.raises(ValueError, match="Image-based cache requires a container_image parameter"): + cache.get_version(params)