From e7a3c7859458c56343fa80d3dae2ec972d86630a Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Thu, 13 Feb 2025 23:11:44 -0800
Subject: [PATCH] fix: Address tool call execution scenario when model produces empty tool call ids (#5509)

Resolves #5508
---
 .../agents/_assistant_agent.py        |  31 ++--
 .../tests/test_assistant_agent.py     | 136 ++++++++++++++++++
 .../src/autogen_core/models/_types.py |  13 +-
 3 files changed, 162 insertions(+), 18 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 6bba99241652..3b109c19670c 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -11,6 +11,7 @@
     List,
     Mapping,
     Sequence,
+    Tuple,
 )
 
 from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
@@ -441,28 +442,26 @@ async def on_messages_stream(
         inner_messages.append(tool_call_msg)
         yield tool_call_msg
 
-        # Execute the tool calls.
-        exec_results = await asyncio.gather(
+        # Execute the tool calls and handoff calls.
+        executed_calls_and_results = await asyncio.gather(
             *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
         )
+        # Collect the execution results in a list.
+        exec_results = [result for _, result in executed_calls_and_results]
+
         # Add the execution results to output and model context.
         tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
         event_logger.debug(tool_call_result_msg)
         await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg
-        # Correlate tool call results with tool calls.
-        tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+        # Separate out tool calls and tool call results from handoff requests.
+        tool_calls: List[FunctionCall] = []
         tool_call_results: List[FunctionExecutionResult] = []
-        for tool_call in tool_calls:
-            found = False
-            for exec_result in exec_results:
-                if exec_result.call_id == tool_call.id:
-                    found = True
-                    tool_call_results.append(exec_result)
-                    break
-            if not found:
-                raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
 
         # Detect handoff requests.
         handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
@@ -546,7 +545,7 @@ async def on_messages_stream(
 
     async def _execute_tool_call(
         self, tool_call: FunctionCall, cancellation_token: CancellationToken
-    ) -> FunctionExecutionResult:
+    ) -> Tuple[FunctionCall, FunctionExecutionResult]:
         """Execute a tool call and return the result."""
         try:
             if not self._tools + self._handoff_tools:
@@ -557,9 +556,9 @@ async def _execute_tool_call(
             arguments = json.loads(tool_call.arguments)
             result = await tool.run_json(arguments, cancellation_token)
             result_as_str = tool.return_value_as_string(result)
-            return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
+            return (tool_call, FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False))
         except Exception as e:
-            return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)
+            return (tool_call, FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True))
 
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index 2d82158ef4cd..827a547660e4 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -427,6 +427,142 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert state == state2
 
 
+@pytest.mark.asyncio
+async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task1"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task2"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="",
+                                type="function",
+                                function=Function(
+                                    name="_echo_function",
+                                    arguments=json.dumps({"input": "task3"}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="pass", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
+    )
+    result = await agent.run(task="task")
+
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].models_usage is None
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert result.messages[1].content == [
+        FunctionCall(id="", arguments=r'{"input": "task1"}', name="_pass_function"),
+        FunctionCall(id="", arguments=r'{"input": "task2"}', name="_pass_function"),
+        FunctionCall(id="", arguments=r'{"input": "task3"}', name="_echo_function"),
+    ]
+    assert result.messages[1].models_usage is not None
+    assert result.messages[1].models_usage.completion_tokens == 5
+    assert result.messages[1].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    expected_content = [
+        FunctionExecutionResult(call_id="", content="pass", is_error=False),
+        FunctionExecutionResult(call_id="", content="pass", is_error=False),
+        FunctionExecutionResult(call_id="", content="task3", is_error=False),
+    ]
+    for expected in expected_content:
+        assert expected in result.messages[2].content
+    assert result.messages[2].models_usage is None
+    assert isinstance(result.messages[3], ToolCallSummaryMessage)
+    assert result.messages[3].content == "pass\npass\ntask3"
+    assert result.messages[3].models_usage is None
+
+    # Test streaming.
+    mock.curr_index = 0  # Reset the mock
+    index = 0
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+        index += 1
+
+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+
+
 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
     handoff = Handoff(target="agent2")
diff --git a/python/packages/autogen-core/src/autogen_core/models/_types.py b/python/packages/autogen-core/src/autogen_core/models/_types.py
index 22b1cff3bd40..76d87ab8d46d 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_types.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -21,6 +21,8 @@ class SystemMessage(BaseModel):
     """
 
     content: str
+    """The content of the message."""
+
     type: Literal["SystemMessage"] = "SystemMessage"
 
 
@@ -28,9 +30,10 @@ class UserMessage(BaseModel):
     """User message contains input from end users, or a catch-all for data provided to the model."""
 
     content: Union[str, List[Union[str, Image]]]
+    """The content of the message."""
 
-    # Name of the agent that sent this message
     source: str
+    """The name of the agent that sent this message."""
 
     type: Literal["UserMessage"] = "UserMessage"
 
 
@@ -39,9 +42,10 @@ class AssistantMessage(BaseModel):
    """Assistant message are sampled from the language model."""
 
     content: Union[str, List[FunctionCall]]
+    """The content of the message."""
 
-    # Name of the agent that sent this message
     source: str
+    """The name of the agent that sent this message."""
 
     type: Literal["AssistantMessage"] = "AssistantMessage"
 
 
@@ -50,8 +54,13 @@ class FunctionExecutionResult(BaseModel):
     """Function execution result contains the output of a function call."""
 
     content: str
+    """The output of the function call."""
+
     call_id: str
+    """The ID of the function call. Note this ID may be empty for some models."""
+
     is_error: bool | None = None
+    """Whether the function call resulted in an error."""
 
 
 class FunctionExecutionResultMessage(BaseModel):
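
Reviewer note on the approach: before this patch, each FunctionExecutionResult
was matched back to its FunctionCall by call_id. When a model emits empty tool
call ids (issue #5508), every call matches the first result, so later calls
silently receive the wrong output. The patch instead has _execute_tool_call
return the originating call together with its result, relying on asyncio.gather
preserving argument order, so call_id is never consulted for correlation. Below
is a minimal, self-contained sketch of that idea; the FunctionCall and
FunctionExecutionResult classes, and the execute helper, are simplified
stand-ins for the autogen_core types, not the library's actual definitions.

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class FunctionCall:
        id: str  # may be "" for models that omit tool call ids
        name: str


    @dataclass
    class FunctionExecutionResult:
        call_id: str
        content: str


    async def execute(call: FunctionCall) -> tuple[FunctionCall, FunctionExecutionResult]:
        # Mirrors the patched _execute_tool_call: return the originating call
        # alongside its result so no lookup by call_id is needed afterward.
        return call, FunctionExecutionResult(call_id=call.id, content=f"ran {call.name}")


    async def main() -> None:
        # Two tool calls with empty ids, as produced by some models.
        calls = [FunctionCall(id="", name="tool_a"), FunctionCall(id="", name="tool_b")]

        # Old approach: correlate by call_id. With empty ids, the lookup for
        # tool_b finds tool_a's result first, a silent mis-correlation.
        results = [FunctionExecutionResult(call_id=c.id, content=f"ran {c.name}") for c in calls]
        wrong = next(r for r in results if r.call_id == calls[1].id)
        assert wrong.content == "ran tool_a"

        # New approach: asyncio.gather preserves argument order, so each call
        # arrives already paired with its own result; ids are never consulted.
        paired = await asyncio.gather(*[execute(c) for c in calls])
        for call, result in paired:
            assert result.content == f"ran {call.name}"


    asyncio.run(main())

Pairing at the execution site also turns the later handoff/tool separation into
a single pass over the gathered pairs, which is why the old RuntimeError
fallback for a missing result could be deleted outright.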