-
Notifications
You must be signed in to change notification settings - Fork 302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
initial make auto cache plugin #2912
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Any, Callable, Protocol, runtime_checkable | ||
|
||
|
||
@runtime_checkable | ||
class AutoCache(Protocol): | ||
""" | ||
A protocol that defines the interface for a caching mechanism | ||
that generates a version hash of a function based on its source code. | ||
|
||
Attributes: | ||
salt (str): A string used to add uniqueness to the generated hash. Default is "salt". | ||
|
||
Methods: | ||
get_version(func: Callable[..., Any]) -> str: | ||
Given a function, generates a version hash based on its source code and the salt. | ||
""" | ||
|
||
def __init__(self, salt: str = "salt") -> None: | ||
""" | ||
Initialize the AutoCache instance with a salt value. | ||
|
||
Args: | ||
salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". | ||
""" | ||
self.salt = salt | ||
|
||
def get_version(self, func: Callable[..., Any]) -> str: | ||
""" | ||
Generate a version hash for the provided function. | ||
|
||
Args: | ||
func (Callable[..., Any]): A callable function whose version hash needs to be generated. | ||
|
||
Returns: | ||
str: The SHA-256 hash of the function's source code combined with the salt. | ||
""" | ||
... |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
from functools import update_wrapper | ||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload | ||
|
||
from flytekit.core.auto_cache import AutoCache | ||
from flytekit.core.utils import str2bool | ||
|
||
try: | ||
|
@@ -99,7 +100,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction | |
def task( | ||
_task_function: None = ..., | ||
task_config: Optional[T] = ..., | ||
cache: bool = ..., | ||
cache: Union[bool, list[AutoCache]] = ..., | ||
cache_serialize: bool = ..., | ||
cache_version: str = ..., | ||
cache_ignore_input_vars: Tuple[str, ...] = ..., | ||
|
@@ -137,7 +138,7 @@ def task( | |
def task( | ||
_task_function: Callable[P, FuncOut], | ||
task_config: Optional[T] = ..., | ||
cache: bool = ..., | ||
cache: Union[bool, list[AutoCache]] = ..., | ||
cache_serialize: bool = ..., | ||
cache_version: str = ..., | ||
cache_ignore_input_vars: Tuple[str, ...] = ..., | ||
|
@@ -174,7 +175,7 @@ def task( | |
def task( | ||
_task_function: Optional[Callable[P, FuncOut]] = None, | ||
task_config: Optional[T] = None, | ||
cache: bool = False, | ||
cache: Union[bool, list[AutoCache]] = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we also accept a single |
||
cache_serialize: bool = False, | ||
cache_version: str = "", | ||
cache_ignore_input_vars: Tuple[str, ...] = (), | ||
|
@@ -248,7 +249,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str: | |
:param _task_function: This argument is implicitly passed and represents the decorated function | ||
:param task_config: This argument provides configuration for a specific task types. | ||
Please refer to the plugins documentation for the right object to use. | ||
:param cache: Boolean that indicates if caching should be enabled | ||
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations | ||
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be | ||
executed in serial when caching is enabled. This means that given multiple concurrent executions over | ||
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This | ||
|
@@ -343,10 +344,16 @@ def launch_dynamically(): | |
""" | ||
|
||
def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: | ||
if isinstance(cache, list) and all(isinstance(item, AutoCache) for item in cache): | ||
cache_versions = [item.get_version() for item in cache] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to pass the function here |
||
task_hash = "".join(cache_versions) | ||
Comment on lines
+348
to
+349
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @eapolinario does the cache version string have a character limit? If so we may need to re-hash the concatenated hashes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I had a feeling this would be an issue. I just left the long join for simplicity and to not introduce a hashing function in |
||
else: | ||
task_hash = "" | ||
|
||
_metadata = TaskMetadata( | ||
cache=cache, | ||
cache_serialize=cache_serialize, | ||
cache_version=cache_version, | ||
cache_version=cache_version if not task_hash else task_hash, | ||
cache_ignore_input_vars=cache_ignore_input_vars, | ||
retries=retries, | ||
interruptible=interruptible, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Flytekit Auto Cache Plugin | ||
|
||
|
||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-auto-cache | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
.. currentmodule:: flytekitplugins.auto_cache | ||
|
||
This package contains things that are useful when extending Flytekit. | ||
|
||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
|
||
CacheFunctionBody | ||
""" | ||
|
||
from .cache_function_body import CacheFunctionBody |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import ast | ||
import hashlib | ||
import inspect | ||
from typing import Any, Callable | ||
|
||
|
||
class CacheFunctionBody: | ||
""" | ||
A class that implements a versioning mechanism for functions by generating | ||
a SHA-256 hash of the function's source code combined with a salt. | ||
|
||
Attributes: | ||
salt (str): A string used to add uniqueness to the generated hash. Default is "salt". | ||
|
||
Methods: | ||
get_version(func: Callable[..., Any]) -> str: | ||
Given a function, generates a version hash based on its source code and the salt. | ||
""" | ||
|
||
def __init__(self, salt: str = "salt") -> None: | ||
""" | ||
Initialize the CacheFunctionBody instance with a salt value. | ||
|
||
Args: | ||
salt (str): A string to be used as the salt in the hashing process. Defaults to "salt". | ||
""" | ||
self.salt = salt | ||
|
||
def get_version(self, func: Callable[..., Any]) -> str: | ||
""" | ||
Generate a version hash for the provided function by parsing its source code | ||
and adding a salt before applying the SHA-256 hash function. | ||
|
||
Args: | ||
func (Callable[..., Any]): A callable function whose version hash needs to be generated. | ||
|
||
Returns: | ||
str: The SHA-256 hash of the function's source code combined with the salt. | ||
""" | ||
# Get the source code of the function | ||
source = inspect.getsource(func) | ||
|
||
# Parse the source code into an Abstract Syntax Tree (AST) | ||
parsed_ast = ast.parse(source) | ||
|
||
# Convert the AST into a string representation (dump it) | ||
ast_bytes = ast.dump(parsed_ast).encode("utf-8") | ||
|
||
# Combine the AST bytes with the salt (encoded into bytes) | ||
combined_data = ast_bytes + self.salt.encode("utf-8") | ||
|
||
# Return the SHA-256 hash of the combined data (AST + salt) | ||
return hashlib.sha256(combined_data).hexdigest() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from setuptools import setup | ||
|
||
PLUGIN_NAME = "auto_cache" | ||
|
||
microlib_name = "flytekitplugins-auto-cache" | ||
|
||
plugin_requires = ["flytekit"] | ||
|
||
__version__ = "0.0.0+develop" | ||
|
||
setup( | ||
name=microlib_name, | ||
version=__version__, | ||
author="flyteorg", | ||
author_email="[email protected]", | ||
description="This package holds the auto cache plugins for flytekit", | ||
namespace_packages=["flytekitplugins"], | ||
packages=[f"flytekitplugins.{PLUGIN_NAME}"], | ||
install_requires=plugin_requires, | ||
license="apache2", | ||
python_requires=">=3.8", | ||
classifiers=[ | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
], | ||
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
def dummy_function(x: int, y: int) -> int: | ||
result = x + y | ||
return result |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
def dummy_function(x: int, y: int) -> int: | ||
# Adding some comments | ||
result = ( | ||
x + # Adding inline comment | ||
y # Another inline comment | ||
) | ||
|
||
# More comments | ||
return result |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
def dummy_function(x: int, y: int) -> int: | ||
result = x * y | ||
return result |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from dummy_functions.dummy_function import dummy_function | ||
from dummy_functions.dummy_function_comments_formatting_change import dummy_function as dummy_function_comments_formatting_change | ||
from dummy_functions.dummy_function_logic_change import dummy_function as dummy_function_logic_change | ||
from flytekitplugins.auto_cache import CacheFunctionBody | ||
|
||
|
||
def test_get_version_with_same_function_and_salt(): | ||
""" | ||
Test that calling get_version with the same function and salt returns the same hash. | ||
""" | ||
cache1 = CacheFunctionBody(salt="salt") | ||
cache2 = CacheFunctionBody(salt="salt") | ||
|
||
# Both calls should return the same hash since the function and salt are the same | ||
version1 = cache1.get_version(dummy_function) | ||
version2 = cache2.get_version(dummy_function) | ||
|
||
assert version1 == version2, f"Expected {version1}, but got {version2}" | ||
|
||
|
||
def test_get_version_with_different_salt(): | ||
""" | ||
Test that calling get_version with different salts returns different hashes for the same function. | ||
""" | ||
cache1 = CacheFunctionBody(salt="salt1") | ||
cache2 = CacheFunctionBody(salt="salt2") | ||
|
||
# The hashes should be different because the salts are different | ||
version1 = cache1.get_version(dummy_function) | ||
version2 = cache2.get_version(dummy_function) | ||
|
||
assert version1 != version2, f"Expected different hashes but got the same: {version1}" | ||
|
||
|
||
|
||
def test_get_version_with_different_logic(): | ||
""" | ||
Test that functions with the same name but different logic produce different hashes. | ||
""" | ||
cache = CacheFunctionBody(salt="salt") | ||
version1 = cache.get_version(dummy_function) | ||
version2 = cache.get_version(dummy_function_logic_change) | ||
|
||
assert version1 != version2, ( | ||
f"Hashes should be different for functions with same name but different logic. " | ||
f"Got {version1} and {version2}" | ||
) | ||
|
||
# Test functions with different names but same logic | ||
def function_one(x: int, y: int) -> int: | ||
result = x + y | ||
return result | ||
|
||
def function_two(x: int, y: int) -> int: | ||
result = x + y | ||
return result | ||
|
||
def test_get_version_with_different_function_names(): | ||
""" | ||
Test that functions with different names but same logic produce different hashes. | ||
""" | ||
cache = CacheFunctionBody(salt="salt") | ||
|
||
version1 = cache.get_version(function_one) | ||
version2 = cache.get_version(function_two) | ||
|
||
assert version1 != version2, ( | ||
f"Hashes should be different for functions with different names. " | ||
f"Got {version1} and {version2}" | ||
) | ||
|
||
def test_get_version_with_formatting_changes(): | ||
""" | ||
Test that changing formatting and comments but keeping the same function name | ||
results in the same hash. | ||
""" | ||
|
||
cache = CacheFunctionBody(salt="salt") | ||
version1 = cache.get_version(dummy_function) | ||
version2 = cache.get_version(dummy_function_comments_formatting_change) | ||
|
||
assert version1 == version2, ( | ||
f"Hashes should be the same for functions with same name but different formatting. " | ||
f"Got {version1} and {version2}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can the default value be
""
?