Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @persist decorator with FlowPersistence interface #1892

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
aeacbbf
Add @persist decorator with SQLite persistence
devin-ai-integration[bot] Jan 14, 2025
668e195
Resolve merge conflicts integrating FlowState with new persistence
devin-ai-integration[bot] Jan 14, 2025
6d08831
Fix remaining merge conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
cc87e08
Fix final CUDA dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
357ca68
Fix nvidia-cusparse-cu12 dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
59e9afa
Fix triton filelock dependency conflicts in uv.lock
devin-ai-integration[bot] Jan 14, 2025
3c0101f
Fix merge conflict in crew_test.py
devin-ai-integration[bot] Jan 14, 2025
6f5e73d
Clean up trailing merge conflict marker in crew_test.py
devin-ai-integration[bot] Jan 14, 2025
e3e7e67
Improve type safety in persistence implementation and resolve merge c…
devin-ai-integration[bot] Jan 14, 2025
4e0a7ba
fix: Add explicit type casting in _create_initial_state method
devin-ai-integration[bot] Jan 14, 2025
212e60f
fix: Improve type safety in flow state handling with proper validation
devin-ai-integration[bot] Jan 14, 2025
9625630
fix: Improve type system with proper TypeVar scoping and validation
devin-ai-integration[bot] Jan 14, 2025
785e97a
fix: Improve state restoration logic and add comprehensive tests
devin-ai-integration[bot] Jan 15, 2025
1b6207d
fix: Initialize FlowState instances without passing id to constructor
devin-ai-integration[bot] Jan 15, 2025
2271447
Merge branch 'main' into devin/1736848480-persist-decorator
joaomdmoura Jan 15, 2025
37dc5ee
feat: Add class-level flow persistence decorator with SQLite default
devin-ai-integration[bot] Jan 15, 2025
7cd11f5
fix: Sort imports in decorators.py to fix lint error
devin-ai-integration[bot] Jan 15, 2025
40afbb9
style: Organize imports according to PEP 8 standard
devin-ai-integration[bot] Jan 15, 2025
1ab8fe3
style: Format typing imports with line breaks for better readability
devin-ai-integration[bot] Jan 15, 2025
8396f6e
style: Simplify import organization to fix lint error
devin-ai-integration[bot] Jan 15, 2025
a4bc624
style: Fix import sorting using Ruff auto-fix
devin-ai-integration[bot] Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
390 changes: 327 additions & 63 deletions src/crewai/flow/flow.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions src/crewai/flow/persistence/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
CrewAI Flow Persistence.

This module provides interfaces and implementations for persisting flow states.
"""

from typing import Any, Dict, TypeVar, Union

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]

StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
DictStateType = Dict[str, Any]
53 changes: 53 additions & 0 deletions src/crewai/flow/persistence/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Base class for flow state persistence."""

import abc
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel


class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.

This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""

@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.

This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass

@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
) -> None:
"""Persist the flow state after method completion.

Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass

@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.

Args:
flow_uuid: Unique identifier for the flow instance

Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass
177 changes: 177 additions & 0 deletions src/crewai/flow/persistence/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
Decorators for flow state persistence.

Example:
```python
from crewai.flow.flow import Flow, start
from crewai.flow.persistence import persist, SQLiteFlowPersistence

class MyFlow(Flow):
@start()
@persist(SQLiteFlowPersistence())
def sync_method(self):
# Synchronous method implementation
pass

@start()
@persist(SQLiteFlowPersistence())
async def async_method(self):
# Asynchronous method implementation
await some_async_operation()
```
"""

import asyncio
import functools
import inspect
import logging
from typing import (
Any,
Callable,
Dict,
Optional,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

logger = logging.getLogger(__name__)
T = TypeVar("T")


def persist(persistence: Optional[FlowPersistence] = None):
"""Decorator to persist flow state.

This decorator can be applied at either the class level or method level.
When applied at the class level, it automatically persists all flow method
states. When applied at the method level, it persists only that method's
state.

Args:
persistence: Optional FlowPersistence implementation to use.
If not provided, uses SQLiteFlowPersistence.

Returns:
A decorator that can be applied to either a class or method

Raises:
ValueError: If the flow state doesn't have an 'id' field
RuntimeError: If state persistence fails

Example:
@persist # Class-level persistence with default SQLite
class MyFlow(Flow[MyState]):
@start()
def begin(self):
pass
"""
def _persist_state(flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Helper to persist state with error handling."""
try:
# Get flow UUID from state
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError(
"Flow state must have an 'id' field for persistence"
)

# Persist the state
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state,
)
except Exception as e:
logger.error(
f"Failed to persist state for method {method_name}: {str(e)}"
)
raise RuntimeError(f"State persistence failed: {str(e)}") from e

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
class_methods = {}
for name, method in target.__dict__.items():
if callable(method) and hasattr(method, "__is_flow_method__"):
# Wrap each flow method with persistence
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
method_coro = method(self, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
_persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
_persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_sync_wrapper

# Preserve flow-specific attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)

# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)

if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
_persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
_persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)

return decorator
124 changes: 124 additions & 0 deletions src/crewai/flow/persistence/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
SQLite-based implementation of flow state persistence.
"""

import json
import os
import sqlite3
import tempfile
from datetime import datetime
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence


class SQLiteFlowPersistence(FlowPersistence):
"""SQLite-based implementation of flow state persistence.

This class provides a simple, file-based persistence implementation using SQLite.
It's suitable for development and testing, or for production use cases with
moderate performance requirements.
"""

db_path: str # Type annotation for instance variable

def __init__(self, db_path: Optional[str] = None):
"""Initialize SQLite persistence.

Args:
db_path: Path to the SQLite database file. If not provided, uses
db_storage_path() from utilities.paths.

Raises:
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
# Get path from argument or default location
path = db_path or db_storage_path()

if not path:
raise ValueError("Database path must be provided")

self.db_path = path # Now mypy knows this is str
self.init_db()

def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS flow_states (
id INTEGER PRIMARY KEY AUTOINCREMENT,
flow_uuid TEXT NOT NULL,
method_name TEXT NOT NULL,
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
)
""")
# Add index for faster UUID lookups
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
""")

def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
) -> None:
"""Save the current flow state to SQLite.

Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
state_dict = dict(state_data) # Use dict() for better type compatibility
elif isinstance(state_data, dict):
state_dict = state_data
else:
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)

with sqlite3.connect(self.db_path) as conn:
conn.execute("""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""", (
flow_uuid,
method_name,
datetime.utcnow().isoformat(),
json.dumps(state_dict),
))

def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.

Args:
flow_uuid: Unique identifier for the flow instance

Returns:
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute("""
SELECT state_json
FROM flow_states
WHERE flow_uuid = ?
ORDER BY id DESC
LIMIT 1
""", (flow_uuid,))
row = cursor.fetchone()

if row:
return json.loads(row[0])
return None
10 changes: 7 additions & 3 deletions src/crewai/utilities/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@

"""Path management utilities for CrewAI storage and configuration."""

def db_storage_path():
"""Returns the path for database storage."""
def db_storage_path() -> str:
"""Returns the path for SQLite database storage.

Returns:
str: Full path to the SQLite database file
"""
app_name = get_project_directory_name()
app_author = "CrewAI"

data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return data_dir
return str(data_dir / "crewai_flows.db")


def get_project_directory_name():
Expand Down
Loading
Loading