Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AssistantAgent no longer sends out StopMessage. We use TextMentionTermination("TERMINATE") on the team instead for default setting. #4030

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
Expand Down Expand Up @@ -232,8 +231,8 @@ def __init__(
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
return [TextMessage, HandoffMessage]
return [TextMessage]

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -303,16 +302,9 @@ async def on_messages_stream(
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
28 changes: 14 additions & 14 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Expand Down Expand Up @@ -151,7 +151,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
expected_messages = [
"Write a program that prints 'Hello, world!'",
Expand All @@ -172,7 +172,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -247,7 +247,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)

assert len(result.messages) == 6
Expand All @@ -256,7 +256,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response

context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -275,7 +275,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -351,7 +351,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -366,7 +366,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -401,7 +401,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -417,7 +417,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -472,7 +472,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
allow_repeated_speaker=True,
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
)
assert len(result.messages) == 4
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -484,7 +484,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -649,7 +649,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
)
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
result = await team.run("task", termination_condition=TextMentionTermination("TERMINATE"))
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage)
Expand All @@ -663,7 +663,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
stream = team.run_stream("task", termination_condition=TextMentionTermination("TERMINATE"))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
Expand Down Expand Up @@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -400,7 +400,9 @@
}
],
"source": [
"result = await team.run(\"Write a financial report on American airlines\", termination_condition=StopMessageTermination())\n",
"result = await team.run(\n",
" \"Write a financial report on American airlines\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")\n",
"print(result)"
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
Expand Down Expand Up @@ -161,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -332,7 +332,7 @@
"\n",
"result = await team.run(\n",
" task=\"Write a literature review on no code tools for building multi agent ai systems\",\n",
" termination_condition=StopMessageTermination(),\n",
" termination_condition=TextMentionTermination(\"TERMINATE\"),\n",
")"
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
Expand Down Expand Up @@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -195,7 +195,9 @@
],
"source": [
"group_chat = RoundRobinGroupChat([planner_agent, local_agent, language_agent, travel_summary_agent])\n",
"result = await group_chat.run(task=\"Plan a 3 day trip to Nepal.\", termination_condition=StopMessageTermination())\n",
"result = await group_chat.run(\n",
" task=\"Plan a 3 day trip to Nepal.\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")\n",
"print(result)"
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -47,7 +47,7 @@
")\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.task import TextMentionTermination\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_core.base import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
Expand Down Expand Up @@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -254,7 +254,7 @@
"team = SelectorGroupChat(\n",
" [user_proxy, flight_broker, travel_assistant], model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")\n",
")\n",
"await team.run(\"Help user plan a trip and book a flight.\", termination_condition=StopMessageTermination())"
"await team.run(\"Help user plan a trip and book a flight.\", termination_condition=TextMentionTermination(\"TERMINATE\"))"
]
}
],
Expand All @@ -274,7 +274,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination\n",
"from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
"\n",
Expand Down Expand Up @@ -140,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -178,7 +178,7 @@
"round_robin_team = RoundRobinGroupChat([writing_assistant_agent])\n",
"\n",
"round_robin_team_result = await round_robin_team.run(\n",
" \"Write a unique, Haiku about the weather in Paris\", termination_condition=StopMessageTermination()\n",
" \"Write a unique, Haiku about the weather in Paris\", termination_condition=TextMentionTermination(\"TERMINATE\")\n",
")"
]
}
Expand Down
Loading