Skip to content

Commit

Permalink
fix tool calling issue (#1467)
Browse files Browse the repository at this point in the history
* fix tool calling issue

* Update tool type check

* Drop print
  • Loading branch information
bhancockio authored Oct 18, 2024
1 parent 60efcad commit 84f48c4
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def _render_text_description_and_args(self, tools: List[Any]) -> str:
"""
tool_strings = []
for tool in tools:
args_schema = str(tool.args)
args_schema = str(tool.model_fields)
if hasattr(tool, "func") and tool.func:
sig = signature(tool.func)
description = (
Expand Down
8 changes: 4 additions & 4 deletions src/crewai/agents/crew_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from typing import Any, Dict, List, Union

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
Expand All @@ -19,7 +20,6 @@
)
from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.agents.agent_builder.base_agent import BaseAgent


class CrewAgentExecutor(CrewAgentExecutorMixin):
Expand Down Expand Up @@ -323,9 +323,9 @@ def _handle_crew_training_output(
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration
if agent_id in training_data and isinstance(train_iteration, int):
training_data[agent_id][train_iteration]["improved_output"] = (
result.output
)
training_data[agent_id][train_iteration][
"improved_output"
] = result.output
training_handler.save(training_data)
else:
self._logger.log(
Expand Down
10 changes: 6 additions & 4 deletions src/crewai/tools/tool_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from textwrap import dedent
from typing import Any, List, Union

import crewai.utilities.events as events
from crewai.agents.tools_handler import ToolsHandler
from crewai.task import Task
from crewai.telemetry import Telemetry
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished
from crewai.utilities import I18N, Converter, ConverterError, Printer
import crewai.utilities.events as events


agentops = None
if os.environ.get("AGENTOPS_API_KEY"):
Expand Down Expand Up @@ -300,8 +299,11 @@ def _render(self) -> str:
descriptions = []
for tool in self.tools:
args = {
k: {k2: v2 for k2, v2 in v.items() if k2 in ["description", "type"]}
for k, v in tool.args.items()
name: {
"description": field.description,
"type": field.annotation.__name__,
}
for name, field in tool.args_schema.model_fields.items()
}
descriptions.append(
"\n".join(
Expand Down
143 changes: 143 additions & 0 deletions tests/tools/test_tool_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import json
import random
from unittest.mock import MagicMock, patch

import pytest
from crewai_tools import BaseTool
from pydantic import BaseModel, Field

from crewai import Agent, Crew, Task
from crewai.tools.tool_usage import ToolUsage


class RandomNumberToolInput(BaseModel):
min_value: int = Field(
..., description="The minimum value of the range (inclusive)"
)
max_value: int = Field(
..., description="The maximum value of the range (inclusive)"
)


class RandomNumberTool(BaseTool):
name: str = "Random Number Generator"
description: str = "Generates a random number within a specified range"
args_schema: type[BaseModel] = RandomNumberToolInput

def _run(self, min_value: int, max_value: int) -> int:
return random.randint(min_value, max_value)


# Example agent and task
example_agent = Agent(
role="Number Generator",
goal="Generate random numbers for various purposes",
backstory="You are an AI agent specialized in generating random numbers within specified ranges.",
tools=[RandomNumberTool()],
verbose=True,
)

example_task = Task(
description="Generate a random number between 1 and 100",
expected_output="A random number between 1 and 100",
agent=example_agent,
)


def test_random_number_tool_usage():
crew = Crew(
agents=[example_agent],
tasks=[example_task],
)

with patch.object(random, "randint", return_value=42):
result = crew.kickoff()

assert "42" in result.raw


def test_random_number_tool_range():
tool = RandomNumberTool()
result = tool._run(1, 10)
assert 1 <= result <= 10


def test_random_number_tool_with_crew():
crew = Crew(
agents=[example_agent],
tasks=[example_task],
)

result = crew.kickoff()

# Check if the result contains a number between 1 and 100
assert any(str(num) in result.raw for num in range(1, 101))


def test_random_number_tool_invalid_range():
tool = RandomNumberTool()
with pytest.raises(ValueError):
tool._run(10, 1) # min_value > max_value


def test_random_number_tool_schema():
tool = RandomNumberTool()

# Get the schema using model_json_schema()
schema = tool.args_schema.model_json_schema()

# Convert the schema to a string
schema_str = json.dumps(schema)

# Check if the schema string contains the expected fields
assert "min_value" in schema_str
assert "max_value" in schema_str

# Parse the schema string back to a dictionary
schema_dict = json.loads(schema_str)

# Check if the schema contains the correct field types
assert schema_dict["properties"]["min_value"]["type"] == "integer"
assert schema_dict["properties"]["max_value"]["type"] == "integer"

# Check if the schema contains the field descriptions
assert (
"minimum value" in schema_dict["properties"]["min_value"]["description"].lower()
)
assert (
"maximum value" in schema_dict["properties"]["max_value"]["description"].lower()
)


def test_tool_usage_render():
tool = RandomNumberTool()

tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[tool],
original_tools=[tool],
tools_description="Sample tool for testing",
tools_names="random_number_generator",
task=MagicMock(),
function_calling_llm=MagicMock(),
agent=MagicMock(),
action=MagicMock(),
)

rendered = tool_usage._render()

# Updated checks to match the actual output
assert "Tool Name: random number generator" in rendered
assert (
"Random Number Generator(min_value: 'integer', max_value: 'integer') - Generates a random number within a specified range min_value: 'The minimum value of the range (inclusive)', max_value: 'The maximum value of the range (inclusive)'"
in rendered
)
assert "Tool Arguments:" in rendered
assert (
"'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}"
in rendered
)
assert (
"'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}"
in rendered
)

0 comments on commit 84f48c4

Please sign in to comment.