From 84f48c465d602a2b121a82d464fe701f84160e9f Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:56:56 -0400 Subject: [PATCH] fix tool calling issue (#1467) * fix tool calling issue * Update tool type check * Drop print --- src/crewai/agent.py | 2 +- src/crewai/agents/crew_agent_executor.py | 8 +- src/crewai/tools/tool_usage.py | 10 +- tests/tools/test_tool_usage.py | 143 +++++++++++++++++++++++ 4 files changed, 154 insertions(+), 9 deletions(-) create mode 100644 tests/tools/test_tool_usage.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 3f81ece215..165a406561 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -394,7 +394,7 @@ def _render_text_description_and_args(self, tools: List[Any]) -> str: """ tool_strings = [] for tool in tools: - args_schema = str(tool.args) + args_schema = str(tool.model_fields) if hasattr(tool, "func") and tool.func: sig = signature(tool.func) description = ( diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index b901fe1323..b11782ca16 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -2,6 +2,7 @@ import re from typing import Any, Dict, List, Union +from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.parser import ( FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE, @@ -19,7 +20,6 @@ ) from crewai.utilities.logger import Logger from crewai.utilities.training_handler import CrewTrainingHandler -from crewai.agents.agent_builder.base_agent import BaseAgent class CrewAgentExecutor(CrewAgentExecutorMixin): @@ -323,9 +323,9 @@ def _handle_crew_training_output( if self.crew is not None and hasattr(self.crew, "_train_iteration"): train_iteration = self.crew._train_iteration if agent_id in training_data and isinstance(train_iteration, int): - training_data[agent_id][train_iteration]["improved_output"] = ( - result.output - ) + training_data[agent_id][train_iteration][ + "improved_output" + ] = result.output training_handler.save(training_data) else: self._logger.log( diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index f75a9443a3..71c02fc3cf 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -6,14 +6,13 @@ from textwrap import dedent from typing import Any, List, Union +import crewai.utilities.events as events from crewai.agents.tools_handler import ToolsHandler from crewai.task import Task from crewai.telemetry import Telemetry from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished from crewai.utilities import I18N, Converter, ConverterError, Printer -import crewai.utilities.events as events - agentops = None if os.environ.get("AGENTOPS_API_KEY"): @@ -300,8 +299,11 @@ def _render(self) -> str: descriptions = [] for tool in self.tools: args = { - k: {k2: v2 for k2, v2 in v.items() if k2 in ["description", "type"]} - for k, v in tool.args.items() + name: { + "description": field.description, + "type": field.annotation.__name__, + } + for name, field in tool.args_schema.model_fields.items() } descriptions.append( "\n".join( diff --git a/tests/tools/test_tool_usage.py b/tests/tools/test_tool_usage.py new file mode 100644 index 0000000000..d0a6ec0b05 --- /dev/null +++ b/tests/tools/test_tool_usage.py @@ -0,0 +1,143 @@ +import json +import random +from unittest.mock import MagicMock, patch + +import pytest +from crewai_tools import BaseTool +from pydantic import BaseModel, Field + +from crewai import Agent, Crew, Task +from crewai.tools.tool_usage import ToolUsage + + +class RandomNumberToolInput(BaseModel): + min_value: int = Field( + ..., description="The minimum value of the range (inclusive)" + ) + max_value: int = Field( + ..., description="The maximum value of the range (inclusive)" + ) + + +class RandomNumberTool(BaseTool): + name: str = "Random Number Generator" + description: str = "Generates a random number within a specified range" + args_schema: type[BaseModel] = RandomNumberToolInput + + def _run(self, min_value: int, max_value: int) -> int: + return random.randint(min_value, max_value) + + +# Example agent and task +example_agent = Agent( + role="Number Generator", + goal="Generate random numbers for various purposes", + backstory="You are an AI agent specialized in generating random numbers within specified ranges.", + tools=[RandomNumberTool()], + verbose=True, +) + +example_task = Task( + description="Generate a random number between 1 and 100", + expected_output="A random number between 1 and 100", + agent=example_agent, +) + + +def test_random_number_tool_usage(): + crew = Crew( + agents=[example_agent], + tasks=[example_task], + ) + + with patch.object(random, "randint", return_value=42): + result = crew.kickoff() + + assert "42" in result.raw + + +def test_random_number_tool_range(): + tool = RandomNumberTool() + result = tool._run(1, 10) + assert 1 <= result <= 10 + + +def test_random_number_tool_with_crew(): + crew = Crew( + agents=[example_agent], + tasks=[example_task], + ) + + result = crew.kickoff() + + # Check if the result contains a number between 1 and 100 + assert any(str(num) in result.raw for num in range(1, 101)) + + +def test_random_number_tool_invalid_range(): + tool = RandomNumberTool() + with pytest.raises(ValueError): + tool._run(10, 1) # min_value > max_value + + +def test_random_number_tool_schema(): + tool = RandomNumberTool() + + # Get the schema using model_json_schema() + schema = tool.args_schema.model_json_schema() + + # Convert the schema to a string + schema_str = json.dumps(schema) + + # Check if the schema string contains the expected fields + assert "min_value" in schema_str + assert "max_value" in schema_str + + # Parse the schema string back to a dictionary + schema_dict = json.loads(schema_str) + + # Check if the schema contains the correct field types + assert schema_dict["properties"]["min_value"]["type"] == "integer" + assert schema_dict["properties"]["max_value"]["type"] == "integer" + + # Check if the schema contains the field descriptions + assert ( + "minimum value" in schema_dict["properties"]["min_value"]["description"].lower() + ) + assert ( + "maximum value" in schema_dict["properties"]["max_value"]["description"].lower() + ) + + +def test_tool_usage_render(): + tool = RandomNumberTool() + + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[tool], + original_tools=[tool], + tools_description="Sample tool for testing", + tools_names="random_number_generator", + task=MagicMock(), + function_calling_llm=MagicMock(), + agent=MagicMock(), + action=MagicMock(), + ) + + rendered = tool_usage._render() + + # Updated checks to match the actual output + assert "Tool Name: random number generator" in rendered + assert ( + "Random Number Generator(min_value: 'integer', max_value: 'integer') - Generates a random number within a specified range min_value: 'The minimum value of the range (inclusive)', max_value: 'The maximum value of the range (inclusive)'" + in rendered + ) + assert "Tool Arguments:" in rendered + assert ( + "'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}" + in rendered + ) + assert ( + "'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}" + in rendered + )