Skip to content

Commit

Permalink
Add map task serialization unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Sep 26, 2024
1 parent cef8463 commit 579c4b8
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import os
import pathlib
import typing
from collections import OrderedDict
from typing import List
Expand All @@ -13,11 +14,12 @@
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.python_auto_container import PICKLE_FILE_PATH
from flytekit.core.task import TaskMetadata
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.accelerators import GPUAccelerator
from flytekit.experimental.eager_function import eager
from flytekit.tools.translator import get_serializable
from flytekit.tools.translator import get_serializable, Options
from flytekit.types.pickle import BatchSize


Expand All @@ -33,6 +35,19 @@ def serialization_settings():
)


@pytest.fixture
def interactive_serialization_settings():
default_img = Image(name="default", fqn="test", tag="tag")
return SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
interactive_mode_enabled=True,
)


def test_map(serialization_settings):
@task
def say_hello(name: str) -> str:
Expand Down Expand Up @@ -134,6 +149,55 @@ def t1(a: int) -> int:
]


def test_interactive_serialization(interactive_serialization_settings):
@task
def t1(a: int) -> int:
return a + 1

def mock_file_uploader(dest: pathlib.Path):
return (0, dest.name)

arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
option = Options()
option.file_uploader = mock_file_uploader
task_spec = get_serializable(OrderedDict(), interactive_serialization_settings, arraynode_maptask, options=option)

assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "python-task"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.args == [
"pyflyte-fast-execute",
"--additional-distribution",
PICKLE_FILE_PATH,
"--dest-dir",
".",
"--",
"pyflyte-map-execute",
"--inputs",
"{{.input}}",
"--output-prefix",
"{{.outputPrefix}}",
"--raw-output-data-prefix",
"{{.rawOutputDataPrefix}}",
"--checkpoint-path",
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--resolver",
"flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver",
"--",
"vars",
"",
"resolver",
"flytekit.core.python_auto_container.default_notebook_task_resolver",
"task-module",
"tests.flytekit.unit.core.test_array_node_map_task",
"task-name",
"t1",
]


def test_fast_serialization(serialization_settings):
serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True)

Expand Down

0 comments on commit 579c4b8

Please sign in to comment.