Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial make auto cache plugin #2912

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions flytekit/core/auto_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class AutoCache(Protocol):
"""
A protocol that defines the interface for a caching mechanism
that generates a version hash of a function based on its source code.

Attributes:
salt (str): A string used to add uniqueness to the generated hash. Default is "salt".

Methods:
get_version(func: Callable[..., Any]) -> str:
Given a function, generates a version hash based on its source code and the salt.
"""

def __init__(self, salt: str = "salt") -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can the default value be ""?

"""
Initialize the AutoCache instance with a salt value.

Args:
salt (str): A string to be used as the salt in the hashing process. Defaults to "salt".
"""
self.salt = salt

def get_version(self, func: Callable[..., Any]) -> str:
"""
Generate a version hash for the provided function.

Args:
func (Callable[..., Any]): A callable function whose version hash needs to be generated.

Returns:
str: The SHA-256 hash of the function's source code combined with the salt.
"""
...
17 changes: 12 additions & 5 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

from flytekit.core.auto_cache import AutoCache
from flytekit.core.utils import str2bool

try:
Expand Down Expand Up @@ -99,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, list[AutoCache]] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -137,7 +138,7 @@ def task(
def task(
_task_function: Callable[P, FuncOut],
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, list[AutoCache]] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -174,7 +175,7 @@ def task(
def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache: Union[bool, list[AutoCache]] = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also accept a single AutoCache object?

cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
Expand Down Expand Up @@ -248,7 +249,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str:
:param _task_function: This argument is implicitly passed and represents the decorated function
:param task_config: This argument provides configuration for a specific task types.
Please refer to the plugins documentation for the right object to use.
:param cache: Boolean that indicates if caching should be enabled
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be
executed in serial when caching is enabled. This means that given multiple concurrent executions over
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This
Expand Down Expand Up @@ -343,10 +344,16 @@ def launch_dynamically():
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache):
cache_versions = [item.get_version() for item in cache]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to pass the function here

task_hash = "".join(cache_versions)
Comment on lines +348 to +349
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eapolinario does the cache version string have a character limit? If so we may need to re-hash the concatenated hashes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I had a feeling this would be an issue. I just left the long join for simplicity and to not introduce a hashing function in flytekit/core/task.py , but happy to hash that.

else:
task_hash = ""

_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_version=cache_version if not task_hash else task_hash,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
interruptible=interruptible,
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-auto-cache/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit Auto Cache Plugin



To install the plugin, run the following command:

```bash
pip install flytekitplugins-auto-cache
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
.. currentmodule:: flytekitplugins.auto_cache

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

CacheFunctionBody
"""

from .cache_function_body import CacheFunctionBody
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import ast
import hashlib
import inspect
from typing import Any, Callable


class CacheFunctionBody:
"""
A class that implements a versioning mechanism for functions by generating
a SHA-256 hash of the function's source code combined with a salt.

Attributes:
salt (str): A string used to add uniqueness to the generated hash. Default is "salt".

Methods:
get_version(func: Callable[..., Any]) -> str:
Given a function, generates a version hash based on its source code and the salt.
"""

def __init__(self, salt: str = "salt") -> None:
"""
Initialize the CacheFunctionBody instance with a salt value.

Args:
salt (str): A string to be used as the salt in the hashing process. Defaults to "salt".
"""
self.salt = salt

def get_version(self, func: Callable[..., Any]) -> str:
"""
Generate a version hash for the provided function by parsing its source code
and adding a salt before applying the SHA-256 hash function.

Args:
func (Callable[..., Any]): A callable function whose version hash needs to be generated.

Returns:
str: The SHA-256 hash of the function's source code combined with the salt.
"""
# Get the source code of the function
source = inspect.getsource(func)

# Parse the source code into an Abstract Syntax Tree (AST)
parsed_ast = ast.parse(source)

# Convert the AST into a string representation (dump it)
ast_bytes = ast.dump(parsed_ast).encode("utf-8")

# Combine the AST bytes with the salt (encoded into bytes)
combined_data = ast_bytes + self.salt.encode("utf-8")

# Return the SHA-256 hash of the combined data (AST + salt)
return hashlib.sha256(combined_data).hexdigest()
37 changes: 37 additions & 0 deletions plugins/flytekit-auto-cache/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from setuptools import setup

PLUGIN_NAME = "auto_cache"

microlib_name = "flytekitplugins-auto-cache"

plugin_requires = ["flytekit"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package holds the auto cache plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.8",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def dummy_function(x: int, y: int) -> int:
result = x + y
return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def dummy_function(x: int, y: int) -> int:
# Adding some comments
result = (
x + # Adding inline comment
y # Another inline comment
)

# More comments
return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def dummy_function(x: int, y: int) -> int:
result = x * y
return result
85 changes: 85 additions & 0 deletions plugins/flytekit-auto-cache/tests/test_auto_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from dummy_functions.dummy_function import dummy_function
from dummy_functions.dummy_function_comments_formatting_change import dummy_function as dummy_function_comments_formatting_change
from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change
from flytekitplugins.auto_cache import CacheFunctionBody


def test_get_version_with_same_function_and_salt():
"""
Test that calling get_version with the same function and salt returns the same hash.
"""
cache1 = CacheFunctionBody(salt="salt")
cache2 = CacheFunctionBody(salt="salt")

# Both calls should return the same hash since the function and salt are the same
version1 = cache1.get_version(dummy_function)
version2 = cache2.get_version(dummy_function)

assert version1 == version2, f"Expected {version1}, but got {version2}"


def test_get_version_with_different_salt():
"""
Test that calling get_version with different salts returns different hashes for the same function.
"""
cache1 = CacheFunctionBody(salt="salt1")
cache2 = CacheFunctionBody(salt="salt2")

# The hashes should be different because the salts are different
version1 = cache1.get_version(dummy_function)
version2 = cache2.get_version(dummy_function)

assert version1 != version2, f"Expected different hashes but got the same: {version1}"



def test_get_version_with_different_logic():
"""
Test that functions with the same name but different logic produce different hashes.
"""
cache = CacheFunctionBody(salt="salt")
version1 = cache.get_version(dummy_function)
version2 = cache.get_version(dummy_function_logic_change)

assert version1 != version2, (
f"Hashes should be different for functions with same name but different logic. "
f"Got {version1} and {version2}"
)

# Test functions with different names but same logic
def function_one(x: int, y: int) -> int:
result = x + y
return result

def function_two(x: int, y: int) -> int:
result = x + y
return result

def test_get_version_with_different_function_names():
"""
Test that functions with different names but same logic produce different hashes.
"""
cache = CacheFunctionBody(salt="salt")

version1 = cache.get_version(function_one)
version2 = cache.get_version(function_two)

assert version1 != version2, (
f"Hashes should be different for functions with different names. "
f"Got {version1} and {version2}"
)

def test_get_version_with_formatting_changes():
"""
Test that changing formatting and comments but keeping the same function name
results in the same hash.
"""

cache = CacheFunctionBody(salt="salt")
version1 = cache.get_version(dummy_function)
version2 = cache.get_version(dummy_function_comments_formatting_change)

assert version1 == version2, (
f"Hashes should be the same for functions with same name but different formatting. "
f"Got {version1} and {version2}"
)
Loading