diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 22b64a342061..70de5713814d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Mapping, Sequence from autogen_core import Component, ComponentModel -from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from pydantic import BaseModel from typing_extensions import Self @@ -39,6 +39,7 @@ def __init__( selector_prompt: str, allow_repeated_speaker: bool, selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None, + max_selector_attempts: int, ) -> None: super().__init__( group_topic_type, @@ -53,6 +54,7 @@ def __init__( self._previous_speaker: str | None = None self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func + self._max_selector_attempts = max_selector_attempts async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass @@ -131,54 +133,71 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: # Select the next speaker. if len(participants) > 1: - select_speaker_prompt = self._selector_prompt.format( - roles=roles, participants=str(participants), history=history - ) - select_speaker_messages: List[SystemMessage | UserMessage] - if self._model_client.model_info["family"] in [ - ModelFamily.GPT_4, - ModelFamily.GPT_4O, - ModelFamily.GPT_35, - ModelFamily.O1, - ModelFamily.O3, - ]: - select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] - else: - # Many other models need a UserMessage to respond to - select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] + agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts) + else: + agent_name = participants[0] + self._previous_speaker = agent_name + trace_logger.debug(f"Selected speaker: {agent_name}") + return agent_name - response = await self._model_client.create(messages=select_speaker_messages) + async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str: + select_speaker_prompt = self._selector_prompt.format( + roles=roles, participants=str(participants), history=history + ) + select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] + if self._model_client.model_info["family"] in [ + ModelFamily.GPT_4, + ModelFamily.GPT_4O, + ModelFamily.GPT_35, + ModelFamily.O1, + ModelFamily.O3, + ]: + select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")] + num_attempts = 0 + while num_attempts < max_attempts: + num_attempts += 1 + response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) + select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) mentions = self._mentioned_agents(response.content, self._participant_topic_types) if len(mentions) == 0: - trace_logger.warning( - f"No valid agent was mentioned in the model response: {response.content}. " - "Using the previous speaker if available, otherwise selecting the first participant." + trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})") + feedback = f"No valid name was mentioned. Please select from: {str(participants)}." + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + elif len(mentions) > 1: + trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})") + feedback = ( + f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}." ) - if self._previous_speaker is not None: - agent_name = self._previous_speaker - else: - agent_name = participants[0] + select_speaker_messages.append(UserMessage(content=feedback, source="user")) else: - if len(mentions) > 1: - trace_logger.warning( - f"Expected exactly one agent to be mentioned, but got {mentions}. Using the first one." - ) agent_name = list(mentions.keys())[0] if ( not self._allow_repeated_speaker and self._previous_speaker is not None and agent_name == self._previous_speaker ): - trace_logger.warning( - f"Repeated speaker is not allowed, but the selector selected the previous speaker: {agent_name}" + trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})") + feedback = ( + f"Repeated speaker is not allowed, please select a different name from: {str(participants)}." ) - else: - agent_name = participants[0] - self._previous_speaker = agent_name - trace_logger.debug(f"Selected speaker: {agent_name}") - return agent_name + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + else: + # Valid selection + trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})") + return agent_name + + if self._previous_speaker is not None: + trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.") + return self._previous_speaker + trace_logger.warning( + f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant." + ) + return participants[0] def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]: """Counts the number of times each agent is mentioned in the provided message content. @@ -224,6 +243,7 @@ class SelectorGroupChatConfig(BaseModel): selector_prompt: str allow_repeated_speaker: bool # selector_func: ComponentModel | None + max_selector_attempts: int = 3 class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -242,11 +262,15 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): Must contain '{roles}', '{participants}', and '{history}' to be filled in. allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn. Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens. + max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3. + If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available, + otherwise the first participant will be used. selector_func (Callable[[Sequence[AgentEvent | ChatMessage]], str | None], optional): A custom selector function that takes the conversation history and returns the name of the next speaker. If provided, this function will be used to override the model to select the next speaker. If the function returns None, the model will be used to select the next speaker. + Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -382,6 +406,7 @@ def __init__( Read the above conversation. Then select the next role from {participants} to play. Only return the role. """, allow_repeated_speaker: bool = False, + max_selector_attempts: int = 3, selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None = None, ): super().__init__( @@ -404,6 +429,7 @@ def __init__( self._model_client = model_client self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func + self._max_selector_attempts = max_selector_attempts def _create_group_chat_manager_factory( self, @@ -425,6 +451,7 @@ def _create_group_chat_manager_factory( self._selector_prompt, self._allow_repeated_speaker, self._selector_func, + self._max_selector_attempts, ) def _to_config(self) -> SelectorGroupChatConfig: @@ -435,6 +462,7 @@ def _to_config(self) -> SelectorGroupChatConfig: max_turns=self._max_turns, selector_prompt=self._selector_prompt, allow_repeated_speaker=self._allow_repeated_speaker, + max_selector_attempts=self._max_selector_attempts, # selector_func=self._selector_func.dump_component() if self._selector_func else None, ) @@ -449,6 +477,7 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self: max_turns=config.max_turns, selector_prompt=config.selector_prompt, allow_repeated_speaker=config.allow_repeated_speaker, + max_selector_attempts=config.max_selector_attempts, # selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None]) # if config.selector_func # else None, diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index b952fa16a89e..098d68349ca1 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -741,33 +741,16 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte @pytest.mark.asyncio -async def test_selector_group_chat_multiple_speakers_selected(monkeypatch: pytest.MonkeyPatch) -> None: - model = "gpt-4o-2024-05-13" - chat_completions = [ - ChatCompletion( - id="id2", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(content="agent2, agent3", role="assistant"), - ) - ], - created=0, - model=model, - object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ), - ] - mock = _MockChatCompletion(chat_completions) - monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) - +async def test_selector_group_chat_succcess_after_2_attempts() -> None: + model_client = ReplayChatCompletionClient( + ["agent2, agent3", "agent2"], + ) agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) agent2 = _EchoAgent("agent2", description="echo agent 2") agent3 = _EchoAgent("agent3", description="echo agent 3") team = SelectorGroupChat( participants=[agent1, agent2, agent3], - model_client=OpenAIChatCompletionClient(model=model, api_key=""), + model_client=model_client, max_turns=1, ) result = await team.run(task="Write a program that prints 'Hello, world!'") @@ -777,48 +760,46 @@ async def test_selector_group_chat_multiple_speakers_selected(monkeypatch: pytes @pytest.mark.asyncio -async def test_selector_group_chat_non_speaker_selected(monkeypatch: pytest.MonkeyPatch) -> None: - model = "gpt-4o-2024-05-13" - chat_completions = [ - ChatCompletion( - id="id2", - choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent5", role="assistant")) - ], - created=0, - model=model, - object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ), - ChatCompletion( - id="id2", - choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent5", role="assistant")) - ], - created=0, - model=model, - object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ), - ] - mock = _MockChatCompletion(chat_completions) - monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) +async def test_selector_group_chat_fall_back_to_first_after_3_attempts() -> None: + model_client = ReplayChatCompletionClient( + [ + "agent2, agent3", # Multiple speakers + "agent5", # Non-existent speaker + "agent3, agent1", # Multiple speakers + ] + ) + agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + team = SelectorGroupChat( + participants=[agent1, agent2, agent3], + model_client=model_client, + max_turns=1, + ) + result = await team.run(task="Write a program that prints 'Hello, world!'") + assert len(result.messages) == 2 + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].source == "agent1" + +@pytest.mark.asyncio +async def test_selector_group_chat_fall_back_to_previous_after_3_attempts() -> None: + model_client = ReplayChatCompletionClient( + ["agent2", "agent2", "agent2", "agent2"], + ) agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) agent2 = _EchoAgent("agent2", description="echo agent 2") agent3 = _EchoAgent("agent3", description="echo agent 3") team = SelectorGroupChat( participants=[agent1, agent2, agent3], - model_client=OpenAIChatCompletionClient(model=model, api_key=""), + model_client=model_client, max_turns=2, ) result = await team.run(task="Write a program that prints 'Hello, world!'") assert len(result.messages) == 3 assert result.messages[0].content == "Write a program that prints 'Hello, world!'" - assert ( - result.messages[1].source == "agent1" - ) # agent1 is selected as the speaker as agent5 is not in the participants list - assert result.messages[2].source == "agent1" # agent1, the previous speaker, is selected as the speaker again. + assert result.messages[1].source == "agent2" + assert result.messages[2].source == "agent2" @pytest.mark.asyncio