Skip to content

Commit

Permalink
use retry loops
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu committed Feb 8, 2025
1 parent 4c87512 commit e65e3f6
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
selector_prompt: str,
allow_repeated_speaker: bool,
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None,
max_selector_attempts: int,
) -> None:
super().__init__(
group_topic_type,
Expand All @@ -53,6 +54,7 @@ def __init__(
self._previous_speaker: str | None = None
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts

async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass
Expand Down Expand Up @@ -131,54 +133,71 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:

# Select the next speaker.
if len(participants) > 1:
select_speaker_prompt = self._selector_prompt.format(
roles=roles, participants=str(participants), history=history
)
select_speaker_messages: List[SystemMessage | UserMessage]
if self._model_client.model_info["family"] in [
ModelFamily.GPT_4,
ModelFamily.GPT_4O,
ModelFamily.GPT_35,
ModelFamily.O1,
ModelFamily.O3,
]:
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
else:
# Many other models need a UserMessage to respond to
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
else:
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")
return agent_name

response = await self._model_client.create(messages=select_speaker_messages)
async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
select_speaker_prompt = self._selector_prompt.format(
roles=roles, participants=str(participants), history=history
)
select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
if self._model_client.model_info["family"] in [
ModelFamily.GPT_4,
ModelFamily.GPT_4O,
ModelFamily.GPT_35,
ModelFamily.O1,
ModelFamily.O3,
]:
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
else:
# Many other models need a UserMessage to respond to
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")]

num_attempts = 0
while num_attempts < max_attempts:
num_attempts += 1
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
if len(mentions) == 0:
trace_logger.warning(
f"No valid agent was mentioned in the model response: {response.content}. "
"Using the previous speaker if available, otherwise selecting the first participant."
trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
elif len(mentions) > 1:
trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})")
feedback = (
f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}."
)
if self._previous_speaker is not None:
agent_name = self._previous_speaker
else:
agent_name = participants[0]
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
else:
if len(mentions) > 1:
trace_logger.warning(
f"Expected exactly one agent to be mentioned, but got {mentions}. Using the first one."
)
agent_name = list(mentions.keys())[0]
if (
not self._allow_repeated_speaker
and self._previous_speaker is not None
and agent_name == self._previous_speaker
):
trace_logger.warning(
f"Repeated speaker is not allowed, but the selector selected the previous speaker: {agent_name}"
trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
feedback = (
f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
)
else:
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")
return agent_name
select_speaker_messages.append(UserMessage(content=feedback, source="user"))
else:
# Valid selection
trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
return agent_name

if self._previous_speaker is not None:
trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
return self._previous_speaker
trace_logger.warning(
f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
)
return participants[0]

def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Expand Down Expand Up @@ -224,6 +243,7 @@ class SelectorGroupChatConfig(BaseModel):
selector_prompt: str
allow_repeated_speaker: bool
# selector_func: ComponentModel | None
max_selector_attempts: int = 3


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Expand All @@ -242,11 +262,15 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Must contain '{roles}', '{participants}', and '{history}' to be filled in.
allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn.
Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens.
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
otherwise the first participant will be used.
selector_func (Callable[[Sequence[AgentEvent | ChatMessage]], str | None], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
If the function returns None, the model will be used to select the next speaker.
Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
Expand Down Expand Up @@ -382,6 +406,7 @@ def __init__(
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
""",
allow_repeated_speaker: bool = False,
max_selector_attempts: int = 3,
selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None = None,
):
super().__init__(
Expand All @@ -404,6 +429,7 @@ def __init__(
self._model_client = model_client
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts

def _create_group_chat_manager_factory(
self,
Expand All @@ -425,6 +451,7 @@ def _create_group_chat_manager_factory(
self._selector_prompt,
self._allow_repeated_speaker,
self._selector_func,
self._max_selector_attempts,
)

def _to_config(self) -> SelectorGroupChatConfig:
Expand All @@ -435,6 +462,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
max_turns=self._max_turns,
selector_prompt=self._selector_prompt,
allow_repeated_speaker=self._allow_repeated_speaker,
max_selector_attempts=self._max_selector_attempts,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
)

Expand All @@ -449,6 +477,7 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
max_turns=config.max_turns,
selector_prompt=config.selector_prompt,
allow_repeated_speaker=config.allow_repeated_speaker,
max_selector_attempts=config.max_selector_attempts,
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None])
# if config.selector_func
# else None,
Expand Down
87 changes: 34 additions & 53 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,33 +741,16 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte


@pytest.mark.asyncio
async def test_selector_group_chat_multiple_speakers_selected(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="agent2, agent3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

async def test_selector_group_chat_succcess_after_2_attempts() -> None:
model_client = ReplayChatCompletionClient(
["agent2, agent3", "agent2"],
)
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
agent2 = _EchoAgent("agent2", description="echo agent 2")
agent3 = _EchoAgent("agent3", description="echo agent 3")
team = SelectorGroupChat(
participants=[agent1, agent2, agent3],
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
max_turns=1,
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
Expand All @@ -777,48 +760,46 @@ async def test_selector_group_chat_multiple_speakers_selected(monkeypatch: pytes


@pytest.mark.asyncio
async def test_selector_group_chat_non_speaker_selected(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent5", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent5", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
async def test_selector_group_chat_fall_back_to_first_after_3_attempts() -> None:
model_client = ReplayChatCompletionClient(
[
"agent2, agent3", # Multiple speakers
"agent5", # Non-existent speaker
"agent3, agent1", # Multiple speakers
]
)
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
agent2 = _EchoAgent("agent2", description="echo agent 2")
agent3 = _EchoAgent("agent3", description="echo agent 3")
team = SelectorGroupChat(
participants=[agent1, agent2, agent3],
model_client=model_client,
max_turns=1,
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 2
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent1"


@pytest.mark.asyncio
async def test_selector_group_chat_fall_back_to_previous_after_3_attempts() -> None:
model_client = ReplayChatCompletionClient(
["agent2", "agent2", "agent2", "agent2"],
)
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
agent2 = _EchoAgent("agent2", description="echo agent 2")
agent3 = _EchoAgent("agent3", description="echo agent 3")
team = SelectorGroupChat(
participants=[agent1, agent2, agent3],
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
max_turns=2,
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 3
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert (
result.messages[1].source == "agent1"
) # agent1 is selected as the speaker as agent5 is not in the participants list
assert result.messages[2].source == "agent1" # agent1, the previous speaker, is selected as the speaker again.
assert result.messages[1].source == "agent2"
assert result.messages[2].source == "agent2"


@pytest.mark.asyncio
Expand Down

0 comments on commit e65e3f6

Please sign in to comment.