diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 58a72f357a..791480435f 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import BranchEvalMode, FlyteContext +from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import VoidPromise @@ -129,9 +129,12 @@ def create_node( return node # Handling local execution - # Note: execution state is set to TASK_EXECUTION when running dynamic task locally + # Note: execution state is set to DYNAMIC_TASK_EXECUTION when running a dynamic task locally # https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262 - elif ctx.execution_state and ctx.execution_state.is_local_execution(): + elif ctx.execution_state and ( + ctx.execution_state.is_local_execution() + or ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION + ): if isinstance(entity, RemoteEntity): raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6bb07fee3e..847d727948 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1270,6 +1270,9 @@ def flyte_entity_call_handler( if inspect.iscoroutine(result): return result + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION: + return result + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 2c01723bdd..a1b863a092 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -308,7 +308,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) + if self.execution_mode == self.ExecutionBehavior.DYNAMIC: + es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION) + else: + es = cast(ExecutionState, ctx.execution_state) + with FlyteContextManager.with_context(ctx.with_execution_state(es)): + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name)