-
-
Notifications
You must be signed in to change notification settings - Fork 532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhanced Configuration, Token Counting and Unified Output Formatting Schema for swarms Module #770
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
from swarms.schemas.agent_step_schemas import Step, ManySteps | ||
|
||
from swarms.schemas.agent_input_schema import AgentSchema | ||
|
||
from swarms.schemas.base_swarm_schemas import BaseSwarmSchema | ||
from swarms.schemas.output_schemas import OutputSchema | ||
|
||
__all__ = [ | ||
"Step", | ||
"ManySteps", | ||
"AgentSchema", | ||
"BaseSwarmSchema", | ||
"OutputSchema", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from typing import Any, Dict, List, Optional, Union | ||
from pydantic import BaseModel, Field, validator | ||
Check failure Code scanning / Pyre Undefined import Error
Undefined import [21]: Could not find a module corresponding to import pydantic.
|
||
import uuid | ||
import time | ||
|
||
class AgentInputConfig(BaseModel): | ||
Check failure Code scanning / Pyre Undefined or invalid type Error
Undefined or invalid type [11]: Annotation BaseModel is not defined as a type.
|
||
""" | ||
Configuration for an agent. This can be further customized | ||
per agent type if needed. | ||
""" | ||
agent_name: str = Field(..., description="Name of the agent") | ||
agent_type: str = Field(..., description="Type of agent (e.g. 'llm', 'tool', 'memory')") | ||
model_name: Optional[str] = Field(None, description="Name of the model to use") | ||
temperature: float = Field(0.7, description="Temperature for model sampling") | ||
max_tokens: int = Field(4096, description="Maximum tokens for model response") | ||
system_prompt: Optional[str] = Field(None, description="System prompt for the agent") | ||
tools: Optional[List[str]] = Field(None, description="List of tool names available to agent") | ||
memory_type: Optional[str] = Field(None, description="Type of memory to use") | ||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional agent metadata") | ||
|
||
class BaseSwarmSchema(BaseModel): | ||
""" | ||
Base schema for all swarm types. | ||
""" | ||
id: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
name: str | ||
Check failure Code scanning / Pyre Uninitialized attribute Error
Uninitialized attribute [13]: Attribute name is declared in class BaseSwarmSchema to have type str but is never initialized.
|
||
description: str | ||
Check failure Code scanning / Pyre Uninitialized attribute Error
Uninitialized attribute [13]: Attribute description is declared in class BaseSwarmSchema to have type str but is never initialized.
|
||
agents: List[AgentInputConfig] # Using AgentInputConfig | ||
Check failure Code scanning / Pyre Uninitialized attribute Error
Uninitialized attribute [13]: Attribute agents is declared in class BaseSwarmSchema to have type List[AgentInputConfig] but is never initialized.
|
||
max_loops: int = 1 | ||
swarm_type: str # e.g., "SequentialWorkflow", "ConcurrentWorkflow", etc. | ||
Check failure Code scanning / Pyre Uninitialized attribute Error
Uninitialized attribute [13]: Attribute swarm\_type is declared in class BaseSwarmSchema to have type str but is never initialized.
|
||
created_at: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")) | ||
config: Dict[str, Any] = Field(default_factory=dict) # Flexible config | ||
|
||
# Additional fields | ||
timeout: Optional[int] = Field(None, description="Timeout in seconds for swarm execution") | ||
error_handling: str = Field("stop", description="Error handling strategy: 'stop', 'continue', or 'retry'") | ||
max_retries: int = Field(3, description="Maximum number of retry attempts") | ||
logging_level: str = Field("info", description="Logging level for the swarm") | ||
metrics_enabled: bool = Field(True, description="Whether to collect metrics") | ||
tags: List[str] = Field(default_factory=list, description="Tags for categorizing swarms") | ||
|
||
@validator("swarm_type") | ||
def validate_swarm_type(cls, v): | ||
"""Validates the swarm type is one of the allowed types""" | ||
allowed_types = [ | ||
"SequentialWorkflow", | ||
"ConcurrentWorkflow", | ||
"AgentRearrange", | ||
"MixtureOfAgents", | ||
"SpreadSheetSwarm", | ||
"AutoSwarm", | ||
"HierarchicalSwarm", | ||
"FeedbackSwarm" | ||
] | ||
if v not in allowed_types: | ||
raise ValueError(f"Swarm type must be one of: {allowed_types}") | ||
return v | ||
|
||
@validator("config") | ||
def validate_config(cls, v, values): | ||
""" | ||
Validates the 'config' dictionary based on the 'swarm_type'. | ||
""" | ||
swarm_type = values.get("swarm_type") | ||
|
||
# Common validation for all swarm types | ||
if not isinstance(v, dict): | ||
raise ValueError("Config must be a dictionary") | ||
|
||
# Type-specific validation | ||
if swarm_type == "SequentialWorkflow": | ||
if "flow" not in v: | ||
raise ValueError("SequentialWorkflow requires a 'flow' configuration.") | ||
if not isinstance(v["flow"], list): | ||
raise ValueError("Flow configuration must be a list") | ||
|
||
elif swarm_type == "ConcurrentWorkflow": | ||
if "max_workers" not in v: | ||
raise ValueError("ConcurrentWorkflow requires a 'max_workers' configuration.") | ||
if not isinstance(v["max_workers"], int) or v["max_workers"] < 1: | ||
raise ValueError("max_workers must be a positive integer") | ||
|
||
elif swarm_type == "AgentRearrange": | ||
if "flow" not in v: | ||
raise ValueError("AgentRearrange requires a 'flow' configuration.") | ||
if not isinstance(v["flow"], list): | ||
raise ValueError("Flow configuration must be a list") | ||
|
||
elif swarm_type == "MixtureOfAgents": | ||
if "aggregator_agent" not in v: | ||
raise ValueError("MixtureOfAgents requires an 'aggregator_agent' configuration.") | ||
if "voting_method" not in v: | ||
v["voting_method"] = "majority" # Set default voting method | ||
|
||
elif swarm_type == "SpreadSheetSwarm": | ||
if "save_file_path" not in v: | ||
raise ValueError("SpreadSheetSwarm requires a 'save_file_path' configuration.") | ||
if not isinstance(v["save_file_path"], str): | ||
raise ValueError("save_file_path must be a string") | ||
|
||
elif swarm_type == "AutoSwarm": | ||
if "optimization_metric" not in v: | ||
v["optimization_metric"] = "performance" # Set default metric | ||
if "adaptation_strategy" not in v: | ||
v["adaptation_strategy"] = "dynamic" # Set default strategy | ||
|
||
elif swarm_type == "HierarchicalSwarm": | ||
if "hierarchy_levels" not in v: | ||
raise ValueError("HierarchicalSwarm requires 'hierarchy_levels' configuration.") | ||
if not isinstance(v["hierarchy_levels"], int) or v["hierarchy_levels"] < 1: | ||
raise ValueError("hierarchy_levels must be a positive integer") | ||
|
||
elif swarm_type == "FeedbackSwarm": | ||
if "feedback_collection" not in v: | ||
v["feedback_collection"] = "continuous" # Set default collection method | ||
if "feedback_integration" not in v: | ||
v["feedback_integration"] = "weighted" # Set default integration method | ||
|
||
return v | ||
|
||
@validator("error_handling") | ||
def validate_error_handling(cls, v): | ||
"""Validates error handling strategy""" | ||
allowed_strategies = ["stop", "continue", "retry"] | ||
if v not in allowed_strategies: | ||
raise ValueError(f"Error handling must be one of: {allowed_strategies}") | ||
return v | ||
|
||
@validator("logging_level") | ||
def validate_logging_level(cls, v): | ||
"""Validates logging level""" | ||
allowed_levels = ["debug", "info", "warning", "error", "critical"] | ||
if v.lower() not in allowed_levels: | ||
raise ValueError(f"Logging level must be one of: {allowed_levels}") | ||
return v.lower() | ||
|
||
def get_agent_by_name(self, name: str) -> Optional[AgentInputConfig]: | ||
"""Helper method to get agent config by name""" | ||
for agent in self.agents: | ||
if agent.agent_name == name: | ||
return agent | ||
return None | ||
|
||
def add_tag(self, tag: str) -> None: | ||
"""Helper method to add a tag""" | ||
if tag not in self.tags: | ||
self.tags.append(tag) | ||
|
||
def remove_tag(self, tag: str) -> None: | ||
"""Helper method to remove a tag""" | ||
if tag in self.tags: | ||
self.tags.remove(tag) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from typing import Any, Dict, List, Optional | ||
from pydantic import BaseModel, Field | ||
Check failure Code scanning / Pyre Undefined import Error
Undefined import [21]: Could not find a module corresponding to import pydantic.
|
||
import time | ||
import uuid | ||
from swarms.utils.litellm_tokenizer import count_tokens | ||
|
||
class Step(BaseModel): | ||
Check failure Code scanning / Pyre Undefined or invalid type Error
Undefined or invalid type [11]: Annotation BaseModel is not defined as a type.
|
||
""" | ||
Represents a single step in an agent's task execution. | ||
""" | ||
id: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
name: str = Field(..., description="Name of the agent") | ||
task: Optional[str] = Field(None, description="Task given to the agent at this step") | ||
input: Optional[str] = Field(None, description="Input provided to the agent at this step") | ||
output: Optional[str] = Field(None, description="Output generated by the agent at this step") | ||
error: Optional[str] = Field(None, description="Error message if any error occurred during the step") | ||
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")) | ||
end_time: Optional[str] = Field(None, description="End time of the step") | ||
runtime: Optional[float] = Field(None, description="Runtime of the step in seconds") | ||
tokens_used: Optional[int] = Field(None, description="Number of tokens used in this step") | ||
cost: Optional[float] = Field(None, description="Cost of the step") | ||
metadata: Optional[Dict[str, Any]] = Field( | ||
None, description="Additional metadata about the step" | ||
) | ||
|
||
def calculate_tokens(self, model: str = "gpt-4o") -> int: | ||
"""Calculate total tokens used in this step""" | ||
total = 0 | ||
if self.input: | ||
total += count_tokens(self.input, model) | ||
if self.output: | ||
total += count_tokens(self.output, model) | ||
self.tokens_used = total | ||
return total | ||
|
||
class AgentTaskOutput(BaseModel): | ||
""" | ||
Represents the output of an agent's execution. | ||
""" | ||
id: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
agent_name: str = Field(..., description="Name of the agent") | ||
task: str = Field(..., description="The task agent was asked to perform") | ||
steps: List[Step] = Field(..., description="List of steps taken by the agent") | ||
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")) | ||
end_time: Optional[str] = Field(None, description="End time of the agent's execution") | ||
total_tokens: Optional[int] = Field(None, description="Total tokens used by the agent") | ||
cost: Optional[float] = Field(None, description="Total cost of the agent execution") | ||
# Add any other fields from ManySteps that are relevant, like full_history | ||
|
||
def calculate_total_tokens(self, model: str = "gpt-4o") -> int: | ||
"""Calculate total tokens across all steps""" | ||
total = 0 | ||
for step in self.steps: | ||
total += step.calculate_tokens(model) | ||
self.total_tokens = total | ||
return total | ||
|
||
class OutputSchema(BaseModel): | ||
""" | ||
Unified output schema for all swarm types. | ||
""" | ||
swarm_id: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
swarm_type: str = Field(..., description="Type of the swarm") | ||
task: str = Field(..., description="The task given to the swarm") | ||
agent_outputs: List[AgentTaskOutput] = Field(..., description="List of agent outputs") | ||
timestamp: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")) | ||
swarm_specific_output: Optional[Dict] = Field(None, description="Additional data specific to the swarm type") | ||
|
||
class SwarmOutputFormatter: | ||
""" | ||
Formatter class to transform raw swarm output into the OutputSchema format. | ||
""" | ||
|
||
@staticmethod | ||
def format_output( | ||
swarm_id: str, | ||
swarm_type: str, | ||
task: str, | ||
agent_outputs: List[AgentTaskOutput], | ||
swarm_specific_output: Optional[Dict] = None, | ||
) -> str: | ||
"""Formats the output into a standardized JSON string.""" | ||
output = OutputSchema( | ||
swarm_id=swarm_id, | ||
swarm_type=swarm_type, | ||
task=task, | ||
agent_outputs=agent_outputs, | ||
swarm_specific_output=swarm_specific_output, | ||
) | ||
Comment on lines
+83
to
+89
Check failure Code scanning / Pyre Unexpected keyword Error
Unexpected keyword [28]: Unexpected keyword argument swarm\_id to call object.\_\_init\_\_.
|
||
return output.model_dump_json(indent=4) | ||
Check failure Code scanning / Pyre Undefined attribute Error
Undefined attribute [16]: OutputSchema has no attribute model\_dump\_json.
|
Check failure
Code scanning / Pyre
Unexpected keyword Error