Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shared_memory to task with extended resources #3096

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@
CACHE_KEY_METADATA = "cache-key-metadata"

SERIALIZATION_FORMAT = "serialization-format"

# Shared memory mount name and path
SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory"
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Insecure hardcoded temporary file path

Consider using a more secure temporary file location instead of hardcoding '/dev/shm'. The shared memory directory could potentially be accessed by other processes on the system. Consider using 'tempfile.gettempdir()' to get a secure temporary directory location.

Code suggestion
Check the AI-generated fix before applying
Suggested change
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
import tempfile
SHARED_MEMORY_MOUNT_PATH = tempfile.gettempdir()

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

12 changes: 9 additions & 3 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import datetime
import typing
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from flyteidl.core import tasks_pb2

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.resources import Resources, construct_extended_resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
Expand Down Expand Up @@ -191,6 +191,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
shared_memory: Optional[Union[bool, str]] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -237,7 +238,12 @@ def with_overrides(

if accelerator is not None:
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if shared_memory is not None:
assert_not_promise(shared_memory, "shared_memory")

self._extended_resources = construct_extended_resources(
accelerator=accelerator, shared_memory=shared_memory)

self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

Expand Down
15 changes: 9 additions & 6 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from abc import ABC
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2

Expand All @@ -13,7 +13,7 @@
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
shared_memory: Optional[Union[Literal[True], str]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding shared memory format validation

Consider validating the shared_memory parameter when it's a string to ensure it follows memory size format (e.g., '1Gi', '512Mi'). Currently there's no validation for the string format.

Code suggestion
Check the AI-generated fix before applying
Suggested change
shared_memory: Optional[Union[Literal[True], str]] = None,
shared_memory: Optional[Union[Literal[True], str]] = None,
if shared_memory and isinstance(shared_memory, str):
import re
if not re.match(r'^[0-9]+(Ki|Mi|Gi|Ti|Pi|Ei|[KMGTPE]i?)?$', shared_memory):
raise ValueError(
f"Invalid shared memory format: {shared_memory}. "
"Must be a valid memory size (e.g., '1Gi', '512Mi')"
)

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

**kwargs,
):
"""
Expand Down Expand Up @@ -78,6 +79,8 @@ def __init__(
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
to the allocated memory. If str, then the shared memory is set to that size.
"""
sec_ctx = None
if secret_requests:
Expand Down Expand Up @@ -128,6 +131,7 @@ def __init__(

self.pod_template = pod_template
self.accelerator = accelerator
self.shared_memory = shared_memory

@property
def task_resolver(self) -> TaskResolverMixin:
Expand Down Expand Up @@ -250,10 +254,9 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[ta
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
if self.accelerator is None:
return None

return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())
return construct_extended_resources(
accelerator=self.accelerator, shared_memory=self.shared_memory
)


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
Expand Down
34 changes: 33 additions & 1 deletion flytekit/core/resources.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from typing import Any, List, Literal, Optional, Union

from flyteidl.core import tasks_pb2
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.extras.accelerators import BaseAccelerator
from flytekit.models import task as task_models
from flytekit.core.constants import SHARED_MEMORY_MOUNT_PATH, SHARED_MEMORY_MOUNT_NAME


@dataclass
Expand Down Expand Up @@ -102,6 +105,35 @@ def convert_resources_to_resource_model(
return task_models.Resources(requests=request_entries, limits=limit_entries)


def construct_extended_resources(
    *,
    accelerator: Optional[BaseAccelerator] = None,
    shared_memory: Optional[Union[bool, str]] = None,
) -> Optional[tasks_pb2.ExtendedResources]:
    """Convert the public extended-resources options into their IDL representation.

    :param accelerator: The accelerator to use for this task.
    :param shared_memory: If True, shared memory is attached to the container with a size equal
        to the allocated memory. If a str, the shared memory is set to that size.
    :return: The populated ``ExtendedResources`` message, or ``None`` when nothing was requested.
    """
    extended_resource_args = {}

    if accelerator is not None:
        extended_resource_args["gpu_accelerator"] = accelerator.to_flyte_idl()

    if shared_memory is True or isinstance(shared_memory, str):
        # A bare True means "use the default size": encode it as an unset size_limit.
        size_limit = shared_memory if isinstance(shared_memory, str) else None
        extended_resource_args["shared_memory"] = tasks_pb2.SharedMemory(
            mount_name=SHARED_MEMORY_MOUNT_NAME,
            mount_path=SHARED_MEMORY_MOUNT_PATH,
            size_limit=size_limit,
        )

    return tasks_pb2.ExtendedResources(**extended_resource_args) if extended_resource_args else None


def pod_spec_from_resources(
k8s_pod_name: str,
requests: Optional[Resources] = None,
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload, Literal

from typing_extensions import ParamSpec # type: ignore

Expand Down Expand Up @@ -128,6 +128,7 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
shared_memory: Optional[Union[Literal[True], str]] = None,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -167,6 +168,7 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
shared_memory: Optional[Union[Literal[True], str]] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -211,6 +213,7 @@ def task(
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
shared_memory: Optional[Union[bool, str]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding shared memory validation

Consider adding validation for the shared_memory parameter to ensure it is either a boolean or a valid memory size string (e.g. '1Gi', '512Mi'). Currently there is no validation which could lead to runtime errors.

Code suggestion
Check the AI-generated fix before applying
Suggested change
shared_memory: Optional[Union[bool, str]] = None,
shared_memory: Optional[Union[bool, str]] = None,
def validate_shared_memory(val: Optional[Union[bool, str]]) -> None:
if val is not None and not isinstance(val, bool):
if not isinstance(val, str) or not re.match(r'^[0-9]+(Mi|Gi)$', val):
raise ValueError('shared_memory must be a boolean or valid memory size string (e.g. "1Gi", "512Mi")')
if shared_memory is not None:
validate_shared_memory(shared_memory)

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Expand Down Expand Up @@ -341,6 +344,8 @@ def launch_dynamically():
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
        to the allocated memory. If str, then the shared memory is set to that size.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -390,6 +395,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template_name=pod_template_name,
accelerator=accelerator,
pickle_untyped=pickle_untyped,
shared_memory=shared_memory,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
Expand Down
6 changes: 6 additions & 0 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,21 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
if mod_file is None:
continue

# if _file_is_in_directory(mod_file, flytekit_root):
# files.append(mod_file)
# continue

if any(_file_is_in_directory(mod_file, directory) for directory in invalid_directories):
continue


if not _file_is_in_directory(mod_file, source_path):
            # Only upload files whose module file is in the source directory
continue

files.append(mod_file)


return files


Expand Down
20 changes: 20 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,26 @@ def wf(x: typing.List[int]):
assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu"


def test_serialization_extended_resources_shared_memory(serialization_settings):
    """A map task must carry the wrapped task's shared_memory setting through serialization."""

    @task(shared_memory="2Gi")
    def t1(a: int) -> int:
        return a + 1

    mapped = map_task(t1)

    @workflow
    def wf(x: typing.List[int]):
        return mapped(a=x)

    entities = OrderedDict()
    get_serializable(entities, serialization_settings, wf)
    spec = entities[mapped]

    assert spec.template.extended_resources.shared_memory.size_limit == "2Gi"


def test_supported_node_type():
@task
def test_task():
Expand Down
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,29 @@ def my_wf() -> str:
assert not accelerator.HasField("unpartitioned")


def test_override_shared_memory():
    """with_overrides(shared_memory=...) must replace the task-level setting in the node's
    serialized extended resources."""

    @task(shared_memory=True)
    def bar() -> str:
        return "hello"

    @workflow
    def my_wf() -> str:
        return bar().with_overrides(shared_memory="128Mi")

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert len(wf_spec.template.nodes) == 1
    assert wf_spec.template.nodes[0].task_node.overrides is not None
    assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None
    shared_memory = wf_spec.template.nodes[0].task_node.overrides.extended_resources.shared_memory
    # Bug fix: the original test extracted `shared_memory` but never asserted on it,
    # so the override value was not actually verified. The override ("128Mi") must win
    # over the task-level default (shared_memory=True -> empty size_limit).
    assert shared_memory.size_limit == "128Mi"

def test_cache_override_values():
@task
def t1(a: str) -> str:
Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from flytekit.core.resources import (
pod_spec_from_resources,
convert_resources_to_resource_model,
construct_extended_resources,
)
from flytekit.extras.accelerators import T4

_ResourceName = _task_models.Resources.ResourceName

Expand Down Expand Up @@ -155,3 +157,18 @@ def test_pod_spec_from_resources_requests_set():
)
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)


@pytest.mark.parametrize("shared_memory", [None, False])
def test_construct_extended_resources_shared_memory_none(shared_memory):
resources = construct_extended_resources(shared_memory=shared_memory)
Comment on lines +162 to +164
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider consolidating redundant test cases

Consider consolidating the test cases for None and False into a single test case since they produce the same behavior. Both values result in resources being None.

Code suggestion
Check the AI-generated fix before applying
Suggested change
@pytest.mark.parametrize("shared_memory", [None, False])
def test_construct_extended_resources_shared_memory_none(shared_memory):
resources = construct_extended_resources(shared_memory=shared_memory)
def test_construct_extended_resources_shared_memory_none():
resources = construct_extended_resources(shared_memory=None)

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

assert resources is None


@pytest.mark.parametrize("shared_memory, expected_size_limit", [
("2Gi", "2Gi"),
(True, ""),
])
def test_construct_extended_resources_shared_memory(shared_memory, expected_size_limit):
resources = construct_extended_resources(shared_memory=shared_memory)
assert resources.shared_memory.size_limit == expected_size_limit
4 changes: 3 additions & 1 deletion tests/flytekit/unit/models/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata
from google.protobuf import text_format

from flytekit.core.resources import construct_extended_resources
import flytekit.models.interface as interface_models
import flytekit.models.literals as literal_models
from flytekit import Description, Documentation, SourceCode
Expand Down Expand Up @@ -110,7 +111,7 @@ def test_task_template(in_tuple):
{"d": "e"},
),
config={"a": "b"},
extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
extended_resources=construct_extended_resources(accelerator=T4, shared_memory="2Gi"),
)
assert obj.id.resource_type == identifier.ResourceType.TASK
assert obj.id.project == "project"
Expand All @@ -130,6 +131,7 @@ def test_task_template(in_tuple):
assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
assert obj.extended_resources.shared_memory.size_limit == "2Gi"


def test_task_spec():
Expand Down
Loading