Skip to content

Commit

Permalink
Feat/joao flow improvement requests (#1795)
Browse files Browse the repository at this point in the history
* Add in or and and in router

* In the middle of improving plotting

* final plot changes

---------

Co-authored-by: João Moura <[email protected]>
  • Loading branch information
bhancockio and joaomdmoura authored Dec 24, 2024
1 parent 9a65abf commit 6cc2f51
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 129 deletions.
157 changes: 72 additions & 85 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,27 @@ def decorator(func):
return decorator


def router(method):
def router(condition):
def decorator(func):
func.__is_router__ = True
func.__router_for__ = method.__name__
# Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
)
return func

return decorator
Expand Down Expand Up @@ -123,8 +140,8 @@ def __new__(mcs, name, bases, dct):

start_methods = []
listeners = {}
routers = {}
router_paths = {}
routers = set()

for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"):
Expand All @@ -137,18 +154,11 @@ def __new__(mcs, name, bases, dct):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)

elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns

# Register router as a listener to its triggering method
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns

setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
Expand All @@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
event_emitter = Signal("event_emitter")
Expand Down Expand Up @@ -210,20 +220,10 @@ def method_outputs(self) -> List[Any]:
return self._method_outputs

def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""
Initializes or updates the state with the provided inputs.
Args:
inputs: Dictionary of inputs to initialize or update the state.
Raises:
ValueError: If inputs do not match the structured state model.
TypeError: If state is neither a BaseModel instance nor a dictionary.
"""
if isinstance(self._state, BaseModel):
# Structured state management
# Structured state
try:
# Define a function to create the dynamic class

def create_model_with_extra_forbid(
base_model: Type[BaseModel],
) -> Type[BaseModel]:
Expand All @@ -233,34 +233,20 @@ class ModelWithExtraForbid(base_model): # type: ignore

return ModelWithExtraForbid

# Create the dynamic class
ModelWithExtraForbid = create_model_with_extra_forbid(
self._state.__class__
)

# Create a new instance using the combined state and inputs
self._state = cast(
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
)

except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e
elif isinstance(self._state, dict):
# Unstructured state management
self._state.update(inputs)
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")

def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow synchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
self.event_emitter.send(
self,
event=FlowStartedEvent(
Expand All @@ -274,32 +260,19 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
return asyncio.run(self.kickoff_async())

async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow asynchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
if not self._start_methods:
raise ValueError("No start method defined")

self._telemetry.flow_execution_span(
self.__class__.__name__, list(self._methods.keys())
)

# Create tasks for all start methods
tasks = [
self._execute_start_method(start_method)
for start_method in self._start_methods
]

# Run all start methods concurrently
await asyncio.gather(*tasks)

# Determine the final output (from the last executed method)
final_output = self._method_outputs[-1] if self._method_outputs else None

self.event_emitter.send(
Expand All @@ -310,7 +283,6 @@ async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
result=final_output,
),
)

return final_output

async def _execute_start_method(self, start_method_name: str) -> None:
Expand All @@ -327,49 +299,68 @@ async def _execute_method(
if asyncio.iscoroutinefunction(method)
else method(*args, **kwargs)
)
self._method_outputs.append(result) # Store the output

# Track method execution counts
self._method_outputs.append(result)
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)

return result

async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = []

if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(
self._routers[trigger_method], router_method
# First, handle routers repeatedly until no router triggers anymore
while True:
routers_triggered = self._find_triggered_methods(
trigger_method, router_only=True
)
trigger_method = path

if not routers_triggered:
break
for router_name in routers_triggered:
await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path
# The last router executed sets the trigger_method
# The router result is the last element in self._method_outputs
trigger_method = self._method_outputs[-1]

# Now that no more routers are triggered by current trigger_method,
# execute normal listeners
listeners_triggered = self._find_triggered_methods(
trigger_method, router_only=False
)
if listeners_triggered:
tasks = [
self._execute_single_listener(listener_name, result)
for listener_name in listeners_triggered
]
await asyncio.gather(*tasks)

def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers

if router_only != is_router:
continue

if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
# Schedule the listener without preventing re-execution
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)

if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)

# Run all listener tasks concurrently and wait for them to complete
if listener_tasks:
await asyncio.gather(*listener_tasks)
return triggered

async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
Expand All @@ -386,17 +377,13 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non

sig = inspect.signature(method)
params = list(sig.parameters.values())

# Exclude 'self' parameter
method_params = [p for p in params if p.name != "self"]

if method_params:
# If listener expects parameters, pass the result
listener_result = await self._execute_method(
listener_name, method, result
)
else:
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method)

self.event_emitter.send(
Expand All @@ -408,8 +395,9 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
),
)

# Execute listeners of this listener
# Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result)

except Exception as e:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
Expand All @@ -422,5 +410,4 @@ def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
)

plot_flow(self, filename)
56 changes: 44 additions & 12 deletions src/crewai/flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,50 @@ def get_possible_return_constants(function):
print(f"Source code:\n{source}")
return None

return_values = []
return_values = set()
dict_definitions = {}

class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)

class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Check if the return value is a constant (Python 3.8+)
if isinstance(node.value, ast.Constant):
return_values.append(node.value.value)

# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)

# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return return_values

return list(return_values) if return_values else None


def calculate_node_levels(flow):
Expand All @@ -61,10 +95,7 @@ def calculate_node_levels(flow):
current_level = levels[current]
visited.add(current)

for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
Expand All @@ -89,7 +120,7 @@ def calculate_node_levels(flow):
queue.append(listener_name)

# Handle router connections
if current in flow._routers.values():
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
Expand All @@ -105,6 +136,7 @@ def calculate_node_levels(flow):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)

return levels


Expand Down Expand Up @@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
dfs_ancestors(listener_name, ancestors, visited, flow)

# Handle router methods separately
if node in flow._routers.values():
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
Expand Down
Loading

0 comments on commit 6cc2f51

Please sign in to comment.