feat: Add strict mode support to BaseTool, ToolSchema and FunctionTool (#5507)

Resolves #4447

The `openai` client supports structured output through its beta
client, which requires the function JSON schema to be strict when in
structured output mode.

Reference:
https://platform.openai.com/docs/guides/function-calling#strict-mode
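In brief, strict mode (per the OpenAI guide linked above) requires `additionalProperties: false` and every property to be listed in `required`. A minimal sketch of a conforming function schema, with an illustrative tool name and parameters that are not part of this commit:

```python
# Sketch of a strict-mode function schema, per the OpenAI guide above.
# The tool name and parameters are illustrative, not part of this commit.
strict_function_schema = {
    "name": "get_weather",
    "description": "Get the weather for a location.",
    "strict": True,
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],  # strict mode: every property must be required
        "additionalProperties": False,  # strict mode: no extra arguments allowed
    },
}
```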
ekzhu authored Feb 13, 2025
1 parent 9704208 commit ec314c5
Showing 6 changed files with 271 additions and 53 deletions.
@@ -53,4 +53,4 @@ def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
         def _handoff_tool() -> str:
             return self.message

-        return FunctionTool(_handoff_tool, name=self.name, description=self.description)
+        return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True)
34 changes: 26 additions & 8 deletions python/packages/autogen-core/src/autogen_core/tools/_base.py
@@ -18,12 +18,14 @@ class ParametersSchema(TypedDict):
     type: str
     properties: Dict[str, Any]
     required: NotRequired[Sequence[str]]
+    additionalProperties: NotRequired[bool]


 class ToolSchema(TypedDict):
     parameters: NotRequired[ParametersSchema]
     name: str
     description: NotRequired[str]
+    strict: NotRequired[bool]


 @runtime_checkable
@@ -66,12 +68,14 @@ def __init__(
         return_type: Type[ReturnT],
         name: str,
         description: str,
+        strict: bool = False,
     ) -> None:
         self._args_type = args_type
         # Normalize Annotated to the base type.
         self._return_type = normalize_annotated_type(return_type)
         self._name = name
         self._description = description
+        self._strict = strict

     @property
     def schema(self) -> ToolSchema:
@@ -81,18 +85,32 @@ def schema(self) -> ToolSchema:
         model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False))  # type: ignore
         del model_schema["$defs"]

+        parameters = ParametersSchema(
+            type="object",
+            properties=model_schema["properties"],
+            required=model_schema.get("required", []),
+            additionalProperties=model_schema.get("additionalProperties", False),
+        )
+
+        # If strict is enabled, the tool schema should list all properties as required.
+        assert "required" in parameters
+        if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()):
+            raise ValueError(
+                "Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode."
+            )
+
+        assert "additionalProperties" in parameters
+        if self._strict and parameters["additionalProperties"]:
+            raise ValueError(
+                "Strict mode is enabled but additional argument is also enabled. This is not allowed in strict mode."
+            )
+
         tool_schema = ToolSchema(
             name=self._name,
             description=self._description,
-            parameters=ParametersSchema(
-                type="object",
-                properties=model_schema["properties"],
-            ),
+            parameters=parameters,
+            strict=self._strict,
         )
-        if "required" in model_schema:
-            assert "parameters" in tool_schema
-            tool_schema["parameters"]["required"] = model_schema["required"]

         return tool_schema

     @property
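One step worth calling out: the `required` list comes straight from the Pydantic JSON schema of the generated args model, and Pydantic omits any field that has a default from `required`. That is why default arguments trip the strict-mode check above. A minimal sketch of this, using a hypothetical `Args` model and the standard Pydantic v2 API:

```python
from pydantic import BaseModel


class Args(BaseModel):
    arg: str
    nonrequired: int = 5  # has a default, so Pydantic leaves it out of "required"


schema = Args.model_json_schema()
print(schema["properties"].keys())  # dict_keys(['arg', 'nonrequired'])
print(schema.get("required", []))   # ['arg'] != all properties -> strict mode raises ValueError
```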
@@ -47,6 +47,9 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
             it does and the context in which it should be called.
         name (str, optional): An optional custom name for the tool. Defaults to
             the function's original name if not provided.
+        strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly
+            defined in the function signature, and no default values will be allowed. Defaults to False.
+            This is required to be set to True when used with models in structured output mode.

     Example:
@@ -83,7 +86,12 @@ async def example():
     component_config_schema = FunctionToolConfig

     def __init__(
-        self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
+        self,
+        func: Callable[..., Any],
+        description: str,
+        name: str | None = None,
+        global_imports: Sequence[Import] = [],
+        strict: bool = False,
     ) -> None:
         self._func = func
         self._global_imports = global_imports
@@ -92,8 +100,7 @@ def __init__(
         args_model = args_base_model_from_signature(func_name + "args", signature)
         return_type = signature.return_annotation
         self._has_cancellation_support = "cancellation_token" in signature.parameters
-
-        super().__init__(args_model, return_type, func_name, description)
+        super().__init__(args_model, return_type, func_name, description, strict)

     async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
         if asyncio.iscoroutinefunction(self._func):
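Putting the new parameter to use: a short sketch (function names are illustrative) of constructing a strict `FunctionTool` and of what happens when a default argument sneaks in. Note that the validation runs lazily, in the `schema` property rather than the constructor:

```python
from typing import Annotated

from autogen_core.tools import FunctionTool


def add(a: int, b: Annotated[int, "second addend"]) -> int:
    """Add two integers."""
    return a + b


# Both arguments are required, so strict mode is satisfied.
strict_tool = FunctionTool(add, description="Add two integers.", strict=True)
print(strict_tool.schema["strict"])  # True
print(strict_tool.schema["parameters"]["additionalProperties"])  # False


def add_with_default(a: int, b: int = 1) -> int:
    return a + b


# The default on `b` keeps it out of "required", so building the schema
# raises ValueError under strict mode.
lazy_tool = FunctionTool(add_with_default, description="Add with default.", strict=True)
try:
    _ = lazy_tool.schema
except ValueError as err:
    print(err)  # "Strict mode is enabled, but not all input arguments are marked as required. ..."
```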
43 changes: 42 additions & 1 deletion python/packages/autogen-core/tests/test_tools.py
@@ -93,6 +93,37 @@ def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5
     assert len(schema["parameters"]["properties"]) == 3


+def test_func_tool_schema_generation_strict() -> None:
+    def my_function1(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
+        return MyResult(result="test")
+
+    with pytest.raises(ValueError, match="Strict mode is enabled"):
+        tool = FunctionTool(my_function1, description="Function tool.", strict=True)
+        schema = tool.schema
+
+    def my_function2(arg: str, other: Annotated[int, "int arg"]) -> MyResult:
+        return MyResult(result="test")
+
+    tool = FunctionTool(my_function2, description="Function tool.", strict=True)
+    schema = tool.schema
+
+    assert schema["name"] == "my_function2"
+    assert "description" in schema
+    assert schema["description"] == "Function tool."
+    assert "parameters" in schema
+    assert schema["parameters"]["type"] == "object"
+    assert schema["parameters"]["properties"].keys() == {"arg", "other"}
+    assert schema["parameters"]["properties"]["arg"]["type"] == "string"
+    assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
+    assert schema["parameters"]["properties"]["other"]["type"] == "integer"
+    assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
+    assert "required" in schema["parameters"]
+    assert schema["parameters"]["required"] == ["arg", "other"]
+    assert len(schema["parameters"]["properties"]) == 2
+    assert "additionalProperties" in schema["parameters"]
+    assert schema["parameters"]["additionalProperties"] is False
+
+
 def test_func_tool_schema_generation_only_default_arg() -> None:
     def my_function(arg: str = "default") -> MyResult:
         return MyResult(result="test")
@@ -107,7 +138,17 @@ def my_function(arg: str = "default") -> MyResult:
     assert len(schema["parameters"]["properties"]) == 1
     assert schema["parameters"]["properties"]["arg"]["type"] == "string"
     assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
-    assert "required" not in schema["parameters"]
+    assert "required" in schema["parameters"]
+    assert schema["parameters"]["required"] == []
+
+
+def test_func_tool_schema_generation_only_default_arg_strict() -> None:
+    def my_function(arg: str = "default") -> MyResult:
+        return MyResult(result="test")
+
+    with pytest.raises(ValueError, match="Strict mode is enabled"):
+        tool = FunctionTool(my_function, description="Function tool.", strict=True)
+        _ = tool.schema


 def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
@@ -307,6 +307,7 @@ def convert_tools(
                 parameters=(
                     cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {}
                 ),
+                strict=(tool_schema["strict"] if "strict" in tool_schema else False),
             ),
         )
     )
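For orientation, `convert_tools` feeds these values into the OpenAI tool definition. A hedged sketch of the JSON the client would then send, with the shape following the OpenAI function-calling docs linked in the commit message and an illustrative tool name; the exact wire format is owned by the `openai` SDK:

```python
# Illustrative payload only -- shown to connect ToolSchema["strict"]
# to the tool definition sent in the chat completion request.
openai_tool_param = {
    "type": "function",
    "function": {
        "name": "sentiment_analysis",
        "description": "Sentiment Analysis",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string", "description": "text"}},
            "required": ["text"],
            "additionalProperties": False,
        },
        "strict": True,  # now populated from ToolSchema["strict"]
    },
}
```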
@@ -977,6 +978,12 @@ def model_info(self) -> ModelInfo:
 class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]):
     """Chat completion client for OpenAI hosted models.

+    To use this client, you must install the `openai` extra:
+
+    .. code-block:: bash
+
+        pip install "autogen-ext[openai]"
+
     You can also use this client for OpenAI-compatible ChatCompletion endpoints.
     **Using this client for non-OpenAI models is not tested or guaranteed.**
@@ -996,7 +1003,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         max_tokens (optional, int):
         n (optional, int):
         presence_penalty (optional, float):
-        response_format (optional, literal["json_object", "text"]):
+        response_format (optional, literal["json_object", "text"] | pydantic.BaseModel):
         seed (optional, int):
         stop (optional, str | List[str]):
         temperature (optional, float):
@@ -1009,63 +1016,132 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         This can be useful for models that do not support the `name` field in
         message. Defaults to False.

     Examples:

-        To use this client, you must install the `openai` extension:
-
-        .. code-block:: bash
-
-            pip install "autogen-ext[openai]"
-
         The following code snippet shows how to use the client with an OpenAI model:

         .. code-block:: python

             from autogen_ext.models.openai import OpenAIChatCompletionClient
             from autogen_core.models import UserMessage

             openai_client = OpenAIChatCompletionClient(
                 model="gpt-4o-2024-08-06",
                 # api_key="sk-...", # Optional if you have an OPENAI_API_KEY environment variable set.
             )

             result = await openai_client.create([UserMessage(content="What is the capital of France?", source="user")])  # type: ignore
             print(result)

         To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
         For example, to use Ollama, you can use the following code snippet:

         .. code-block:: python

             from autogen_ext.models.openai import OpenAIChatCompletionClient
             from autogen_core.models import ModelFamily

             custom_model_client = OpenAIChatCompletionClient(
                 model="deepseek-r1:1.5b",
                 base_url="http://localhost:11434/v1",
                 api_key="placeholder",
                 model_info={
                     "vision": False,
                     "function_calling": False,
                     "json_output": False,
                     "family": ModelFamily.R1,
                 },
             )

+        To use structured output as well as function calling, you can use the following code snippet:
+
+        .. code-block:: python
+
+            import asyncio
+            from typing import Literal
+
+            from autogen_core.models import (
+                AssistantMessage,
+                FunctionExecutionResult,
+                FunctionExecutionResultMessage,
+                SystemMessage,
+                UserMessage,
+            )
+            from autogen_core.tools import FunctionTool
+            from autogen_ext.models.openai import OpenAIChatCompletionClient
+            from pydantic import BaseModel
+
+
+            # Define the structured output format.
+            class AgentResponse(BaseModel):
+                thoughts: str
+                response: Literal["happy", "sad", "neutral"]
+
+
+            # Define the function to be called as a tool.
+            def sentiment_analysis(text: str) -> str:
+                \"\"\"Given a text, return the sentiment.\"\"\"
+                return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"
+
+
+            # Create a FunctionTool instance with `strict=True`,
+            # which is required for structured output mode.
+            tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
+
+            # Create an OpenAIChatCompletionClient instance.
+            model_client = OpenAIChatCompletionClient(
+                model="gpt-4o-mini",
+                response_format=AgentResponse,  # type: ignore
+            )
+
+
+            async def main() -> None:
+                # Generate a response using the tool.
+                response1 = await model_client.create(
+                    messages=[
+                        SystemMessage(content="Analyze input text sentiment using the tool provided."),
+                        UserMessage(content="I am happy.", source="user"),
+                    ],
+                    tools=[tool],
+                )
+                print(response1.content)
+                # Should be a list of tool calls.
+                # [FunctionCall(name="sentiment_analysis", arguments={"text": "I am happy."}, ...)]
+
+                assert isinstance(response1.content, list)
+                response2 = await model_client.create(
+                    messages=[
+                        SystemMessage(content="Analyze input text sentiment using the tool provided."),
+                        UserMessage(content="I am happy.", source="user"),
+                        AssistantMessage(content=response1.content, source="assistant"),
+                        FunctionExecutionResultMessage(
+                            content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False)]
+                        ),
+                    ],
+                )
+                print(response2.content)
+                # Should be a structured output.
+                # {"thoughts": "The user is happy.", "response": "happy"}
+
+
+            asyncio.run(main())
+
         To load the client from a configuration, you can use the `load_component` method:

         .. code-block:: python

             from autogen_core.models import ChatCompletionClient

             config = {
                 "provider": "OpenAIChatCompletionClient",
                 "config": {"model": "gpt-4o", "api_key": "REPLACE_WITH_YOUR_API_KEY"},
             }

             client = ChatCompletionClient.load_component(config)

         To view the full list of available configuration options, see the :py:class:`OpenAIClientConfigurationConfigModel` class.
     """
