Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into eugene/tlk-2361-north-modify-prea…
Browse files Browse the repository at this point in the history
…mble-or-instructions-to-generate-valid
  • Loading branch information
EugeneLightsOn committed Dec 13, 2024
2 parents fca688e + 320f6ef commit 6974443
Show file tree
Hide file tree
Showing 28 changed files with 260 additions and 211 deletions.
15 changes: 8 additions & 7 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
from backend.tools.base import ToolAuthException, ToolError, ToolErrorCode
from backend.tools.base import (
ToolAuthException,
ToolErrorCode,
)

TIMEOUT_SECONDS = 60

Expand Down Expand Up @@ -110,19 +113,17 @@ async def _call_tool_async(
{
"call": tool_call,
"outputs": tool.get_tool_error(
ToolError(
text="Tool authentication failed",
details=str(e),
type=ToolErrorCode.AUTH,
)
details=str(e),
text="Tool authentication failed",
error_type=ToolErrorCode.AUTH,
),
}
]
except Exception as e:
return [
{
"call": tool_call,
"outputs": tool.get_tool_error(ToolError(text=str(e))),
"outputs": tool.get_tool_error(details=str(e)),
}
]

Expand Down
5 changes: 5 additions & 0 deletions src/backend/config/configuration.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ tools:
# List of web search tool names, eg: google_web_search, tavily_web_search
enabled_web_searches:
- tavily_web_search
# List of domains to filter (exclusively) for web search
domain_filters:
# - wikipedia.org
# List of sites to filter (exclusively) for web scraping
site_filters:
python_interpreter:
url: http://terrarium:8080
gmail:
Expand Down
2 changes: 2 additions & 0 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ class BraveWebSearchSettings(BaseSettings, BaseModel):
class HybridWebSearchSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
enabled_web_searches: Optional[List[str]] = []
domain_filters: Optional[List[str]] = []
site_filters: Optional[List[str]] = []


class ToolSettings(BaseSettings, BaseModel):
Expand Down
5 changes: 0 additions & 5 deletions src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@
from pydantic import BaseModel, Field


class AgentToolMetadataArtifactsType(StrEnum):
DOMAIN = "domain"
SITE = "site"


class AgentVisibility(StrEnum):
PRIVATE = "private"
PUBLIC = "public"
Expand Down
18 changes: 8 additions & 10 deletions src/backend/tests/unit/chat/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,10 @@ async def call(
"name": "toolkit_calculator",
"parameters": {"code": "6*7"},
},
"outputs": [{'type': 'other', 'success': False, 'text': 'Calculator failed', 'details': ''}],
"outputs": [{'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'details': 'Calculator failed'}],
},
]


@patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1)
def test_async_call_tools_timeout(mock_get_available_tools) -> None:
class MockCalculator(BaseTool):
Expand Down Expand Up @@ -249,8 +248,8 @@ async def call(
)

assert {'call': {'name': 'web_scrape', 'parameters': {'code': '6*7'}}, 'outputs': [
{'details': '', 'success': False, 'text': "Model didn't pass required parameter: url", 'type'
: 'other'}]} in results
{'type': 'other', 'success': False, 'text': 'Error calling tool web_scrape.',
'details': "Model didn't pass required parameter: url"}]} in results
assert {
"call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}},
"outputs": [{"result": 42}],
Expand Down Expand Up @@ -299,7 +298,7 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'invalid_param': '6*7'}}, 'outputs': [
{'details': '', 'success': False, 'text': "Model didn't pass required parameter: code",
{'details': "Model didn't pass required parameter: code", 'success': False, 'text': 'Error calling tool toolkit_calculator.',
'type': 'other'}]} in results

def test_tools_params_checker_invalid_param_type(mock_get_available_tools) -> None:
Expand Down Expand Up @@ -343,9 +342,8 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': 6}}, 'outputs': [
{'details': '', 'success': False,
'text': "Model passed invalid parameter. Parameter 'code' must be of type str, but got int",
'type': 'other'}]} in results
{'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.',
'details': "Model passed invalid parameter. Parameter 'code' must be of type str, but got int"}]} in results

def test_tools_params_checker_required_param_empty(mock_get_available_tools) -> None:
class MockCalculator(BaseTool):
Expand Down Expand Up @@ -388,5 +386,5 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': ''}}, 'outputs': [
{'details': '', 'success': False, 'text': 'Model passed empty value for required parameter: code',
'type': 'other'}]} in results
{'details': 'Model passed empty value for required parameter: code', 'success': False,
'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]} in results
3 changes: 2 additions & 1 deletion src/backend/tests/unit/tools/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ async def test_calculator_invalid_syntax() -> None:
ctx = Context()
calculator = Calculator()
result = await calculator.call({"code": "2+"}, ctx)
assert result == {"text": "Parsing error - syntax not allowed."}

assert result == [{'details': 'parse error [column 2]: parity, expression: 2+', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]
5 changes: 3 additions & 2 deletions src/backend/tests/unit/tools/test_lang_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ async def test_wiki_retriever_no_docs() -> None:
):
result = await retriever.call({"query": query}, ctx)

assert result == []
assert result == [{'details': '','success': False,'text': 'No results found.','type': 'other'}]



@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
Expand Down Expand Up @@ -163,4 +164,4 @@ async def test_vector_db_retriever_no_docs() -> None:
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = await retriever.call({"query": query}, ctx)

assert result == []
assert result == [{'details': '', 'success': False, 'text': 'No results found.', 'type': 'other'}]
10 changes: 6 additions & 4 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ def __init__(self, message, tool_id: str):
self.message = message
self.tool_id = tool_id


class ToolError(BaseModel, extra="allow"):
type: ToolErrorCode = ToolErrorCode.OTHER
success: bool = False
text: str
details: str = ""

class ToolArgument(StrEnum):
DOMAIN_FILTER = "domain_filter"
SITE_FILTER = "site_filter"

class ParametersValidationMeta(type):
"""
Expand Down Expand Up @@ -123,14 +125,14 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None:
...

@classmethod
def get_tool_error(cls, err: ToolError):
tool_error = err.model_dump()
def get_tool_error(cls, details: str, text: str = "Error calling tool", error_type: ToolErrorCode = ToolErrorCode.OTHER):
tool_error = ToolError(text=f"{text} {cls.ID}.", details=details, type=error_type).model_dump()
logger.error(event=f"Error calling tool {cls.ID}", error=tool_error)
return [tool_error]

@classmethod
def get_no_results_error(cls):
return cls.get_tool_error(ToolError(text="No results found."))
return ToolError(text="No results found.", details="No results found for the given params.")

@abstractmethod
async def call(
Expand Down
25 changes: 11 additions & 14 deletions src/backend/tools/brave_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from backend.config.settings import Settings
from backend.database_models.database import DBSessionDep
from backend.schemas.agent import AgentToolMetadataArtifactsType
from backend.schemas.tool import ToolCategory, ToolDefinition
from backend.tools.base import BaseTool
from backend.tools.base import BaseTool, ToolArgument
from backend.tools.brave_search.client import BraveClient
from backend.tools.utils.mixins import WebSearchFilteringMixin


class BraveWebSearch(BaseTool, WebSearchFilteringMixin):
class BraveWebSearch(BaseTool):
ID = "brave_web_search"
BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key')

Expand Down Expand Up @@ -49,23 +47,22 @@ async def call(
) -> List[Dict[str, Any]]:
query = parameters.get("query", "")

# Get domain filtering from kwargs or set on Agent tool metadata
if "include_domains" in kwargs:
filtered_domains = kwargs.get("include_domains")
else:
filtered_domains = self.get_filters(
AgentToolMetadataArtifactsType.DOMAIN, session, ctx
# Get domain filtering from kwargs
filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, [])

try:
response = await self.client.search_async(
q=query, count=self.num_results, include_domains=filtered_domains
)
except Exception as e:
return self.get_tool_error(details=str(e))

response = await self.client.search_async(
q=query, count=self.num_results, include_domains=filtered_domains
)
response = dict(response)

results = response.get("web", {}).get("results", [])

if not results:
self.get_no_results_error()
return self.get_no_results_error()

tool_results = []
for result in results:
Expand Down
6 changes: 3 additions & 3 deletions src/backend/tools/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ async def call(

to_evaluate = expression.replace("pi", "PI").replace("e", "E")

result = []
try:
result = {"text": math_parser.parse(to_evaluate).evaluate({})}
except Exception as e:
logger.error(event=f"[Calculator] Error parsing expression: {e}")
result = {"text": "Parsing error - syntax not allowed."}
return self.get_tool_error(details=str(e))

return result

return result # type: ignore
13 changes: 9 additions & 4 deletions src/backend/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
session = kwargs.get("session")
user_id = kwargs.get("user_id")
if not file:
return []
return self.get_tool_error(details="Files are not passed in model generated params")

_, file_id = file
retrieved_file = file_crud.get_file(session, file_id, user_id)
if not retrieved_file:
return []
return self.get_tool_error(details="The wrong files were passed in the tool parameters, or files were not found")

return [
{
Expand Down Expand Up @@ -125,13 +125,15 @@ async def call(
user_id = kwargs.get("user_id")

if not query or not files:
return []
return self.get_tool_error(
details="Missing query or files. The wrong files might have been passed in the tool parameters")

file_ids = [file_id for _, file_id in files]
retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id)

if not retrieved_files:
return []
return self.get_tool_error(
details="Missing files. The wrong files might have been passed in the tool parameters")

results = []
for file in retrieved_files:
Expand All @@ -142,4 +144,7 @@ async def call(
"url": file.file_name,
}
)
if not results:
return self.get_no_results_error()

return results
37 changes: 21 additions & 16 deletions src/backend/tools/google_drive/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,17 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
# Search Google Drive
logger.info(event="[Google Drive] Defaulting to raw Google Drive search.")
agent_tool_metadata = kwargs["agent_tool_metadata"]
documents = await _default_gdrive_list_files(
user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata
)
try:
documents = await _default_gdrive_list_files(
user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata
)
except Exception as e:
return self.get_tool_error(details=str(e))

if not documents:
logger.info(event="[Google Drive] No documents found.")
return self.get_no_results_error()

return documents


Expand Down Expand Up @@ -141,20 +149,17 @@ async def _default_gdrive_list_files(
fields = f"nextPageToken, files({DOC_FIELDS})"

search_results = []
try:
search_results = (
service.files()
.list(
pageSize=SEARCH_LIMIT,
q=q,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields=fields,
)
.execute()
search_results = (
service.files()
.list(
pageSize=SEARCH_LIMIT,
q=q,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields=fields,
)
except Exception as error:
logger.error(event="[Google Drive] Error searching files", error=error)
.execute()
)

files = search_results.get("files", [])
if not files:
Expand Down
25 changes: 10 additions & 15 deletions src/backend/tools/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

from backend.config.settings import Settings
from backend.database_models.database import DBSessionDep
from backend.schemas.agent import AgentToolMetadataArtifactsType
from backend.schemas.tool import ToolCategory, ToolDefinition
from backend.tools.base import BaseTool
from backend.tools.utils.mixins import WebSearchFilteringMixin
from backend.tools.base import BaseTool, ToolArgument


class GoogleWebSearch(BaseTool, WebSearchFilteringMixin):
class GoogleWebSearch(BaseTool):
ID = "google_web_search"
API_KEY = Settings().get('tools.google_web_search.api_key')
CSE_ID = Settings().get('tools.google_web_search.cse_id')
Expand Down Expand Up @@ -48,17 +46,14 @@ async def call(
query = parameters.get("query", "")
cse = self.client.cse()

# Get domain filtering from kwargs or set on Agent tool metadata
if "include_domains" in kwargs:
filtered_domains = kwargs.get("include_domains")
else:
filtered_domains = self.get_filters(
AgentToolMetadataArtifactsType.DOMAIN, session, ctx
)

site_filters = [f"site:{domain}" for domain in filtered_domains]
response = cse.list(q=query, cx=self.CSE_ID, orTerms=site_filters).execute()
search_results = response.get("items", [])
# Get domain filtering from kwargs
filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, [])
domain_filters = [f"site:{domain}" for domain in filtered_domains]
try:
response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute()
search_results = response.get("items", [])
except Exception as e:
return self.get_tool_error(details=str(e))

if not search_results:
return self.get_no_results_error()
Expand Down
Loading

0 comments on commit 6974443

Please sign in to comment.