Enhanced Configuration, Token Counting and Unified Output Formatting Schema for swarms Module #770

Open
wants to merge 2 commits into base: master
6 changes: 4 additions & 2 deletions swarms/schemas/__init__.py
@@ -1,10 +1,12 @@
from swarms.schemas.agent_step_schemas import Step, ManySteps

from swarms.schemas.agent_input_schema import AgentSchema

from swarms.schemas.base_swarm_schemas import BaseSwarmSchema
from swarms.schemas.output_schemas import OutputSchema

__all__ = [
"Step",
"ManySteps",
"AgentSchema",
"BaseSwarmSchema",
"OutputSchema",
]
29 changes: 29 additions & 0 deletions swarms/schemas/base_schemas.py
@@ -3,6 +3,7 @@
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from swarms.utils.litellm_tokenizer import count_tokens


class ModelCard(BaseModel):
@@ -49,6 +50,18 @@
)
content: Union[str, List[ContentItem]]

def count_tokens(self, model: str = "gpt-4o") -> int:
"""Count tokens in the message content"""
if isinstance(self.content, str):
return count_tokens(self.content, model)
elif isinstance(self.content, list):
total = 0
for item in self.content:
if isinstance(item, TextContent):
total += count_tokens(item.text, model)
return total
return 0


class ChatMessageResponse(BaseModel):
role: str = Field(
@@ -92,6 +105,22 @@
total_tokens: int = 0
completion_tokens: Optional[int] = 0

@classmethod
def calculate_usage(
cls,
messages: List[ChatMessageInput],
completion: Optional[str] = None,
model: str = "gpt-4o"
) -> "UsageInfo":
"""Calculate token usage for messages and completion"""
prompt_tokens = sum(msg.count_tokens(model) for msg in messages)
completion_tokens = count_tokens(completion, model) if completion else 0
return cls(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
Comment on lines +118 to +122. Check failure (Code scanning / Pyre): Unexpected keyword [28]: Unexpected keyword argument prompt_tokens to call object.__init__.

class ChatCompletionResponse(BaseModel):
model: str
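A minimal usage sketch for the token-counting additions in base_schemas.py. It assumes ChatMessageInput exposes role and content fields (only content is visible in the hunk above) and that swarms.utils.litellm_tokenizer.count_tokens resolves at runtime; the message strings and model name are illustrative only.

```python
# Hypothetical usage of the new token-counting helpers from this diff.
# The role field and the example strings are assumptions for illustration.
from swarms.schemas.base_schemas import ChatMessageInput, UsageInfo

messages = [
    ChatMessageInput(role="user", content="Summarize the quarterly report."),
    ChatMessageInput(role="assistant", content="Sure, which quarter do you mean?"),
]

# Per-message counting via the new ChatMessageInput.count_tokens method.
prompt_total = sum(msg.count_tokens(model="gpt-4o") for msg in messages)

# Aggregate accounting via the new UsageInfo.calculate_usage classmethod.
usage = UsageInfo.calculate_usage(
    messages=messages,
    completion="Q2 revenue grew 12% quarter over quarter.",
    model="gpt-4o",
)
print(prompt_total, usage.total_tokens)
```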
152 changes: 152 additions & 0 deletions swarms/schemas/base_swarm_schemas.py
@@ -0,0 +1,152 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, validator

Check failure (Code scanning / Pyre): Undefined import [21]: Could not find a module corresponding to import pydantic.
import uuid
import time

class AgentInputConfig(BaseModel):

Check failure (Code scanning / Pyre): Undefined or invalid type [11]: Annotation BaseModel is not defined as a type.
"""
Configuration for an agent. This can be further customized
per agent type if needed.
"""
agent_name: str = Field(..., description="Name of the agent")
agent_type: str = Field(..., description="Type of agent (e.g. 'llm', 'tool', 'memory')")
model_name: Optional[str] = Field(None, description="Name of the model to use")
temperature: float = Field(0.7, description="Temperature for model sampling")
max_tokens: int = Field(4096, description="Maximum tokens for model response")
system_prompt: Optional[str] = Field(None, description="System prompt for the agent")
tools: Optional[List[str]] = Field(None, description="List of tool names available to agent")
memory_type: Optional[str] = Field(None, description="Type of memory to use")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional agent metadata")

class BaseSwarmSchema(BaseModel):
"""
Base schema for all swarm types.
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str

Check failure (Code scanning / Pyre): Uninitialized attribute [13]: Attribute name is declared in class BaseSwarmSchema to have type str but is never initialized.
description: str

Check failure (Code scanning / Pyre): Uninitialized attribute [13]: Attribute description is declared in class BaseSwarmSchema to have type str but is never initialized.
agents: List[AgentInputConfig] # Using AgentInputConfig

Check failure (Code scanning / Pyre): Uninitialized attribute [13]: Attribute agents is declared in class BaseSwarmSchema to have type List[AgentInputConfig] but is never initialized.
max_loops: int = 1
swarm_type: str # e.g., "SequentialWorkflow", "ConcurrentWorkflow", etc.

Check failure (Code scanning / Pyre): Uninitialized attribute [13]: Attribute swarm_type is declared in class BaseSwarmSchema to have type str but is never initialized.
created_at: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
config: Dict[str, Any] = Field(default_factory=dict) # Flexible config

# Additional fields
timeout: Optional[int] = Field(None, description="Timeout in seconds for swarm execution")
error_handling: str = Field("stop", description="Error handling strategy: 'stop', 'continue', or 'retry'")
max_retries: int = Field(3, description="Maximum number of retry attempts")
logging_level: str = Field("info", description="Logging level for the swarm")
metrics_enabled: bool = Field(True, description="Whether to collect metrics")
tags: List[str] = Field(default_factory=list, description="Tags for categorizing swarms")

@validator("swarm_type")
def validate_swarm_type(cls, v):
"""Validates the swarm type is one of the allowed types"""
allowed_types = [
"SequentialWorkflow",
"ConcurrentWorkflow",
"AgentRearrange",
"MixtureOfAgents",
"SpreadSheetSwarm",
"AutoSwarm",
"HierarchicalSwarm",
"FeedbackSwarm"
]
if v not in allowed_types:
raise ValueError(f"Swarm type must be one of: {allowed_types}")
return v

@validator("config")
def validate_config(cls, v, values):
"""
Validates the 'config' dictionary based on the 'swarm_type'.
"""
swarm_type = values.get("swarm_type")

# Common validation for all swarm types
if not isinstance(v, dict):
raise ValueError("Config must be a dictionary")

# Type-specific validation
if swarm_type == "SequentialWorkflow":
if "flow" not in v:
raise ValueError("SequentialWorkflow requires a 'flow' configuration.")
if not isinstance(v["flow"], list):
raise ValueError("Flow configuration must be a list")

elif swarm_type == "ConcurrentWorkflow":
if "max_workers" not in v:
raise ValueError("ConcurrentWorkflow requires a 'max_workers' configuration.")
if not isinstance(v["max_workers"], int) or v["max_workers"] < 1:
raise ValueError("max_workers must be a positive integer")

elif swarm_type == "AgentRearrange":
if "flow" not in v:
raise ValueError("AgentRearrange requires a 'flow' configuration.")
if not isinstance(v["flow"], list):
raise ValueError("Flow configuration must be a list")

elif swarm_type == "MixtureOfAgents":
if "aggregator_agent" not in v:
raise ValueError("MixtureOfAgents requires an 'aggregator_agent' configuration.")
if "voting_method" not in v:
v["voting_method"] = "majority" # Set default voting method

elif swarm_type == "SpreadSheetSwarm":
if "save_file_path" not in v:
raise ValueError("SpreadSheetSwarm requires a 'save_file_path' configuration.")
if not isinstance(v["save_file_path"], str):
raise ValueError("save_file_path must be a string")

elif swarm_type == "AutoSwarm":
if "optimization_metric" not in v:
v["optimization_metric"] = "performance" # Set default metric
if "adaptation_strategy" not in v:
v["adaptation_strategy"] = "dynamic" # Set default strategy

elif swarm_type == "HierarchicalSwarm":
if "hierarchy_levels" not in v:
raise ValueError("HierarchicalSwarm requires 'hierarchy_levels' configuration.")
if not isinstance(v["hierarchy_levels"], int) or v["hierarchy_levels"] < 1:
raise ValueError("hierarchy_levels must be a positive integer")

elif swarm_type == "FeedbackSwarm":
if "feedback_collection" not in v:
v["feedback_collection"] = "continuous" # Set default collection method
if "feedback_integration" not in v:
v["feedback_integration"] = "weighted" # Set default integration method

return v

@validator("error_handling")
def validate_error_handling(cls, v):
"""Validates error handling strategy"""
allowed_strategies = ["stop", "continue", "retry"]
if v not in allowed_strategies:
raise ValueError(f"Error handling must be one of: {allowed_strategies}")
return v

@validator("logging_level")
def validate_logging_level(cls, v):
"""Validates logging level"""
allowed_levels = ["debug", "info", "warning", "error", "critical"]
if v.lower() not in allowed_levels:
raise ValueError(f"Logging level must be one of: {allowed_levels}")
return v.lower()

def get_agent_by_name(self, name: str) -> Optional[AgentInputConfig]:
"""Helper method to get agent config by name"""
for agent in self.agents:
if agent.agent_name == name:
return agent
return None

def add_tag(self, tag: str) -> None:
"""Helper method to add a tag"""
if tag not in self.tags:
self.tags.append(tag)

def remove_tag(self, tag: str) -> None:
"""Helper method to remove a tag"""
if tag in self.tags:
self.tags.remove(tag)
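A hedged sketch of how the new BaseSwarmSchema might be instantiated, using only the fields and validators shown in this file; the agent names, model choice, and flow values are placeholders.

```python
# Illustrative construction of the swarm schema added in this file.
# Agent names, model_name, and the flow entries are placeholder values.
from swarms.schemas.base_swarm_schemas import AgentInputConfig, BaseSwarmSchema

researcher = AgentInputConfig(
    agent_name="researcher",
    agent_type="llm",
    model_name="gpt-4o",
    system_prompt="Collect background material for the task.",
)
writer = AgentInputConfig(agent_name="writer", agent_type="llm", model_name="gpt-4o")

# For swarm_type="SequentialWorkflow", validate_config requires a
# list-valued "flow" entry in config and raises ValueError otherwise.
swarm = BaseSwarmSchema(
    name="report-pipeline",
    description="Research a topic, then draft a short report.",
    agents=[researcher, writer],
    swarm_type="SequentialWorkflow",
    config={"flow": ["researcher", "writer"]},
    tags=["demo"],
)

swarm.add_tag("sequential")
found = swarm.get_agent_by_name("writer")
print(found.model_name if found else "agent not found")
```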
90 changes: 90 additions & 0 deletions swarms/schemas/output_schemas.py
@@ -0,0 +1,90 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

Check failure (Code scanning / Pyre): Undefined import [21]: Could not find a module corresponding to import pydantic.
import time
import uuid
from swarms.utils.litellm_tokenizer import count_tokens

class Step(BaseModel):

Check failure (Code scanning / Pyre): Undefined or invalid type [11]: Annotation BaseModel is not defined as a type.
"""
Represents a single step in an agent's task execution.
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str = Field(..., description="Name of the agent")
task: Optional[str] = Field(None, description="Task given to the agent at this step")
input: Optional[str] = Field(None, description="Input provided to the agent at this step")
output: Optional[str] = Field(None, description="Output generated by the agent at this step")
error: Optional[str] = Field(None, description="Error message if any error occurred during the step")
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
end_time: Optional[str] = Field(None, description="End time of the step")
runtime: Optional[float] = Field(None, description="Runtime of the step in seconds")
tokens_used: Optional[int] = Field(None, description="Number of tokens used in this step")
cost: Optional[float] = Field(None, description="Cost of the step")
metadata: Optional[Dict[str, Any]] = Field(
None, description="Additional metadata about the step"
)

def calculate_tokens(self, model: str = "gpt-4o") -> int:
"""Calculate total tokens used in this step"""
total = 0
if self.input:
total += count_tokens(self.input, model)
if self.output:
total += count_tokens(self.output, model)
self.tokens_used = total
return total

class AgentTaskOutput(BaseModel):
"""
Represents the output of an agent's execution.
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
agent_name: str = Field(..., description="Name of the agent")
task: str = Field(..., description="The task agent was asked to perform")
steps: List[Step] = Field(..., description="List of steps taken by the agent")
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
end_time: Optional[str] = Field(None, description="End time of the agent's execution")
total_tokens: Optional[int] = Field(None, description="Total tokens used by the agent")
cost: Optional[float] = Field(None, description="Total cost of the agent execution")
# Add any other fields from ManySteps that are relevant, like full_history

def calculate_total_tokens(self, model: str = "gpt-4o") -> int:
"""Calculate total tokens across all steps"""
total = 0
for step in self.steps:
total += step.calculate_tokens(model)
self.total_tokens = total
return total

class OutputSchema(BaseModel):
"""
Unified output schema for all swarm types.
"""
swarm_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
swarm_type: str = Field(..., description="Type of the swarm")
task: str = Field(..., description="The task given to the swarm")
agent_outputs: List[AgentTaskOutput] = Field(..., description="List of agent outputs")
timestamp: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
swarm_specific_output: Optional[Dict] = Field(None, description="Additional data specific to the swarm type")

class SwarmOutputFormatter:
"""
Formatter class to transform raw swarm output into the OutputSchema format.
"""

@staticmethod
def format_output(
swarm_id: str,
swarm_type: str,
task: str,
agent_outputs: List[AgentTaskOutput],
swarm_specific_output: Optional[Dict] = None,
) -> str:
"""Formats the output into a standardized JSON string."""
output = OutputSchema(
swarm_id=swarm_id,
swarm_type=swarm_type,
task=task,
agent_outputs=agent_outputs,
swarm_specific_output=swarm_specific_output,
)
Comment on lines +83 to +89. Check failure (Code scanning / Pyre): Unexpected keyword [28]: Unexpected keyword argument swarm_id to call object.__init__.
return output.model_dump_json(indent=4)

Check failure (Code scanning / Pyre): Undefined attribute [16]: OutputSchema has no attribute model_dump_json.
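A usage sketch for the unified output path in output_schemas.py, assuming pydantic v2 (or an equivalent model_dump_json method) is available at runtime; the step contents, swarm id, and swarm_specific_output payload are placeholders.

```python
# Hypothetical end-to-end use of the new output schemas and formatter.
# Step and agent contents and the swarm id are placeholder values.
from swarms.schemas.output_schemas import (
    AgentTaskOutput,
    Step,
    SwarmOutputFormatter,
)

step = Step(
    name="researcher",
    task="Summarize the input document",
    input="A short document about token counting.",
    output="Token counting estimates prompt size before an API call.",
)
step.calculate_tokens(model="gpt-4o")  # fills step.tokens_used

agent_output = AgentTaskOutput(
    agent_name="researcher",
    task="Summarize the input document",
    steps=[step],
)
agent_output.calculate_total_tokens(model="gpt-4o")

report_json = SwarmOutputFormatter.format_output(
    swarm_id="example-swarm-id",
    swarm_type="SequentialWorkflow",
    task="Summarize the input document",
    agent_outputs=[agent_output],
    swarm_specific_output={"flow": ["researcher"]},
)
print(report_json)  # indented JSON string following OutputSchema
```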