From ec314c586cb2b84eafc64ad4aa3b1e1af94567f5 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Thu, 13 Feb 2025 11:44:55 -0800
Subject: [PATCH] feat: Add strict mode support to BaseTool, ToolSchema and
 FunctionTool (#5507)

Resolves #4447

The `openai` client's structured output support is provided through its beta
client, which requires the function JSON schema to be strict when in
structured output mode.

Reference: https://platform.openai.com/docs/guides/function-calling#strict-mode
---
 .../src/autogen_agentchat/base/_handoff.py    |   2 +-
 .../src/autogen_core/tools/_base.py           |  34 +++-
 .../src/autogen_core/tools/_function_tool.py  |  13 +-
 .../packages/autogen-core/tests/test_tools.py |  43 ++++-
 .../models/openai/_openai_client.py           | 156 +++++++++++++-----
 .../tests/models/test_openai_model_client.py  |  76 +++++++++
 6 files changed, 271 insertions(+), 53 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py
index e5bfa98d3b02..7afef094b640 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py
@@ -53,4 +53,4 @@ def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
         def _handoff_tool() -> str:
             return self.message
 
-        return FunctionTool(_handoff_tool, name=self.name, description=self.description)
+        return FunctionTool(_handoff_tool, name=self.name, description=self.description, strict=True)
diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py
index b484ef84f3e9..8b6489731159 100644
--- a/python/packages/autogen-core/src/autogen_core/tools/_base.py
+++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py
@@ -18,12 +18,14 @@ class ParametersSchema(TypedDict):
     type: str
     properties: Dict[str, Any]
     required: NotRequired[Sequence[str]]
+    additionalProperties: NotRequired[bool]
 
 
 class ToolSchema(TypedDict):
     parameters: NotRequired[ParametersSchema]
     name: str
     description: NotRequired[str]
+    strict: NotRequired[bool]
 
 
 @runtime_checkable
@@ -66,12 +68,14 @@ def __init__(
         return_type: Type[ReturnT],
         name: str,
         description: str,
+        strict: bool = False,
     ) -> None:
         self._args_type = args_type
         # Normalize Annotated to the base type.
         self._return_type = normalize_annotated_type(return_type)
         self._name = name
         self._description = description
+        self._strict = strict
 
     @property
     def schema(self) -> ToolSchema:
@@ -81,18 +85,32 @@ def schema(self) -> ToolSchema:
         model_schema = cast(Dict[str, Any], jsonref.replace_refs(obj=model_schema, proxies=False))  # type: ignore
         del model_schema["$defs"]
 
+        parameters = ParametersSchema(
+            type="object",
+            properties=model_schema["properties"],
+            required=model_schema.get("required", []),
+            additionalProperties=model_schema.get("additionalProperties", False),
+        )
+
+        # If strict is enabled, the tool schema should list all properties as required.
+        assert "required" in parameters
+        if self._strict and set(parameters["required"]) != set(parameters["properties"].keys()):
+            raise ValueError(
+                "Strict mode is enabled, but not all input arguments are marked as required. Default arguments are not allowed in strict mode."
+            )
+
+        assert "additionalProperties" in parameters
+        if self._strict and parameters["additionalProperties"]:
+            raise ValueError(
+                "Strict mode is enabled, but additional properties are allowed. Additional properties are not allowed in strict mode."
+            )
+
         tool_schema = ToolSchema(
             name=self._name,
             description=self._description,
-            parameters=ParametersSchema(
-                type="object",
-                properties=model_schema["properties"],
-            ),
+            parameters=parameters,
+            strict=self._strict,
         )
-        if "required" in model_schema:
-            assert "parameters" in tool_schema
-            tool_schema["parameters"]["required"] = model_schema["required"]
-
         return tool_schema
 
     @property
diff --git a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py
index 7cd4d0a9e8d4..fe3f8fb75c3a 100644
--- a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py
+++ b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py
@@ -47,6 +47,9 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
             it does and the context in which it should be called.
         name (str, optional): An optional custom name for the tool.
             Defaults to the function's original name if not provided.
+        strict (bool, optional): If set to True, the tool schema will only contain arguments that are explicitly
+            defined in the function signature, and no default values will be allowed. Defaults to False.
+            This must be set to True when the tool is used with a model in structured output mode.
 
     Example:
 
@@ -83,7 +86,12 @@ async def example():
     component_config_schema = FunctionToolConfig
 
     def __init__(
-        self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
+        self,
+        func: Callable[..., Any],
+        description: str,
+        name: str | None = None,
+        global_imports: Sequence[Import] = [],
+        strict: bool = False,
     ) -> None:
         self._func = func
         self._global_imports = global_imports
@@ -92,8 +100,7 @@ def __init__(
         args_model = args_base_model_from_signature(func_name + "args", signature)
         return_type = signature.return_annotation
         self._has_cancellation_support = "cancellation_token" in signature.parameters
-
-        super().__init__(args_model, return_type, func_name, description)
+        super().__init__(args_model, return_type, func_name, description, strict)
 
     async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
         if asyncio.iscoroutinefunction(self._func):
diff --git a/python/packages/autogen-core/tests/test_tools.py b/python/packages/autogen-core/tests/test_tools.py
index f1c2e44f0610..d0346c193634 100644
--- a/python/packages/autogen-core/tests/test_tools.py
+++ b/python/packages/autogen-core/tests/test_tools.py
@@ -93,6 +93,37 @@ def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5
     assert len(schema["parameters"]["properties"]) == 3
 
 
+def test_func_tool_schema_generation_strict() -> None:
+    def my_function1(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
+        return MyResult(result="test")
+
+    with pytest.raises(ValueError, match="Strict mode is enabled"):
+        tool = FunctionTool(my_function1, description="Function tool.", strict=True)
+        schema = tool.schema
+
+    def my_function2(arg: str, other: Annotated[int, "int arg"]) -> MyResult:
+        return MyResult(result="test")
+
+    tool = FunctionTool(my_function2, description="Function tool.", strict=True)
+    schema = tool.schema
+
+    assert schema["name"] == "my_function2"
+    assert "description" in schema
+    assert schema["description"] == "Function tool."
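+    # In strict mode, every property must be listed as required and
+    # additionalProperties must be False; both are asserted below.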
+    assert "parameters" in schema
+    assert schema["parameters"]["type"] == "object"
+    assert schema["parameters"]["properties"].keys() == {"arg", "other"}
+    assert schema["parameters"]["properties"]["arg"]["type"] == "string"
+    assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
+    assert schema["parameters"]["properties"]["other"]["type"] == "integer"
+    assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
+    assert "required" in schema["parameters"]
+    assert schema["parameters"]["required"] == ["arg", "other"]
+    assert len(schema["parameters"]["properties"]) == 2
+    assert "additionalProperties" in schema["parameters"]
+    assert schema["parameters"]["additionalProperties"] is False
+
+
 def test_func_tool_schema_generation_only_default_arg() -> None:
     def my_function(arg: str = "default") -> MyResult:
         return MyResult(result="test")
@@ -107,7 +138,17 @@ def my_function(arg: str = "default") -> MyResult:
     assert len(schema["parameters"]["properties"]) == 1
     assert schema["parameters"]["properties"]["arg"]["type"] == "string"
     assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
-    assert "required" not in schema["parameters"]
+    assert "required" in schema["parameters"]
+    assert schema["parameters"]["required"] == []
+
+
+def test_func_tool_schema_generation_only_default_arg_strict() -> None:
+    def my_function(arg: str = "default") -> MyResult:
+        return MyResult(result="test")
+
+    with pytest.raises(ValueError, match="Strict mode is enabled"):
+        tool = FunctionTool(my_function, description="Function tool.", strict=True)
+        _ = tool.schema
 
 
 def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 8469f921ba73..78e932895d92 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -307,6 +307,7 @@ def convert_tools(
                 parameters=(
                     cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {}
                 ),
+                strict=(tool_schema["strict"] if "strict" in tool_schema else False),
             ),
         )
     )
@@ -977,6 +978,12 @@ def model_info(self) -> ModelInfo:
 class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]):
     """Chat completion client for OpenAI hosted models.
 
+    To use this client, you must install the `openai` extra:
+
+    .. code-block:: bash
+
+        pip install "autogen-ext[openai]"
+
     You can also use this client for OpenAI-compatible ChatCompletion endpoints.
     **Using this client for non-OpenAI models is not tested or guaranteed.**
@@ -996,7 +1003,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         max_tokens (optional, int):
         n (optional, int):
         presence_penalty (optional, float):
-        response_format (optional, literal["json_object", "text"]):
+        response_format (optional, literal["json_object", "text"] | pydantic.BaseModel):
         seed (optional, int):
         stop (optional, str | List[str]):
         temperature (optional, float):
@@ -1009,63 +1016,132 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
             This can be useful for models that do not support the `name` field in messages. Defaults to False.
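For illustration, a minimal sketch of the schema emitted under ``strict=True``, consistent with ``test_func_tool_schema_generation_strict`` above (the ``echo`` function and the printed shape are illustrative assumptions, not part of the patch):

.. code-block:: python

    from autogen_core.tools import FunctionTool


    def echo(text: str) -> str:
        return text


    # Strict tools mark every parameter as required and set additionalProperties to False.
    tool = FunctionTool(echo, description="Echo the input text.", strict=True)
    print(tool.schema)
    # Expected shape (abridged; pydantic may add extra keys such as "title"):
    # {
    #     "name": "echo",
    #     "description": "Echo the input text.",
    #     "parameters": {
    #         "type": "object",
    #         "properties": {"text": {"type": "string", "description": "text"}},
    #         "required": ["text"],
    #         "additionalProperties": False,
    #     },
    #     "strict": True,
    # }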
+ Examples: - To use this client, you must install the `openai` extension: + The following code snippet shows how to use the client with an OpenAI model: - .. code-block:: bash + .. code-block:: python - pip install "autogen-ext[openai]" + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_core.models import UserMessage - The following code snippet shows how to use the client with an OpenAI model: + openai_client = OpenAIChatCompletionClient( + model="gpt-4o-2024-08-06", + # api_key="sk-...", # Optional if you have an OPENAI_API_KEY environment variable set. + ) - .. code-block:: python + result = await openai_client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore + print(result) - from autogen_ext.models.openai import OpenAIChatCompletionClient - from autogen_core.models import UserMessage - openai_client = OpenAIChatCompletionClient( - model="gpt-4o-2024-08-06", - # api_key="sk-...", # Optional if you have an OPENAI_API_KEY environment variable set. - ) + To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info. + For example, to use Ollama, you can use the following code snippet: - result = await openai_client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore - print(result) + .. code-block:: python + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_core.models import ModelFamily + + custom_model_client = OpenAIChatCompletionClient( + model="deepseek-r1:1.5b", + base_url="http://localhost:11434/v1", + api_key="placeholder", + model_info={ + "vision": False, + "function_calling": False, + "json_output": False, + "family": ModelFamily.R1, + }, + ) - To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info. - For example, to use Ollama, you can use the following code snippet: + To use structured output as well as function calling, you can use the following code snippet: - .. code-block:: python + .. code-block:: python - from autogen_ext.models.openai import OpenAIChatCompletionClient - from autogen_core.models import ModelFamily - - custom_model_client = OpenAIChatCompletionClient( - model="deepseek-r1:1.5b", - base_url="http://localhost:11434/v1", - api_key="placeholder", - model_info={ - "vision": False, - "function_calling": False, - "json_output": False, - "family": ModelFamily.R1, - }, - ) + import asyncio + from typing import Literal - To load the client from a configuration, you can use the `load_component` method: + from autogen_core.models import ( + AssistantMessage, + FunctionExecutionResult, + FunctionExecutionResultMessage, + SystemMessage, + UserMessage, + ) + from autogen_core.tools import FunctionTool + from autogen_ext.models.openai import OpenAIChatCompletionClient + from pydantic import BaseModel - .. code-block:: python - from autogen_core.models import ChatCompletionClient + # Define the structured output format. + class AgentResponse(BaseModel): + thoughts: str + response: Literal["happy", "sad", "neutral"] - config = { - "provider": "OpenAIChatCompletionClient", - "config": {"model": "gpt-4o", "api_key": "REPLACE_WITH_YOUR_API_KEY"}, - } - client = ChatCompletionClient.load_component(config) + # Define the function to be called as a tool. 
+ def sentiment_analysis(text: str) -> str: + \"\"\"Given a text, return the sentiment.\"\"\" + return "happy" if "happy" in text else "sad" if "sad" in text else "neutral" + + + # Create a FunctionTool instance with `strict=True`, + # which is required for structured output mode. + tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True) + + # Create an OpenAIChatCompletionClient instance. + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + response_format=AgentResponse, # type: ignore + ) + + + async def main() -> None: + # Generate a response using the tool. + response1 = await model_client.create( + messages=[ + SystemMessage(content="Analyze input text sentiment using the tool provided."), + UserMessage(content="I am happy.", source="user"), + ], + tools=[tool], + ) + print(response1.content) + # Should be a list of tool calls. + # [FunctionCall(name="sentiment_analysis", arguments={"text": "I am happy."}, ...)] + + assert isinstance(response1.content, list) + response2 = await model_client.create( + messages=[ + SystemMessage(content="Analyze input text sentiment using the tool provided."), + UserMessage(content="I am happy.", source="user"), + AssistantMessage(content=response1.content, source="assistant"), + FunctionExecutionResultMessage( + content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False)] + ), + ], + ) + print(response2.content) + # Should be a structured output. + # {"thoughts": "The user is happy.", "response": "happy"} + + + asyncio.run(main()) + + + To load the client from a configuration, you can use the `load_component` method: + + .. code-block:: python + + from autogen_core.models import ChatCompletionClient + + config = { + "provider": "OpenAIChatCompletionClient", + "config": {"model": "gpt-4o", "api_key": "REPLACE_WITH_YOUR_API_KEY"}, + } + + client = ChatCompletionClient.load_component(config) - To view the full list of available configuration options, see the :py:class:`OpenAIClientConfigurationConfigModel` class. + To view the full list of available configuration options, see the :py:class:`OpenAIClientConfigurationConfigModel` class. """ diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index 74c4b2c08f3f..ea5b28e28a63 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -956,6 +956,82 @@ async def test_openai() -> None: await _test_model_client_with_function_calling(model_client) +@pytest.mark.asyncio +async def test_openai_structured_output() -> None: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY not found in environment variables") + + class AgentResponse(BaseModel): + thoughts: str + response: Literal["happy", "sad", "neutral"] + + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + api_key=api_key, + response_format=AgentResponse, # type: ignore + ) + + # Test that the openai client was called with the correct response format. 
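+    # With response_format set to a pydantic model, create() is expected to
+    # return a JSON string that validates against AgentResponse (checked below).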
+ create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")]) + assert isinstance(create_result.content, str) + response = AgentResponse.model_validate(json.loads(create_result.content)) + assert response.thoughts + assert response.response in ["happy", "sad", "neutral"] + + +@pytest.mark.asyncio +async def test_openai_structured_output_with_tool_calls() -> None: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY not found in environment variables") + + class AgentResponse(BaseModel): + thoughts: str + response: Literal["happy", "sad", "neutral"] + + def sentiment_analysis(text: str) -> str: + """Given a text, return the sentiment.""" + return "happy" if "happy" in text else "sad" if "sad" in text else "neutral" + + tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True) + + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + api_key=api_key, + response_format=AgentResponse, # type: ignore + ) + + response1 = await model_client.create( + messages=[ + SystemMessage(content="Analyze input text sentiment using the tool provided."), + UserMessage(content="I am happy.", source="user"), + ], + tools=[tool], + ) + assert isinstance(response1.content, list) + assert len(response1.content) == 1 + assert isinstance(response1.content[0], FunctionCall) + assert response1.content[0].name == "sentiment_analysis" + assert json.loads(response1.content[0].arguments) == {"text": "I am happy."} + assert response1.finish_reason == "function_calls" + + response2 = await model_client.create( + messages=[ + SystemMessage(content="Analyze input text sentiment using the tool provided."), + UserMessage(content="I am happy.", source="user"), + AssistantMessage(content=response1.content, source="assistant"), + FunctionExecutionResultMessage( + content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False)] + ), + ], + ) + assert isinstance(response2.content, str) + parsed_response = AgentResponse.model_validate(json.loads(response2.content)) + assert parsed_response.thoughts + assert parsed_response.response in ["happy", "sad", "neutral"] + + @pytest.mark.asyncio async def test_gemini() -> None: api_key = os.getenv("GEMINI_API_KEY")
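To round out the tests above, a minimal sketch of the strict-mode failure path (the ``greet`` function is a hypothetical example; the behavior mirrors ``test_func_tool_schema_generation_only_default_arg_strict``):

.. code-block:: python

    from autogen_core.tools import FunctionTool


    def greet(name: str = "world") -> str:
        return f"Hello, {name}!"


    # Default arguments are rejected in strict mode: accessing the schema raises.
    tool = FunctionTool(greet, description="Greeting tool.", strict=True)
    try:
        _ = tool.schema
    except ValueError as e:
        print(e)
        # Strict mode is enabled, but not all input arguments are marked as
        # required. Default arguments are not allowed in strict mode.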