diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py
index 903e5d5ced..a80ed0f9e4 100644
--- a/flytekit/core/constants.py
+++ b/flytekit/core/constants.py
@@ -38,3 +38,7 @@
 CACHE_KEY_METADATA = "cache-key-metadata"
 
 SERIALIZATION_FORMAT = "serialization-format"
+
+# Shared memory mount name and path
+SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory"
+SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
diff --git a/flytekit/core/node.py b/flytekit/core/node.py
index 61ae41c060..08bd89391c 100644
--- a/flytekit/core/node.py
+++ b/flytekit/core/node.py
@@ -2,11 +2,11 @@
 import datetime
 import typing
 from typing import Any, Dict, List, Optional, Union
 
 from flyteidl.core import tasks_pb2
 
-from flytekit.core.resources import Resources, convert_resources_to_resource_model
+from flytekit.core.resources import Resources, construct_extended_resources, convert_resources_to_resource_model
 from flytekit.core.utils import _dnsify
 from flytekit.extras.accelerators import BaseAccelerator
 from flytekit.loggers import logger
@@ -191,6 +191,7 @@ def with_overrides(
         cache: Optional[bool] = None,
         cache_version: Optional[str] = None,
         cache_serialize: Optional[bool] = None,
+        shared_memory: Optional[Union[bool, str]] = None,
         *args,
         **kwargs,
     ):
@@ -237,7 +238,11 @@ def with_overrides(
         if accelerator is not None:
             assert_not_promise(accelerator, "accelerator")
-            self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())
+
+        if shared_memory is not None:
+            assert_not_promise(shared_memory, "shared_memory")
+
+        self._extended_resources = construct_extended_resources(accelerator=accelerator, shared_memory=shared_memory)
 
         self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index dfbd678fb6..3cda47e63f 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -4,7 +4,7 @@
 import re
 from abc import ABC
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, TypeVar, Union
+from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union
 
 from flyteidl.core import tasks_pb2
 
@@ -13,7 +13,7 @@
 from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
 from flytekit.core.context_manager import FlyteContextManager
 from flytekit.core.pod_template import PodTemplate
-from flytekit.core.resources import Resources, ResourceSpec
+from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources
 from flytekit.core.tracked_abc import FlyteTrackedABC
 from flytekit.core.tracker import TrackedInstance, extract_task_module
 from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
@@ -51,6 +51,7 @@ def __init__(
         pod_template: Optional[PodTemplate] = None,
         pod_template_name: Optional[str] = None,
         accelerator: Optional[BaseAccelerator] = None,
+        shared_memory: Optional[Union[Literal[True], str]] = None,
         **kwargs,
     ):
         """
@@ -78,6 +79,8 @@
         :param pod_template: Custom PodTemplate for this task.
         :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
         :param accelerator: The accelerator to use for this task.
+        :param shared_memory: If True, then shared memory will be attached to the container where the size is equal
+            to the allocated memory. If str, then the shared memory is set to that size.
         """
         sec_ctx = None
         if secret_requests:
@@ -128,6 +131,7 @@ def __init__(
 
         self.pod_template = pod_template
         self.accelerator = accelerator
+        self.shared_memory = shared_memory
 
     @property
     def task_resolver(self) -> TaskResolverMixin:
@@ -250,10 +254,9 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[ta
         """
         Returns the extended resources to allocate to the task on hosted Flyte.
         """
-        if self.accelerator is None:
-            return None
-
-        return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())
+        return construct_extended_resources(
+            accelerator=self.accelerator, shared_memory=self.shared_memory
+        )
 
 
 class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py
index 9a334f98f6..0aefb192de 100644
--- a/flytekit/core/resources.py
+++ b/flytekit/core/resources.py
@@ -1,10 +1,13 @@
 from dataclasses import dataclass, fields
 from typing import Any, List, Optional, Union
 
+from flyteidl.core import tasks_pb2
 from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
 from mashumaro.mixins.json import DataClassJSONMixin
 
+from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH
+from flytekit.extras.accelerators import BaseAccelerator
 from flytekit.models import task as task_models
 
 
 @dataclass
@@ -102,6 +105,35 @@ def convert_resources_to_resource_model(
     return task_models.Resources(requests=request_entries, limits=limit_entries)
 
 
+def construct_extended_resources(
+    *,
+    accelerator: Optional[BaseAccelerator] = None,
+    shared_memory: Optional[Union[bool, str]] = None,
+) -> Optional[tasks_pb2.ExtendedResources]:
+    """Convert public extended resources to idl.
+
+    :param accelerator: The accelerator to use for this task.
+    :param shared_memory: If True, then shared memory will be attached to the container where the size is equal
+        to the allocated memory. If str, then the shared memory is set to that size.
+    """
+    kwargs = {}
+    if accelerator is not None:
+        kwargs["gpu_accelerator"] = accelerator.to_flyte_idl()
+    if isinstance(shared_memory, str) or shared_memory is True:
+        # For True, leave size_limit unset so the mount is sized by the container's memory allocation.
+        size_limit = shared_memory if isinstance(shared_memory, str) else None
+        kwargs["shared_memory"] = tasks_pb2.SharedMemory(
+            mount_name=SHARED_MEMORY_MOUNT_NAME,
+            mount_path=SHARED_MEMORY_MOUNT_PATH,
+            size_limit=size_limit,
+        )
+
+    if not kwargs:
+        return None
+
+    return tasks_pb2.ExtendedResources(**kwargs)
+
+
 def pod_spec_from_resources(
     k8s_pod_name: str,
     requests: Optional[Resources] = None,
diff --git a/flytekit/core/task.py b/flytekit/core/task.py
index 6451e742c5..0546456f5c 100644
--- a/flytekit/core/task.py
+++ b/flytekit/core/task.py
@@ -4,7 +4,7 @@
 import inspect
 import os
 from functools import partial, update_wrapper
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload
 
 from typing_extensions import ParamSpec  # type: ignore
 
@@ -128,6 +128,7 @@ def task(
     pod_template_name: Optional[str] = ...,
     accelerator: Optional[BaseAccelerator] = ...,
     pickle_untyped: bool = ...,
+    shared_memory: Optional[Union[Literal[True], str]] = ...,
 ) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...
 
 
@@ -167,6 +168,7 @@ def task(
     pod_template_name: Optional[str] = ...,
     accelerator: Optional[BaseAccelerator] = ...,
     pickle_untyped: bool = ...,
+    shared_memory: Optional[Union[Literal[True], str]] = ...,
 ) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...
 
 
@@ -211,6 +213,7 @@ def task(
     pod_template_name: Optional[str] = None,
     accelerator: Optional[BaseAccelerator] = None,
     pickle_untyped: bool = False,
+    shared_memory: Optional[Union[bool, str]] = None,
 ) -> Union[
     Callable[P, FuncOut],
     Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
@@ -341,6 +344,8 @@ def launch_dynamically():
     :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
     :param accelerator: The accelerator to use for this task.
     :param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
+    :param shared_memory: If True, then shared memory will be attached to the container where the size is equal
+        to the allocated memory. If str, then the shared memory is set to that size.
     """
 
     def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
@@ -390,6 +395,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
             pod_template_name=pod_template_name,
             accelerator=accelerator,
             pickle_untyped=pickle_untyped,
+            shared_memory=shared_memory,
         )
         update_wrapper(task_instance, decorated_fn)
         return task_instance
diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py
index 97693940e0..e130ece470 100644
--- a/tests/flytekit/unit/core/test_array_node_map_task.py
+++ b/tests/flytekit/unit/core/test_array_node_map_task.py
@@ -487,6 +487,24 @@ def wf(x: typing.List[int]):
     assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu"
 
 
+def test_serialization_extended_resources_shared_memory(serialization_settings):
+    @task(shared_memory="2Gi")
+    def t1(a: int) -> int:
+        return a + 1
+
+    arraynode_maptask = map_task(t1)
+
+    @workflow
+    def wf(x: typing.List[int]):
+        return arraynode_maptask(a=x)
+
+    od = OrderedDict()
+    get_serializable(od, serialization_settings, wf)
+    task_spec = od[arraynode_maptask]
+
+    assert task_spec.template.extended_resources.shared_memory.size_limit == "2Gi"
+
+
 def test_supported_node_type():
     @task
     def test_task():
diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py
index 381f456bdb..1eaecefe05 100644
--- a/tests/flytekit/unit/core/test_node_creation.py
+++ b/tests/flytekit/unit/core/test_node_creation.py
@@ -497,6 +497,30 @@ def my_wf() -> str:
     assert not accelerator.HasField("unpartitioned")
 
 
+def test_override_shared_memory():
+    @task(shared_memory=True)
+    def bar() -> str:
+        return "hello"
+
+    @workflow
+    def my_wf() -> str:
+        return bar().with_overrides(shared_memory="128Mi")
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="test_proj",
+        domain="test_domain",
+        version="abc",
+        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
+        env={},
+    )
+    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
+    assert len(wf_spec.template.nodes) == 1
+    assert wf_spec.template.nodes[0].task_node.overrides is not None
+    assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None
+    shared_memory = wf_spec.template.nodes[0].task_node.overrides.extended_resources.shared_memory
+    assert shared_memory.size_limit == "128Mi"
+
+
 def test_cache_override_values():
     @task
     def t1(a: str) -> str:
diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py
index 1c09a111e3..dd0dcfe938 100644
--- a/tests/flytekit/unit/core/test_resources.py
+++ b/tests/flytekit/unit/core/test_resources.py
@@ -8,7 +8,8 @@
 from flytekit.core.resources import (
     pod_spec_from_resources,
     convert_resources_to_resource_model,
+    construct_extended_resources,
 )
 
 _ResourceName = _task_models.Resources.ResourceName
@@ -155,3 +156,18 @@ def test_pod_spec_from_resources_requests_set():
     )
     pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
     assert expected_pod_spec == V1PodSpec(**pod_spec)
+
+
+@pytest.mark.parametrize("shared_memory", [None, False])
+def test_construct_extended_resources_shared_memory_none(shared_memory):
+    resources = construct_extended_resources(shared_memory=shared_memory)
+    assert resources is None
+
+
+@pytest.mark.parametrize("shared_memory, expected_size_limit", [
+    ("2Gi", "2Gi"),
+    (True, ""),
+])
+def test_construct_extended_resources_shared_memory(shared_memory, expected_size_limit):
+    resources = construct_extended_resources(shared_memory=shared_memory)
+    assert resources.shared_memory.size_limit == expected_size_limit
diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py
index b9685736b7..048cfb1db9 100644
--- a/tests/flytekit/unit/models/test_tasks.py
+++ b/tests/flytekit/unit/models/test_tasks.py
@@ -5,6 +5,7 @@
 from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata
 from google.protobuf import text_format
 
+from flytekit.core.resources import construct_extended_resources
 import flytekit.models.interface as interface_models
 import flytekit.models.literals as literal_models
 from flytekit import Description, Documentation, SourceCode
@@ -110,7 +111,7 @@ def test_task_template(in_tuple):
             {"d": "e"},
         ),
         config={"a": "b"},
-        extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
+        extended_resources=construct_extended_resources(accelerator=T4, shared_memory="2Gi"),
     )
     assert obj.id.resource_type == identifier.ResourceType.TASK
     assert obj.id.project == "project"
@@ -130,6 +131,7 @@ def test_task_template(in_tuple):
     assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
     assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
    assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
+    assert obj.extended_resources.shared_memory.size_limit == "2Gi"
 
 
 def test_task_spec():
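
Usage sketch (reviewer note, not part of the patch): with this change applied, `shared_memory=True` on a task mounts `/dev/shm` sized to the task's memory allocation, a string value sets an explicit size limit, and `with_overrides(shared_memory=...)` adjusts it per node. The task and workflow names below are illustrative only.

    from flytekit import task, workflow

    @task(shared_memory=True)  # /dev/shm sized to the container's memory allocation
    def trainer() -> str:
        return "done"

    @task(shared_memory="2Gi")  # /dev/shm capped at an explicit 2Gi size limit
    def loader() -> str:
        return "ready"

    @workflow
    def wf() -> str:
        loader().with_overrides(shared_memory="128Mi")  # per-node override of the mount size
        return trainer()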