Skip to content

Commit

Permalink
fix react bug introduced in 2.6.0 that cause tools without type hint …
Browse files Browse the repository at this point in the history
…to fail (#7621)
  • Loading branch information
kalanyuz authored Jan 31, 2025
1 parent ef8fe82 commit 71a2a1e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 12 deletions.
37 changes: 28 additions & 9 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@


class Tool:
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
def __init__(
self,
func: Callable,
name: str = None,
desc: str = None,
args: dict[str, Any] = None,
):
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
self.func = func
self.name = name or getattr(func, "__name__", type(func).__name__)
Expand Down Expand Up @@ -59,7 +65,12 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
f"Signals that the final outputs, i.e. {outputs}, are now available and marks the task as complete."
)
finish_args = {} # k: v.annotation for k, v in signature.output_fields.items()}
tools["finish"] = Tool(func=lambda **kwargs: "Completed.", name="finish", desc=finish_desc, args=finish_args)
tools["finish"] = Tool(
func=lambda **kwargs: "Completed.",
name="finish",
desc=finish_desc,
args=finish_args,
)

for idx, tool in enumerate(tools.values()):
args = tool.args if hasattr(tool, "args") else str({tool.input_variable: str})
Expand All @@ -76,7 +87,8 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
)

fallback_signature = dspy.Signature(
{**signature.input_fields, **signature.output_fields}, signature.instructions
{**signature.input_fields, **signature.output_fields},
signature.instructions,
).append("trajectory", dspy.InputField(), type_=str)

self.tools = tools
Expand All @@ -91,20 +103,27 @@ def format(trajectory: dict[str, Any], last_iteration: bool):

trajectory = {}
for idx in range(self.max_iters):
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters - 1)))
pred = self.react(
**input_args,
trajectory=format(trajectory, last_iteration=(idx == self.max_iters - 1)),
)

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
parsed_tool_args = {}
tool = self.tools[pred.next_tool_name]
for k, v in pred.next_tool_args.items():
arg_type = self.tools[pred.next_tool_name].arg_types[k]
if isinstance((origin := get_origin(arg_type) or arg_type), type) and issubclass(origin, BaseModel):
parsed_tool_args[k] = arg_type.model_validate(v)
else:
parsed_tool_args[k] = v
if hasattr(tool, "arg_types") and k in tool.arg_types:
arg_type = tool.arg_types[k]
if isinstance((origin := get_origin(arg_type) or arg_type), type) and issubclass(
origin, BaseModel
):
parsed_tool_args[k] = arg_type.model_validate(v)
continue
parsed_tool_args[k] = v
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**parsed_tool_args)
except Exception as e:
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"
Expand Down
38 changes: 35 additions & 3 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass

import dspy
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.predict import react
from pydantic import BaseModel

import dspy
from dspy.predict import react
from dspy.utils.dummies import DummyLM, dummy_rm

# def test_example_no_tools():
# # Create a simple dataset which the model will use with the Retrieve tool.
Expand Down Expand Up @@ -228,3 +228,35 @@ class InvitationSignature(dspy.Signature):
"observation_1": "Completed.",
}
assert outputs.trajectory == expected_trajectory


def test_tool_calling_without_typehint():
def foo(a, b):
"""Add two numbers."""
return a + b

react = dspy.ReAct("a, b -> c:int", tools=[foo])
lm = DummyLM(
[
{"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
{"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}},
{"reasoning": "I added the numbers successfully", "c": 3},
]
)
dspy.settings.configure(lm=lm)
outputs = react(a=1, b=2)

expected_trajectory = {
"thought_0": "I need to add two numbers.",
"tool_name_0": "foo",
"tool_args_0": {
"a": 1,
"b": 2,
},
"observation_0": 3,
"thought_1": "I have the sum, now I can finish.",
"tool_name_1": "finish",
"tool_args_1": {},
"observation_1": "Completed.",
}
assert outputs.trajectory == expected_trajectory

0 comments on commit 71a2a1e

Please sign in to comment.