Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Cache Plugin #2971

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open

Auto Cache Plugin #2971

wants to merge 18 commits into from

Conversation

dansola
Copy link
Contributor

@dansola dansola commented Dec 2, 2024

Why are the changes needed?

Make caching easier to use in flytekit by reducing cognitive burden of specifying cache versions

What changes were proposed in this pull request?

To use the caching mechanism in a Flyte task, you can define a CachePolicy that combines multiple caching strategies. Here’s an example of how to set it up:

from flytekit import task
from flytekit.core.auto_cache import CachePolicy
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules

cache_policy = CachePolicy(
    auto_cache_policies = [
        CacheFunctionBody(),
        CachePrivateModules(root_dir="../my_package"),
        ...,
    ]
    salt="my_salt"
)

@task(cache=cache_policy)
def task_fn():
    ...

Salt Parameter

The salt parameter in the CachePolicy adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task.

Cache Implementations

Users can add any number of cache policies that implement the AutoCache protocol defined in @auto_cache.py. Below are the implementations available so far:

1. CacheFunctionBody

This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning.

2. CacheImage

This implementation includes the hash of the container_image object passed. If the image is specified as a name, that string is hashed. If it is an ImageSpec, the parametrization of the ImageSpec is hashed, allowing for precise versioning of the container image used in the task.

3. CachePrivateModules

This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash.

It accounts for both import and from-import statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored.

4. CacheExternalDependencies

This implementation recursively searches through all the callables like CachePrivateModules, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning.

How was this patch tested?

Unit tests for the following:

  • verifying a function hash changes only when function contents change, not when formatting or comments are added
  • verify that a dummy repository can be recursively searched when various import statements are used
  • verify that functions not used by the task of interest are not hashed
  • verify that the all constants used by a task are and any of the functions it calls are identified
  • verify that in a new python environment, the correct external libraries are identified
  • verify that the correct dependency versions can be identified

Setup process

Screenshots

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Related PRs

Docs link

Summary by Bito

This PR introduces a comprehensive auto-cache plugin for Flytekit that implements multiple caching strategies including function body hashing, container image versioning, and dependency tracking. The plugin provides flexible cache versioning options through a CachePolicy class that can combine multiple caching mechanisms. The implementation integrates seamlessly with existing task and workflow decorators, supported by extensive test coverage.

Unit tests added: True

Estimated effort to review (1-5, lower is better): 5

@@ -132,9 +133,9 @@ def task(

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change P to ...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were a bunch of linter errors like error: The first argument to Callable must be a list of types, parameter specification, or "..." [valid-type] if I didn't make that change. I think these linker errors existed in the repo already and were unrelated, though I decided to clean it up to make the linter more readable.

Comment on lines +65 to +67
self.cache_serialize = cache_serialize
self.cache_version = cache_version
self.cache_ignore_input_vars = cache_ignore_input_vars
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the purpose of saving this state here? aren't these just forwarded to the underlying TaskMetadata?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea with this is the user could use the CachePolicy to define all the arguments relating to caching. This simplifies the UX a bit as opposed to having a CachePolicy and a cache_ignore_input_vars, cache_serialize, etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a little confusing:

  • cache_version should not be exposed, since the AutoCache protocol is meant to produce this value automatically, and salt is meant to fulfill the need of manually bumping the cache.
  • I think it makes sense to keep cache_serialize and cache_ignore_input_vars as options to specify in the @task decorator as opposed to introducing this redundancy here.

@@ -95,7 +96,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy] = ...,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this accept any AutoCache-compliant object?

Basically the user can provide just a single autocache object like CacheFunctionBody or compose multiple into a CachePolicy, but users should be forced to always use a CachePolicy object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think I'm a bit confused since you said "users should be forced to always use a CachePolicy object". If they just use a single autocache object like CacheFunctionBody, aren't they NOT using a CachePolicy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh reading your other comment and re-visiting this, I'm thinking you might have meant ... but users _shouldn't_ be forced to always ... Let me make that change!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... but users shouldn't be forced to always ...
correct! sorry for the typo

Comment on lines 350 to 357
cache_version_val = cache_version or cache.get_version(params=params)
cache_serialize_val = cache_serialize or cache.cache_serialize
cache_serialize_val = cache_ignore_input_vars or cache.cache_ignore_input_vars
else:
cache_val = cache
cache_version_val = cache_version
cache_serialize_val = cache_serialize
cache_ignore_input_vars_val = cache_ignore_input_vars
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the purpose of forwarding all of these parameters via the CachePolicy object? It doesn't look like it's being modified there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I thought that if the user wants to put all cache settings in the CachePolicy, they can do so for simplicity. However, if the user also sets some cache setting in the task decorator, the decorator takes precedence.

Comment on lines 20 to 27
cache_policy = CachePolicy(
auto_cache_policies = [
CacheFunctionBody(),
CachePrivateModules(root_dir="../my_package"),
...,
]
salt="my_salt"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also provide an example of not needing to provide a CachePolicy object, e.g. just a passing in CacheFunctionBody.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do!

Copy link

codecov bot commented Jan 3, 2025

Codecov Report

Attention: Patch coverage is 47.82609% with 24 lines in your changes missing coverage. Please review.

Project coverage is 77.96%. Comparing base (3b7cb3c) to head (18f253e).
Report is 51 commits behind head on master.

Files with missing lines Patch % Lines
flytekit/core/auto_cache.py 45.16% 17 Missing ⚠️
flytekit/core/task.py 53.33% 6 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2971      +/-   ##
==========================================
+ Coverage   76.49%   77.96%   +1.46%     
==========================================
  Files         200      202       +2     
  Lines       20901    21324     +423     
  Branches     2689     2739      +50     
==========================================
+ Hits        15989    16625     +636     
+ Misses       4195     3904     -291     
- Partials      717      795      +78     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@flyte-bot
Copy link
Contributor

flyte-bot commented Jan 4, 2025

Code Review Agent Run #bc105b

Actionable Suggestions - 9
  • flytekit/core/auto_cache.py - 1
    • Return type inconsistency in get_version method · Line 95-95
  • plugins/flytekit-auto-cache/tests/requirements-test.txt - 1
    • Consider flexible version pinning for dependencies · Line 1-20
  • plugins/flytekit-auto-cache/tests/verify_identified_packages.py - 1
    • Missing assert keyword in test validation · Line 11-11
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py - 1
    • Consider splitting long method into smaller ones · Line 108-221
  • flytekit/core/task.py - 1
    • Consider implications of looser type hints · Line 136-136
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py - 2
    • Consider validating empty salt parameter · Line 23-23
    • Consider adding type check for func · Line 32-35
  • plugins/flytekit-auto-cache/tests/my_package/module_a.py - 1
    • Consider handling sum function result · Line 12-12
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py - 1
Additional Suggestions - 7
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py - 3
  • flytekit/core/task.py - 1
    • Consider extracting cache policy handling logic · Line 353-362
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py - 1
    • Consider extracting duplicate hash logic · Line 46-51
  • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py - 1
    • Consider breaking down version lookup logic · Line 89-126
  • plugins/flytekit-auto-cache/tests/my_package/main.py - 1
    • Consider splitting function responsibilities · Line 14-22
Review Details
  • Files reviewed - 27 · Commit Range: 2786c5b..18f253e
    • flytekit/core/auto_cache.py
    • flytekit/core/task.py
    • flytekit/core/workflow.py
    • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/__init__.py
    • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_external_dependencies.py
    • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_function_body.py
    • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_image.py
    • plugins/flytekit-auto-cache/flytekitplugins/auto_cache/cache_private_modules.py
    • plugins/flytekit-auto-cache/setup.py
    • plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function.py
    • plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_comments_formatting_change.py
    • plugins/flytekit-auto-cache/tests/dummy_functions/dummy_function_logic_change.py
    • plugins/flytekit-auto-cache/tests/my_package/main.py
    • plugins/flytekit-auto-cache/tests/my_package/module_a.py
    • plugins/flytekit-auto-cache/tests/my_package/module_b.py
    • plugins/flytekit-auto-cache/tests/my_package/module_c.py
    • plugins/flytekit-auto-cache/tests/my_package/module_d.py
    • plugins/flytekit-auto-cache/tests/my_package/my_dir/__init__.py
    • plugins/flytekit-auto-cache/tests/my_package/my_dir/module_in_dir.py
    • plugins/flytekit-auto-cache/tests/my_package/utils.py
    • plugins/flytekit-auto-cache/tests/requirements-test.txt
    • plugins/flytekit-auto-cache/tests/test_external_dependencies.py
    • plugins/flytekit-auto-cache/tests/test_function_body.py
    • plugins/flytekit-auto-cache/tests/test_image.py
    • plugins/flytekit-auto-cache/tests/test_recursive.py
    • plugins/flytekit-auto-cache/tests/verify_identified_packages.py
    • plugins/flytekit-auto-cache/tests/verify_versions.py
  • Files skipped - 1
    • plugins/flytekit-auto-cache/README.md - Reason: Filter setting
  • Tools
    • Whispers (Secret Scanner) - ✔︎ Successful
    • Detect-secrets (Secret Scanner) - ✔︎ Successful
    • MyPy (Static Code Analysis) - ✔︎ Successful
    • Astral Ruff (Static Code Analysis) - ✔︎ Successful

AI Code Review powered by Bito Logo

@flyte-bot
Copy link
Contributor

Changelist by Bito

This pull request implements the following key changes.

Key Change Files Impacted
New Feature - Auto Cache Plugin Implementation

auto_cache.py - Introduces core auto-cache functionality with VersionParameters and CachePolicy classes

__init__.py - Creates plugin initialization module for auto-cache functionality

cache_external_dependencies.py - Implements external dependency tracking for caching

cache_function_body.py - Implements function body hashing for caching

cache_image.py - Implements container image-based caching

cache_private_modules.py - Implements private module dependency tracking for caching

Feature Improvement - Task and Workflow Cache Integration

task.py - Updates task decorator to support new auto-cache functionality

workflow.py - Enhances workflow functionality to support auto-cache features

New Feature - Auto Cache Plugin Implementation

auto_cache.py - Introduces core auto-cache functionality with VersionParameters and CachePolicy classes

__init__.py - Creates plugin initialization module for auto-cache functionality

cache_external_dependencies.py - Implements external dependency tracking for caching

cache_function_body.py - Implements function body hashing for caching

cache_image.py - Implements container image-based caching

cache_private_modules.py - Implements private module dependency tracking for caching

setup.py - Sets up the auto-cache plugin package configuration

Testing - Comprehensive Test Suite Implementation

dummy_function.py - Adds test function for basic caching functionality

dummy_function_comments_formatting_change.py - Tests caching behavior with formatting changes

dummy_function_logic_change.py - Tests caching behavior with logic changes

main.py - Implements test package main module

module_a.py - Adds test module for dependency tracking

module_b.py - Adds additional test module for dependency tracking

module_c.py - Implements test module with class-based components

module_d.py - Adds test module for external package imports

__init__.py - Sets up test package directory structure

module_in_dir.py - Tests nested module functionality

utils.py - Implements utility constants for testing

requirements-test.txt - Defines test dependencies

test_external_dependencies.py - Tests external dependency tracking

test_function_body.py - Tests function body hashing functionality

test_image.py - Tests container image-based caching

test_recursive.py - Tests recursive dependency analysis

verify_identified_packages.py - Verifies package identification

verify_versions.py - Validates package version detection

Feature Improvement - Task and Workflow Cache Integration

task.py - Updates task decorator to support new auto-cache functionality

workflow.py - Enhances workflow functionality to support auto-cache features

hash_obj = hashlib.sha256(task_hash.encode())
return hash_obj.hexdigest()

return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type inconsistency in get_version method

Consider returning an empty string instead of None for consistency in return types. The method signature indicates it returns str but can return None.

Code suggestion
Check the AI-generated fix before applying
Suggested change
return None
return ""

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +1 to +20
numpy==1.24.3
pandas==2.0.3
requests==2.31.0
matplotlib==3.7.2
pillow==10.0.0
scipy==1.11.2
pytest==7.4.0
urllib3==2.0.4
cryptography==41.0.3
setuptools==68.0.0
flask==2.3.2
django==4.2.4
scikit-learn==1.3.0
beautifulsoup4==4.12.2
pyyaml==6.0
fastapi==0.100.0
sqlalchemy==2.0.36
tqdm==4.65.0
pytest-mock==3.11.0
jinja2==3.1.2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider flexible version pinning for dependencies

Consider pinning dependencies to compatible versions using ~= or >= instead of == to allow for minor version updates that include security patches while maintaining compatibility. This helps keep dependencies up-to-date with security fixes.

Code suggestion
Check the AI-generated fix before applying
Suggested change
numpy==1.24.3
pandas==2.0.3
requests==2.31.0
matplotlib==3.7.2
pillow==10.0.0
scipy==1.11.2
pytest==7.4.0
urllib3==2.0.4
cryptography==41.0.3
setuptools==68.0.0
flask==2.3.2
django==4.2.4
scikit-learn==1.3.0
beautifulsoup4==4.12.2
pyyaml==6.0
fastapi==0.100.0
sqlalchemy==2.0.36
tqdm==4.65.0
pytest-mock==3.11.0
jinja2==3.1.2
numpy~=1.24.3
pandas~=2.0.3
requests~=2.31.0
matplotlib~=3.7.2
pillow~=10.0.0
scipy~=1.11.2
pytest~=7.4.0
urllib3~=2.0.4
cryptography~=41.0.3
setuptools~=68.0.0
flask~=2.3.2
django~=4.2.4
scikit-learn~=1.3.0
beautifulsoup4~=4.12.2
pyyaml~=6.0
fastapi~=0.100.0
sqlalchemy~=2.0.36
tqdm~=4.65.0
pytest-mock~=3.11.0
jinja2~=3.1.2

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

packages = cache.get_version_dict().keys()

expected_packages = {'PIL', 'bs4', 'numpy', 'pandas', 'scipy', 'sklearn'}
set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing assert keyword in test validation

The assertion statement appears to be missing the assert keyword, which means this comparison won't actually validate the test condition. Consider adding the assert keyword.

Code suggestion
Check the AI-generated fix before applying
Suggested change
set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}"
assert set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}"

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +108 to +221

Returns:
Set[Callable[..., Any]]: A set of all dependencies found.
"""

dependencies = set()
source = textwrap.dedent(inspect.getsource(func))
parsed_ast = ast.parse(source)

# Initialize a dictionary to mimic the function's global namespace for locally defined imports
locals_dict = {}
# Initialize a dictionary to hold constant imports and class attributes
constant_imports = {}
# If class attributes are provided, include them in the constant imports
if class_attributes:
constant_imports.update(class_attributes)

# Check each function call in the AST
for node in ast.walk(parsed_ast):
if isinstance(node, ast.Import):
# For each alias in the import statement, we import the module and add it to the locals_dict.
# This is because the module itself is being imported, not a specific attribute or function.
for alias in node.names:
module = importlib.import_module(alias.name)
locals_dict[self._get_alias_name(alias)] = module
# We then get all the literal constants defined in the module's __init__.py file.
# These constants are later checked for usage within the function.
module_constants = self.get_module_literal_constants(module)
constant_imports.update(
{f"{self._get_alias_name(alias)}.{name}": value for name, value in module_constants.items()}
)
elif isinstance(node, ast.ImportFrom):
module_name = node.module
module = importlib.import_module(module_name)
for alias in node.names:
# Attempt to resolve the imported object directly from the module
imported_obj = getattr(module, alias.name, None)
if imported_obj:
# If the object is found directly in the module, add it to the locals_dict
locals_dict[self._get_alias_name(alias)] = imported_obj
# Check if the imported object is a literal constant and add it to constant_imports if so
if self.is_literal_constant(imported_obj):
constant_imports.update({f"{self._get_alias_name(alias)}": imported_obj})
else:
# If the object is not found directly in the module, attempt to import it as a submodule
# This is necessary for cases like `from PIL import Image`, where Image is not imported in PIL's __init__.py
# PIL and similar packages use different mechanisms to expose their objects, requiring this fallback approach
submodule = importlib.import_module(f"{module_name}.{alias.name}")
imported_obj = getattr(submodule, alias.name, None)
locals_dict[self._get_alias_name(alias)] = imported_obj

elif isinstance(node, ast.Call):
# Add callable to the set of dependencies if it's user defined and continue the recursive search within those callables.
func_name = self._get_callable_name(node.func)
if func_name and func_name not in visited:
visited.add(func_name)
try:
# Attempt to resolve the callable object using locals first, then globals
func_obj = self._resolve_callable(func_name, locals_dict) or self._resolve_callable(
func_name, func.__globals__
)
# If the callable is a class and user-defined, we add and search all method. We also include attributes as potential constants.
if inspect.isclass(func_obj) and self._is_user_defined(func_obj):
current_class_attributes = {
f"class.{func_name}.{name}": value for name, value in func_obj.__dict__.items()
}
for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction):
if method not in visited:
visited.add(method.__qualname__)
dependencies.add(method)
dependencies.update(
self._get_function_dependencies(method, visited, current_class_attributes)
)
# If the callable is a function or method and user-defined, add it as a dependency and search its dependencies
elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined(
func_obj
):
# Add the function or method as a dependency
dependencies.add(func_obj)
# Recursively search the function or method's dependencies
dependencies.update(self._get_function_dependencies(func_obj, visited))
except (NameError, AttributeError) as e:
click.secho(f"Could not process the callable {func_name} due to error: {str(e)}", fg="yellow")

# Extract potential constants from the global import context
global_constants = {}
for key, value in func.__globals__.items():
if hasattr(value, "__dict__"):
module_constants = self.get_module_literal_constants(value)
global_constants.update({f"{key}.{name}": value for name, value in module_constants.items()})
elif self.is_literal_constant(value):
global_constants[key] = value

# Check for the usage of all potnential constants and update the set of constants to be hashed
referenced_constants = self.get_referenced_constants(
func=func, constant_imports=constant_imports, global_constants=global_constants
)
self.constants.update(referenced_constants)

return dependencies
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider splitting long method into smaller ones

The _get_function_dependencies method is quite long (>100 lines) and handles multiple responsibilities including import handling, AST traversal, and constant extraction. Consider breaking it down into smaller focused methods for better maintainability.

Code suggestion
Check the AI-generated fix before applying
 @@ -108,114 +108,50 @@
  def _get_function_dependencies(
          self, func: Callable[..., Any], visited: Set[str], class_attributes: dict = None
      ) -> Set[Callable[..., Any]]:
 -        dependencies = set()
 -        source = textwrap.dedent(inspect.getsource(func))
 -        parsed_ast = ast.parse(source)
 -
 -        # Initialize dictionaries
 -        locals_dict = {}
 -        constant_imports = {}
 -        if class_attributes:
 -            constant_imports.update(class_attributes)
 -
 -        # Check each function call in the AST
 -        for node in ast.walk(parsed_ast):
 -            if isinstance(node, ast.Import):
 -                # Handle imports
 -                for alias in node.names:
 -                    module = importlib.import_module(alias.name)
 -                    locals_dict[self._get_alias_name(alias)] = module
 -                    module_constants = self.get_module_literal_constants(module)
 -                    constant_imports.update(
 -                        {f"{self._get_alias_name(alias)}.{name}": value for name, value in module_constants.items()}
 -                    )
 -            elif isinstance(node, ast.ImportFrom):
 -                # Handle from imports
 -                module_name = node.module
 -                module = importlib.import_module(module_name)
 -                for alias in node.names:
 -                    imported_obj = getattr(module, alias.name, None)
 -                    if imported_obj:
 -                        locals_dict[self._get_alias_name(alias)] = imported_obj
 -                        if self.is_literal_constant(imported_obj):
 -                            constant_imports.update({f"{self._get_alias_name(alias)}": imported_obj})
 -                    else:
 -                        submodule = importlib.import_module(f"{module_name}.{alias.name}")
 -                        imported_obj = getattr(submodule, alias.name, None)
 -                        locals_dict[self._get_alias_name(alias)] = imported_obj
 +        dependencies = set()
 +        source = textwrap.dedent(inspect.getsource(func))
 +        parsed_ast = ast.parse(source)
 +        
 +        locals_dict, constant_imports = self._initialize_dictionaries(class_attributes)
 +        
 +        for node in ast.walk(parsed_ast):
 +            if isinstance(node, ast.Import):
 +                self._handle_imports(node, locals_dict, constant_imports)
 +            elif isinstance(node, ast.ImportFrom):
 +                self._handle_import_from(node, locals_dict, constant_imports)
 +            elif isinstance(node, ast.Call):
 +                self._process_callable(node, locals_dict, func, visited, dependencies)
 +        
 +        self._extract_constants(func, constant_imports)
 +        return dependencies

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

@@ -132,9 +133,9 @@

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider implications of looser type hints

Consider if changing _task_function type hint from Callable[P, FuncOut] to Callable[..., FuncOut] could make the type checking less strict. The ... allows any arguments which may hide potential type errors at compile time. Similar issues were also found in:

  • flytekit/core/workflow.py (line 846-864)
  • flytekit/core/workflow.py (line 857-901)
Code suggestion
Check the AI-generated fix before applying
Suggested change
_task_function: Callable[..., FuncOut],
_task_function: Callable[P, FuncOut],

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Given a function, generates a version hash based on its source code and the salt.
"""

def __init__(self, salt: str = "") -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider validating empty salt parameter

Consider adding validation for empty salt parameter in __init__. An empty salt could potentially lead to weaker caching behavior.

Code suggestion
Check the AI-generated fix before applying
 @@ -23,8 +23,10 @@
  def __init__(self, salt: str = "") -> None:
          """
          Initialize the CacheFunctionBody instance with a salt value.
          """
 -        self.salt = salt
 +        if not salt:
 +            raise ValueError("Salt cannot be empty as it affects cache effectiveness")
 +        self.salt = salt

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +32 to +35
def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")
return self._get_version(func=params.func)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding type check for func

The get_version method could benefit from type checking params.func before accessing it to provide a more descriptive error message.

Code suggestion
Check the AI-generated fix before applying
Suggested change
def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")
return self._get_version(func=params.func)
def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")
if not callable(params.func):
raise TypeError("params.func must be a callable function")
return self._get_version(func=params.func)

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

module_b.another_helper()
result = norm([1, 2, 3])
print(result)
sum([SOME_CONSTANT, utils.THIRD_CONSTANT])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider handling sum function result

The sum() function call result is not being stored or used, which may indicate meaningless executed code. Consider either storing the result or removing if not needed.

Code suggestion
Check the AI-generated fix before applying
Suggested change
sum([SOME_CONSTANT, utils.THIRD_CONSTANT])
result = sum([SOME_CONSTANT, utils.THIRD_CONSTANT])

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Comment on lines +106 to +107
except Exception as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too broad exception handling

Catching a broad 'Exception' may hide bugs. Consider catching specific exceptions instead.

Code suggestion
Check the AI-generated fix before applying
Suggested change
except Exception as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")
except (ImportError, AttributeError) as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

...


class CachePolicy:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can CachePolicy live in the plugin? It makes sense for the abstract AutoCache protocol to be defined in flytekit core, but any implementation of it should be in the plugin.

task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy, AutoCache] = ...,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be Union[bool, AutoCache], since CachePolicy is just an implementation of the AutoCache protocol.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants