Skip to content

Commit

Permalink
Adding Autocomplete to OSS (#1198)
Browse files Browse the repository at this point in the history
* Cleaned up model_config

* Fix pydantic issues

* 99% done with autocomplete

* fixed test issues

* Fix type checking issues
  • Loading branch information
bhancockio authored Aug 16, 2024
1 parent 3451b6f commit bf7372f
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 121 deletions.
11 changes: 6 additions & 5 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ class Agent(BaseAgent):
description="Maximum number of retries for an agent to execute a task when an error occurs.",
)

def __init__(__pydantic_self__, **data):
config = data.pop("config", {})
super().__init__(**config, **data)
__pydantic_self__.agent_ops_agent_name = __pydantic_self__.role
@model_validator(mode="after")
def set_agent_ops_agent_name(self) -> "Agent":
"""Set agent ops agent name."""
self.agent_ops_agent_name = self.role
return self

@model_validator(mode="after")
def set_agent_executor(self) -> "Agent":
Expand Down Expand Up @@ -213,7 +214,7 @@ def execute_task(
raise e
result = self.execute_task(task, context, tools)

if self.max_rpm:
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()

# If there was any tool in self.tools_results that had result_as_answer
Expand Down
23 changes: 9 additions & 14 deletions src/crewai/agents/agent_builder/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
Expand Down Expand Up @@ -74,12 +73,17 @@ class BaseAgent(ABC, BaseModel):
"""

__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr(default=None)
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
formatting_errors: int = 0
model_config = ConfigDict(arbitrary_types_allowed=True)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
Expand Down Expand Up @@ -123,15 +127,6 @@ class BaseAgent(ABC, BaseModel):
default=None, description="Maximum number of tokens for the agent's execution."
)

_original_role: str | None = None
_original_goal: str | None = None
_original_backstory: str | None = None
_token_process: TokenProcess = TokenProcess()

def __init__(__pydantic_self__, **data):
config = data.pop("config", {})
super().__init__(**config, **data)

@model_validator(mode="after")
def set_config_attributes(self):
if self.config:
Expand Down
11 changes: 5 additions & 6 deletions src/crewai/agents/cache/cache_handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel, PrivateAttr

class CacheHandler:
"""Callback handler for tool usage."""

_cache: dict = {}
class CacheHandler(BaseModel):
"""Callback handler for tool usage."""

def __init__(self):
self._cache = {}
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)

def add(self, tool, input, output):
self._cache[f"{tool}-{input}"] = output
Expand Down
2 changes: 0 additions & 2 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
Json,
Expand Down Expand Up @@ -105,7 +104,6 @@ class Crew(BaseModel):

name: Optional[str] = Field(default=None)
cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
Expand Down
2 changes: 0 additions & 2 deletions src/crewai/project/crew_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

import yaml
from dotenv import load_dotenv
from pydantic import ConfigDict

load_dotenv()


def CrewBase(cls):
class WrappedClass(cls):
model_config = ConfigDict(arbitrary_types_allowed=True)
is_crew_class: bool = True # type: ignore

# Get the directory of the class being decorated
Expand Down
20 changes: 10 additions & 10 deletions src/crewai/project/pipeline_base.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
from typing import Callable, Dict

from pydantic import ConfigDict
from typing import Any, Callable, Dict, List, Type, Union

from crewai.crew import Crew
from crewai.pipeline.pipeline import Pipeline
from crewai.routers.router import Router

PipelineStage = Union[Crew, List[Crew], Router]


# TODO: Could potentially remove. Need to check with @joao and @gui if this is needed for CrewAI+
def PipelineBase(cls):
def PipelineBase(cls: Type[Any]) -> Type[Any]:
class WrappedClass(cls):
model_config = ConfigDict(arbitrary_types_allowed=True)
is_pipeline_class: bool = True # type: ignore
stages: List[PipelineStage]

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.stages = []
self._map_pipeline_components()

def _get_all_functions(self):
def _get_all_functions(self) -> Dict[str, Callable[..., Any]]:
return {
name: getattr(self, name)
for name in dir(self)
if callable(getattr(self, name))
}

def _filter_functions(
self, functions: Dict[str, Callable], attribute: str
) -> Dict[str, Callable]:
self, functions: Dict[str, Callable[..., Any]], attribute: str
) -> Dict[str, Callable[..., Any]]:
return {
name: func
for name, func in functions.items()
if hasattr(func, attribute)
}

def _map_pipeline_components(self):
def _map_pipeline_components(self) -> None:
all_functions = self._get_all_functions()
crew_functions = self._filter_functions(all_functions, "is_crew")
router_functions = self._filter_functions(all_functions, "is_router")
Expand Down
42 changes: 18 additions & 24 deletions src/crewai/routers/router.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,43 @@
from copy import deepcopy
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
from typing import Any, Callable, Dict, Tuple

from pydantic import BaseModel, Field, PrivateAttr

T = TypeVar("T", bound=Dict[str, Any])
U = TypeVar("U")

class Route(BaseModel):
condition: Callable[[Dict[str, Any]], bool]
pipeline: Any

class Route(Generic[T, U]):
condition: Callable[[T], bool]
pipeline: U

def __init__(self, condition: Callable[[T], bool], pipeline: U):
self.condition = condition
self.pipeline = pipeline


class Router(BaseModel, Generic[T, U]):
routes: Dict[str, Route[T, U]] = Field(
class Router(BaseModel):
routes: Dict[str, Route] = Field(
default_factory=dict,
description="Dictionary of route names to (condition, pipeline) tuples",
)
default: U = Field(..., description="Default pipeline if no conditions are met")
default: Any = Field(..., description="Default pipeline if no conditions are met")
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)

model_config = {"arbitrary_types_allowed": True}
class Config:
arbitrary_types_allowed = True

def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
def __init__(self, routes: Dict[str, Route], default: Any, **data):
super().__init__(routes=routes, default=default, **data)
self._check_copyable(default)
for name, route in routes.items():
self._check_copyable(route.pipeline)
self._route_types[name] = type(route.pipeline)

@staticmethod
def _check_copyable(obj):
def _check_copyable(obj: Any) -> None:
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")

def add_route(
self,
name: str,
condition: Callable[[T], bool],
pipeline: U,
) -> "Router[T, U]":
condition: Callable[[Dict[str, Any]], bool],
pipeline: Any,
) -> "Router":
"""
Add a named route with its condition and corresponding pipeline to the router.
Expand All @@ -60,7 +54,7 @@ def add_route(
self._route_types[name] = type(pipeline)
return self

def route(self, input_data: T) -> Tuple[U, str]:
def route(self, input_data: Dict[str, Any]) -> Tuple[Any, str]:
"""
Evaluate the input against the conditions and return the appropriate pipeline.
Expand All @@ -76,15 +70,15 @@ def route(self, input_data: T) -> Tuple[U, str]:

return self.default, "default"

def copy(self) -> "Router[T, U]":
def copy(self) -> "Router":
"""Create a deep copy of the Router."""
new_routes = {
name: Route(
condition=deepcopy(route.condition),
pipeline=route.pipeline.copy(), # type: ignore
pipeline=route.pipeline.copy(),
)
for name, route in self.routes.items()
}
new_default = self.default.copy() # type: ignore
new_default = self.default.copy()

return Router(routes=new_routes, default=new_default)
38 changes: 15 additions & 23 deletions src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from opentelemetry.trace import Span
from pydantic import UUID4, BaseModel, Field, field_validator, model_validator
from pydantic import (
UUID4,
BaseModel,
Field,
PrivateAttr,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError

from crewai.agents.agent_builder.base_agent import BaseAgent
Expand Down Expand Up @@ -39,9 +46,6 @@ class Task(BaseModel):
tools: List of tools/resources limited for task execution.
"""

class Config:
arbitrary_types_allowed = True

__hash__ = object.__hash__ # type: ignore
used_tools: int = 0
tools_errors: int = 0
Expand Down Expand Up @@ -104,16 +108,12 @@ class Config:
default=None,
)

_telemetry: Telemetry
_execution_span: Span | None = None
_original_description: str | None = None
_original_expected_output: str | None = None
_thread: threading.Thread | None = None
_execution_time: float | None = None

def __init__(__pydantic_self__, **data):
config = data.pop("config", {})
super().__init__(**config, **data)
_telemetry: Telemetry = PrivateAttr(default_factory=Telemetry)
_execution_span: Optional[Span] = PrivateAttr(default=None)
_original_description: Optional[str] = PrivateAttr(default=None)
_original_expected_output: Optional[str] = PrivateAttr(default=None)
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
_execution_time: Optional[float] = PrivateAttr(default=None)

@field_validator("id", mode="before")
@classmethod
Expand All @@ -137,12 +137,6 @@ def output_file_validation(cls, value: str) -> str:
return value[1:]
return value

@model_validator(mode="after")
def set_private_attrs(self) -> "Task":
"""Set private attributes."""
self._telemetry = Telemetry()
return self

@model_validator(mode="after")
def set_attributes_based_on_config(self) -> "Task":
"""Set attributes based on the agent configuration."""
Expand Down Expand Up @@ -263,9 +257,7 @@ def _execute_core(
content = (
json_output
if json_output
else pydantic_output.model_dump_json()
if pydantic_output
else result
else pydantic_output.model_dump_json() if pydantic_output else result
)
self._save_file(content)

Expand Down
5 changes: 2 additions & 3 deletions src/crewai/tools/cache_tools.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from langchain.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field

from crewai.agents.cache import CacheHandler


class CacheTools(BaseModel):
"""Default tools to hit the cache."""

model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = "Hit Cache"
cache_handler: CacheHandler = Field(
description="Cache Handler for the crew",
default=CacheHandler(),
default_factory=CacheHandler,
)

def tool(self):
Expand Down
10 changes: 5 additions & 5 deletions src/crewai/utilities/logger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from datetime import datetime

from crewai.utilities.printer import Printer
from pydantic import BaseModel, Field, PrivateAttr

from crewai.utilities.printer import Printer

class Logger:
_printer = Printer()

def __init__(self, verbose=False):
self.verbose = verbose
class Logger(BaseModel):
verbose: bool = Field(default=False)
_printer: Printer = PrivateAttr(default_factory=Printer)

def log(self, level, message, color="bold_green"):
if self.verbose:
Expand Down
Loading

0 comments on commit bf7372f

Please sign in to comment.