
Commit

fix: cleaned up SDK (#182)
vijayvammi authored Jan 11, 2025
1 parent 955deef commit 74ba96b
Showing 3 changed files with 46 additions and 102 deletions.
3 changes: 2 additions & 1 deletion examples/02-sequential/on_failure_fail.py
@@ -28,7 +28,8 @@ def main():
     step_3 = Stub(name="step 3", terminate_with_success=True)
     step_4 = Stub(name="step 4", terminate_with_failure=True) # (1)

-    step_1.on_failure = step_4.name
+    on_failure_pipeline = Pipeline(steps=[step_4])
+    step_1.on_failure = on_failure_pipeline # (2)

     pipeline = Pipeline(
         steps=[step_1, step_2, step_3],
6 changes: 4 additions & 2 deletions examples/02-sequential/on_failure_succeed.py
@@ -28,10 +28,12 @@ def main():
     step_3 = Stub(name="step 3", terminate_with_success=True)
     step_4 = Stub(name="step 4", terminate_with_success=True) # (1)

-    step_1.on_failure = step_4.name
+    on_failure_pipeline = Pipeline(steps=[step_4])
+
+    step_1.on_failure = on_failure_pipeline # (2)

     pipeline = Pipeline(
-        steps=[step_1, step_2, step_3, [step_4]],
+        steps=[step_1, step_2, step_3],
     )
     pipeline.execute()

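Both example diffs make the same switch: instead of assigning a step name to `on_failure`, they assign a handler `Pipeline`, and the handler step is no longer nested inside `steps` as `[step_4]`. The following is a minimal end-to-end sketch of the new wiring. It assumes `Pipeline` and `Stub` are importable from the top-level `runnable` package (as in the examples directory), and it reconstructs the `step_1`/`step_2` definitions that the hunks above do not show.

```python
# Minimal sketch of the new per-step on_failure API, assembled from the two
# example diffs above. Assumes `Pipeline` and `Stub` come from the top-level
# runnable package; step_1 and step_2 are reconstructed, not shown in the diff.
from runnable import Pipeline, Stub


def main():
    step_1 = Stub(name="step 1")
    step_2 = Stub(name="step 2")
    step_3 = Stub(name="step 3", terminate_with_success=True)

    # The failure handler is itself a Pipeline. Whether its last step
    # terminates with failure (on_failure_fail.py) or with success
    # (on_failure_succeed.py) decides the overall outcome of the run.
    step_4 = Stub(name="step 4", terminate_with_failure=True)
    on_failure_pipeline = Pipeline(steps=[step_4])

    # on_failure now takes a Pipeline object instead of a step name.
    step_1.on_failure = on_failure_pipeline

    # The happy path stays a flat list of steps.
    pipeline = Pipeline(steps=[step_1, step_2, step_3])
    pipeline.execute()


if __name__ == "__main__":
    main()
```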
139 changes: 40 additions & 99 deletions runnable/sdk.py
@@ -84,7 +84,7 @@ class BaseTraversal(ABC, BaseModel):
     next_node: str = Field(default="", serialization_alias="next_node")
     terminate_with_success: bool = Field(default=False, exclude=True)
     terminate_with_failure: bool = Field(default=False, exclude=True)
-    on_failure: str = Field(default="", alias="on_failure")
+    on_failure: Optional[Pipeline] = Field(default=None)

     model_config = ConfigDict(extra="forbid")

@@ -117,18 +117,6 @@ def __lshift__(self, other: TraversalNode) -> TraversalNode:

         return other

-    def depends_on(self, node: StepType) -> Self:
-        assert not isinstance(node, Success)
-        assert not isinstance(node, Fail)
-
-        if node.next_node:
-            raise Exception(
-                f"The {node} node already has a next node: {node.next_node}"
-            )
-
-        node.next_node = self.name
-        return self
-
     @model_validator(mode="after")
     def validate_terminations(self) -> Self:
         if self.terminate_with_failure and self.terminate_with_success:
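With `depends_on` removed, ordering between steps is expressed either by position in the `steps` list (the rewritten `model_post_init` further down links adjacent steps) or by the shift operators that `BaseTraversal` still defines (`__lshift__` is the retained context above, and `>>` is used by the new linking code). A minimal sketch of the two remaining options, assuming `Stub` and `Pipeline` from the top-level `runnable` package; the step names are illustrative:

```python
# Sketch: two ways to order steps now that depends_on is gone.
from runnable import Pipeline, Stub

extract = Stub(name="extract")
load = Stub(name="load", terminate_with_success=True)

# 1) Implicit ordering: list position; model_post_init links adjacent steps via >>.
pipeline = Pipeline(steps=[extract, load])

# 2) Explicit ordering: the retained operators, e.g. `extract >> load` or
#    `load << extract` (left as a comment here so it does not duplicate the
#    linking already done in (1)).
```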
@@ -175,7 +163,6 @@ def serialize_returns(
             if isinstance(x, str):
                 task_returns.append(TaskReturns(name=x, kind="json"))
                 continue
-
             # Its already task returns
             task_returns.append(x)

@@ -188,6 +175,9 @@ def create_node(self) -> TaskNode:
                 "A node not being terminated must have a user defined next node"
             )

+        if self.on_failure:
+            self.on_failure = self.on_failure.steps[0].name  # type: ignore
+
         return TaskNode.parse_from_config(
             self.model_dump(exclude_none=True, by_alias=True)
         )
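At the SDK level `on_failure` is now a `Pipeline`, but the executor-facing node config still stores only a step name, so `create_node` collapses the handler pipeline to the name of its first step. A toy mimic of that collapse, using illustrative stand-in dataclasses rather than runnable's own classes:

```python
# Toy mimic of the create_node collapse above: the handler Pipeline is reduced
# to the name of its first step before the node config is serialized.
# FakeStep/FakePipeline are illustrative stand-ins, not runnable's classes.
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class FakeStep:
    name: str
    on_failure: Optional[Union["FakePipeline", str]] = None


@dataclass
class FakePipeline:
    steps: List[FakeStep] = field(default_factory=list)


step = FakeStep(name="step 1")
step.on_failure = FakePipeline(steps=[FakeStep(name="step 4")])

# What create_node does before serializing the node config:
if step.on_failure:
    step.on_failure = step.on_failure.steps[0].name

assert step.on_failure == "step 4"
```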
@@ -605,8 +595,6 @@ class Pipeline(BaseModel):
The order of steps is important as it determines the order of execution.
Any on failure behavior should the first step in ```on_failure``` pipelines.
on_failure (List[List[Pipeline], optional): A list of Pipelines to execute in case of failure.
For example, for the below pipeline:
@@ -624,7 +612,7 @@ class Pipeline(BaseModel):
     """

-    steps: List[Union[StepType, List["Pipeline"]]]
+    steps: List[StepType]
     name: str = ""
     description: str = ""

@@ -637,114 +625,67 @@ def add_terminal_nodes(self) -> bool:
_dag: graph.Graph = PrivateAttr()
model_config = ConfigDict(extra="forbid")

def _validate_path(self, path: List[StepType], failure_path: bool = False) -> None:
# TODO: Drastically simplify this
# Check if one and only one step terminates with success
# Check no more than one step terminates with failure

reached_success = False
reached_failure = False

for step in path:
if step.terminate_with_success:
if reached_success:
raise Exception(
"A pipeline cannot have more than one step that terminates with success"
)
reached_success = True
continue
if step.terminate_with_failure:
if reached_failure:
raise Exception(
"A pipeline cannot have more than one step that terminates with failure"
)
reached_failure = True

if not reached_success and not reached_failure:
raise Exception(
"A pipeline must have at least one step that terminates with success"
)

def _construct_path(self, path: List[StepType]) -> None:
prev_step = path[0]

for step in path:
if step == prev_step:
continue

if prev_step.terminate_with_success or prev_step.terminate_with_failure:
raise Exception(
f"A step that terminates with success/failure cannot have a next step: {prev_step}"
)

if prev_step.next_node and prev_step.next_node not in ["success", "fail"]:
raise Exception(f"Step already has a next node: {prev_step} ")

prev_step.next_node = step.name
prev_step = step

def model_post_init(self, __context: Any) -> None:
"""
The sequence of steps can either be:
[step1, step2,..., stepN, [step11, step12,..., step1N], [step21, step22,...,]]
[step1, step2,..., stepN]
indicates:
- step1 > step2 > ... > stepN
- We expect terminate with success or fail to be explicitly stated on a step.
- If it is stated, the step cannot have a next step defined apart from "success" and "fail".
The inner list of steps is only to accommodate on-failure behaviors.
- For sake of simplicity, lets assume that it has the same behavior as the happy pipeline.
- A task which was already seen should not be part of this.
- There should be at least one step which terminates with success
Any definition of pipeline should have one node that terminates with success.
"""
# TODO: Bug with repeat names
# TODO: https://github.com/AstraZeneca/runnable/issues/156
# The last step of the pipeline is defaulted to be a success step
# unless it is explicitly stated to terminate with failure.
terminal_step: StepType = self.steps[-1]
if not terminal_step.terminate_with_failure:
terminal_step.terminate_with_success = True

success_path: List[StepType] = []
on_failure_paths: List[List[StepType]] = []
# assert that there is only one termination node with success or failure
# Assert that there are no duplicate step names
observed: Dict[str, str] = {}
count_termination: int = 0

for step in self.steps:
if isinstance(
step, (Stub, PythonTask, NotebookTask, ShellTask, Parallel, Map)
):
success_path.append(step)
continue
# on_failure_paths.append(step)

if not success_path:
raise Exception("There should be some success path")

# Check all paths are valid and construct the path
paths = [success_path] + on_failure_paths
failure_path = False
for path in paths:
self._validate_path(path, failure_path)
self._construct_path(path)

failure_path = True
if step.terminate_with_success or step.terminate_with_failure:
count_termination += 1
if step.name in observed:
raise Exception(
f"Step names should be unique. Found duplicate: {step.name}"
)
observed[step.name] = step.name

all_steps: List[StepType] = []
if count_termination > 1:
raise AssertionError(
"A pipeline can only have one termination node with success or failure"
)

for path in paths:
for step in path:
all_steps.append(step)
# link the steps by assigning the next_node name to be that name of the node
# immediately after it.
for i in range(len(self.steps) - 1):
self.steps[i] >> self.steps[i + 1]

seen = set()
unique = [x for x in all_steps if not (x in seen or seen.add(x))] # type: ignore
# Add any on_failure pipelines to the steps
gathered_on_failure: List[StepType] = []
for step in self.steps:
if step.on_failure:
gathered_on_failure.extend(step.on_failure.steps)

self._dag = graph.Graph(
start_at=all_steps[0].name,
start_at=self.steps[0].name,
description=self.description,
internal_branch_name=self.internal_branch_name,
)

for step in unique:
self.steps.extend(gathered_on_failure)

for step in self.steps:
self._dag.add_node(step.create_node())

if self.add_terminal_nodes:
self._dag.add_terminal_nodes()
self._dag.add_terminal_nodes()

self._dag.check_graph()

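The rewritten `model_post_init` replaces the `_validate_path` / `_construct_path` machinery with a single pass: default the last step to terminate with success, enforce unique step names and a single terminating step, link adjacent steps with `>>`, gather each step's `on_failure` pipeline into the step list, and build the graph. A simplified, self-contained mimic of that control flow, using plain classes instead of runnable's Pydantic models and `graph.Graph`; `WireStep` and `wire` are illustrative names, not part of the SDK:

```python
# Simplified mimic of the wiring logic in the new model_post_init above.
from typing import List, Optional


class WireStep:
    def __init__(self, name: str, terminate_with_success: bool = False,
                 terminate_with_failure: bool = False) -> None:
        self.name = name
        self.terminate_with_success = terminate_with_success
        self.terminate_with_failure = terminate_with_failure
        self.next_node: str = ""
        # The SDK stores a Pipeline here; the mimic keeps just its steps.
        self.on_failure: Optional[List["WireStep"]] = None


def wire(steps: List[WireStep]) -> List[WireStep]:
    # The last step defaults to terminating with success unless it already
    # terminates with failure.
    if not steps[-1].terminate_with_failure:
        steps[-1].terminate_with_success = True

    # No duplicate names; at most one terminating step.
    observed = set()
    count_termination = 0
    for step in steps:
        if step.terminate_with_success or step.terminate_with_failure:
            count_termination += 1
        if step.name in observed:
            raise Exception(f"Step names should be unique. Found duplicate: {step.name}")
        observed.add(step.name)
    if count_termination > 1:
        raise AssertionError(
            "A pipeline can only have one termination node with success or failure"
        )

    # Link adjacent steps (the SDK does this with the >> operator).
    for i in range(len(steps) - 1):
        steps[i].next_node = steps[i + 1].name

    # Append the steps of any on_failure handler pipelines so they become
    # nodes of the same graph.
    gathered: List[WireStep] = []
    for step in steps:
        if step.on_failure:
            gathered.extend(step.on_failure)
    steps.extend(gathered)
    return steps
```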
