Skip to content

Commit

Permalink
Simplify _convert_messages_to_anthropic_format
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Oct 25, 2024
1 parent 1228f8c commit ad5980f
Showing 1 changed file with 75 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace

from haystack_experimental.dataclasses import ChatMessage, ToolCall
from haystack_experimental.dataclasses.chat_message import ChatRole
from haystack_experimental.dataclasses.chat_message import ChatRole, ToolCallResult
from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,14 +43,66 @@
chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]


def _convert_tool_call_results_to_anthropic_format(
tool_call_results: List[ToolCallResult], anthropic_msg: Dict[str, Any]
) -> None:
"""
Convert a list of tool call results to the format expected by Anthropic Chat API.
:param tool_call_results: The list of ToolCallResults to convert.
:param anthropic_msg: The Anthropic message to update.
"""
anthropic_content = []
if anthropic_msg.get("content"):
anthropic_content = anthropic_msg["content"]
else:
anthropic_msg["content"] = anthropic_content

for tool_call_result in tool_call_results:
if tool_call_result.origin.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_content.append(
{
"type": "tool_result",
"tool_use_id": tool_call_result.origin.id,
"content": [{"type": "text", "text": tool_call_result.result}],
"is_error": tool_call_result.error,
}
)


def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
"""
Convert a list of tool calls to the format expected by Anthropic Chat API.
:param tool_calls: The list of ToolCalls to convert.
:return: A list of dictionaries in the format expected by Anthropic API.
"""
anthropic_tool_calls = []
for tc in tool_calls:
if tc.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_tool_calls.append(
{
"type": "tool_use",
"id": tc.id,
"name": tc.tool_name,
"input": tc.arguments,
}
)
return anthropic_tool_calls


def _convert_messages_to_anthropic_format( # noqa: PLR0912
messages: List[ChatMessage],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Convert a list of messages to the format expected by Anthropic Chat API.
:param messages: The list of ChatMessages to convert.
:return: A list of dictionaries in the format expected by Anthropic API.
:return: A tuple of two lists:
- A list of system message dictionaries in the format expected by Anthropic API.
- A list of non-system message dictionaries in the format expected by Anthropic API.
"""
anthropic_system_messages = []
anthropic_non_system_messages = []
Expand All @@ -65,76 +117,37 @@ def _convert_messages_to_anthropic_format( # noqa: PLR0912
if message.is_from(ChatRole.SYSTEM):
anthropic_system_messages.append({"type": "text", "text": message.text})
continue

# create the base anthropic message for this message
anthropic_msg: Dict[str, Any] = {"role": message._role.value}

# Special case for when we have both text and tool calls in the same message
if message.texts and message.tool_calls:
# Special case for when we have both text and tool calls in the same message
anthropic_tool_calls = []
for tc in message.tool_calls:
if tc.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_tool_calls.append(
{
"type": "tool_use",
"id": tc.id,
"name": tc.tool_name,
"input": tc.arguments,
}
)
anthropic_tool_calls = _convert_tool_calls_to_anthropic_format(message.tool_calls)
anthropic_msg["content"] = [{"type": "text", "text": message.texts[0]}] + anthropic_tool_calls

# only tool calls
elif message.tool_calls:
anthropic_tool_calls = _convert_tool_calls_to_anthropic_format(message.tool_calls)
anthropic_msg["content"] = anthropic_tool_calls

# only text
elif message.texts:
anthropic_msg["content"] = [{"type": "text", "text": message.texts[0]}]

# handle tool call results and special case for tool call results stitching
elif message.tool_call_results:
if previous_message and previous_message.tool_call_results:
# special case - we already handled tool call results stitching
# in the previous message, so we skip this message
continue
anthropic_msg["content"] = []
for tool_call_result in message.tool_call_results:
if tool_call_result.origin.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_msg["content"].append(
{
"type": "tool_result",
"tool_use_id": tool_call_result.origin.id,
"content": [{"type": "text", "text": tool_call_result.result}],
"is_error": tool_call_result.error,
}
)
_convert_tool_call_results_to_anthropic_format(message.tool_call_results, anthropic_msg)
# special case - check if the next message is a tool result as well
# if so, we need to combine this and the next message into a single anthropic message
if next_message and next_message.tool_call_results:
for tool_call_result in next_message.tool_call_results:
if tool_call_result.origin.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_msg["content"].append(
{
"type": "tool_result",
"tool_use_id": tool_call_result.origin.id,
"content": [{"type": "text", "text": tool_call_result.result}],
"is_error": tool_call_result.error,
}
)
# Anthropic API requires the role to be set to "user" for tool results
_convert_tool_call_results_to_anthropic_format(next_message.tool_call_results, anthropic_msg)
# Anthropic API requires the role to be set to "user" for tool call results
anthropic_msg["role"] = "user"

elif message.tool_calls:
anthropic_tool_calls = []
for tc in message.tool_calls:
if tc.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
anthropic_tool_calls.append(
{
"type": "tool_use",
"id": tc.id,
"name": tc.tool_name,
"input": tc.arguments,
}
)
anthropic_msg["content"] = anthropic_tool_calls

else:
raise ValueError(
"A `ChatMessage` must contain at " "least one `TextContent`, `ToolCall`, or `ToolCallResult`."
Expand All @@ -145,7 +158,13 @@ def _convert_messages_to_anthropic_format( # noqa: PLR0912
return anthropic_system_messages, anthropic_non_system_messages


def _check_duplicate_tool_names(tools: List[Tool]):
def _check_duplicate_tool_names(tools: List[Tool]) -> None:
"""
Check for duplicate tool names.
:param tools: The list of tools to check.
:raises ValueError: If duplicate tool names are found.
"""
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
Expand Down

0 comments on commit ad5980f

Please sign in to comment.