Skip to content

Commit

Permalink
Fix flows to support cycles and added in test (#1556)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhancockio authored Nov 5, 2024
1 parent 8204de6 commit 9a979f8
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 28 deletions.
53 changes: 25 additions & 28 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __new__(mcs, name, bases, dct):
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)

# TODO: should we add a check for __condition_type__ 'AND'?
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
Expand Down Expand Up @@ -171,8 +170,7 @@ class _FlowGeneric(cls): # type: ignore
def __init__(self) -> None:
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._executed_methods: Set[str] = set()
self._scheduled_tasks: Set[str] = set()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs

Expand Down Expand Up @@ -309,7 +307,10 @@ async def _execute_method(
)
self._method_outputs.append(result) # Store the output

self._executed_methods.add(method_name)
# Track method execution counts
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)

return result

Expand All @@ -319,35 +320,34 @@ async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(
trigger_method, router_method
) # TODO: Change or not?
# Use the path as the new trigger method
self._routers[trigger_method], router_method
)
trigger_method = path

for listener_name, (condition_type, methods) in self._listeners.items():
if condition_type == "OR":
if trigger_method in methods:
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Schedule the listener without preventing re-execution
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND":
if all(method in self._executed_methods for method in methods):
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)

# Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks)
if listener_tasks:
await asyncio.gather(*listener_tasks)

async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
Expand All @@ -367,9 +367,6 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method)

# Remove from scheduled tasks after execution
self._scheduled_tasks.discard(listener_name)

# Execute listeners of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
Expand Down
264 changes: 264 additions & 0 deletions tests/flow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""Test Flow creation and execution basic functionality."""

import asyncio

import pytest
from crewai.flow.flow import Flow, and_, listen, or_, router, start


def test_simple_sequential_flow():
"""Test a simple flow with two steps called sequentially."""
execution_order = []

class SimpleFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = SimpleFlow()
flow.kickoff()

assert execution_order == ["step_1", "step_2"]


def test_flow_with_multiple_starts():
"""Test a flow with multiple start methods."""
execution_order = []

class MultiStartFlow(Flow):
@start()
def step_a(self):
execution_order.append("step_a")

@start()
def step_b(self):
execution_order.append("step_b")

@listen(step_a)
def step_c(self):
execution_order.append("step_c")

@listen(step_b)
def step_d(self):
execution_order.append("step_d")

flow = MultiStartFlow()
flow.kickoff()

assert "step_a" in execution_order
assert "step_b" in execution_order
assert "step_c" in execution_order
assert "step_d" in execution_order
assert execution_order.index("step_c") > execution_order.index("step_a")
assert execution_order.index("step_d") > execution_order.index("step_b")


def test_cyclic_flow():
"""Test a cyclic flow that runs a finite number of iterations."""
execution_order = []

class CyclicFlow(Flow):
iteration = 0
max_iterations = 3

@start("loop")
def step_1(self):
if self.iteration >= self.max_iterations:
return # Do not proceed further
execution_order.append(f"step_1_{self.iteration}")

@listen(step_1)
def step_2(self):
execution_order.append(f"step_2_{self.iteration}")

@router(step_2)
def step_3(self):
execution_order.append(f"step_3_{self.iteration}")
self.iteration += 1
if self.iteration < self.max_iterations:
return "loop"

return "exit"

flow = CyclicFlow()
flow.kickoff()

expected_order = []
for i in range(flow.max_iterations):
expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])

assert execution_order == expected_order


def test_flow_with_and_condition():
"""Test a flow where a step waits for multiple other steps to complete."""
execution_order = []

class AndConditionFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@start()
def step_2(self):
execution_order.append("step_2")

@listen(and_(step_1, step_2))
def step_3(self):
execution_order.append("step_3")

flow = AndConditionFlow()
flow.kickoff()

assert "step_1" in execution_order
assert "step_2" in execution_order
assert execution_order[-1] == "step_3"
assert execution_order.index("step_3") > execution_order.index("step_1")
assert execution_order.index("step_3") > execution_order.index("step_2")


def test_flow_with_or_condition():
"""Test a flow where a step is triggered when any of multiple steps complete."""
execution_order = []

class OrConditionFlow(Flow):
@start()
def step_a(self):
execution_order.append("step_a")

@start()
def step_b(self):
execution_order.append("step_b")

@listen(or_(step_a, step_b))
def step_c(self):
execution_order.append("step_c")

flow = OrConditionFlow()
flow.kickoff()

assert "step_a" in execution_order or "step_b" in execution_order
assert "step_c" in execution_order
assert execution_order.index("step_c") > min(
execution_order.index("step_a"), execution_order.index("step_b")
)


def test_flow_with_router():
"""Test a flow that uses a router method to determine the next step."""
execution_order = []

class RouterFlow(Flow):
@start()
def start_method(self):
execution_order.append("start_method")

@router(start_method)
def router(self):
execution_order.append("router")
# Ensure the condition is set to True to follow the "step_if_true" path
condition = True
return "step_if_true" if condition else "step_if_false"

@listen("step_if_true")
def truthy(self):
execution_order.append("step_if_true")

@listen("step_if_false")
def falsy(self):
execution_order.append("step_if_false")

flow = RouterFlow()
flow.kickoff()

assert execution_order == ["start_method", "router", "step_if_true"]


def test_async_flow():
"""Test an asynchronous flow."""
execution_order = []

class AsyncFlow(Flow):
@start()
async def step_1(self):
execution_order.append("step_1")
await asyncio.sleep(0.1)

@listen(step_1)
async def step_2(self):
execution_order.append("step_2")
await asyncio.sleep(0.1)

flow = AsyncFlow()
asyncio.run(flow.kickoff_async())

assert execution_order == ["step_1", "step_2"]


def test_flow_with_exceptions():
"""Test flow behavior when exceptions occur in steps."""
execution_order = []

class ExceptionFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")
raise ValueError("An error occurred in step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = ExceptionFlow()

with pytest.raises(ValueError):
flow.kickoff()

# Ensure step_2 did not execute
assert execution_order == ["step_1"]


def test_flow_restart():
"""Test restarting a flow after it has completed."""
execution_order = []

class RestartableFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = RestartableFlow()
flow.kickoff()
flow.kickoff() # Restart the flow

assert execution_order == ["step_1", "step_2", "step_1", "step_2"]


def test_flow_with_custom_state():
"""Test a flow that maintains and modifies internal state."""

class StateFlow(Flow):
def __init__(self):
super().__init__()
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = StateFlow()
flow.kickoff()
assert flow.counter == 2

0 comments on commit 9a979f8

Please sign in to comment.