diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 9b6463d659..fa09025945 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -131,7 +131,6 @@ def __new__(mcs, name, bases, dct): condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) - # TODO: should we add a check for __condition_type__ 'AND'? elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value) @@ -171,8 +170,7 @@ class _FlowGeneric(cls): # type: ignore def __init__(self) -> None: self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() - self._executed_methods: Set[str] = set() - self._scheduled_tasks: Set[str] = set() + self._method_execution_counts: Dict[str, int] = {} self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs @@ -309,7 +307,10 @@ async def _execute_method( ) self._method_outputs.append(result) # Store the output - self._executed_methods.add(method_name) + # Track method execution counts + self._method_execution_counts[method_name] = ( + self._method_execution_counts.get(method_name, 0) + 1 + ) return result @@ -319,35 +320,34 @@ async def _execute_listeners(self, trigger_method: str, result: Any) -> None: if trigger_method in self._routers: router_method = self._methods[self._routers[trigger_method]] path = await self._execute_method( - trigger_method, router_method - ) # TODO: Change or not? - # Use the path as the new trigger method + self._routers[trigger_method], router_method + ) trigger_method = path for listener_name, (condition_type, methods) in self._listeners.items(): if condition_type == "OR": if trigger_method in methods: - if ( - listener_name not in self._executed_methods - and listener_name not in self._scheduled_tasks - ): - self._scheduled_tasks.add(listener_name) - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + # Schedule the listener without preventing re-execution + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) elif condition_type == "AND": - if all(method in self._executed_methods for method in methods): - if ( - listener_name not in self._executed_methods - and listener_name not in self._scheduled_tasks - ): - self._scheduled_tasks.add(listener_name) - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + # Initialize pending methods for this listener if not already done + if listener_name not in self._pending_and_listeners: + self._pending_and_listeners[listener_name] = set(methods) + # Remove the trigger method from pending methods + self._pending_and_listeners[listener_name].discard(trigger_method) + if not self._pending_and_listeners[listener_name]: + # All required methods have been executed + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) + # Reset pending methods for this listener + self._pending_and_listeners.pop(listener_name, None) # Run all listener tasks concurrently and wait for them to complete - await asyncio.gather(*listener_tasks) + if listener_tasks: + await asyncio.gather(*listener_tasks) async def _execute_single_listener(self, listener_name: str, result: Any) -> None: try: @@ -367,9 +367,6 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non # If listener does not expect parameters, call without arguments listener_result = await self._execute_method(listener_name, method) - # Remove from scheduled tasks after execution - self._scheduled_tasks.discard(listener_name) - # Execute listeners of this listener await self._execute_listeners(listener_name, listener_result) except Exception as e: diff --git a/tests/flow_test.py b/tests/flow_test.py new file mode 100644 index 0000000000..ffd82367c1 --- /dev/null +++ b/tests/flow_test.py @@ -0,0 +1,264 @@ +"""Test Flow creation and execution basic functionality.""" + +import asyncio + +import pytest +from crewai.flow.flow import Flow, and_, listen, or_, router, start + + +def test_simple_sequential_flow(): + """Test a simple flow with two steps called sequentially.""" + execution_order = [] + + class SimpleFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = SimpleFlow() + flow.kickoff() + + assert execution_order == ["step_1", "step_2"] + + +def test_flow_with_multiple_starts(): + """Test a flow with multiple start methods.""" + execution_order = [] + + class MultiStartFlow(Flow): + @start() + def step_a(self): + execution_order.append("step_a") + + @start() + def step_b(self): + execution_order.append("step_b") + + @listen(step_a) + def step_c(self): + execution_order.append("step_c") + + @listen(step_b) + def step_d(self): + execution_order.append("step_d") + + flow = MultiStartFlow() + flow.kickoff() + + assert "step_a" in execution_order + assert "step_b" in execution_order + assert "step_c" in execution_order + assert "step_d" in execution_order + assert execution_order.index("step_c") > execution_order.index("step_a") + assert execution_order.index("step_d") > execution_order.index("step_b") + + +def test_cyclic_flow(): + """Test a cyclic flow that runs a finite number of iterations.""" + execution_order = [] + + class CyclicFlow(Flow): + iteration = 0 + max_iterations = 3 + + @start("loop") + def step_1(self): + if self.iteration >= self.max_iterations: + return # Do not proceed further + execution_order.append(f"step_1_{self.iteration}") + + @listen(step_1) + def step_2(self): + execution_order.append(f"step_2_{self.iteration}") + + @router(step_2) + def step_3(self): + execution_order.append(f"step_3_{self.iteration}") + self.iteration += 1 + if self.iteration < self.max_iterations: + return "loop" + + return "exit" + + flow = CyclicFlow() + flow.kickoff() + + expected_order = [] + for i in range(flow.max_iterations): + expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"]) + + assert execution_order == expected_order + + +def test_flow_with_and_condition(): + """Test a flow where a step waits for multiple other steps to complete.""" + execution_order = [] + + class AndConditionFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @start() + def step_2(self): + execution_order.append("step_2") + + @listen(and_(step_1, step_2)) + def step_3(self): + execution_order.append("step_3") + + flow = AndConditionFlow() + flow.kickoff() + + assert "step_1" in execution_order + assert "step_2" in execution_order + assert execution_order[-1] == "step_3" + assert execution_order.index("step_3") > execution_order.index("step_1") + assert execution_order.index("step_3") > execution_order.index("step_2") + + +def test_flow_with_or_condition(): + """Test a flow where a step is triggered when any of multiple steps complete.""" + execution_order = [] + + class OrConditionFlow(Flow): + @start() + def step_a(self): + execution_order.append("step_a") + + @start() + def step_b(self): + execution_order.append("step_b") + + @listen(or_(step_a, step_b)) + def step_c(self): + execution_order.append("step_c") + + flow = OrConditionFlow() + flow.kickoff() + + assert "step_a" in execution_order or "step_b" in execution_order + assert "step_c" in execution_order + assert execution_order.index("step_c") > min( + execution_order.index("step_a"), execution_order.index("step_b") + ) + + +def test_flow_with_router(): + """Test a flow that uses a router method to determine the next step.""" + execution_order = [] + + class RouterFlow(Flow): + @start() + def start_method(self): + execution_order.append("start_method") + + @router(start_method) + def router(self): + execution_order.append("router") + # Ensure the condition is set to True to follow the "step_if_true" path + condition = True + return "step_if_true" if condition else "step_if_false" + + @listen("step_if_true") + def truthy(self): + execution_order.append("step_if_true") + + @listen("step_if_false") + def falsy(self): + execution_order.append("step_if_false") + + flow = RouterFlow() + flow.kickoff() + + assert execution_order == ["start_method", "router", "step_if_true"] + + +def test_async_flow(): + """Test an asynchronous flow.""" + execution_order = [] + + class AsyncFlow(Flow): + @start() + async def step_1(self): + execution_order.append("step_1") + await asyncio.sleep(0.1) + + @listen(step_1) + async def step_2(self): + execution_order.append("step_2") + await asyncio.sleep(0.1) + + flow = AsyncFlow() + asyncio.run(flow.kickoff_async()) + + assert execution_order == ["step_1", "step_2"] + + +def test_flow_with_exceptions(): + """Test flow behavior when exceptions occur in steps.""" + execution_order = [] + + class ExceptionFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + raise ValueError("An error occurred in step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = ExceptionFlow() + + with pytest.raises(ValueError): + flow.kickoff() + + # Ensure step_2 did not execute + assert execution_order == ["step_1"] + + +def test_flow_restart(): + """Test restarting a flow after it has completed.""" + execution_order = [] + + class RestartableFlow(Flow): + @start() + def step_1(self): + execution_order.append("step_1") + + @listen(step_1) + def step_2(self): + execution_order.append("step_2") + + flow = RestartableFlow() + flow.kickoff() + flow.kickoff() # Restart the flow + + assert execution_order == ["step_1", "step_2", "step_1", "step_2"] + + +def test_flow_with_custom_state(): + """Test a flow that maintains and modifies internal state.""" + + class StateFlow(Flow): + def __init__(self): + super().__init__() + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = StateFlow() + flow.kickoff() + assert flow.counter == 2