Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @persist decorator with FlowPersistence interface #1892

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
aeacbbf
Add @persist decorator with SQLite persistence
devin-ai-integration[bot] Jan 14, 2025
668e195
Resolve merge conflicts integrating FlowState with new persistence
devin-ai-integration[bot] Jan 14, 2025
6d08831
Fix remaining merge conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
cc87e08
Fix final CUDA dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
357ca68
Fix nvidia-cusparse-cu12 dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
59e9afa
Fix triton filelock dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
3c0101f
Fix merge conflict in crew_test.py
devin-ai-integration[bot] Jan 14, 2025
6f5e73d
Clean up trailing merge conflict marker in crew_test.py
devin-ai-integration[bot] Jan 14, 2025
e3e7e67
Improve type safety in persistence implementation and resolve merge c…
devin-ai-integration[bot] Jan 14, 2025
4e0a7ba
fix: Add explicit type casting in _create_initial_state method
devin-ai-integration[bot] Jan 14, 2025
212e60f
fix: Improve type safety in flow state handling with proper validation
devin-ai-integration[bot] Jan 14, 2025
9625630
fix: Improve type system with proper TypeVar scoping and validation
devin-ai-integration[bot] Jan 14, 2025
785e97a
fix: Improve state restoration logic and add comprehensive tests
devin-ai-integration[bot] Jan 15, 2025
1b6207d
fix: Initialize FlowState instances without passing id to constructor
devin-ai-integration[bot] Jan 15, 2025
2271447
Merge branch 'main' into devin/1736848480-persist-decorator
joaomdmoura Jan 15, 2025
37dc5ee
feat: Add class-level flow persistence decorator with SQLite default
devin-ai-integration[bot] Jan 15, 2025
7cd11f5
fix: Sort imports in decorators.py to fix lint error
devin-ai-integration[bot] Jan 15, 2025
40afbb9
style: Organize imports according to PEP 8 standard
devin-ai-integration[bot] Jan 15, 2025
1ab8fe3
style: Format typing imports with line breaks for better readability
devin-ai-integration[bot] Jan 15, 2025
8396f6e
style: Simplify import organization to fix lint error
devin-ai-integration[bot] Jan 15, 2025
a4bc624
style: Fix import sorting using Ruff auto-fix
devin-ai-integration[bot] Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 131 additions & 30 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import uuid
from typing import (
Any,
Callable,
Expand All @@ -25,6 +26,8 @@
MethodExecutionStartedEvent,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence import FlowPersistence
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry

Expand Down Expand Up @@ -326,21 +329,27 @@ def __new__(mcs, name, bases, dct):
routers = set()

for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"):
start_methods.append(attr_name)
# Check for any flow-related attributes
if (hasattr(attr_value, "__is_flow_method__") or
hasattr(attr_value, "__is_start_method__") or
hasattr(attr_value, "__trigger_methods__") or
hasattr(attr_value, "__is_router__")):

# Register start methods
if hasattr(attr_value, "__is_start_method__"):
start_methods.append(attr_name)

# Register listeners and routers
if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
elif hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns

if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns

setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
Expand All @@ -367,22 +376,70 @@ class _FlowGeneric(cls): # type: ignore
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
return _FlowGeneric

def __init__(self) -> None:
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
restore_uuid: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.

Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
**kwargs: Additional state values to initialize or override
"""
# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(self.initial_state, FlowState):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")

self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence

# First restore from persistence if requested
if restore_uuid and self._persistence is not None:
stored_state = self._persistence.load_state(restore_uuid)
if stored_state:
self._restore_state(stored_state)

# Then apply any additional kwargs to override/update state
if kwargs:
self._initialize_state(kwargs)

self._telemetry.flow_creation_span(self.__class__.__name__)

# Register all flow-related methods
for method_name in dir(self):
if callable(getattr(self, method_name)) and not method_name.startswith(
"__"
):
self._methods[method_name] = getattr(self, method_name)
if not method_name.startswith("_"):
method = getattr(self, method_name)
# Check for any flow-related attributes
if (hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__is_router__")):
# Ensure method is bound to this instance
if not hasattr(method, "__self__"):
method = method.__get__(self, self.__class__)
self._methods[method_name] = method

def _create_initial_state(self) -> T:
"""Create and initialize flow state with UUID.

Returns:
New state instance with UUID initialized

Raises:
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = getattr(self, "_initial_state_T")
Expand All @@ -394,6 +451,8 @@ def _create_initial_state(self) -> T:
class StateWithId(state_type, FlowState): # type: ignore
pass
return StateWithId() # type: ignore
elif state_type == dict:
return {"id": str(uuid4())} # type: ignore

# Handle case where no initial state is provided
if self.initial_state is None:
Expand All @@ -404,14 +463,19 @@ class StateWithId(state_type, FlowState): # type: ignore
if issubclass(self.initial_state, FlowState):
return self.initial_state() # type: ignore
elif issubclass(self.initial_state, BaseModel):
# Create a new type that includes the ID field
class StateWithId(self.initial_state, FlowState): # type: ignore
pass
return StateWithId() # type: ignore
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return self.initial_state() # type: ignore
elif self.initial_state == dict:
return {"id": str(uuid4())} # type: ignore

# Handle dictionary case
if isinstance(self.initial_state, dict) and "id" not in self.initial_state:
self.initial_state["id"] = str(uuid4())
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
if "id" not in self.initial_state:
self.initial_state["id"] = str(uuid4())
return self.initial_state

return self.initial_state # type: ignore

Expand All @@ -425,14 +489,30 @@ def method_outputs(self) -> List[Any]:
return self._method_outputs

def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.

Args:
inputs: Dictionary of state values to set/update

Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
"""
if isinstance(self._state, dict):
# Preserve the ID when updating unstructured state
current_id = self._state.get("id")
self._state.update(inputs)
if current_id:
self._state["id"] = current_id
elif "id" not in self._state:
self._state["id"] = str(uuid4())
# For dict states, preserve existing ID or use provided one
if "id" in inputs:
# Create new state dict with provided ID
new_state = dict(inputs)
self._state.clear()
self._state.update(new_state)
else:
# Preserve existing ID if any
current_id = self._state.get("id")
self._state.update(inputs)
if current_id:
self._state["id"] = current_id
elif "id" not in self._state:
self._state["id"] = str(uuid4())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow this. The logic is that if an id is passed, the flow should load the state from the persistence layer and override only the new fields being sent; but before overriding, it should reload all the state fields based on the id.

elif isinstance(self._state, BaseModel):
# Structured state
try:
Expand Down Expand Up @@ -469,6 +549,27 @@ class ModelWithExtraForbid(base_model): # type: ignore
raise ValueError(f"Invalid inputs for structured state: {e}") from e
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")

def _restore_state(self, stored_state: Dict[str, Any]) -> None:
    """Rehydrate the flow's state from a previously persisted snapshot.

    Args:
        stored_state: State dictionary that was loaded from persistence.

    Raises:
        ValueError: If the snapshot has no truthy 'id' entry, or if the
            state initializer rejects the values for a structured state.
        TypeError: If the current state is neither a BaseModel nor a dict
            (raised by the state initializer).
    """
    # A snapshot without an identifier cannot be tied back to a flow run.
    if not stored_state.get("id"):
        raise ValueError("Stored state must have an 'id' field")

    # Pass a shallow copy on, so the initializer never receives the
    # caller's own dictionary object.
    self._initialize_state(dict(stored_state))

def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
self.event_emitter.send(
Expand Down
18 changes: 18 additions & 0 deletions src/crewai/flow/persistence/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
CrewAI Flow Persistence.

This module provides interfaces and implementations for persisting flow states.
"""

from typing import Any, Dict, TypeVar, Union

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

# Names re-exported as the public surface of the persistence package.
__all__ = [
    "FlowPersistence",
    "persist",
    "SQLiteFlowPersistence",
]

# Type variable for a flow state: either a plain dict or a Pydantic model.
StateType = TypeVar("StateType", bound=Union[Dict[str, Any], BaseModel])

# Alias for the unstructured (dictionary) form of a flow state.
DictStateType = Dict[str, Any]
53 changes: 53 additions & 0 deletions src/crewai/flow/persistence/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Base class for flow state persistence."""

import abc
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel


class FlowPersistence(abc.ABC):
    """Contract that every flow-state storage backend has to fulfil.

    Concrete backends subclass this and implement all three abstract
    methods. State may be handed to a backend either as a plain dictionary
    (unstructured) or as a Pydantic ``BaseModel`` (structured).
    """

    @abc.abstractmethod
    def init_db(self) -> None:
        """Prepare the backend for use.

        Implementations perform whatever setup they need, for example
        creating tables, establishing connections or setting up indexes.
        """

    @abc.abstractmethod
    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: Union[Dict[str, Any], BaseModel],
    ) -> None:
        """Store the flow state once a method has finished.

        Args:
            flow_uuid: Unique identifier of the flow instance.
            method_name: Name of the method that just completed.
            state_data: Current state, as a dict or a Pydantic model.
        """

    @abc.abstractmethod
    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
        """Fetch the most recent state saved for a flow.

        Args:
            flow_uuid: Unique identifier of the flow instance.

        Returns:
            The most recent state as a dictionary, or None when nothing
            has been stored for that identifier.
        """
Loading
Loading