Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flows to support cycles and added in test #1556

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __new__(mcs, name, bases, dct):
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)

# TODO: should we add a check for __condition_type__ 'AND'?
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
Expand Down Expand Up @@ -171,8 +170,7 @@ class _FlowGeneric(cls): # type: ignore
def __init__(self) -> None:
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._executed_methods: Set[str] = set()
self._scheduled_tasks: Set[str] = set()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs

Expand Down Expand Up @@ -309,7 +307,10 @@ async def _execute_method(
)
self._method_outputs.append(result) # Store the output

self._executed_methods.add(method_name)
# Track method execution counts
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)

return result

Expand All @@ -319,35 +320,34 @@ async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(
trigger_method, router_method
) # TODO: Change or not?
# Use the path as the new trigger method
self._routers[trigger_method], router_method
)
trigger_method = path

for listener_name, (condition_type, methods) in self._listeners.items():
if condition_type == "OR":
if trigger_method in methods:
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Schedule the listener without preventing re-execution
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND":
if all(method in self._executed_methods for method in methods):
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)

# Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks)
if listener_tasks:
await asyncio.gather(*listener_tasks)

async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
Expand All @@ -367,9 +367,6 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method)

# Remove from scheduled tasks after execution
self._scheduled_tasks.discard(listener_name)

# Execute listeners of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
Expand Down
264 changes: 264 additions & 0 deletions tests/flow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""Test Flow creation and execution basic functionality."""

import asyncio

import pytest
from crewai.flow.flow import Flow, and_, listen, or_, router, start


def test_simple_sequential_flow():
"""Test a simple flow with two steps called sequentially."""
execution_order = []

class SimpleFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = SimpleFlow()
flow.kickoff()

assert execution_order == ["step_1", "step_2"]


def test_flow_with_multiple_starts():
"""Test a flow with multiple start methods."""
execution_order = []

class MultiStartFlow(Flow):
@start()
def step_a(self):
execution_order.append("step_a")

@start()
def step_b(self):
execution_order.append("step_b")

@listen(step_a)
def step_c(self):
execution_order.append("step_c")

@listen(step_b)
def step_d(self):
execution_order.append("step_d")

flow = MultiStartFlow()
flow.kickoff()

assert "step_a" in execution_order
assert "step_b" in execution_order
assert "step_c" in execution_order
assert "step_d" in execution_order
assert execution_order.index("step_c") > execution_order.index("step_a")
assert execution_order.index("step_d") > execution_order.index("step_b")


def test_cyclic_flow():
"""Test a cyclic flow that runs a finite number of iterations."""
execution_order = []

class CyclicFlow(Flow):
iteration = 0
max_iterations = 3

@start("loop")
def step_1(self):
if self.iteration >= self.max_iterations:
return # Do not proceed further
execution_order.append(f"step_1_{self.iteration}")

@listen(step_1)
def step_2(self):
execution_order.append(f"step_2_{self.iteration}")

@router(step_2)
def step_3(self):
execution_order.append(f"step_3_{self.iteration}")
self.iteration += 1
if self.iteration < self.max_iterations:
return "loop"

return "exit"

flow = CyclicFlow()
flow.kickoff()

expected_order = []
for i in range(flow.max_iterations):
expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])

assert execution_order == expected_order


def test_flow_with_and_condition():
"""Test a flow where a step waits for multiple other steps to complete."""
execution_order = []

class AndConditionFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@start()
def step_2(self):
execution_order.append("step_2")

@listen(and_(step_1, step_2))
def step_3(self):
execution_order.append("step_3")

flow = AndConditionFlow()
flow.kickoff()

assert "step_1" in execution_order
assert "step_2" in execution_order
assert execution_order[-1] == "step_3"
assert execution_order.index("step_3") > execution_order.index("step_1")
assert execution_order.index("step_3") > execution_order.index("step_2")


def test_flow_with_or_condition():
"""Test a flow where a step is triggered when any of multiple steps complete."""
execution_order = []

class OrConditionFlow(Flow):
@start()
def step_a(self):
execution_order.append("step_a")

@start()
def step_b(self):
execution_order.append("step_b")

@listen(or_(step_a, step_b))
def step_c(self):
execution_order.append("step_c")

flow = OrConditionFlow()
flow.kickoff()

assert "step_a" in execution_order or "step_b" in execution_order
assert "step_c" in execution_order
assert execution_order.index("step_c") > min(
execution_order.index("step_a"), execution_order.index("step_b")
)


def test_flow_with_router():
"""Test a flow that uses a router method to determine the next step."""
execution_order = []

class RouterFlow(Flow):
@start()
def start_method(self):
execution_order.append("start_method")

@router(start_method)
def router(self):
execution_order.append("router")
# Ensure the condition is set to True to follow the "step_if_true" path
condition = True
return "step_if_true" if condition else "step_if_false"

@listen("step_if_true")
def truthy(self):
execution_order.append("step_if_true")

@listen("step_if_false")
def falsy(self):
execution_order.append("step_if_false")

flow = RouterFlow()
flow.kickoff()

assert execution_order == ["start_method", "router", "step_if_true"]


def test_async_flow():
"""Test an asynchronous flow."""
execution_order = []

class AsyncFlow(Flow):
@start()
async def step_1(self):
execution_order.append("step_1")
await asyncio.sleep(0.1)

@listen(step_1)
async def step_2(self):
execution_order.append("step_2")
await asyncio.sleep(0.1)

flow = AsyncFlow()
asyncio.run(flow.kickoff_async())

assert execution_order == ["step_1", "step_2"]


def test_flow_with_exceptions():
"""Test flow behavior when exceptions occur in steps."""
execution_order = []

class ExceptionFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")
raise ValueError("An error occurred in step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = ExceptionFlow()

with pytest.raises(ValueError):
flow.kickoff()

# Ensure step_2 did not execute
assert execution_order == ["step_1"]


def test_flow_restart():
"""Test restarting a flow after it has completed."""
execution_order = []

class RestartableFlow(Flow):
@start()
def step_1(self):
execution_order.append("step_1")

@listen(step_1)
def step_2(self):
execution_order.append("step_2")

flow = RestartableFlow()
flow.kickoff()
flow.kickoff() # Restart the flow

assert execution_order == ["step_1", "step_2", "step_1", "step_2"]


def test_flow_with_custom_state():
"""Test a flow that maintains and modifies internal state."""

class StateFlow(Flow):
def __init__(self):
super().__init__()
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = StateFlow()
flow.kickoff()
assert flow.counter == 2
Loading