-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add shared_memory to task with extended resources #3096
base: master
Are you sure you want to change the base?
Changes from all commits
ac61755
e3221e2
65dda33
aecce00
eadab88
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,7 +4,7 @@ | |||||||||||||||||||
import re | ||||||||||||||||||||
from abc import ABC | ||||||||||||||||||||
from dataclasses import dataclass | ||||||||||||||||||||
from typing import Callable, Dict, List, Optional, TypeVar, Union | ||||||||||||||||||||
from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union | ||||||||||||||||||||
|
||||||||||||||||||||
from flyteidl.core import tasks_pb2 | ||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -13,7 +13,7 @@ | |||||||||||||||||||
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin | ||||||||||||||||||||
from flytekit.core.context_manager import FlyteContextManager | ||||||||||||||||||||
from flytekit.core.pod_template import PodTemplate | ||||||||||||||||||||
from flytekit.core.resources import Resources, ResourceSpec | ||||||||||||||||||||
from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources | ||||||||||||||||||||
from flytekit.core.tracked_abc import FlyteTrackedABC | ||||||||||||||||||||
from flytekit.core.tracker import TrackedInstance, extract_task_module | ||||||||||||||||||||
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit | ||||||||||||||||||||
|
@@ -51,6 +51,7 @@ def __init__( | |||||||||||||||||||
pod_template: Optional[PodTemplate] = None, | ||||||||||||||||||||
pod_template_name: Optional[str] = None, | ||||||||||||||||||||
accelerator: Optional[BaseAccelerator] = None, | ||||||||||||||||||||
shared_memory: Optional[Union[Literal[True], str]] = None, | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding shared memory format validation
Consider validating the Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #3d3e41 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||
**kwargs, | ||||||||||||||||||||
): | ||||||||||||||||||||
""" | ||||||||||||||||||||
|
@@ -78,6 +79,8 @@ def __init__( | |||||||||||||||||||
:param pod_template: Custom PodTemplate for this task. | ||||||||||||||||||||
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. | ||||||||||||||||||||
:param accelerator: The accelerator to use for this task. | ||||||||||||||||||||
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal | ||||||||||||||||||||
to the allocated memory. If str, then the shared memory is set to that size. | ||||||||||||||||||||
""" | ||||||||||||||||||||
sec_ctx = None | ||||||||||||||||||||
if secret_requests: | ||||||||||||||||||||
|
@@ -128,6 +131,7 @@ def __init__( | |||||||||||||||||||
|
||||||||||||||||||||
self.pod_template = pod_template | ||||||||||||||||||||
self.accelerator = accelerator | ||||||||||||||||||||
self.shared_memory = shared_memory | ||||||||||||||||||||
|
||||||||||||||||||||
@property | ||||||||||||||||||||
def task_resolver(self) -> TaskResolverMixin: | ||||||||||||||||||||
|
@@ -250,10 +254,9 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[ta | |||||||||||||||||||
""" | ||||||||||||||||||||
Returns the extended resources to allocate to the task on hosted Flyte. | ||||||||||||||||||||
""" | ||||||||||||||||||||
if self.accelerator is None: | ||||||||||||||||||||
return None | ||||||||||||||||||||
|
||||||||||||||||||||
return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl()) | ||||||||||||||||||||
return construct_extended_resources( | ||||||||||||||||||||
accelerator=self.accelerator, shared_memory=self.shared_memory | ||||||||||||||||||||
) | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): | ||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,7 +4,7 @@ | |||||||||||||||||
import inspect | ||||||||||||||||||
import os | ||||||||||||||||||
from functools import partial, update_wrapper | ||||||||||||||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload | ||||||||||||||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload, Literal | ||||||||||||||||||
|
||||||||||||||||||
from typing_extensions import ParamSpec # type: ignore | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -128,6 +128,7 @@ def task( | |||||||||||||||||
pod_template_name: Optional[str] = ..., | ||||||||||||||||||
accelerator: Optional[BaseAccelerator] = ..., | ||||||||||||||||||
pickle_untyped: bool = ..., | ||||||||||||||||||
shared_memory: Optional[Union[Literal[True], str]] = None, | ||||||||||||||||||
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
|
@@ -167,6 +168,7 @@ def task( | |||||||||||||||||
pod_template_name: Optional[str] = ..., | ||||||||||||||||||
accelerator: Optional[BaseAccelerator] = ..., | ||||||||||||||||||
pickle_untyped: bool = ..., | ||||||||||||||||||
shared_memory: Optional[Union[Literal[True], str]] = ..., | ||||||||||||||||||
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
|
@@ -211,6 +213,7 @@ def task( | |||||||||||||||||
pod_template_name: Optional[str] = None, | ||||||||||||||||||
accelerator: Optional[BaseAccelerator] = None, | ||||||||||||||||||
pickle_untyped: bool = False, | ||||||||||||||||||
shared_memory: Optional[Union[bool, str]] = None, | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding shared memory validation
Consider adding validation for the Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #3d3e41 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||
) -> Union[ | ||||||||||||||||||
Callable[P, FuncOut], | ||||||||||||||||||
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], | ||||||||||||||||||
|
@@ -341,6 +344,8 @@ def launch_dynamically(): | |||||||||||||||||
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. | ||||||||||||||||||
:param accelerator: The accelerator to use for this task. | ||||||||||||||||||
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types. | ||||||||||||||||||
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal | ||||||||||||||||||
to the allocated memory. If int, then the shared memory is set to that size. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: | ||||||||||||||||||
|
@@ -390,6 +395,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: | |||||||||||||||||
pod_template_name=pod_template_name, | ||||||||||||||||||
accelerator=accelerator, | ||||||||||||||||||
pickle_untyped=pickle_untyped, | ||||||||||||||||||
shared_memory=shared_memory, | ||||||||||||||||||
) | ||||||||||||||||||
update_wrapper(task_instance, decorated_fn) | ||||||||||||||||||
return task_instance | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -8,7 +8,9 @@ | |||||||||||
from flytekit.core.resources import ( | ||||||||||||
pod_spec_from_resources, | ||||||||||||
convert_resources_to_resource_model, | ||||||||||||
construct_extended_resources, | ||||||||||||
) | ||||||||||||
from flytekit.extras.accelerators import T4 | ||||||||||||
|
||||||||||||
_ResourceName = _task_models.Resources.ResourceName | ||||||||||||
|
||||||||||||
|
@@ -155,3 +157,18 @@ def test_pod_spec_from_resources_requests_set(): | |||||||||||
) | ||||||||||||
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits) | ||||||||||||
assert expected_pod_spec == V1PodSpec(**pod_spec) | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("shared_memory", [None, False]) | ||||||||||||
def test_construct_extended_resources_shared_memory_none(shared_memory): | ||||||||||||
resources = construct_extended_resources(shared_memory=shared_memory) | ||||||||||||
Comment on lines
+162
to
+164
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider consolidating redundant test cases
Consider consolidating the test cases for Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #3d3e41 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||
assert resources is None | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("shared_memory, expected_size_limit", [ | ||||||||||||
("2Gi", "2Gi"), | ||||||||||||
(True, ""), | ||||||||||||
]) | ||||||||||||
def test_construct_extended_resources_shared_memory(shared_memory, expected_size_limit): | ||||||||||||
resources = construct_extended_resources(shared_memory=shared_memory) | ||||||||||||
assert resources.shared_memory.size_limit == expected_size_limit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using a more secure temporary file location instead of hardcoding '/dev/shm'. The shared memory directory could potentially be accessed by other processes on the system. Consider using 'tempfile.gettempdir()' to get a secure temporary directory location.
Code suggestion
Code Review Run #3d3e41
Is this a valid issue, or was it incorrectly flagged by the Agent?