Add `error_on_tool_error` param to `FunctionCallingLLM.predict_and_call` (#17663)

alexander-azizi-martin authored Jan 31, 2025
1 parent 0a6fa89 commit 772b6d2
Showing 3 changed files with 165 additions and 5 deletions.
10 changes: 7 additions & 3 deletions CONTRIBUTING.md
@@ -172,15 +172,19 @@ LlamaIndex is organized as a **monorepo**, meaning different packages live within
    ```bash
    curl -sSL https://install.python-poetry.org | python3 -
    ```
-2. Activate the environment:
+2. Install the Poetry shell plugin (if you don't already have it):
+   ```bash
+   poetry self add poetry-plugin-shell
+   ```
+3. Activate the environment:
    ```bash
    poetry shell
    ```
-3. Install dependencies:
+4. Install dependencies:
    ```bash
    poetry install --only dev,docs --no-root
    ```
-4. Install the package(s) you want to work on. You will for sure need to install `llama-index-core`:
+5. Install the package(s) you want to work on. You will for sure need to install `llama-index-core`:
 
    ```bash
    pip install -e llama-index-core
22 changes: 20 additions & 2 deletions llama-index-core/llama_index/core/llms/function_calling.py
@@ -161,6 +161,7 @@ def predict_and_call(
         verbose: bool = False,
         allow_parallel_tool_calls: bool = False,
         error_on_no_tool_call: bool = True,
+        error_on_tool_error: bool = False,
         **kwargs: Any,
     ) -> "AgentChatResponse":
         """Predict and call the tool."""
@@ -193,7 +194,15 @@ def predict_and_call(
             call_tool_with_selection(tool_call, tools, verbose=verbose)
             for tool_call in tool_calls
         ]
-        if allow_parallel_tool_calls:
+        tool_outputs_with_error = [
+            tool_output for tool_output in tool_outputs if tool_output.is_error
+        ]
+        if error_on_tool_error and len(tool_outputs_with_error) > 0:
+            error_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            raise ValueError(error_text)
+        elif allow_parallel_tool_calls:
             output_text = "\n\n".join(
                 [tool_output.content for tool_output in tool_outputs]
             )
@@ -218,6 +227,7 @@ async def apredict_and_call(
         verbose: bool = False,
         allow_parallel_tool_calls: bool = False,
         error_on_no_tool_call: bool = True,
+        error_on_tool_error: bool = False,
         **kwargs: Any,
     ) -> "AgentChatResponse":
         """Predict and call the tool."""
@@ -252,7 +262,15 @@ async def apredict_and_call(
             for tool_call in tool_calls
         ]
         tool_outputs = await asyncio.gather(*tool_tasks)
-        if allow_parallel_tool_calls:
+        tool_outputs_with_error = [
+            tool_output for tool_output in tool_outputs if tool_output.is_error
+        ]
+        if error_on_tool_error and len(tool_outputs_with_error) > 0:
+            error_text = "\n\n".join(
+                [tool_output.content for tool_output in tool_outputs]
+            )
+            raise ValueError(error_text)
+        elif allow_parallel_tool_calls:
             output_text = "\n\n".join(
                 [tool_output.content for tool_output in tool_outputs]
             )
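Aside (not part of the commit): a minimal usage sketch of the new flag. It assumes `llm` is any concrete `FunctionCallingLLM` instance, and the `divide` tool is a hypothetical example invented for illustration:

```python
from llama_index.core.tools import FunctionTool

# Assumption: any function-calling LLM works here, e.g.
# from llama_index.llms.openai import OpenAI; llm = OpenAI(model="gpt-4o")


def divide(a: float, b: float) -> float:
    """Divide a by b."""
    return a / b  # raises ZeroDivisionError when b == 0


divide_tool = FunctionTool.from_defaults(fn=divide)

# Default (error_on_tool_error=False): a failing tool call comes back as a
# ToolOutput with is_error=True in response.sources; nothing is raised.
response = llm.predict_and_call(tools=[divide_tool], user_msg="What is 1 / 0?")
print(any(source.is_error for source in response.sources))

# With error_on_tool_error=True: the joined content of the tool outputs is
# raised as a ValueError instead of being folded into the response.
try:
    llm.predict_and_call(
        tools=[divide_tool],
        user_msg="What is 1 / 0?",
        error_on_tool_error=True,
    )
except ValueError as exc:
    print(exc)
```

Note that, per both hunks above, the raised message joins the `content` of every tool output, not only the failing ones.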
138 changes: 138 additions & 0 deletions llama-index-core/tests/llms/test_function_calling.py
@@ -0,0 +1,138 @@
from typing import Any, AsyncGenerator, Coroutine, Dict, List, Optional, Sequence, Union

import pytest
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponse,
LLMMetadata,
)
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.program.function_program import FunctionTool, get_function_tool
from llama_index.core.tools.types import BaseTool
from pydantic import BaseModel, Field


class MockFunctionCallingLLM(FunctionCallingLLM):
def __init__(self, tool_selection: List[ToolSelection]):
super().__init__()
self._tool_selection = tool_selection

async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> Coroutine[Any, Any, ChatResponse]:
return ChatResponse(message=ChatMessage(role="user", content=""))

def acomplete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> Coroutine[Any, Any, CompletionResponse]:
pass

def astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> Coroutine[Any, Any, AsyncGenerator[ChatResponse, None]]:
pass

def astream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> Coroutine[Any, Any, AsyncGenerator[CompletionResponse, None]]:
pass

def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
return ChatResponse(message=ChatMessage(role="user", content=""))

def complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
pass

def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
pass

def stream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> ChatResponseGen:
pass

@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(is_function_calling_model=True)

def _prepare_chat_with_tools(
self,
tools: Sequence["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
return {"messages": []}

def get_tool_calls_from_response(
self,
response: ChatResponse,
error_on_no_tool_call: bool = True,
**kwargs: Any,
) -> List[ToolSelection]:
return self._tool_selection


class Person(BaseModel):
name: str = Field(description="Person name")


@pytest.fixture()
def person_tool() -> FunctionTool:
return get_function_tool(Person)


@pytest.fixture()
def person_tool_selection(person_tool: FunctionTool) -> ToolSelection:
return ToolSelection(
tool_id="",
tool_name=person_tool.metadata.name,
tool_kwargs={},
)


def test_predict_and_call(
person_tool: FunctionTool, person_tool_selection: ToolSelection
) -> None:
"""Test predict_and_call will return ToolOutput with error rather than raising one."""
llm = MockFunctionCallingLLM([person_tool_selection])
response = llm.predict_and_call(tools=[person_tool])
assert all(tool_output.is_error for tool_output in response.sources)


def test_predict_and_call_throws_if_error_on_tool(
person_tool: FunctionTool, person_tool_selection: ToolSelection
) -> None:
"""Test predict_and_call will raise an error."""
llm = MockFunctionCallingLLM([person_tool_selection])
with pytest.raises(ValueError):
llm.predict_and_call(tools=[person_tool], error_on_tool_error=True)


@pytest.mark.asyncio()
async def test_apredict_and_call(
person_tool: FunctionTool, person_tool_selection: ToolSelection
) -> None:
"""Test apredict_and_call will return ToolOutput with error rather than raising one."""
llm = MockFunctionCallingLLM([person_tool_selection])
response = await llm.apredict_and_call(tools=[person_tool])
assert all(tool_output.is_error for tool_output in response.sources)


@pytest.mark.asyncio()
async def test_apredict_and_call_throws_if_error_on_tool(
person_tool: FunctionTool, person_tool_selection: ToolSelection
) -> None:
"""Test apredict_and_call will raise an error."""
llm = MockFunctionCallingLLM([person_tool_selection])
with pytest.raises(ValueError):
await llm.apredict_and_call(tools=[person_tool], error_on_tool_error=True)
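
For completeness, a sketch of the async path under the same assumptions, reusing the hypothetical `llm` and `divide_tool` names from the sketch after the `function_calling.py` diff:

```python
import asyncio


async def main() -> None:
    try:
        # Same semantics as the sync path: a failing tool call surfaces
        # as a ValueError when error_on_tool_error=True.
        await llm.apredict_and_call(
            tools=[divide_tool],
            user_msg="What is 1 / 0?",
            error_on_tool_error=True,
        )
    except ValueError as exc:
        print(exc)


asyncio.run(main())
```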
