-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Add @persist decorator with SQLite persistence - Add FlowPersistence abstract base class - Implement SQLiteFlowPersistence backend - Add @persist decorator for flow state persistence - Add tests for flow persistence functionality Co-Authored-By: Joe Moura <[email protected]> * Fix remaining merge conflicts in uv.lock - Remove stray merge conflict markers - Keep main's comprehensive platform-specific resolution markers - Preserve all required dependencies for persistence functionality Co-Authored-By: Joe Moura <[email protected]> * Fix final CUDA dependency conflicts in uv.lock - Resolve NVIDIA CUDA solver dependency conflicts - Use main's comprehensive platform checks - Ensure all merge conflict markers are removed - Preserve persistence-related dependencies Co-Authored-By: Joe Moura <[email protected]> * Fix nvidia-cusparse-cu12 dependency conflicts in uv.lock - Resolve NVIDIA CUSPARSE dependency conflicts - Use main's comprehensive platform checks - Complete systematic check of entire uv.lock file - Ensure all merge conflict markers are removed Co-Authored-By: Joe Moura <[email protected]> * Fix triton filelock dependency conflicts in uv.lock - Resolve triton package filelock dependency conflict - Use main's comprehensive platform checks - Complete final systematic check of entire uv.lock file - Ensure TOML file structure is valid Co-Authored-By: Joe Moura <[email protected]> * Fix merge conflict in crew_test.py - Remove duplicate assertion in test_multimodal_agent_live_image_analysis - Clean up conflict markers - Preserve test functionality Co-Authored-By: Joe Moura <[email protected]> * Clean up trailing merge conflict marker in crew_test.py - Remove remaining conflict marker at end of file - Preserve test functionality - Complete conflict resolution Co-Authored-By: Joe Moura <[email protected]> * Improve type safety in persistence implementation and resolve merge conflicts Co-Authored-By: Joe Moura <[email protected]> * fix: Add explicit type casting in _create_initial_state method Co-Authored-By: Joe Moura <[email protected]> * fix: Improve type safety in flow state handling with proper validation Co-Authored-By: Joe Moura <[email protected]> * fix: Improve type system with proper TypeVar scoping and validation Co-Authored-By: Joe Moura <[email protected]> * fix: Improve state restoration logic and add comprehensive tests Co-Authored-By: Joe Moura <[email protected]> * fix: Initialize FlowState instances without passing id to constructor Co-Authored-By: Joe Moura <[email protected]> * feat: Add class-level flow persistence decorator with SQLite default - Add class-level @persist decorator support - Set SQLiteFlowPersistence as default backend - Use db_storage_path for consistent database location - Improve async method handling and type safety - Add comprehensive docstrings and examples Co-Authored-By: Joe Moura <[email protected]> * fix: Sort imports in decorators.py to fix lint error Co-Authored-By: Joe Moura <[email protected]> * style: Organize imports according to PEP 8 standard Co-Authored-By: Joe Moura <[email protected]> * style: Format typing imports with line breaks for better readability Co-Authored-By: Joe Moura <[email protected]> * style: Simplify import organization to fix lint error Co-Authored-By: Joe Moura <[email protected]> * style: Fix import sorting using Ruff auto-fix Co-Authored-By: Joe Moura <[email protected]> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura <[email protected]> Co-authored-by: João Moura <[email protected]>
- Loading branch information
1 parent
3dc4428
commit 294f2cc
Showing
10 changed files
with
1,064 additions
and
141 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
CrewAI Flow Persistence. | ||
This module provides interfaces and implementations for persisting flow states. | ||
""" | ||
|
||
from typing import Any, Dict, TypeVar, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
from crewai.flow.persistence.base import FlowPersistence | ||
from crewai.flow.persistence.decorators import persist | ||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence | ||
|
||
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"] | ||
|
||
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel]) | ||
DictStateType = Dict[str, Any] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""Base class for flow state persistence.""" | ||
|
||
import abc | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class FlowPersistence(abc.ABC): | ||
"""Abstract base class for flow state persistence. | ||
This class defines the interface that all persistence implementations must follow. | ||
It supports both structured (Pydantic BaseModel) and unstructured (dict) states. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def init_db(self) -> None: | ||
"""Initialize the persistence backend. | ||
This method should handle any necessary setup, such as: | ||
- Creating tables | ||
- Establishing connections | ||
- Setting up indexes | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def save_state( | ||
self, | ||
flow_uuid: str, | ||
method_name: str, | ||
state_data: Union[Dict[str, Any], BaseModel] | ||
) -> None: | ||
"""Persist the flow state after method completion. | ||
Args: | ||
flow_uuid: Unique identifier for the flow instance | ||
method_name: Name of the method that just completed | ||
state_data: Current state data (either dict or Pydantic model) | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: | ||
"""Load the most recent state for a given flow UUID. | ||
Args: | ||
flow_uuid: Unique identifier for the flow instance | ||
Returns: | ||
The most recent state as a dictionary, or None if no state exists | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
""" | ||
Decorators for flow state persistence. | ||
Example: | ||
```python | ||
from crewai.flow.flow import Flow, start | ||
from crewai.flow.persistence import persist, SQLiteFlowPersistence | ||
class MyFlow(Flow): | ||
@start() | ||
@persist(SQLiteFlowPersistence()) | ||
def sync_method(self): | ||
# Synchronous method implementation | ||
pass | ||
@start() | ||
@persist(SQLiteFlowPersistence()) | ||
async def async_method(self): | ||
# Asynchronous method implementation | ||
await some_async_operation() | ||
``` | ||
""" | ||
|
||
import asyncio | ||
import functools | ||
import inspect | ||
import logging | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Optional, | ||
Type, | ||
TypeVar, | ||
Union, | ||
cast, | ||
get_type_hints, | ||
) | ||
|
||
from pydantic import BaseModel | ||
|
||
from crewai.flow.persistence.base import FlowPersistence | ||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence | ||
|
||
logger = logging.getLogger(__name__) | ||
T = TypeVar("T") | ||
|
||
|
||
def persist(persistence: Optional[FlowPersistence] = None): | ||
"""Decorator to persist flow state. | ||
This decorator can be applied at either the class level or method level. | ||
When applied at the class level, it automatically persists all flow method | ||
states. When applied at the method level, it persists only that method's | ||
state. | ||
Args: | ||
persistence: Optional FlowPersistence implementation to use. | ||
If not provided, uses SQLiteFlowPersistence. | ||
Returns: | ||
A decorator that can be applied to either a class or method | ||
Raises: | ||
ValueError: If the flow state doesn't have an 'id' field | ||
RuntimeError: If state persistence fails | ||
Example: | ||
@persist # Class-level persistence with default SQLite | ||
class MyFlow(Flow[MyState]): | ||
@start() | ||
def begin(self): | ||
pass | ||
""" | ||
def _persist_state(flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None: | ||
"""Helper to persist state with error handling.""" | ||
try: | ||
# Get flow UUID from state | ||
state = getattr(flow_instance, 'state', None) | ||
if state is None: | ||
raise ValueError("Flow instance has no state") | ||
|
||
flow_uuid: Optional[str] = None | ||
if isinstance(state, dict): | ||
flow_uuid = state.get('id') | ||
elif isinstance(state, BaseModel): | ||
flow_uuid = getattr(state, 'id', None) | ||
|
||
if not flow_uuid: | ||
raise ValueError( | ||
"Flow state must have an 'id' field for persistence" | ||
) | ||
|
||
# Persist the state | ||
persistence_instance.save_state( | ||
flow_uuid=flow_uuid, | ||
method_name=method_name, | ||
state_data=state, | ||
) | ||
except Exception as e: | ||
logger.error( | ||
f"Failed to persist state for method {method_name}: {str(e)}" | ||
) | ||
raise RuntimeError(f"State persistence failed: {str(e)}") from e | ||
|
||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: | ||
"""Decorator that handles both class and method decoration.""" | ||
actual_persistence = persistence or SQLiteFlowPersistence() | ||
|
||
if isinstance(target, type): | ||
# Class decoration | ||
class_methods = {} | ||
for name, method in target.__dict__.items(): | ||
if callable(method) and hasattr(method, "__is_flow_method__"): | ||
# Wrap each flow method with persistence | ||
if asyncio.iscoroutinefunction(method): | ||
@functools.wraps(method) | ||
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: | ||
method_coro = method(self, *args, **kwargs) | ||
if asyncio.iscoroutine(method_coro): | ||
result = await method_coro | ||
else: | ||
result = method_coro | ||
_persist_state(self, method.__name__, actual_persistence) | ||
return result | ||
class_methods[name] = class_async_wrapper | ||
else: | ||
@functools.wraps(method) | ||
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: | ||
result = method(self, *args, **kwargs) | ||
_persist_state(self, method.__name__, actual_persistence) | ||
return result | ||
class_methods[name] = class_sync_wrapper | ||
|
||
# Preserve flow-specific attributes | ||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: | ||
if hasattr(method, attr): | ||
setattr(class_methods[name], attr, getattr(method, attr)) | ||
setattr(class_methods[name], "__is_flow_method__", True) | ||
|
||
# Update class with wrapped methods | ||
for name, method in class_methods.items(): | ||
setattr(target, name, method) | ||
return target | ||
else: | ||
# Method decoration | ||
method = target | ||
setattr(method, "__is_flow_method__", True) | ||
|
||
if asyncio.iscoroutinefunction(method): | ||
@functools.wraps(method) | ||
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: | ||
method_coro = method(flow_instance, *args, **kwargs) | ||
if asyncio.iscoroutine(method_coro): | ||
result = await method_coro | ||
else: | ||
result = method_coro | ||
_persist_state(flow_instance, method.__name__, actual_persistence) | ||
return result | ||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: | ||
if hasattr(method, attr): | ||
setattr(method_async_wrapper, attr, getattr(method, attr)) | ||
setattr(method_async_wrapper, "__is_flow_method__", True) | ||
return cast(Callable[..., T], method_async_wrapper) | ||
else: | ||
@functools.wraps(method) | ||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: | ||
result = method(flow_instance, *args, **kwargs) | ||
_persist_state(flow_instance, method.__name__, actual_persistence) | ||
return result | ||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: | ||
if hasattr(method, attr): | ||
setattr(method_sync_wrapper, attr, getattr(method, attr)) | ||
setattr(method_sync_wrapper, "__is_flow_method__", True) | ||
return cast(Callable[..., T], method_sync_wrapper) | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
SQLite-based implementation of flow state persistence. | ||
""" | ||
|
||
import json | ||
import os | ||
import sqlite3 | ||
import tempfile | ||
from datetime import datetime | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
from crewai.flow.persistence.base import FlowPersistence | ||
|
||
|
||
class SQLiteFlowPersistence(FlowPersistence): | ||
"""SQLite-based implementation of flow state persistence. | ||
This class provides a simple, file-based persistence implementation using SQLite. | ||
It's suitable for development and testing, or for production use cases with | ||
moderate performance requirements. | ||
""" | ||
|
||
db_path: str # Type annotation for instance variable | ||
|
||
def __init__(self, db_path: Optional[str] = None): | ||
"""Initialize SQLite persistence. | ||
Args: | ||
db_path: Path to the SQLite database file. If not provided, uses | ||
db_storage_path() from utilities.paths. | ||
Raises: | ||
ValueError: If db_path is invalid | ||
""" | ||
from crewai.utilities.paths import db_storage_path | ||
# Get path from argument or default location | ||
path = db_path or db_storage_path() | ||
|
||
if not path: | ||
raise ValueError("Database path must be provided") | ||
|
||
self.db_path = path # Now mypy knows this is str | ||
self.init_db() | ||
|
||
def init_db(self) -> None: | ||
"""Create the necessary tables if they don't exist.""" | ||
with sqlite3.connect(self.db_path) as conn: | ||
conn.execute(""" | ||
CREATE TABLE IF NOT EXISTS flow_states ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
flow_uuid TEXT NOT NULL, | ||
method_name TEXT NOT NULL, | ||
timestamp DATETIME NOT NULL, | ||
state_json TEXT NOT NULL | ||
) | ||
""") | ||
# Add index for faster UUID lookups | ||
conn.execute(""" | ||
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid | ||
ON flow_states(flow_uuid) | ||
""") | ||
|
||
def save_state( | ||
self, | ||
flow_uuid: str, | ||
method_name: str, | ||
state_data: Union[Dict[str, Any], BaseModel], | ||
) -> None: | ||
"""Save the current flow state to SQLite. | ||
Args: | ||
flow_uuid: Unique identifier for the flow instance | ||
method_name: Name of the method that just completed | ||
state_data: Current state data (either dict or Pydantic model) | ||
""" | ||
# Convert state_data to dict, handling both Pydantic and dict cases | ||
if isinstance(state_data, BaseModel): | ||
state_dict = dict(state_data) # Use dict() for better type compatibility | ||
elif isinstance(state_data, dict): | ||
state_dict = state_data | ||
else: | ||
raise ValueError( | ||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" | ||
) | ||
|
||
with sqlite3.connect(self.db_path) as conn: | ||
conn.execute(""" | ||
INSERT INTO flow_states ( | ||
flow_uuid, | ||
method_name, | ||
timestamp, | ||
state_json | ||
) VALUES (?, ?, ?, ?) | ||
""", ( | ||
flow_uuid, | ||
method_name, | ||
datetime.utcnow().isoformat(), | ||
json.dumps(state_dict), | ||
)) | ||
|
||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: | ||
"""Load the most recent state for a given flow UUID. | ||
Args: | ||
flow_uuid: Unique identifier for the flow instance | ||
Returns: | ||
The most recent state as a dictionary, or None if no state exists | ||
""" | ||
with sqlite3.connect(self.db_path) as conn: | ||
cursor = conn.execute(""" | ||
SELECT state_json | ||
FROM flow_states | ||
WHERE flow_uuid = ? | ||
ORDER BY id DESC | ||
LIMIT 1 | ||
""", (flow_uuid,)) | ||
row = cursor.fetchone() | ||
|
||
if row: | ||
return json.loads(row[0]) | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.