Skip to content

Commit

Permalink
fix: Address tool call execution scenario when model produces empty tool call ids (#5509)
Browse files Browse the repository at this point in the history

Resolves #5508
  • Loading branch information
ekzhu authored Feb 14, 2025
1 parent ff7f863 commit e7a3c78
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Mapping,
Sequence,
Tuple,
)

from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
Expand Down Expand Up @@ -441,28 +442,26 @@ async def on_messages_stream(
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
exec_results = await asyncio.gather(
# Execute the tool calls and handoff calls.
executed_calls_and_results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
)
# Collect the execution results in a list.
exec_results = [result for _, result in executed_calls_and_results]
# Add the execution results to output and model context.
tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
event_logger.debug(tool_call_result_msg)
await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Correlate tool call results with tool calls.
tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
# Separate out tool calls and tool call results from handoff requests.
tool_calls: List[FunctionCall] = []
tool_call_results: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
found = False
for exec_result in exec_results:
if exec_result.call_id == tool_call.id:
found = True
tool_call_results.append(exec_result)
break
if not found:
raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
for exec_call, exec_result in executed_calls_and_results:
if exec_call.name not in self._handoffs:
tool_calls.append(exec_call)
tool_call_results.append(exec_result)

# Detect handoff requests.
handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
Expand Down Expand Up @@ -546,7 +545,7 @@ async def on_messages_stream(

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
) -> Tuple[FunctionCall, FunctionExecutionResult]:
"""Execute a tool call and return the result."""
try:
if not self._tools + self._handoff_tools:
Expand All @@ -557,9 +556,9 @@ async def _execute_tool_call(
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
return (tool_call, FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False))
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)
return (tool_call, FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True))

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
Expand Down
136 changes: 136 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,142 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert state == state2


@pytest.mark.asyncio
async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.MonkeyPatch) -> None:
    """Regression test for #5508: parallel tool calls whose ids are empty strings
    (some models omit call ids) must still be executed and their results
    correlated with the originating calls."""
    model = "gpt-4o-2024-05-13"

    def _tool_call(name: str, task: str) -> ChatCompletionMessageToolCall:
        # Every call deliberately carries an empty id to reproduce the bug.
        return ChatCompletionMessageToolCall(
            id="",
            type="function",
            function=Function(name=name, arguments=json.dumps({"input": task})),
        )

    def _text_completion(text: str) -> ChatCompletion:
        # A plain assistant text reply with fixed token usage.
        return ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content=text, role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        )

    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            _tool_call("_pass_function", "task1"),
                            _tool_call("_pass_function", "task2"),
                            _tool_call("_echo_function", "task3"),
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        _text_completion("pass"),
        _text_completion("TERMINATE"),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
    )
    result = await agent.run(task="task")

    # Expect four messages: user task, tool-call request, execution results, summary.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None

    request_event = result.messages[1]
    assert isinstance(request_event, ToolCallRequestEvent)
    assert request_event.content == [
        FunctionCall(id="", arguments=r'{"input": "task1"}', name="_pass_function"),
        FunctionCall(id="", arguments=r'{"input": "task2"}', name="_pass_function"),
        FunctionCall(id="", arguments=r'{"input": "task3"}', name="_echo_function"),
    ]
    assert request_event.models_usage is not None
    assert request_event.models_usage.completion_tokens == 5
    assert request_event.models_usage.prompt_tokens == 10

    exec_event = result.messages[2]
    assert isinstance(exec_event, ToolCallExecutionEvent)
    # Parallel execution does not guarantee ordering, so check membership only.
    for expected in (
        FunctionExecutionResult(call_id="", content="pass", is_error=False),
        FunctionExecutionResult(call_id="", content="pass", is_error=False),
        FunctionExecutionResult(call_id="", content="task3", is_error=False),
    ):
        assert expected in exec_event.content
    assert exec_event.models_usage is None

    summary = result.messages[3]
    assert isinstance(summary, ToolCallSummaryMessage)
    assert summary.content == "pass\npass\ntask3"
    assert summary.models_usage is None

    # Streaming must yield the same messages, finishing with the TaskResult.
    mock.curr_index = 0  # Reset the mock
    position = 0
    async for streamed in agent.run_stream(task="task"):
        if isinstance(streamed, TaskResult):
            assert streamed == result
        else:
            assert streamed == result.messages[position]
            position += 1

    # Saved state must round-trip through a fresh agent unchanged.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    assert state == state2


@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoff = Handoff(target="agent2")
Expand Down
13 changes: 11 additions & 2 deletions python/packages/autogen-core/src/autogen_core/models/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@ class SystemMessage(BaseModel):
"""

content: str
"""The content of the message."""

type: Literal["SystemMessage"] = "SystemMessage"


class UserMessage(BaseModel):
"""User message contains input from end users, or a catch-all for data provided to the model."""

content: Union[str, List[Union[str, Image]]]
"""The content of the message."""

# Name of the agent that sent this message
source: str
"""The name of the agent that sent this message."""

type: Literal["UserMessage"] = "UserMessage"

Expand All @@ -39,9 +42,10 @@ class AssistantMessage(BaseModel):
"""Assistant message are sampled from the language model."""

content: Union[str, List[FunctionCall]]
"""The content of the message."""

# Name of the agent that sent this message
source: str
"""The name of the agent that sent this message."""

type: Literal["AssistantMessage"] = "AssistantMessage"

Expand All @@ -50,8 +54,13 @@ class FunctionExecutionResult(BaseModel):
"""Function execution result contains the output of a function call."""

content: str
"""The output of the function call."""

call_id: str
"""The ID of the function call. Note this ID may be empty for some models."""

is_error: bool | None = None
"""Whether the function call resulted in an error."""


class FunctionExecutionResultMessage(BaseModel):
Expand Down

0 comments on commit e7a3c78

Please sign in to comment.