From f370c466f540d93543076ffe891799e4aa6f3fd0 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Thu, 12 Dec 2024 14:36:44 -0500 Subject: [PATCH 1/3] backend: Add config for domain and site Web Search filtering (#863) * Make filtering occur from backend config * Add site filter by enum * Comment wikipedia --- .../config/configuration.template.yaml | 5 +++ src/backend/config/settings.py | 2 ++ src/backend/schemas/agent.py | 5 --- src/backend/tools/base.py | 4 ++- src/backend/tools/brave_search/tool.py | 15 +++----- src/backend/tools/google_search.py | 19 ++++------- src/backend/tools/hybrid_search.py | 23 +++++-------- src/backend/tools/tavily_search.py | 15 +++----- src/backend/tools/utils/mixins.py | 34 ------------------- 9 files changed, 32 insertions(+), 90 deletions(-) diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index 0597a946fb..2fe3e55a8d 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -24,6 +24,11 @@ tools: # List of web search tool names, eg: google_web_search, tavily_web_search enabled_web_searches: - tavily_web_search + # List of domains to filter (exclusively) for web search + domain_filters: + # - wikipedia.org + # List of sites to filter (exclusively) for web scraping + site_filters: python_interpreter: url: http://terrarium:8080 gmail: diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 123a0ca3c1..322980dc2b 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -208,6 +208,8 @@ class BraveWebSearchSettings(BaseSettings, BaseModel): class HybridWebSearchSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG enabled_web_searches: Optional[List[str]] = [] + domain_filters: Optional[List[str]] = [] + site_filters: Optional[List[str]] = [] class ToolSettings(BaseSettings, BaseModel): diff --git a/src/backend/schemas/agent.py b/src/backend/schemas/agent.py 
index 18edc2d8b9..5719f7e55c 100644 --- a/src/backend/schemas/agent.py +++ b/src/backend/schemas/agent.py @@ -5,11 +5,6 @@ from pydantic import BaseModel, Field -class AgentToolMetadataArtifactsType(StrEnum): - DOMAIN = "domain" - SITE = "site" - - class AgentVisibility(StrEnum): PRIVATE = "private" PUBLIC = "public" diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index e5343fab3a..b491ae3907 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -35,7 +35,9 @@ class ToolError(BaseModel, extra="allow"): text: str details: str = "" - +class ToolArgument(StrEnum): + DOMAIN_FILTER = "domain_filter" + SITE_FILTER = "site_filter" class ParametersValidationMeta(type): """ Metaclass to decorate all tools `call` methods with the parameter checker. diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 7a775ff047..5331500531 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -2,14 +2,12 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep -from backend.schemas.agent import AgentToolMetadataArtifactsType from backend.schemas.tool import ToolCategory, ToolDefinition -from backend.tools.base import BaseTool +from backend.tools.base import BaseTool, ToolArgument from backend.tools.brave_search.client import BraveClient -from backend.tools.utils.mixins import WebSearchFilteringMixin -class BraveWebSearch(BaseTool, WebSearchFilteringMixin): +class BraveWebSearch(BaseTool): ID = "brave_web_search" BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key') @@ -49,13 +47,8 @@ async def call( ) -> List[Dict[str, Any]]: query = parameters.get("query", "") - # Get domain filtering from kwargs or set on Agent tool metadata - if "include_domains" in kwargs: - filtered_domains = kwargs.get("include_domains") - else: - filtered_domains = self.get_filters( - AgentToolMetadataArtifactsType.DOMAIN, 
session, ctx - ) + # Get domain filtering from kwargs + filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, []) response = await self.client.search_async( q=query, count=self.num_results, include_domains=filtered_domains diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index 54ea8586b5..f05af3f8b1 100644 --- a/src/backend/tools/google_search.py +++ b/src/backend/tools/google_search.py @@ -4,13 +4,11 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep -from backend.schemas.agent import AgentToolMetadataArtifactsType from backend.schemas.tool import ToolCategory, ToolDefinition -from backend.tools.base import BaseTool -from backend.tools.utils.mixins import WebSearchFilteringMixin +from backend.tools.base import BaseTool, ToolArgument -class GoogleWebSearch(BaseTool, WebSearchFilteringMixin): +class GoogleWebSearch(BaseTool): ID = "google_web_search" API_KEY = Settings().get('tools.google_web_search.api_key') CSE_ID = Settings().get('tools.google_web_search.cse_id') @@ -48,16 +46,11 @@ async def call( query = parameters.get("query", "") cse = self.client.cse() - # Get domain filtering from kwargs or set on Agent tool metadata - if "include_domains" in kwargs: - filtered_domains = kwargs.get("include_domains") - else: - filtered_domains = self.get_filters( - AgentToolMetadataArtifactsType.DOMAIN, session, ctx - ) + # Get domain filtering from kwargs + filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, []) + domain_filters = [f"site:{domain}" for domain in filtered_domains] - site_filters = [f"site:{domain}" for domain in filtered_domains] - response = cse.list(q=query, cx=self.CSE_ID, orTerms=site_filters).execute() + response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute() search_results = response.get("items", []) if not search_results: diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index 
693ee26c49..d8b815de04 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -5,22 +5,22 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment -from backend.schemas.agent import AgentToolMetadataArtifactsType from backend.schemas.tool import ToolCategory, ToolDefinition -from backend.tools.base import BaseTool +from backend.tools.base import BaseTool, ToolArgument from backend.tools.brave_search.tool import BraveWebSearch from backend.tools.google_search import GoogleWebSearch from backend.tools.tavily_search import TavilyWebSearch -from backend.tools.utils.mixins import WebSearchFilteringMixin from backend.tools.web_scrape import WebScrapeTool -class HybridWebSearch(BaseTool, WebSearchFilteringMixin): +class HybridWebSearch(BaseTool): ID = "hybrid_web_search" POST_RERANK_MAX_RESULTS = 5 AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch] - ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches') WEB_SCRAPE_TOOL = WebScrapeTool + ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches') + DOMAIN_FILTER = Settings().get('tools.hybrid_web_search.domain_filters') or [] + SITE_FILTER = Settings().get('tools.hybrid_web_search.site_filters') or [] def __init__(self): available_search_tools = self.get_available_search_tools() @@ -83,7 +83,7 @@ def _gather_search_tasks( tasks.append(search_tool.call(parameters, ctx, session, **kwargs)) # Add web scrape tool calls - filtered_sites = kwargs.get("include_sites", []) + filtered_sites = kwargs.get(ToolArgument.SITE_FILTER, []) for site in filtered_sites: tasks.append(self.web_scrape_tool.call({"url": site}, ctx, **kwargs)) @@ -96,16 +96,9 @@ async def call( query = parameters.get("query", "") # Handle domain filtering -> filter in search APIs - filtered_domains = self.get_filters( - 
AgentToolMetadataArtifactsType.DOMAIN, session, ctx - ) - kwargs["include_domains"] = filtered_domains - + kwargs[ToolArgument.DOMAIN_FILTER] = self.DOMAIN_FILTER # Handle site filtering -> perform web scraping on sites - filtered_sites = self.get_filters( - AgentToolMetadataArtifactsType.SITE, session, ctx - ) - kwargs["include_sites"] = filtered_sites + kwargs[ToolArgument.SITE_FILTER] = self.SITE_FILTER tasks = self._gather_search_tasks(parameters, ctx, session, **kwargs) diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index 51d8103a59..b8746c74c0 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -4,13 +4,11 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep -from backend.schemas.agent import AgentToolMetadataArtifactsType from backend.schemas.tool import ToolCategory, ToolDefinition -from backend.tools.base import BaseTool -from backend.tools.utils.mixins import WebSearchFilteringMixin +from backend.tools.base import BaseTool, ToolArgument -class TavilyWebSearch(BaseTool, WebSearchFilteringMixin): +class TavilyWebSearch(BaseTool): ID = "tavily_web_search" TAVILY_API_KEY = Settings().get('tools.tavily_web_search.api_key') @@ -48,13 +46,8 @@ async def call( # Gather search parameters query = parameters.get("query", "") - # Get domain filtering from kwargs or set on Agent tool metadata - if "include_domains" in kwargs: - filtered_domains = kwargs.get("include_domains") - else: - filtered_domains = self.get_filters( - AgentToolMetadataArtifactsType.DOMAIN, session, ctx - ) + # Get domain filtering from kwargs + filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, []) try: result = self.client.search( diff --git a/src/backend/tools/utils/mixins.py b/src/backend/tools/utils/mixins.py index ce4827140f..ff3001f125 100644 --- a/src/backend/tools/utils/mixins.py +++ b/src/backend/tools/utils/mixins.py @@ -1,7 +1,3 @@ -from 
backend.crud import agent_tool_metadata as agent_tool_metadata_crud -from backend.database_models.database import DBSessionDep -from backend.schemas.agent import AgentToolMetadataArtifactsType -from backend.schemas.context import Context from backend.services.auth.crypto import encrypt from backend.services.cache import cache_get_dict, cache_put @@ -24,33 +20,3 @@ def insert_tool_auth_cache(self, user_id: str, tool_id: str) -> str: cache_put(key, payload) return key - - -class WebSearchFilteringMixin: - def get_filters( - self, - filter_type: AgentToolMetadataArtifactsType, - session: DBSessionDep, - ctx: Context, - ) -> list[str]: - agent_id = ctx.get_agent_id() - user_id = ctx.get_user_id() - - if not agent_id or not user_id: - return [] - - agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata( - db=session, - agent_id=agent_id, - tool_name=self.ID, - user_id=user_id, - ) - - if not agent_tool_metadata: - return [] - - return [ - artifact[filter_type] - for artifact in agent_tool_metadata.artifacts - if filter_type in artifact - ] From 925f3c2784d8c3316cf6315e98d0d62f6492e8d9 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Fri, 13 Dec 2024 13:27:52 -0500 Subject: [PATCH 2/3] backend: Fix web scrape issue with results format (#878) Fix web scrape issue --- src/backend/tools/web_scrape.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/backend/tools/web_scrape.py b/src/backend/tools/web_scrape.py index 6c87e5421e..849ac606ed 100644 --- a/src/backend/tools/web_scrape.py +++ b/src/backend/tools/web_scrape.py @@ -76,13 +76,14 @@ async def call( async def handle_response(self, response: aiohttp.ClientResponse, url: str): content_type = response.headers.get("content-type") + results = [] # If URL is a PDF, read contents using helper function if "application/pdf" in content_type: - return { + results.append({ "text": read_pdf(response.content), "url": url, - } + }) elif "text/html" in content_type: content = await 
response.text() soup = BeautifulSoup(content, "html.parser") @@ -98,6 +99,8 @@ async def handle_response(self, response: aiohttp.ClientResponse, url: str): if title: data["title"] = title - return data + results.append(data) else: raise ValueError(f"Unsupported Content Type using web scrape: {content_type}") + + return results From 320f6ef9c504e51fe2ba4cb3e13debc9d23a2bde Mon Sep 17 00:00:00 2001 From: Eugene P <144219719+EugeneLightsOn@users.noreply.github.com> Date: Fri, 13 Dec 2024 20:16:39 +0100 Subject: [PATCH 3/3] TLK-2090 tool errors processing (#861) * TLK-2090 tool errors processing * TLK-2090 tool errors processing - lint + tests * TLK-2090 tool errors processing - review fixes * TLK-2090 tool errors processing - review fixes --- src/backend/chat/custom/tool_calls.py | 15 +++--- .../tests/unit/chat/test_tool_calls.py | 18 ++++--- .../tests/unit/tools/test_calculator.py | 3 +- .../tests/unit/tools/test_lang_chain.py | 5 +- src/backend/tools/base.py | 8 ++-- src/backend/tools/brave_search/tool.py | 12 +++-- src/backend/tools/calculator.py | 6 +-- src/backend/tools/files.py | 13 +++-- src/backend/tools/google_drive/tool.py | 37 ++++++++------- src/backend/tools/google_search.py | 8 ++-- src/backend/tools/hybrid_search.py | 3 ++ src/backend/tools/lang_chain.py | 37 ++++++++++----- src/backend/tools/python_interpreter.py | 13 +++-- src/backend/tools/slack/tool.py | 12 ++++- src/backend/tools/tavily_search.py | 2 +- src/community/tools/arxiv.py | 9 +++- src/community/tools/clinicaltrials.py | 8 +++- src/community/tools/connector.py | 10 +++- src/community/tools/llama_index.py | 47 +++++++++---------- src/community/tools/pub_med.py | 8 +++- src/community/tools/wolfram.py | 9 +++- .../src/components/MessageRow/ToolEvents.tsx | 43 +++++++++++++++-- .../assistants_web/src/hooks/use-chat.ts | 20 +++----- 23 files changed, 225 insertions(+), 121 deletions(-) diff --git a/src/backend/chat/custom/tool_calls.py b/src/backend/chat/custom/tool_calls.py index 
5a67a42d93..a7592dfd7c 100644 --- a/src/backend/chat/custom/tool_calls.py +++ b/src/backend/chat/custom/tool_calls.py @@ -9,7 +9,10 @@ from backend.model_deployments.base import BaseDeployment from backend.schemas.context import Context from backend.services.logger.utils import LoggerFactory -from backend.tools.base import ToolAuthException, ToolError, ToolErrorCode +from backend.tools.base import ( + ToolAuthException, + ToolErrorCode, +) TIMEOUT_SECONDS = 60 @@ -110,11 +113,9 @@ async def _call_tool_async( { "call": tool_call, "outputs": tool.get_tool_error( - ToolError( - text="Tool authentication failed", - details=str(e), - type=ToolErrorCode.AUTH, - ) + details=str(e), + text="Tool authentication failed", + error_type=ToolErrorCode.AUTH, ), } ] @@ -122,7 +123,7 @@ async def _call_tool_async( return [ { "call": tool_call, - "outputs": tool.get_tool_error(ToolError(text=str(e))), + "outputs": tool.get_tool_error(details=str(e)), } ] diff --git a/src/backend/tests/unit/chat/test_tool_calls.py b/src/backend/tests/unit/chat/test_tool_calls.py index 16e5b15b59..0d7b3ba364 100644 --- a/src/backend/tests/unit/chat/test_tool_calls.py +++ b/src/backend/tests/unit/chat/test_tool_calls.py @@ -116,11 +116,10 @@ async def call( "name": "toolkit_calculator", "parameters": {"code": "6*7"}, }, - "outputs": [{'type': 'other', 'success': False, 'text': 'Calculator failed', 'details': ''}], + "outputs": [{'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'details': 'Calculator failed'}], }, ] - @patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1) def test_async_call_tools_timeout(mock_get_available_tools) -> None: class MockCalculator(BaseTool): @@ -249,8 +248,8 @@ async def call( ) assert {'call': {'name': 'web_scrape', 'parameters': {'code': '6*7'}}, 'outputs': [ - {'details': '', 'success': False, 'text': "Model didn't pass required parameter: url", 'type' - : 'other'}]} in results + {'type': 'other', 'success': False, 'text': 'Error 
calling tool web_scrape.', + 'details': "Model didn't pass required parameter: url"}]} in results assert { "call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}}, "outputs": [{"result": 42}], @@ -299,7 +298,7 @@ async def call( async_call_tools(chat_history, MockCohereDeployment(), ctx) ) assert {'call': {'name': 'toolkit_calculator', 'parameters': {'invalid_param': '6*7'}}, 'outputs': [ - {'details': '', 'success': False, 'text': "Model didn't pass required parameter: code", + {'details': "Model didn't pass required parameter: code", 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]} in results def test_tools_params_checker_invalid_param_type(mock_get_available_tools) -> None: @@ -343,9 +342,8 @@ async def call( async_call_tools(chat_history, MockCohereDeployment(), ctx) ) assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': 6}}, 'outputs': [ - {'details': '', 'success': False, - 'text': "Model passed invalid parameter. Parameter 'code' must be of type str, but got int", - 'type': 'other'}]} in results + {'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.', + 'details': "Model passed invalid parameter. 
Parameter 'code' must be of type str, but got int"}]} in results def test_tools_params_checker_required_param_empty(mock_get_available_tools) -> None: class MockCalculator(BaseTool): @@ -388,5 +386,5 @@ async def call( async_call_tools(chat_history, MockCohereDeployment(), ctx) ) assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': ''}}, 'outputs': [ - {'details': '', 'success': False, 'text': 'Model passed empty value for required parameter: code', - 'type': 'other'}]} in results + {'details': 'Model passed empty value for required parameter: code', 'success': False, + 'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]} in results diff --git a/src/backend/tests/unit/tools/test_calculator.py b/src/backend/tests/unit/tools/test_calculator.py index 5ff68ab063..3a821a4d45 100644 --- a/src/backend/tests/unit/tools/test_calculator.py +++ b/src/backend/tests/unit/tools/test_calculator.py @@ -17,4 +17,5 @@ async def test_calculator_invalid_syntax() -> None: ctx = Context() calculator = Calculator() result = await calculator.call({"code": "2+"}, ctx) - assert result == {"text": "Parsing error - syntax not allowed."} + + assert result == [{'details': 'parse error [column 2]: parity, expression: 2+', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}] diff --git a/src/backend/tests/unit/tools/test_lang_chain.py b/src/backend/tests/unit/tools/test_lang_chain.py index 1ca920502b..50a593dc92 100644 --- a/src/backend/tests/unit/tools/test_lang_chain.py +++ b/src/backend/tests/unit/tools/test_lang_chain.py @@ -78,7 +78,8 @@ async def test_wiki_retriever_no_docs() -> None: ): result = await retriever.call({"query": query}, ctx) - assert result == [] + assert result == [{'details': '','success': False,'text': 'No results found.','type': 'other'}] + @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") @@ -163,4 +164,4 @@ async def test_vector_db_retriever_no_docs() -> None: 
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs result = await retriever.call({"query": query}, ctx) - assert result == [] + assert result == [{'details': '', 'success': False, 'text': 'No results found.', 'type': 'other'}] diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index b491ae3907..06c2b0aede 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -28,7 +28,6 @@ def __init__(self, message, tool_id: str): self.message = message self.tool_id = tool_id - class ToolError(BaseModel, extra="allow"): type: ToolErrorCode = ToolErrorCode.OTHER success: bool = False @@ -38,6 +37,7 @@ class ToolError(BaseModel, extra="allow"): class ToolArgument(StrEnum): DOMAIN_FILTER = "domain_filter" SITE_FILTER = "site_filter" + class ParametersValidationMeta(type): """ Metaclass to decorate all tools `call` methods with the parameter checker. @@ -90,14 +90,14 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ... 
@classmethod - def get_tool_error(cls, err: ToolError): - tool_error = err.model_dump() + def get_tool_error(cls, details: str, text: str = "Error calling tool", error_type: ToolErrorCode = ToolErrorCode.OTHER): + tool_error = ToolError(text=f"{text} {cls.ID}.", details=details, type=error_type).model_dump() logger.error(event=f"Error calling tool {cls.ID}", error=tool_error) return [tool_error] @classmethod def get_no_results_error(cls): - return cls.get_tool_error(ToolError(text="No results found.")) + return ToolError(text="No results found.", details="No results found for the given params.") @abstractmethod async def call( diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 5331500531..4e424b5f40 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -50,15 +50,19 @@ async def call( # Get domain filtering from kwargs filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, []) - response = await self.client.search_async( - q=query, count=self.num_results, include_domains=filtered_domains - ) + try: + response = await self.client.search_async( + q=query, count=self.num_results, include_domains=filtered_domains + ) + except Exception as e: + return self.get_tool_error(details=str(e)) + response = dict(response) results = response.get("web", {}).get("results", []) if not results: - self.get_no_results_error() + return self.get_no_results_error() tool_results = [] for result in results: diff --git a/src/backend/tools/calculator.py b/src/backend/tools/calculator.py index 3b96859663..5de2c1dcda 100644 --- a/src/backend/tools/calculator.py +++ b/src/backend/tools/calculator.py @@ -52,11 +52,11 @@ async def call( to_evaluate = expression.replace("pi", "PI").replace("e", "E") - result = [] try: result = {"text": math_parser.parse(to_evaluate).evaluate({})} except Exception as e: logger.error(event=f"[Calculator] Error parsing expression: {e}") - result = {"text": "Parsing error 
- syntax not allowed."} + return self.get_tool_error(details=str(e)) - return result + + return result # type: ignore diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 146a741c0e..707d4cb1ec 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -58,12 +58,12 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: session = kwargs.get("session") user_id = kwargs.get("user_id") if not file: - return [] + return self.get_tool_error(details="Files are not passed in model generated params") _, file_id = file retrieved_file = file_crud.get_file(session, file_id, user_id) if not retrieved_file: - return [] + return self.get_tool_error(details="The wrong files were passed in the tool parameters, or files were not found") return [ { @@ -125,13 +125,15 @@ async def call( user_id = kwargs.get("user_id") if not query or not files: - return [] + return self.get_tool_error( + details="Missing query or files. The wrong files might have been passed in the tool parameters") file_ids = [file_id for _, file_id in files] retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id) if not retrieved_files: - return [] + return self.get_tool_error( + details="Missing files. 
The wrong files might have been passed in the tool parameters") results = [] for file in retrieved_files: @@ -142,4 +144,7 @@ async def call( "url": file.file_name, } ) + if not results: + return self.get_no_results_error() + return results diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index 6131e095d6..d5b0a1f8f5 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -77,9 +77,17 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: # Search Google Drive logger.info(event="[Google Drive] Defaulting to raw Google Drive search.") agent_tool_metadata = kwargs["agent_tool_metadata"] - documents = await _default_gdrive_list_files( - user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata - ) + try: + documents = await _default_gdrive_list_files( + user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata + ) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not documents: + logger.info(event="[Google Drive] No documents found.") + return self.get_no_results_error() + return documents @@ -141,20 +149,17 @@ async def _default_gdrive_list_files( fields = f"nextPageToken, files({DOC_FIELDS})" search_results = [] - try: - search_results = ( - service.files() - .list( - pageSize=SEARCH_LIMIT, - q=q, - includeItemsFromAllDrives=True, - supportsAllDrives=True, - fields=fields, - ) - .execute() + search_results = ( + service.files() + .list( + pageSize=SEARCH_LIMIT, + q=q, + includeItemsFromAllDrives=True, + supportsAllDrives=True, + fields=fields, ) - except Exception as error: - logger.error(event="[Google Drive] Error searching files", error=error) + .execute() + ) files = search_results.get("files", []) if not files: diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index f05af3f8b1..14a7e21dd1 100644 --- a/src/backend/tools/google_search.py +++ 
b/src/backend/tools/google_search.py @@ -49,9 +49,11 @@ async def call( # Get domain filtering from kwargs filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, []) domain_filters = [f"site:{domain}" for domain in filtered_domains] - - response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute() - search_results = response.get("items", []) + try: + response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute() + search_results = response.get("items", []) + except Exception as e: + return self.get_tool_error(details=str(e)) if not search_results: return self.get_no_results_error() diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index d8b815de04..849691c3cb 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -115,6 +115,9 @@ async def call( **kwargs, ) + if not reranked_results: + return self.get_no_results_error() + return reranked_results async def rerank_results( diff --git a/src/backend/tools/lang_chain.py b/src/backend/tools/lang_chain.py index 4510a27ea2..71f12c5d1c 100644 --- a/src/backend/tools/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -59,11 +59,17 @@ async def call( ) -> List[Dict[str, Any]]: wiki_retriever = WikipediaRetriever() query = parameters.get("query", "") - docs = wiki_retriever.get_relevant_documents(query) - text_splitter = CharacterTextSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap - ) - documents = text_splitter.split_documents(docs) + try: + docs = wiki_retriever.get_relevant_documents(query) + text_splitter = CharacterTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + ) + documents = text_splitter.split_documents(docs) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not documents: + return self.get_no_results_error() return [ { @@ -115,13 +121,18 @@ async def call( cohere_embeddings = 
CohereEmbeddings(cohere_api_key=self.COHERE_API_KEY) # Load text files and split into chunks - loader = PyPDFLoader(self.filepath) - text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0) - pages = loader.load_and_split(text_splitter) - - # Create a vector store from the documents - db = Chroma.from_documents(documents=pages, embedding=cohere_embeddings) - query = parameters.get("query", "") - input_docs = db.as_retriever().get_relevant_documents(query) + try: + loader = PyPDFLoader(self.filepath) + text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0) + pages = loader.load_and_split(text_splitter) + + # Create a vector store from the documents + db = Chroma.from_documents(documents=pages, embedding=cohere_embeddings) + query = parameters.get("query", "") + input_docs = db.as_retriever().get_relevant_documents(query) + except Exception as e: + return self.get_tool_error(details=str(e)) + if not input_docs: + return self.get_no_results_error() return [{"text": doc.page_content} for doc in input_docs] diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py index 426844ab48..e7015703f9 100644 --- a/src/backend/tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -57,9 +57,15 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any): raise Exception("Python Interpreter tool called while URL not set") code = parameters.get("code", "") - res = requests.post(self.INTERPRETER_URL, json={"code": code}) + try: + res = requests.post(self.INTERPRETER_URL, json={"code": code}) + clean_res = self._clean_response(res.json()) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not clean_res: + return self.get_no_results_error() - clean_res = self._clean_response(res.json()) return clean_res def _clean_response(self, result: Any) -> Dict[str, str]: @@ -82,7 +88,8 @@ def _clean_response(self, result: Any) -> Dict[str, str]: r.setdefault("text", 
r.get("std_out")) elif r.get("success") is False: error_message = r.get("error", {}).get("message", "") - r.setdefault("text", error_message) + # r.setdefault("text", error_message) + return self.get_tool_error(details=error_message) elif r.get("output_file") and r.get("output_file").get("filename"): if r["output_file"]["filename"] != "": r.setdefault( diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index 35e78f0aea..20c9616374 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -68,6 +68,14 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> List[Dict[str # Search Slack slack_service = get_slack_service(user_id=user_id, search_limit=SEARCH_LIMIT) - all_results = slack_service.search_all(query=query) - return slack_service.serialize_results(all_results) + try: + all_results = slack_service.search_all(query=query) + results = slack_service.serialize_results(all_results) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not results: + return self.get_no_results_error() + + return results diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index b8746c74c0..24ef7f3f94 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -58,7 +58,7 @@ async def call( ) except Exception as e: logger.error(f"Failed to perform Tavily web search: {str(e)}") - raise Exception(f"Failed to perform Tavily web search: {str(e)}") + return self.get_tool_error(details=str(e)) results = result.get("results", []) diff --git a/src/community/tools/arxiv.py b/src/community/tools/arxiv.py index ce5cfac71c..8293d8b681 100644 --- a/src/community/tools/arxiv.py +++ b/src/community/tools/arxiv.py @@ -38,5 +38,12 @@ def get_tool_definition(cls) -> ToolDefinition: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") - result = self.client.run(query) + try: + result = 
self.client.run(query) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not result: + return self.get_no_results_error() + return [{"text": result}] diff --git a/src/community/tools/clinicaltrials.py b/src/community/tools/clinicaltrials.py index 3db15271ac..e9c6c4fd16 100644 --- a/src/community/tools/clinicaltrials.py +++ b/src/community/tools/clinicaltrials.py @@ -75,9 +75,13 @@ async def call( response = requests.get(self._url, params=query_params) response.raise_for_status() except requests.exceptions.RequestException as e: - return [{"text": f"Could not retrieve studies: {str(e)}"}] + return self.get_tool_error(details=str(e)) - return self._parse_response(response, location, intervention) + results = self._parse_response(response, location, intervention) + if not results: + return self.get_no_results_error() + + return results def _parse_response( self, response: requests.Response, location: str, intervention: str diff --git a/src/community/tools/connector.py b/src/community/tools/connector.py index b19445ddad..a8edfa549b 100644 --- a/src/community/tools/connector.py +++ b/src/community/tools/connector.py @@ -47,7 +47,13 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } + try: + response = requests.get(self.url, json=body, headers=headers) + results = response.json()["results"] + except Exception as e: + return self.get_tool_error(details=str(e)) - response = requests.get(self.url, json=body, headers=headers) + if not results: + return self.get_no_results_error() - return response.json()["results"] + return results diff --git a/src/community/tools/llama_index.py b/src/community/tools/llama_index.py index aafdc1b491..d2dfa2f601 100644 --- a/src/community/tools/llama_index.py +++ b/src/community/tools/llama_index.py @@ -87,35 +87,34 @@ async def call( file_ids = [file_id for _, file_id in files] retrieved_files = 
file_crud.get_files_by_ids(session, file_ids, user_id) if not retrieved_files: - return [] + return self.get_no_results_error() - all_results = [] file_str_list = [] for file in retrieved_files: file_str_list.append(file.file_content) - all_results.append( - { - "text": file.file_content, - "title": file.file_name, - "url": file.file_name, - } - ) # LLamaIndex get documents from parsed PDFs, split it into sentences, embed, index and retrieve - docs = StringIterableReader().load_data(file_str_list) - node_parser = SentenceSplitter(chunk_size=LlamaIndexUploadPDFRetriever.CHUNK_SIZE) - nodes = node_parser.get_nodes_from_documents(docs) - embed_model = self._get_embedding("search_document") - vector_index = VectorStoreIndex( - nodes, - embed_model=embed_model, - ) - embed_model_query = self._get_embedding("search_query") - retriever = vector_index.as_retriever( - similarity_top_k=10, - embed_model=embed_model_query, - ) - results = retriever.retrieve(query) - llama_results = [{"text": doc.text} for doc in results] + try: + docs = StringIterableReader().load_data(file_str_list) + node_parser = SentenceSplitter(chunk_size=LlamaIndexUploadPDFRetriever.CHUNK_SIZE) + nodes = node_parser.get_nodes_from_documents(docs) + embed_model = self._get_embedding("search_document") + vector_index = VectorStoreIndex( + nodes, + embed_model=embed_model, + ) + embed_model_query = self._get_embedding("search_query") + retriever = vector_index.as_retriever( + similarity_top_k=10, + embed_model=embed_model_query, + ) + results = retriever.retrieve(query) + llama_results = [{"text": doc.text} for doc in results] + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not llama_results and not docs: + return self.get_no_results_error() + # If llama results are found, return them if llama_results: return llama_results diff --git a/src/community/tools/pub_med.py b/src/community/tools/pub_med.py index 6968e57ea3..7680194fb0 100644 --- a/src/community/tools/pub_med.py +++ 
b/src/community/tools/pub_med.py @@ -38,5 +38,11 @@ def get_tool_definition(cls) -> ToolDefinition: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") - result = self.client.invoke(query) + try: + result = self.client.invoke(query) + except Exception as e: + return self.get_tool_error(details=str(e)) + if not result: + return self.get_no_results_error() + return [{"text": result}] diff --git a/src/community/tools/wolfram.py b/src/community/tools/wolfram.py index dc098e77ed..dc4c27e22a 100644 --- a/src/community/tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -41,5 +41,12 @@ def get_tool_definition(cls) -> ToolDefinition: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: to_evaluate = parameters.get("expression", "") - result = self.tool.run(to_evaluate) + try: + result = self.tool.run(to_evaluate) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not result: + return self.get_no_results_error() + return {"result": result, "text": result} diff --git a/src/interfaces/assistants_web/src/components/MessageRow/ToolEvents.tsx b/src/interfaces/assistants_web/src/components/MessageRow/ToolEvents.tsx index 2a719184c8..3d65dcd722 100644 --- a/src/interfaces/assistants_web/src/components/MessageRow/ToolEvents.tsx +++ b/src/interfaces/assistants_web/src/components/MessageRow/ToolEvents.tsx @@ -5,7 +5,7 @@ import { Fragment, PropsWithChildren } from 'react'; import { StreamSearchResults, StreamToolCallsGeneration, ToolCall } from '@/cohere-client'; import { Markdown } from '@/components/Markdown'; -import { Icon, IconName, Text } from '@/components/UI'; +import { Icon, IconButton, IconName, Text, Tooltip } from '@/components/UI'; import { TOOL_CALCULATOR_ID, TOOL_GOOGLE_DRIVE_ID, @@ -21,6 +21,13 @@ type Props = { events: StreamToolCallsGeneration[] | undefined; }; +const hasToolErrorsDocuments = (search_results: StreamSearchResults | null) => { + 
return search_results?.documents?.some((document) => document.fields?.success === 'false'); +}; + +const getErrorDocumentsFromEvent = (search_results: StreamSearchResults | null) => + search_results?.documents?.filter((document) => document.fields?.success === 'false') || []; + /** * @description Renders a list of events depending on the model's plan and tool inputs. */ @@ -74,6 +81,7 @@ const ToolEvent: React.FC = ({ plan, event, stream_search_result if (plan) { return {plan}; } + const toolName = event?.name || ''; if (stream_search_results) { const artifacts = @@ -85,7 +93,16 @@ const ToolEvent: React.FC = ({ plan, event, stream_search_result .filter((value, index, self) => index === self.findIndex((t) => t.title === value.title)) || []; - return ( + const hasErrorsDocuments = hasToolErrorsDocuments(stream_search_results); + const errorDocuments = getErrorDocumentsFromEvent(stream_search_results); + + return hasErrorsDocuments ? ( + + {errorDocuments[errorDocuments.length - 1].text} + + ) : toolName && toolName != TOOL_PYTHON_INTERPRETER_ID ? ( {artifacts.length > 0 ? ( <> @@ -108,10 +125,9 @@ const ToolEvent: React.FC = ({ plan, event, stream_search_result <>No resources found. )} - ); + ) : null; } - const toolName = event?.name || ''; const icon = getToolIcon(toolName); switch (toolName) { @@ -193,6 +209,25 @@ const ToolEventWrapper: React.FC> = ({ ); }; +const ToolErrorWrapper: React.FC> = ({ + tooltip = 'Some error occurred', + children, +}) => { + return ( +
+ + + {children} + +
+ ); +}; + const truncateString = (str: string, max_length: number = 50) => { return str.length < max_length ? str : str.substring(0, max_length) + '...'; }; diff --git a/src/interfaces/assistants_web/src/hooks/use-chat.ts b/src/interfaces/assistants_web/src/hooks/use-chat.ts index 3fdf47afee..e1ed63964d 100644 --- a/src/interfaces/assistants_web/src/hooks/use-chat.ts +++ b/src/interfaces/assistants_web/src/hooks/use-chat.ts @@ -260,19 +260,13 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => { outputFiles = { ...outputFiles, ...newOutputFilesMap }; saveOutputFiles({ ...savedOutputFiles, ...outputFiles }); - // we are only interested in web_search results - // ignore search results of pyhton interpreter tool - if ( - toolEvents[currentToolEventIndex - 1]?.tool_calls?.[0]?.name !== - TOOL_PYTHON_INTERPRETER_ID - ) { - toolEvents.push({ - text: '', - stream_search_results: data, - tool_calls: [], - } as StreamToolCallsGeneration); - currentToolEventIndex += 1; - } + toolEvents.push({ + text: '', + stream_search_results: data, + tool_calls: [], + } as StreamToolCallsGeneration); + currentToolEventIndex += 1; + break; }