From 6cc2f510bf737b0053ae01a8a6d5582086ae5ec0 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Tue, 24 Dec 2024 16:55:44 -0500 Subject: [PATCH] Feat/joao flow improvement requests (#1795) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add in or and and in router * In the middle of improving plotting * final plot changes --------- Co-authored-by: João Moura --- src/crewai/flow/flow.py | 157 ++++++++++++------------- src/crewai/flow/utils.py | 56 +++++++-- src/crewai/flow/visualization_utils.py | 94 ++++++++++----- tests/flow_test.py | 59 ++++++++++ 4 files changed, 237 insertions(+), 129 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index ccc76dc95c..4a6361cce4 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -80,10 +80,27 @@ def decorator(func): return decorator -def router(method): +def router(condition): def decorator(func): func.__is_router__ = True - func.__router_for__ = method.__name__ + # Handle conditions like listen/start + if isinstance(condition, str): + func.__trigger_methods__ = [condition] + func.__condition_type__ = "OR" + elif ( + isinstance(condition, dict) + and "type" in condition + and "methods" in condition + ): + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + elif callable(condition) and hasattr(condition, "__name__"): + func.__trigger_methods__ = [condition.__name__] + func.__condition_type__ = "OR" + else: + raise ValueError( + "Condition must be a method, string, or a result of or_() or and_()" + ) return func return decorator @@ -123,8 +140,8 @@ def __new__(mcs, name, bases, dct): start_methods = [] listeners = {} - routers = {} router_paths = {} + routers = set() for attr_name, attr_value in dct.items(): if hasattr(attr_value, "__is_start_method__"): @@ -137,18 +154,11 @@ def __new__(mcs, name, bases, dct): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) - - elif hasattr(attr_value, "__is_router__"): - routers[attr_value.__router_for__] = attr_name - possible_returns = get_possible_return_constants(attr_value) - if possible_returns: - router_paths[attr_name] = possible_returns - - # Register router as a listener to its triggering method - trigger_method_name = attr_value.__router_for__ - methods = [trigger_method_name] - condition_type = "OR" - listeners[attr_name] = (condition_type, methods) + if hasattr(attr_value, "__is_router__") and attr_value.__is_router__: + routers.add(attr_name) + possible_returns = get_possible_return_constants(attr_value) + if possible_returns: + router_paths[attr_name] = possible_returns setattr(cls, "_start_methods", start_methods) setattr(cls, "_listeners", listeners) @@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta): _start_methods: List[str] = [] _listeners: Dict[str, tuple[str, List[str]]] = {} - _routers: Dict[str, str] = {} + _routers: Set[str] = set() _router_paths: Dict[str, List[str]] = {} initial_state: Union[Type[T], T, None] = None event_emitter = Signal("event_emitter") @@ -210,20 +220,10 @@ def method_outputs(self) -> List[Any]: return self._method_outputs def _initialize_state(self, inputs: Dict[str, Any]) -> None: - """ - Initializes or updates the state with the provided inputs. - - Args: - inputs: Dictionary of inputs to initialize or update the state. - - Raises: - ValueError: If inputs do not match the structured state model. - TypeError: If state is neither a BaseModel instance nor a dictionary. - """ if isinstance(self._state, BaseModel): - # Structured state management + # Structured state try: - # Define a function to create the dynamic class + def create_model_with_extra_forbid( base_model: Type[BaseModel], ) -> Type[BaseModel]: @@ -233,34 +233,20 @@ class ModelWithExtraForbid(base_model): # type: ignore return ModelWithExtraForbid - # Create the dynamic class ModelWithExtraForbid = create_model_with_extra_forbid( self._state.__class__ ) - - # Create a new instance using the combined state and inputs self._state = cast( T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs}) ) - except ValidationError as e: raise ValueError(f"Invalid inputs for structured state: {e}") from e elif isinstance(self._state, dict): - # Unstructured state management self._state.update(inputs) else: raise TypeError("State must be a BaseModel instance or a dictionary.") def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: - """ - Starts the execution of the flow synchronously. - - Args: - inputs: Optional dictionary of inputs to initialize or update the state. - - Returns: - The final output from the flow execution. - """ self.event_emitter.send( self, event=FlowStartedEvent( @@ -274,15 +260,6 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: return asyncio.run(self.kickoff_async()) async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: - """ - Starts the execution of the flow asynchronously. - - Args: - inputs: Optional dictionary of inputs to initialize or update the state. - - Returns: - The final output from the flow execution. - """ if not self._start_methods: raise ValueError("No start method defined") @@ -290,16 +267,12 @@ async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: self.__class__.__name__, list(self._methods.keys()) ) - # Create tasks for all start methods tasks = [ self._execute_start_method(start_method) for start_method in self._start_methods ] - - # Run all start methods concurrently await asyncio.gather(*tasks) - # Determine the final output (from the last executed method) final_output = self._method_outputs[-1] if self._method_outputs else None self.event_emitter.send( @@ -310,7 +283,6 @@ async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: result=final_output, ), ) - return final_output async def _execute_start_method(self, start_method_name: str) -> None: @@ -327,49 +299,68 @@ async def _execute_method( if asyncio.iscoroutinefunction(method) else method(*args, **kwargs) ) - self._method_outputs.append(result) # Store the output - - # Track method execution counts + self._method_outputs.append(result) self._method_execution_counts[method_name] = ( self._method_execution_counts.get(method_name, 0) + 1 ) - return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: - listener_tasks = [] - - if trigger_method in self._routers: - router_method = self._methods[self._routers[trigger_method]] - path = await self._execute_method( - self._routers[trigger_method], router_method + # First, handle routers repeatedly until no router triggers anymore + while True: + routers_triggered = self._find_triggered_methods( + trigger_method, router_only=True ) - trigger_method = path - + if not routers_triggered: + break + for router_name in routers_triggered: + await self._execute_single_listener(router_name, result) + # After executing router, the router's result is the path + # The last router executed sets the trigger_method + # The router result is the last element in self._method_outputs + trigger_method = self._method_outputs[-1] + + # Now that no more routers are triggered by current trigger_method, + # execute normal listeners + listeners_triggered = self._find_triggered_methods( + trigger_method, router_only=False + ) + if listeners_triggered: + tasks = [ + self._execute_single_listener(listener_name, result) + for listener_name in listeners_triggered + ] + await asyncio.gather(*tasks) + + def _find_triggered_methods( + self, trigger_method: str, router_only: bool + ) -> List[str]: + triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): + is_router = listener_name in self._routers + + if router_only != is_router: + continue + if condition_type == "OR": + # If the trigger_method matches any in methods, run this if trigger_method in methods: - # Schedule the listener without preventing re-execution - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + triggered.append(listener_name) elif condition_type == "AND": # Initialize pending methods for this listener if not already done if listener_name not in self._pending_and_listeners: self._pending_and_listeners[listener_name] = set(methods) # Remove the trigger method from pending methods - self._pending_and_listeners[listener_name].discard(trigger_method) + if trigger_method in self._pending_and_listeners[listener_name]: + self._pending_and_listeners[listener_name].discard(trigger_method) + if not self._pending_and_listeners[listener_name]: # All required methods have been executed - listener_tasks.append( - self._execute_single_listener(listener_name, result) - ) + triggered.append(listener_name) # Reset pending methods for this listener self._pending_and_listeners.pop(listener_name, None) - # Run all listener tasks concurrently and wait for them to complete - if listener_tasks: - await asyncio.gather(*listener_tasks) + return triggered async def _execute_single_listener(self, listener_name: str, result: Any) -> None: try: @@ -386,17 +377,13 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non sig = inspect.signature(method) params = list(sig.parameters.values()) - - # Exclude 'self' parameter method_params = [p for p in params if p.name != "self"] if method_params: - # If listener expects parameters, pass the result listener_result = await self._execute_method( listener_name, method, result ) else: - # If listener does not expect parameters, call without arguments listener_result = await self._execute_method(listener_name, method) self.event_emitter.send( @@ -408,8 +395,9 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non ), ) - # Execute listeners of this listener + # Execute listeners (and possibly routers) of this listener await self._execute_listeners(listener_name, listener_result) + except Exception as e: print( f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" @@ -422,5 +410,4 @@ def plot(self, filename: str = "crewai_flow") -> None: self._telemetry.flow_plotting_span( self.__class__.__name__, list(self._methods.keys()) ) - plot_flow(self, filename) diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index 98d03f24f6..dc1f611fbb 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -31,16 +31,50 @@ def get_possible_return_constants(function): print(f"Source code:\n{source}") return None - return_values = [] + return_values = set() + dict_definitions = {} + + class DictionaryAssignmentVisitor(ast.NodeVisitor): + def visit_Assign(self, node): + # Check if this assignment is assigning a dictionary literal to a variable + if isinstance(node.value, ast.Dict) and len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name): + var_name = target.id + dict_values = [] + # Extract string values from the dictionary + for val in node.value.values: + if isinstance(val, ast.Constant) and isinstance(val.value, str): + dict_values.append(val.value) + # If non-string, skip or just ignore + if dict_values: + dict_definitions[var_name] = dict_values + self.generic_visit(node) class ReturnVisitor(ast.NodeVisitor): def visit_Return(self, node): - # Check if the return value is a constant (Python 3.8+) - if isinstance(node.value, ast.Constant): - return_values.append(node.value.value) - + # Direct string return + if isinstance(node.value, ast.Constant) and isinstance( + node.value.value, str + ): + return_values.add(node.value.value) + # Dictionary-based return, like return paths[result] + elif isinstance(node.value, ast.Subscript): + # Check if we're subscripting a known dictionary variable + if isinstance(node.value.value, ast.Name): + var_name = node.value.value.id + if var_name in dict_definitions: + # Add all possible dictionary values + for v in dict_definitions[var_name]: + return_values.add(v) + self.generic_visit(node) + + # First pass: identify dictionary assignments + DictionaryAssignmentVisitor().visit(code_ast) + # Second pass: identify returns ReturnVisitor().visit(code_ast) - return return_values + + return list(return_values) if return_values else None def calculate_node_levels(flow): @@ -61,10 +95,7 @@ def calculate_node_levels(flow): current_level = levels[current] visited.add(current) - for listener_name, ( - condition_type, - trigger_methods, - ) in flow._listeners.items(): + for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): if condition_type == "OR": if current in trigger_methods: if ( @@ -89,7 +120,7 @@ def calculate_node_levels(flow): queue.append(listener_name) # Handle router connections - if current in flow._routers.values(): + if current in flow._routers: router_method_name = current paths = flow._router_paths.get(router_method_name, []) for path in paths: @@ -105,6 +136,7 @@ def calculate_node_levels(flow): levels[listener_name] = current_level + 1 if listener_name not in visited: queue.append(listener_name) + return levels @@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow): dfs_ancestors(listener_name, ancestors, visited, flow) # Handle router methods separately - if node in flow._routers.values(): + if node in flow._routers: router_method_name = node paths = flow._router_paths.get(router_method_name, []) for path in paths: diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index 5b95a13699..321f633443 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -94,12 +94,14 @@ def add_edges(net, flow, node_positions, colors): ancestors = build_ancestor_dict(flow) parent_children = build_parent_children_dict(flow) + # Edges for normal listeners for method_name in flow._listeners: condition_type, trigger_methods = flow._listeners[method_name] is_and_condition = condition_type == "AND" for trigger in trigger_methods: - if trigger in flow._methods or trigger in flow._routers.values(): + # Check if nodes exist before adding edges + if trigger in node_positions and method_name in node_positions: is_router_edge = any( trigger in paths for paths in flow._router_paths.values() ) @@ -135,7 +137,22 @@ def add_edges(net, flow, node_positions, colors): } net.add_edge(trigger, method_name, **edge_style) + else: + # Nodes not found in node_positions. Check if it's a known router outcome and a known method. + is_router_edge = any( + trigger in paths for paths in flow._router_paths.values() + ) + # Check if method_name is a known method + method_known = method_name in flow._methods + + # If it's a known router edge and the method is known, don't warn. + # This means the path is legitimate, just not reflected as nodes here. + if not (is_router_edge and method_known): + print( + f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge." + ) + # Edges for router return paths for router_method_name, paths in flow._router_paths.items(): for path in paths: for listener_name, ( @@ -143,36 +160,49 @@ def add_edges(net, flow, node_positions, colors): trigger_methods, ) in flow._listeners.items(): if path in trigger_methods: - is_cycle_edge = is_ancestor(trigger, method_name, ancestors) - parent_has_multiple_children = ( - len(parent_children.get(router_method_name, [])) > 1 - ) - needs_curvature = is_cycle_edge or parent_has_multiple_children - - if needs_curvature: - source_pos = node_positions.get(router_method_name) - target_pos = node_positions.get(listener_name) - - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" - index = get_child_index( - router_method_name, listener_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } + if ( + router_method_name in node_positions + and listener_name in node_positions + ): + is_cycle_edge = is_ancestor( + router_method_name, listener_name, ancestors + ) + parent_has_multiple_children = ( + len(parent_children.get(router_method_name, [])) > 1 + ) + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + source_pos = node_positions.get(router_method_name) + target_pos = node_positions.get(listener_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index( + router_method_name, listener_name, parent_children + ) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} else: - edge_smooth = {"type": "cubicBezier"} + edge_smooth = False + + edge_style = { + "color": colors["router_edge"], + "width": 2, + "arrows": "to", + "dashes": True, + "smooth": edge_smooth, + } + net.add_edge(router_method_name, listener_name, **edge_style) else: - edge_smooth = False - - edge_style = { - "color": colors["router_edge"], - "width": 2, - "arrows": "to", - "dashes": True, - "smooth": edge_smooth, - } - net.add_edge(router_method_name, listener_name, **edge_style) + # Same check here: known router edge and known method? + method_known = listener_name in flow._methods + if not method_known: + print( + f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge." + ) diff --git a/tests/flow_test.py b/tests/flow_test.py index 2e20203619..d52c459ce8 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -263,3 +263,62 @@ def step_2(self): flow = StateFlow() flow.kickoff() assert flow.counter == 2 + + +def test_router_with_multiple_conditions(): + """Test a router that triggers when any of multiple steps complete (OR condition), + and another router that triggers only after all specified steps complete (AND condition). + """ + + execution_order = [] + + class ComplexRouterFlow(Flow): + @start() + def step_a(self): + execution_order.append("step_a") + + @start() + def step_b(self): + execution_order.append("step_b") + + @router(or_("step_a", "step_b")) + def router_or(self): + execution_order.append("router_or") + return "next_step_or" + + @listen("next_step_or") + def handle_next_step_or_event(self): + execution_order.append("handle_next_step_or_event") + + @listen(handle_next_step_or_event) + def branch_2_step(self): + execution_order.append("branch_2_step") + + @router(and_(handle_next_step_or_event, branch_2_step)) + def router_and(self): + execution_order.append("router_and") + return "final_step" + + @listen("final_step") + def log_final_step(self): + execution_order.append("log_final_step") + + flow = ComplexRouterFlow() + flow.kickoff() + + assert "step_a" in execution_order + assert "step_b" in execution_order + assert "router_or" in execution_order + assert "handle_next_step_or_event" in execution_order + assert "branch_2_step" in execution_order + assert "router_and" in execution_order + assert "log_final_step" in execution_order + + # Check that the AND router triggered after both relevant steps: + assert execution_order.index("router_and") > execution_order.index( + "handle_next_step_or_event" + ) + assert execution_order.index("router_and") > execution_order.index("branch_2_step") + + # final_step should run after router_and + assert execution_order.index("log_final_step") > execution_order.index("router_and")