Skip to content

Commit

Permalink
kevin's update
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Jun 17, 2024
1 parent 77cbf14 commit f8c2ad3
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import ListTransformer, TypeEngine
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.array_job import ArrayJob
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql, Task
from flytekit.tools.module_loader import load_object_from_module
from flytekit.models import literals as _literal_models


class ArrayNodeMapTask(PythonTask):
Expand Down Expand Up @@ -216,14 +216,16 @@ def _literal_map_to_python_input(
) -> Dict[str, Any]:
task_index = self._compute_array_job_index()
map_task_inputs = {}

inputs = self.python_interface.inputs
for k in self.interface.inputs.keys():
v = literal_map.literals[k]
if isinstance(v, list) and k not in self.bound_inputs:
map_task_inputs[k] = v[task_index]
if v.collection and k not in self.bound_inputs:
map_task_inputs[k] = v.collection.literals[task_index]
v = self.python_interface.inputs[k]
sub_type = ListTransformer.get_sub_type(v)
if typing.get_origin(v) is list:
inputs[k] = v.__args__[0]
inputs[k] = sub_type
else:
map_task_inputs[k] = v
return TypeEngine.literal_map_to_kwargs(ctx, _literal_models.LiteralMap(literals=map_task_inputs), inputs)
Expand Down

0 comments on commit f8c2ad3

Please sign in to comment.