Skip to content

Commit

Permalink
Remove false error inside dynamic task in local executions (#2675)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 14, 2024
1 parent 1cd8160 commit 03d2301
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
9 changes: 6 additions & 3 deletions flytekit/core/node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Union

from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import BranchEvalMode, FlyteContext
from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import VoidPromise
Expand Down Expand Up @@ -129,9 +129,12 @@ def create_node(
return node

# Handling local execution
# Note: execution state is set to TASK_EXECUTION when running dynamic task locally
# Note: execution state is set to DYNAMIC_TASK_EXECUTION when running a dynamic task locally
# https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262
elif ctx.execution_state and ctx.execution_state.is_local_execution():
elif ctx.execution_state and (
ctx.execution_state.is_local_execution()
or ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION
):
if isinstance(entity, RemoteEntity):
raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}")

Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,9 @@ def flyte_entity_call_handler(
if inspect.iscoroutine(result):
return result

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION:
return result

if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
# local_execute directly though since that converts inputs into Promises.
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}")
self._create_and_cache_dynamic_workflow()
function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs)
if self.execution_mode == self.ExecutionBehavior.DYNAMIC:
es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION)
else:
es = cast(ExecutionState, ctx.execution_state)
with FlyteContextManager.with_context(ctx.with_execution_state(es)):
function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs)

if isinstance(function_outputs, VoidPromise) or function_outputs is None:
return VoidPromise(self.name)
Expand Down

0 comments on commit 03d2301

Please sign in to comment.