From e0af79f0eeabbfc988f0b219f09917e7616b39e0 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Thu, 7 Nov 2024 09:54:36 -0800 Subject: [PATCH 01/14] chore: Remove merge tests (#832) Remove all merge tests except integration --- .github/workflows/backend_unit_tests.yml | 1 - .github/workflows/frontend_assistants_web_tests.yml | 1 - .github/workflows/frontend_coral_web_tests.yml | 1 - .github/workflows/frontend_slack_bot_tests.yml | 1 - 4 files changed, 4 deletions(-) diff --git a/.github/workflows/backend_unit_tests.yml b/.github/workflows/backend_unit_tests.yml index 7c7dd7f59a..d44f747f5c 100644 --- a/.github/workflows/backend_unit_tests.yml +++ b/.github/workflows/backend_unit_tests.yml @@ -4,7 +4,6 @@ on: push: branches: [main] pull_request: {} - merge_group: {} jobs: pytest: diff --git a/.github/workflows/frontend_assistants_web_tests.yml b/.github/workflows/frontend_assistants_web_tests.yml index b73a6b95b4..5ffbb5eee6 100644 --- a/.github/workflows/frontend_assistants_web_tests.yml +++ b/.github/workflows/frontend_assistants_web_tests.yml @@ -6,7 +6,6 @@ on: paths: - src/interfaces/assistants_web/** pull_request: {} - merge_group: {} jobs: interface_tests: diff --git a/.github/workflows/frontend_coral_web_tests.yml b/.github/workflows/frontend_coral_web_tests.yml index afdba43aab..c45556ba2d 100644 --- a/.github/workflows/frontend_coral_web_tests.yml +++ b/.github/workflows/frontend_coral_web_tests.yml @@ -6,7 +6,6 @@ on: paths: - src/interfaces/coral_web/** pull_request: {} - merge_group: {} jobs: interface_tests: diff --git a/.github/workflows/frontend_slack_bot_tests.yml b/.github/workflows/frontend_slack_bot_tests.yml index 0b7009224a..b87ecaa5d0 100644 --- a/.github/workflows/frontend_slack_bot_tests.yml +++ b/.github/workflows/frontend_slack_bot_tests.yml @@ -6,7 +6,6 @@ on: paths: - src/interfaces/slack_bot/** pull_request: {} - merge_group: {} jobs: interface_tests: From 4b1b8cdcd4afb01c0e2d767848ba7dcac83caa2b Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Fri, 8 Nov 2024 07:18:55 -0800 Subject: [PATCH 02/14] backend: Tool config major refactoring (#822) * replace with iD * wip * wip * Wip * more wip * Fix coral web unit tests * Resolve coral_web tests * Fix last case * Fix coral-web test * Refactor get_available_tools to use function directly * wip * wip * wip, community tools todo * Fix tests * Lint * wip * Fix all unit tests * Remove makefile change * Fix unit tests * Fixing tests wip * Fix lint * Remove pdbs * remove merge group --- .../workflows/backend_integration_tests.yml | 1 - docs/config_details/config_description.md | 2 - docs/custom_tool_guides/tool_guide.md | 50 ++- docs/how_to_guides.md | 2 +- src/backend/chat/custom/custom.py | 16 +- src/backend/chat/custom/tool_calls.py | 4 +- src/backend/cli/constants.py | 18 +- src/backend/cli/main.py | 5 +- .../config/configuration.template.yaml | 8 - src/backend/config/settings.py | 1 - src/backend/config/tools.py | 291 ++---------------- src/backend/main.py | 5 + src/backend/pytest.ini | 4 +- src/backend/routers/auth.py | 17 +- src/backend/routers/organization.py | 5 +- src/backend/routers/tool.py | 15 +- src/backend/schemas/tool.py | 17 +- src/backend/services/auth/strategies/base.py | 7 +- src/backend/services/chat.py | 8 +- src/backend/services/file.py | 6 +- src/backend/services/request_validators.py | 8 +- .../tests/integration/routers/test_agent.py | 14 +- .../tests/unit/chat/test_tool_calls.py | 146 ++++----- .../tests/unit/config/test_deployments.py | 15 +- src/backend/tests/unit/config/test_tools.py | 0 src/backend/tests/unit/configuration.yaml | 1 - src/backend/tests/unit/crud/test_agent.py | 16 +- .../unit/crud/test_agent_tool_metadata.py | 36 ++- src/backend/tests/unit/crud/test_tool_auth.py | 6 +- src/backend/tests/unit/factories/agent.py | 14 +- .../unit/factories/agent_tool_metadata.py | 16 +- src/backend/tests/unit/factories/tool_auth.py | 4 +- src/backend/tests/unit/routers/test_agent.py | 36 +-- src/backend/tests/unit/routers/test_chat.py | 59 +--- src/backend/tests/unit/routers/test_tool.py | 26 +- src/backend/tools/base.py | 42 ++- src/backend/tools/brave_search/tool.py | 26 +- src/backend/tools/calculator.py | 23 +- src/backend/tools/files.py | 58 +++- src/backend/tools/google_drive/tool.py | 27 +- src/backend/tools/google_search.py | 23 +- src/backend/tools/hybrid_search.py | 28 +- src/backend/tools/lang_chain.py | 28 +- src/backend/tools/python_interpreter.py | 31 +- src/backend/tools/slack/tool.py | 25 +- src/backend/tools/tavily_search.py | 52 +++- src/backend/tools/utils/mixins.py | 2 +- src/backend/tools/utils/tools_checkers.py | 20 +- src/backend/tools/web_scrape.py | 28 +- src/community/config/tools.py | 140 +-------- src/community/tools/__init__.py | 5 - src/community/tools/arxiv.py | 25 +- src/community/tools/clinicaltrials.py | 40 ++- src/community/tools/connector.py | 30 +- src/community/tools/llama_index.py | 35 ++- src/community/tools/pub_med.py | 25 +- src/community/tools/wolfram.py | 18 +- .../src/app/(main)/(chat)/Chat.tsx | 6 +- .../cohere-client/generated/schemas.gen.ts | 228 +++++++------- .../cohere-client/generated/services.gen.ts | 11 +- .../src/cohere-client/generated/types.gen.ts | 56 ++-- .../AgentSettingsForm/ToolsStep.tsx | 4 +- .../src/components/Composer/Composer.tsx | 6 +- .../components/Composer/ComposerToolbar.tsx | 4 +- .../components/Composer/DataSourceMenu.tsx | 6 +- .../components/Conversation/Conversation.tsx | 4 +- .../MessagingContainer/AssistantTools.tsx | 6 +- .../assistants_web/src/hooks/use-tools.ts | 12 +- 68 files changed, 995 insertions(+), 958 deletions(-) create mode 100644 src/backend/tests/unit/config/test_tools.py diff --git a/.github/workflows/backend_integration_tests.yml b/.github/workflows/backend_integration_tests.yml index 9346fd8a02..c63dbeed9c 100644 --- a/.github/workflows/backend_integration_tests.yml +++ b/.github/workflows/backend_integration_tests.yml @@ -3,7 +3,6 @@ name: Backend - Integration tests on: push: branches: [main] - merge_group: {} jobs: pytest: diff --git a/docs/config_details/config_description.md b/docs/config_details/config_description.md index ca80b1ba04..785eb9c936 100644 --- a/docs/config_details/config_description.md +++ b/docs/config_details/config_description.md @@ -23,8 +23,6 @@ - redis - Redis configurations - url - URL of the redis, for example, redis://:redis@redis:6379 - tools - Tool configurations - - enabled_tools - these are the tools that are enabled for the toolkit. The full list of tools can be found in the src/backend/config/tools.py file. - The community tools are listed in the src/community/config/tools.py file. Please note that the tools availability is checked too. - python_interpreter - Python interpreter configurations - url - URL of the python interpreter tool - feature_flags - Feature flags configurations diff --git a/docs/custom_tool_guides/tool_guide.md b/docs/custom_tool_guides/tool_guide.md index d8c37d57b9..e1eff9e9ed 100644 --- a/docs/custom_tool_guides/tool_guide.md +++ b/docs/custom_tool_guides/tool_guide.md @@ -53,7 +53,7 @@ from community.tools import BaseTool class ArxivRetriever(BaseTool): - NAME = "arxiv" + ID = "arxiv" def __init__(self): self.client = ArxivAPIWrapper() @@ -64,6 +64,27 @@ class ArxivRetriever(BaseTool): def is_available(cls) -> bool: return True + @classmethod + # You will need to add a tool definition here + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Arxiv", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Arxiv.", + ) + # Your tool needs to implement this call() method def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]: result = self.client.run(parameters) @@ -84,34 +105,9 @@ return [{"text": "The fox is blue", "url": "wikipedia.org/foxes", "title": "Colo Next, add your tool class to the init file by locating it in `src/community/tools/__init__.py`. Import your tool here, then add it to the `__all__` list. -To enable your tool, you will need to go to the `configuration.yaml` file and add your tool's name to the list of `enabled_tools`. This tool name will correspond to the one defined in the `NAME` attribute of your class. - Finally, you will need to add your tool definition to the config file. Locate it in `src/community/config/tools.py`, and import your tool at the top with `from backend.tools import ..`. -In the ToolName enum, add your tool as an enum value. For example, `My_Tool = MyTool.NAME`. - -In the `ALL_TOOLS` dictionary, add your tool definition. This should look like: - -```python - ToolName.My_Tool: ManagedTool( # THE TOOLNAME HERE CORRESPONDS TO THE ENUM YOU DEFINED EARLIER - display_name="My Tool", - implementation=MyTool, # THIS IS THE CLASS YOU IMPORTED AT THE TOP - parameter_definitions={ # THESE ARE PARAMS THE MODEL WILL SEND TO YOUR TOOL, ADJUST AS NEEDED - "query": { - "description": "Query to search with", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=MyTool.is_available(), - auth_implementation=None, # EMPTY IF NO AUTH NEEDED - error_message="Something went wrong", - category=Category.DataLoader, # CHECK CATEGORY ENUM FOR POSSIBLE VALUES - description="An example definition to get you started.", - ), -``` - +Finally, to enable your tool, add your tool as an enum value. For example, `My_Tool = MyToolClass`. ## Step 5: Test Your Tool! diff --git a/docs/how_to_guides.md b/docs/how_to_guides.md index b498f8cb66..5aee501b91 100644 --- a/docs/how_to_guides.md +++ b/docs/how_to_guides.md @@ -48,7 +48,7 @@ The core chat interface is the Coral frontend. To implement your own interface: If you have already created a [connector](https://docs.cohere.com/docs/connectors), you can utilize it within the toolkit by following these steps: 1. Configure your connector using `ConnectorRetriever`. -2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `Category.DataLoader`. +2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `ToolCategory.DataLoader`. You can now use both the Coral frontend and API with your connector. diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 16ce04e3b5..1fb5e78240 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -6,20 +6,19 @@ from backend.chat.custom.tool_calls import async_call_tools from backend.chat.custom.utils import get_deployment from backend.chat.enums import StreamEvent -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.database_models.file import File from backend.model_deployments.base import BaseDeployment from backend.schemas.chat import ChatMessage, ChatRole, EventState from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.schemas.tool import Category, Tool +from backend.schemas.tool import Tool, ToolCategory from backend.services.chat import check_death_loop from backend.services.file import get_file_service from backend.tools.utils.tools_checkers import tool_has_category MAX_STEPS = 15 - class CustomChat(BaseChat): """Custom chat flow not using integrations for models.""" @@ -163,7 +162,7 @@ async def call_chat( file_reader_tools_names = [] if managed_tools: chat_request.tools = managed_tools - file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, Category.FileLoader)] + file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, ToolCategory.FileLoader)] # Get files if available all_files = [] @@ -248,17 +247,18 @@ def update_chat_history_with_tool_results( chat_request.chat_history.extend(tool_results) def get_managed_tools(self, chat_request: CohereChatRequest, full_schema=False): + available_tools = get_available_tools() if full_schema: return [ - AVAILABLE_TOOLS.get(tool.name) + available_tools.get(tool.name) for tool in chat_request.tools - if AVAILABLE_TOOLS.get(tool.name) + if available_tools.get(tool.name) ] return [ - Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump()) + Tool(**available_tools.get(tool.name).model_dump()) for tool in chat_request.tools - if AVAILABLE_TOOLS.get(tool.name) + if available_tools.get(tool.name) ] def add_files_to_chat_history( diff --git a/src/backend/chat/custom/tool_calls.py b/src/backend/chat/custom/tool_calls.py index c5fab560bf..2c003878d5 100644 --- a/src/backend/chat/custom/tool_calls.py +++ b/src/backend/chat/custom/tool_calls.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from backend.chat.collate import rerank_and_chunk, to_dict -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.model_deployments.base import BaseDeployment from backend.schemas.context import Context from backend.services.logger.utils import LoggerFactory @@ -76,7 +76,7 @@ async def _call_tool_async( tool_call: dict, deployment_model: BaseDeployment, ) -> List[Dict[str, Any]]: - tool = AVAILABLE_TOOLS.get(tool_call["name"]) + tool = get_available_tools().get(tool_call["name"]) if not tool: logger.info( event=f"[Custom Chat] Tool not included in tools parameter: {tool_call['name']}", diff --git a/src/backend/cli/constants.py b/src/backend/cli/constants.py index e3f7dbd642..5d3facd9bd 100644 --- a/src/backend/cli/constants.py +++ b/src/backend/cli/constants.py @@ -22,9 +22,11 @@ class BuildTarget(StrEnum): PROD = "prod" -class ToolName(StrEnum): +class Tool(StrEnum): PythonInterpreter = "Python Interpreter" TavilyInternetSearch = "Tavily Internet Search" + Wolfram_Alpha = "Wolfram Alpha" + WELCOME_MESSAGE = r""" @@ -50,18 +52,28 @@ class ToolName(StrEnum): TOOLS = { - ToolName.PythonInterpreter: { + Tool.PythonInterpreter: { "secrets": { "PYTHON_INTERPRETER_URL": PYTHON_INTERPRETER_URL_DEFAULT, }, }, - ToolName.TavilyInternetSearch: { + Tool.TavilyInternetSearch: { "secrets": { "TAVILY_API_KEY": None, }, }, } +# For main.py cli setup script +COMMUNITY_TOOLS = { + Tool.Wolfram_Alpha: { + "secrets": { + "WOLFRAM_APP_ID": None, # default value + }, + }, +} + + ENV_YAML_CONFIG_MAPPING = { "USE_COMMUNITY_FEATURES": { "type": "config", diff --git a/src/backend/cli/main.py b/src/backend/cli/main.py index 9dbe44a62a..40d6eb7e79 100755 --- a/src/backend/cli/main.py +++ b/src/backend/cli/main.py @@ -1,6 +1,6 @@ import argparse -from backend.cli.constants import TOOLS +from backend.cli.constants import COMMUNITY_TOOLS, TOOLS from backend.cli.prompts import ( PROMPTS, community_tools_prompt, @@ -23,7 +23,6 @@ from community.config.deployments import ( AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, ) -from community.config.tools import COMMUNITY_TOOLS_SETUP def start(): @@ -43,7 +42,7 @@ def start(): # ENABLE COMMUNITY TOOLS use_community_features = args.use_community and community_tools_prompt(secrets) if use_community_features: - TOOLS.update(COMMUNITY_TOOLS_SETUP) + TOOLS.update(COMMUNITY_TOOLS) # SET UP TOOLS for name, configs in TOOLS.items(): diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index b5ffb72e17..a2034de59a 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -20,14 +20,6 @@ database: redis: url: redis://:redis@redis:6379 tools: - enabled_tools: - - wikipedia - - search_file - - read_file - - toolkit_python_interpreter - - toolkit_calculator - - hybrid_web_search - - web_scrape hybrid_web_search: # List of web search tool names, eg: google_web_search, tavily_web_search enabled_web_searches: diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index a60be0942b..63e1853b7c 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -195,7 +195,6 @@ class HybridWebSearchSettings(BaseSettings, BaseModel): class ToolSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG - enabled_tools: Optional[List[str]] = None python_interpreter: Optional[PythonToolSettings] = Field( default=PythonToolSettings() diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 6a5d7a13b4..082978adaf 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -1,20 +1,18 @@ -from enum import StrEnum +from enum import Enum from backend.config.settings import Settings -from backend.schemas.tool import Category, ManagedTool +from backend.schemas.tool import ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools import ( BraveWebSearch, Calculator, GoogleDrive, - GoogleDriveAuth, GoogleWebSearch, HybridWebSearch, LangChainWikiRetriever, PythonInterpreter, ReadFileTool, SearchFileTool, - SlackAuth, SlackTool, TavilyWebSearch, WebScrapeTool, @@ -23,270 +21,41 @@ logger = LoggerFactory().get_logger() """ -List of available tools. Each tool should have a name, implementation, is_visible and category. -They can also have kwargs if necessary. - -You can switch the visibility of a tool by changing the is_visible parameter to True or False. -If a tool is not visible, it will not be shown in the frontend. - -If you want to add a new tool, check the instructions on how to implement a retriever in the documentation. -Don't forget to add the implementation to this AVAILABLE_TOOLS dictionary! +Tool Name enum, mapping to the tool's main implementation class. """ - -class ToolName(StrEnum): - Wiki_Retriever_LangChain = LangChainWikiRetriever.NAME - Search_File = SearchFileTool.NAME - Read_File = ReadFileTool.NAME - Python_Interpreter = PythonInterpreter.NAME - Calculator = Calculator.NAME - Google_Drive = GoogleDrive.NAME - Web_Scrape = WebScrapeTool.NAME - Tavily_Web_Search = TavilyWebSearch.NAME - Google_Web_Search = GoogleWebSearch.NAME - Brave_Web_Search = BraveWebSearch.NAME - Hybrid_Web_Search = HybridWebSearch.NAME - Slack = SlackTool.NAME - - -ALL_TOOLS = { - ToolName.Search_File: ManagedTool( - display_name="Search File", - implementation=SearchFileTool, - parameter_definitions={ - "search_query": { - "description": "Textual search query to search over the file's content for", - "type": "str", - "required": True, - }, - "files": { - "description": "A list of files represented as tuples of (filename, file ID) to search over", - "type": "list[tuple[str, str]]", - "required": True, - }, - }, - is_visible=True, - is_available=SearchFileTool.is_available(), - error_message="SearchFileTool not available.", - category=Category.FileLoader, - description="Performs a search over a list of one or more of the attached files for a textual search query", - ), - ToolName.Read_File: ManagedTool( - display_name="Read Document", - implementation=ReadFileTool, - parameter_definitions={ - "file": { - "description": "A file represented as a tuple (filename, file ID) to read over", - "type": "tuple[str, str]", - "required": True, - } - }, - is_visible=True, - is_available=ReadFileTool.is_available(), - error_message="ReadFileTool not available.", - category=Category.FileLoader, - description="Returns the textual contents of an uploaded file, broken up in text chunks.", - ), - ToolName.Python_Interpreter: ManagedTool( - display_name="Python Interpreter", - implementation=PythonInterpreter, - parameter_definitions={ - "code": { - "description": ( - "Python code to execute using the Python interpreter with no internet access. " - "Do not generate code that tries to open files directly, instead use file contents passed to the interpreter, " - "then print output or save output to a file." - ), - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=PythonInterpreter.is_available(), - error_message="PythonInterpreterFunctionTool not available, please make sure to set the tools.python_interpreter.url variable in your configuration.yaml", - category=Category.Function, - description="Runs python code in a sandbox.", - ), - ToolName.Wiki_Retriever_LangChain: ManagedTool( - display_name="Wikipedia", - implementation=LangChainWikiRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - kwargs={"chunk_size": 300, "chunk_overlap": 0}, - is_visible=True, - is_available=LangChainWikiRetriever.is_available(), - error_message="LangChainWikiRetriever not available.", - category=Category.DataLoader, - description="Retrieves documents from Wikipedia using LangChain.", - ), - ToolName.Calculator: ManagedTool( - display_name="Calculator", - implementation=Calculator, - parameter_definitions={ - "code": { - "description": "The expression for the calculator to evaluate, it should be a valid mathematical expression.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=Calculator.is_available(), - error_message="Calculator tool not available.", - category=Category.Function, - description="This is a powerful multi-purpose calculator which is capable of a wide array of math calculations.", - ), - ToolName.Google_Drive: ManagedTool( - display_name="Google Drive", - implementation=GoogleDrive, - parameter_definitions={ - "query": { - "description": "Query to search Google Drive documents with.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=GoogleDrive.is_available(), - auth_implementation=GoogleDriveAuth, - error_message="Google Drive not available, please enable it in the GoogleDrive tool class.", - category=Category.DataLoader, - description="Returns a list of relevant document snippets for the user's google drive.", - ), - ToolName.Web_Scrape: ManagedTool( - name=ToolName.Web_Scrape, - display_name="Web Scrape", - implementation=WebScrapeTool, - parameter_definitions={ - "url": { - "description": "The url to scrape.", - "type": "str", - "required": True, - }, - "query": { - "description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage", - "type": "str", - "required": False, - }, - }, - is_visible=True, - is_available=WebScrapeTool.is_available(), - error_message="WebScrapeTool not available.", - category=Category.DataLoader, - description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", - ), - ToolName.Tavily_Web_Search: ManagedTool( - display_name="Web Search", - implementation=TavilyWebSearch, - parameter_definitions={ - "query": { - "description": "Query to search the internet with", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=TavilyWebSearch.is_available(), - error_message="TavilyWebSearch not available, please make sure to set the tools.tavily_web_search.api_key variable in your secrets.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet.", - ), - ToolName.Google_Web_Search: ManagedTool( - display_name="Google Web Search", - implementation=GoogleWebSearch, - parameter_definitions={ - "query": { - "description": "A search query for the Google search engine.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=GoogleWebSearch.is_available(), - error_message="Google Web Search not available, please enable it in the GoogleWebSearch tool class.", - category=Category.WebSearch, - description="Returns relevant results by performing a Google web search.", - ), - ToolName.Brave_Web_Search: ManagedTool( - display_name="Brave Web Search", - implementation=BraveWebSearch, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=BraveWebSearch.is_available(), - error_message="BraveWebSearch not available, please make sure to set the tools.brave_web_search.api_key variable in your secrets.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet using Brave Search.", - ), - ToolName.Hybrid_Web_Search: ManagedTool( - display_name="Hybrid Web Search", - implementation=HybridWebSearch, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=HybridWebSearch.is_available(), - error_message="HybridWebSearch not available, please make sure to set at least one option in the tools.hybrid_web_search.enabled_web_searches variable in your configuration.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet using a mix of any existing Web Search tools.", - ), - ToolName.Slack: ManagedTool( - display_name="Slack", - implementation=SlackTool, - parameter_definitions={ - "query": { - "description": "Query to search slack.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=SlackTool.is_available(), - auth_implementation=SlackAuth, - error_message="SlackTool not available, please enable it in the SlackTool class.", - category=Category.DataLoader, - description="Returns a list of relevant document snippets from slack.", - ), -} - - -def get_available_tools() -> dict[ToolName, dict]: +class Tool(Enum): + Wiki_Retriever_LangChain = LangChainWikiRetriever + Read_File = ReadFileTool + Search_File = SearchFileTool + Python_Interpreter = PythonInterpreter + Calculator = Calculator + Google_Drive = GoogleDrive + Web_Scrape = WebScrapeTool + Tavily_Web_Search = TavilyWebSearch + Google_Web_Search = GoogleWebSearch + Brave_Web_Search = BraveWebSearch + Hybrid_Web_Search = HybridWebSearch + Slack = SlackTool + + +def get_available_tools() -> dict[str, ToolDefinition]: + # Get list of implementations from Tool Enum + tool_classes = [tool.value for tool in Tool] + # Generate dictionary of ToolDefinitions keyed by Tool ID + tools = { + tool.ID: tool.get_tool_definition() for tool in tool_classes + } + + # Handle adding Community-implemented tools use_community_tools = Settings().get('feature_flags.use_community_features') - - tools = ALL_TOOLS.copy() if use_community_tools: try: - from community.config.tools import COMMUNITY_TOOLS - - tools = ALL_TOOLS.copy() - tools.update(COMMUNITY_TOOLS) + from community.config.tools import get_community_tools + community_tools = get_community_tools() + tools.update(community_tools) except ImportError: logger.warning( event="[Tools] Error loading tools: Community tools not available." ) - for tool in tools.values(): - # Conditionally set error message - tool.error_message = tool.error_message if not tool.is_available else None - # Retrieve name - tool.name = tool.implementation.NAME - - enabled_tools = Settings().get('tools.enabled_tools') - if enabled_tools is not None and len(enabled_tools) > 0: - tools = {key: value for key, value in tools.items() if key in enabled_tools} return tools - - -AVAILABLE_TOOLS = get_available_tools() diff --git a/src/backend/main.py b/src/backend/main.py index 9569cde052..3bdd288a30 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,3 +1,5 @@ +import logging + from alembic.command import upgrade from alembic.config import Config from dotenv import load_dotenv @@ -29,6 +31,9 @@ from backend.services.context import ContextMiddleware, get_context from backend.services.logger.middleware import LoggingMiddleware +# Only show errors for Pydantic +logging.getLogger('pydantic').setLevel(logging.ERROR) + load_dotenv() # CORS Origins diff --git a/src/backend/pytest.ini b/src/backend/pytest.ini index e8ad063ce5..3ba10e4a74 100644 --- a/src/backend/pytest.ini +++ b/src/backend/pytest.ini @@ -1,3 +1,5 @@ [pytest] env = - DATABASE_URL=postgresql://postgres:postgres@localhost:5433/postgres \ No newline at end of file + DATABASE_URL=postgresql://postgres:postgres@localhost:5433/postgres +filterwarnings = + ignore::UserWarning:pydantic.* diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth.py index b726fef898..3e3d6b52a7 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth.py @@ -9,7 +9,7 @@ from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING from backend.config.routers import RouterName from backend.config.settings import Settings -from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.config.tools import Tool, get_available_tools from backend.crud import blacklist as blacklist_crud from backend.database_models import Blacklist from backend.database_models.database import DBSessionDep @@ -295,8 +295,9 @@ def log_and_redirect_err(error_message: str): err = f"Tool Auth cache {tool_auth_cache} does not contain user_id or tool_id." log_and_redirect_err(err) - if tool_id in AVAILABLE_TOOLS: - tool = AVAILABLE_TOOLS.get(tool_id) + available_tools = get_available_tools() + if tool_id in available_tools: + tool = available_tools.get(tool_id) err = None # Tool not found @@ -336,7 +337,7 @@ async def delete_tool_auth( If completed, the corresponding ToolAuth for the requesting user is removed from the DB. Args: - tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the ToolName string enum class. + tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool string enum class. request (Request): current Request object. session (DBSessionDep): Database session. ctx (Context): Context object. @@ -356,16 +357,16 @@ async def delete_tool_auth( if user_id is None or user_id == "" or user_id == "default": logger.error_and_raise_http_exception(event="User ID not found.") - if tool_id not in [tool_name.value for tool_name in ToolName]: + if tool_id not in [tool_name.value for tool_name in Tool]: logger.error_and_raise_http_exception( - event="tool_id must be present in the path of the request and must be a member of the ToolName string enum class.", + event="tool_id must be present in the path of the request and must be a member of the Tool string enum class.", ) - tool = AVAILABLE_TOOLS.get(tool_id) + tool = get_available_tools().get(tool_id) if tool is None: logger.error_and_raise_http_exception( - event=f"Tool {tool_id} is not available in AVAILABLE_TOOLS." + event=f"Tool {tool_id} is not available." ) if tool.auth_implementation is None: diff --git a/src/backend/routers/organization.py b/src/backend/routers/organization.py index f1a14c2512..6c252f7c6c 100644 --- a/src/backend/routers/organization.py +++ b/src/backend/routers/organization.py @@ -87,9 +87,10 @@ def get_organization( Args: organization_id (str): Tool ID. session (DBSessionDep): Database session. + ctx: Context. Returns: - ManagedTool: Organization with the given ID. + Organization: Organization with the given ID. """ organization = organization_crud.get_organization(session, organization_id) if not organization: @@ -135,7 +136,7 @@ def list_organizations( session (DBSessionDep): Database session. Returns: - list[ManagedTool]: List of available organizations. + list[Organization]: List of available organizations. """ all_organizations = organization_crud.get_organizations(session) return all_organizations diff --git a/src/backend/routers/tool.py b/src/backend/routers/tool.py index a0e95fb3ba..b9078ebc0c 100644 --- a/src/backend/routers/tool.py +++ b/src/backend/routers/tool.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, Depends, Request from backend.config.routers import RouterName -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.database_models.database import DBSessionDep from backend.schemas.context import Context -from backend.schemas.tool import ManagedTool +from backend.schemas.tool import ToolDefinition from backend.services.agent import validate_agent_exists from backend.services.context import get_context @@ -12,13 +12,13 @@ router.name = RouterName.TOOL -@router.get("", response_model=list[ManagedTool]) +@router.get("", response_model=list[ToolDefinition]) def list_tools( request: Request, session: DBSessionDep, agent_id: str | None = None, ctx: Context = Depends(get_context), -) -> list[ManagedTool]: +) -> list[ToolDefinition]: """ List all available tools. @@ -28,19 +28,20 @@ def list_tools( agent_id (str): Agent ID. ctx (Context): Context object. Returns: - list[ManagedTool]: List of available tools. + list[ToolDefinition]: List of available tools. """ user_id = ctx.get_user_id() logger = ctx.get_logger() - all_tools = AVAILABLE_TOOLS.values() + available_tools = get_available_tools() + all_tools = list(available_tools.values()) if agent_id is not None: agent_tools = [] agent = validate_agent_exists(session, agent_id, user_id) for tool in agent.tools: - agent_tools.append(AVAILABLE_TOOLS[tool]) + agent_tools.append(available_tools[tool]) all_tools = agent_tools for tool in all_tools: diff --git a/src/backend/schemas/tool.py b/src/backend/schemas/tool.py index ec92090ae8..d8fa884bd7 100644 --- a/src/backend/schemas/tool.py +++ b/src/backend/schemas/tool.py @@ -4,30 +4,25 @@ from pydantic import BaseModel, Field -class Category(StrEnum): +class ToolCategory(StrEnum): DataLoader = "Data loader" FileLoader = "File loader" Function = "Function" WebSearch = "Web search" -class ToolInput(BaseModel): - pass - - class Tool(BaseModel): name: Optional[str] = "" - display_name: str = "" - description: Optional[str] = "" parameter_definitions: Optional[dict] = {} - -class ManagedTool(Tool): +class ToolDefinition(Tool): + display_name: str = "" + description: str = "" + error_message: Optional[str] = "" kwargs: dict = {} is_visible: bool = False is_available: bool = False - error_message: Optional[str] = "" - category: Category = Category.DataLoader + category: ToolCategory = ToolCategory.DataLoader is_auth_required: bool = False # Per user auth_url: Optional[str] = "" # Per user diff --git a/src/backend/services/auth/strategies/base.py b/src/backend/services/auth/strategies/base.py index 38112519a2..e305545824 100644 --- a/src/backend/services/auth/strategies/base.py +++ b/src/backend/services/auth/strategies/base.py @@ -43,13 +43,14 @@ class BaseOAuthStrategy: def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): + @classmethod + def _post_init_check(cls): if any( [ - self.NAME is None, + cls.NAME is None, ] ): - raise ValueError(f"{self.__name__} must have NAME attribute defined.") + raise ValueError(f"{cls.__name__} must have NAME attribute defined.") @abstractmethod def get_client_id(self, **kwargs: Any): diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index bf1d560a7b..8e8abc6e6e 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -9,7 +9,7 @@ from backend.chat.collate import to_dict from backend.chat.enums import StreamEvent -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.crud import conversation as conversation_crud from backend.crud import message as message_crud @@ -156,7 +156,7 @@ def process_chat( tools = chat_request.tools managed_tools = ( - len([tool.name for tool in tools if tool.name in AVAILABLE_TOOLS]) > 0 + len([tool.name for tool in tools if tool.name in get_available_tools()]) > 0 ) return ( @@ -253,7 +253,7 @@ def process_message_regeneration( ) managed_tools = ( - len([tool.name for tool in chat_request.tools if tool.name in AVAILABLE_TOOLS]) > 0 + len([tool.name for tool in chat_request.tools if tool.name in get_available_tools()]) > 0 ) return ( @@ -313,7 +313,7 @@ def is_custom_tool_call(chat_response: BaseChatRequest) -> bool: # check if any of the tools is not in the available tools for tool in chat_response.tools: - if tool.name not in AVAILABLE_TOOLS: + if tool.name not in get_available_tools(): return True return False diff --git a/src/backend/services/file.py b/src/backend/services/file.py index f1d8b02234..d52212efc7 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -132,7 +132,7 @@ def get_files_by_agent_id( Returns: list[File]: The files that were created """ - from backend.config.tools import ToolName + from backend.config.tools import Tool from backend.tools.files import FileToolsArtifactTypes agent = validate_agent_exists(session, agent_id, user_id) @@ -144,8 +144,8 @@ def get_files_by_agent_id( ( tool_metadata.artifacts for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == ToolName.Read_File - or tool_metadata.tool_name == ToolName.Search_File + if tool_metadata.tool_name == Tool.Read_File.value.ID + or tool_metadata.tool_name == Tool.Search_File.value.ID ), [], # Default value if the generator is empty ) diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index badb2b4369..21d6012628 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -8,7 +8,7 @@ find_config_by_deployment_id, find_config_by_deployment_name, ) -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud from backend.crud import deployment as deployment_crud @@ -212,7 +212,7 @@ async def validate_chat_request(session: DBSessionDep, request: Request): if not tools: return - managed_tools = [tool["name"] for tool in tools if tool["name"] in AVAILABLE_TOOLS] + managed_tools = [tool["name"] for tool in tools if tool["name"] in get_available_tools()] if managed_tools and len(tools) != len(managed_tools): raise HTTPException( status_code=400, detail="Cannot mix both managed and custom tools" @@ -288,7 +288,7 @@ async def validate_create_agent_request(session: DBSessionDep, request: Request) tools = body.get("tools") if tools: for tool in tools: - if tool not in AVAILABLE_TOOLS: + if tool not in get_available_tools(): raise HTTPException(status_code=404, detail=f"Tool {tool} not found.") name = body.get("name") @@ -339,7 +339,7 @@ async def validate_update_agent_request(session: DBSessionDep, request: Request) tools = body.get("tools") if tools: for tool in tools: - if tool not in AVAILABLE_TOOLS: + if tool not in get_available_tools(): logger.error(event="Tool not found.", tool=tool) raise HTTPException(status_code=404, detail=f"Tool {tool} not found.") diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 9661606fe2..9ba0be0649 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session from backend.config.deployments import ModelDeploymentName -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.tests.unit.factories import get_factory @@ -18,7 +18,7 @@ def test_create_agent(session_client: TestClient, session: Session, user) -> Non "temperature": 0.5, "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Calculator, ToolName.Search_File, ToolName.Read_File], + "tools": [Tool.Calculator.value.ID, Tool.Search_File.value.ID, Tool.Read_File.value.ID], } response = session_client.post( @@ -59,10 +59,10 @@ def test_create_agent_with_tool_metadata( "temperature": 0.5, "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Google_Drive, ToolName.Search_File], + "tools": [Tool.Google_Drive.value.ID, Tool.Search_File.value.ID], "tools_metadata": [ { - "tool_name": ToolName.Google_Drive, + "tool_name": Tool.Google_Drive.value.ID, "artifacts": [ { "name": "/folder", @@ -72,7 +72,7 @@ def test_create_agent_with_tool_metadata( ], }, { - "tool_name": ToolName.Search_File, + "tool_name": Tool.Search_File.value.ID, "artifacts": [ { "name": "file.txt", @@ -96,11 +96,11 @@ def test_create_agent_with_tool_metadata( .all() ) assert len(tool_metadata) == 2 - assert tool_metadata[0].tool_name == ToolName.Google_Drive + assert tool_metadata[0].tool_name == Tool.Google_Drive.value.ID assert tool_metadata[0].artifacts == [ {"name": "/folder", "ids": "folder1", "type": "folder_ids"}, ] - assert tool_metadata[1].tool_name == ToolName.Search_File + assert tool_metadata[1].tool_name == Tool.Search_File.value.ID assert tool_metadata[1].artifacts == [ {"name": "file.txt", "ids": "file1", "type": "file_ids"} ] diff --git a/src/backend/tests/unit/chat/test_tool_calls.py b/src/backend/tests/unit/chat/test_tool_calls.py index e30173d295..b161049de3 100644 --- a/src/backend/tests/unit/chat/test_tool_calls.py +++ b/src/backend/tests/unit/chat/test_tool_calls.py @@ -6,16 +6,22 @@ from fastapi import HTTPException from backend.chat.custom.tool_calls import async_call_tools -from backend.config.tools import AVAILABLE_TOOLS, ToolName -from backend.schemas.tool import ManagedTool +from backend.config.tools import Tool +from backend.schemas.tool import ToolDefinition from backend.services.context import Context from backend.tests.unit.model_deployments.mock_deployments import MockCohereDeployment from backend.tools.base import BaseTool -def test_async_call_tools_success() -> None: +@pytest.fixture +def mock_get_available_tools(): + with patch("backend.chat.custom.tool_calls.get_available_tools") as mock: + yield mock + + +def test_async_call_tools_success(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -26,29 +32,28 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert results == [ - { - "call": { - "name": "toolkit_calculator", - "parameters": {"expression": "6*7"}, - }, - "outputs": [{"result": 42}], - } - ] - - -def test_async_call_tools_failure() -> None: + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert results == [ + { + "call": { + "name": "toolkit_calculator", + "parameters": {"code": "6*7"}, + }, + "outputs": [{"result": 42}], + } + ] + + +def test_async_call_tools_failure(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -59,32 +64,31 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert results == [ - { - "call": { - "name": "toolkit_calculator", - "parameters": {"expression": "6*7"}, - }, - "outputs": [ - {"error": "Calculator failed", "status_code": 500, "success": False} - ], + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert results == [ + { + "call": { + "name": "toolkit_calculator", + "parameters": {"code": "6*7"}, }, - ] + "outputs": [ + {"error": "Calculator failed", "status_code": 500, "success": False} + ], + }, + ] @patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1) -def test_async_call_tools_timeout() -> None: +def test_async_call_tools_timeout(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -96,23 +100,23 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - with pytest.raises(HTTPException) as excinfo: - asyncio.run(async_call_tools(chat_history, MockCohereDeployment(), ctx)) - assert str(excinfo.value.status_code) == "500" - assert ( - str(excinfo.value.detail) == "Timeout while calling tools with timeout: 1" + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + + with pytest.raises(HTTPException) as excinfo: + asyncio.run(async_call_tools(chat_history, MockCohereDeployment(), ctx)) + assert str(excinfo.value.status_code) == "500" + assert ( + str(excinfo.value.detail) == "Timeout while calling tools with timeout: 1" ) -def test_async_call_tools_failure_and_success() -> None: +def test_async_call_tools_failure_and_success(mock_get_available_tools) -> None: class MockWebScrape(BaseTool): - NAME = "web_scrape" + ID = "web_scrape" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -120,7 +124,7 @@ async def call( raise Exception("Web scrape failed") class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -131,26 +135,26 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "web_scrape", "parameters": {"expression": "6*7"}}, - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}}, + {"name": "web_scrape", "parameters": {"code": "6*7"}}, + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}}, ] } ] - MOCKED_TOOLS = { - ToolName.Calculator: ManagedTool(implementation=MockCalculator), - ToolName.Web_Scrape: ManagedTool(implementation=MockWebScrape), + mock_get_available_tools.return_value = { + Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator), + Tool.Web_Scrape.value.ID: ToolDefinition(implementation=MockWebScrape), } - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert { - "call": {"name": "web_scrape", "parameters": {"expression": "6*7"}}, - "outputs": [ - {"error": "Web scrape failed", "status_code": 500, "success": False} - ], - } in results - assert { - "call": {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}}, - "outputs": [{"result": 42}], - } in results + + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert { + "call": {"name": "web_scrape", "parameters": {"code": "6*7"}}, + "outputs": [ + {"error": "Web scrape failed", "status_code": 500, "success": False} + ], + } in results + assert { + "call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}}, + "outputs": [{"result": 42}], + } in results diff --git a/src/backend/tests/unit/config/test_deployments.py b/src/backend/tests/unit/config/test_deployments.py index bb6bac146f..adaa443040 100644 --- a/src/backend/tests/unit/config/test_deployments.py +++ b/src/backend/tests/unit/config/test_deployments.py @@ -1,13 +1,6 @@ -from unittest.mock import Mock +from backend.config.tools import Tool -from backend.config.deployments import ( - get_default_deployment, -) -from backend.tests.unit.model_deployments.mock_deployments.mock_cohere_platform import ( - MockCohereDeployment, -) - -def test_get_default_deployment(mock_available_model_deployments: Mock) -> None: - default_deployment = get_default_deployment() - assert isinstance(default_deployment, MockCohereDeployment) +def test_all_tools_have_id() -> None: + for tool in Tool: + assert tool.value.ID is not None diff --git a/src/backend/tests/unit/config/test_tools.py b/src/backend/tests/unit/config/test_tools.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/tests/unit/configuration.yaml b/src/backend/tests/unit/configuration.yaml index a620a18a20..501c4b531e 100644 --- a/src/backend/tests/unit/configuration.yaml +++ b/src/backend/tests/unit/configuration.yaml @@ -16,7 +16,6 @@ database: redis: url: tools: - enabled_tools: python_interpreter: url: feature_flags: diff --git a/src/backend/tests/unit/crud/test_agent.py b/src/backend/tests/unit/crud/test_agent.py index 18d92b6c15..5da2fafdc8 100644 --- a/src/backend/tests/unit/crud/test_agent.py +++ b/src/backend/tests/unit/crud/test_agent.py @@ -2,7 +2,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.sql.expression import false -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent as agent_crud from backend.database_models.agent import Agent from backend.schemas.agent import AgentVisibility, UpdateAgentRequest @@ -17,7 +17,7 @@ def test_create_agent(session, user): description="test", preamble="test", temperature=0.5, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], is_private=True, ) @@ -28,7 +28,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 - assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] + assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] assert agent.is_private agent = agent_crud.get_agent_by_id(session, agent.id, user.id) @@ -38,7 +38,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 - assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] + assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] def test_create_agent_empty_non_required_fields(session, user): @@ -87,7 +87,7 @@ def test_create_agent_duplicate_name_version(session, user): description="test", preamble="test", temperature=0.5, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], ) with pytest.raises(IntegrityError): @@ -205,7 +205,7 @@ def test_update_agent(session, user): preamble="test", temperature=0.5, user=user, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], ) new_agent_data = UpdateAgentRequest( @@ -214,7 +214,7 @@ def test_update_agent(session, user): version=2, preamble="new_test", temperature=0.6, - tools=[ToolName.Python_Interpreter, ToolName.Calculator], + tools=[Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID], ) agent = agent_crud.update_agent(session, agent, new_agent_data, user.id) @@ -223,7 +223,7 @@ def test_update_agent(session, user): assert agent.version == new_agent_data.version assert agent.preamble == new_agent_data.preamble assert agent.temperature == new_agent_data.temperature - assert agent.tools == [ToolName.Python_Interpreter, ToolName.Calculator] + assert agent.tools == [Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID] def test_delete_agent(session, user): diff --git a/src/backend/tests/unit/crud/test_agent_tool_metadata.py b/src/backend/tests/unit/crud/test_agent_tool_metadata.py index 813e85cbf3..d6ec6a9e64 100644 --- a/src/backend/tests/unit/crud/test_agent_tool_metadata.py +++ b/src/backend/tests/unit/crud/test_agent_tool_metadata.py @@ -1,7 +1,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.schemas.agent import UpdateAgentToolMetadataRequest @@ -23,13 +23,13 @@ def test_create_agent_tool_metadata(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) agent_tool_metadata = agent_tool_metadata_crud.create_agent_tool_metadata( @@ -37,7 +37,7 @@ def test_create_agent_tool_metadata(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1] assert agent_tool_metadata.artifacts[0]["type"] == "document" @@ -46,7 +46,7 @@ def test_create_agent_tool_metadata(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1] assert agent_tool_metadata.artifacts[0]["type"] == "document" @@ -54,7 +54,7 @@ def test_create_agent_tool_metadata(session, user): def test_create_agent_missing_agent_id(session, user): agent_tool_metadata_data = AgentToolMetadata( user_id=user.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) with pytest.raises(IntegrityError): @@ -65,7 +65,7 @@ def test_create_agent_missing_agent_id(session, user): def test_create_agent_missing_tool_name(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( @@ -81,12 +81,12 @@ def test_create_agent_missing_tool_name(session, user): def test_create_agent_missing_user_id(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], user_id="123", ) @@ -101,7 +101,7 @@ def test_update_agent_tool_metadata(session, user): original_agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) @@ -123,7 +123,7 @@ def test_get_agent_tool_metadata_by_id(session, user): agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata_by_id( @@ -131,7 +131,7 @@ def test_get_agent_tool_metadata_by_id(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1, mock_artifact_2] @@ -143,18 +143,20 @@ def test_get_all_agent_tool_metadata_by_agent_id(session, user): _ = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent1.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) + # Constraint was added preventing multiple entries for the same user + agent + tool so fixing to change the tool used i = 0 - for tool in ToolName: + for tool in Tool: i += 1 + _ = get_factory("Agent", session).create(user_id=user.id) _ = get_factory("AgentToolMetadata", session).create( id=f"{i}", - tool_name=tool.value, + tool_name=tool.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], user_id=user.id, agent_id=agent2.id, @@ -165,7 +167,7 @@ def test_get_all_agent_tool_metadata_by_agent_id(session, user): session, agent_id=agent2.id ) ) - assert len(all_agent_tool_metadata) == len(ToolName) + assert len(all_agent_tool_metadata) == len(Tool) def test_delete_agent_tool_metadata_by_id(session, user): @@ -173,7 +175,7 @@ def test_delete_agent_tool_metadata_by_id(session, user): agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) diff --git a/src/backend/tests/unit/crud/test_tool_auth.py b/src/backend/tests/unit/crud/test_tool_auth.py index 251adce406..c6d2772858 100644 --- a/src/backend/tests/unit/crud/test_tool_auth.py +++ b/src/backend/tests/unit/crud/test_tool_auth.py @@ -1,6 +1,6 @@ from datetime import datetime -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import tool_auth as tool_auth_crud from backend.database_models.tool_auth import ToolAuth from backend.tests.unit.factories import get_factory @@ -9,7 +9,7 @@ def test_create_tool_auth(session, user): tool_auth_data = ToolAuth( user_id=user.id, - tool_id=ToolName.Google_Drive, + tool_id=Tool.Google_Drive.value.ID, token_type="Bearer", encrypted_access_token=bytes(b"foobar"), encrypted_refresh_token=bytes(b"foobar"), @@ -34,7 +34,7 @@ def test_create_tool_auth(session, user): def test_delete_tool_auth_by_tool_id(session, user): tool_auth = get_factory("ToolAuth", session).create( user_id=user.id, - tool_id=ToolName.Google_Drive, + tool_id=Tool.Google_Drive.value.ID, token_type="Bearer", encrypted_access_token=bytes(b"foobar"), encrypted_refresh_token=bytes(b"foobar"), diff --git a/src/backend/tests/unit/factories/agent.py b/src/backend/tests/unit/factories/agent.py index 7b50336e26..0b04348157 100644 --- a/src/backend/tests/unit/factories/agent.py +++ b/src/backend/tests/unit/factories/agent.py @@ -1,6 +1,6 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent import Agent from backend.tests.unit.factories.base import BaseFactory from backend.tests.unit.factories.user import UserFactory @@ -25,12 +25,12 @@ class Meta: factory.Faker( "random_element", elements=[ - ToolName.Wiki_Retriever_LangChain, - ToolName.Search_File, - ToolName.Read_File, - ToolName.Python_Interpreter, - ToolName.Calculator, - ToolName.Tavily_Web_Search, + Tool.Wiki_Retriever_LangChain.value.ID, + Tool.Search_File.value.ID, + Tool.Read_File.value.ID, + Tool.Python_Interpreter.value.ID, + Tool.Calculator.value.ID, + Tool.Tavily_Web_Search.value.ID, ], ) ] diff --git a/src/backend/tests/unit/factories/agent_tool_metadata.py b/src/backend/tests/unit/factories/agent_tool_metadata.py index d8c9151542..0bb8520528 100644 --- a/src/backend/tests/unit/factories/agent_tool_metadata.py +++ b/src/backend/tests/unit/factories/agent_tool_metadata.py @@ -1,6 +1,6 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent_tool_metadata import AgentToolMetadata from .base import BaseFactory @@ -16,13 +16,13 @@ class Meta: factory.Faker( "random_element", elements=[ - ToolName.Wiki_Retriever_LangChain, - ToolName.Search_File, - ToolName.Read_File, - ToolName.Python_Interpreter, - ToolName.Calculator, - ToolName.Tavily_Web_Search, - ToolName.Google_Drive, + Tool.Wiki_Retriever_LangChain.value.ID, + Tool.Search_File.value.ID, + Tool.Read_File.value.ID, + Tool.Python_Interpreter.value.ID, + Tool.Calculator.value.ID, + Tool.Tavily_Web_Search.value.ID, + Tool.Google_Drive.value.ID, ], ) ] diff --git a/src/backend/tests/unit/factories/tool_auth.py b/src/backend/tests/unit/factories/tool_auth.py index 8393eae259..af6f198288 100644 --- a/src/backend/tests/unit/factories/tool_auth.py +++ b/src/backend/tests/unit/factories/tool_auth.py @@ -2,7 +2,7 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.tool_auth import ToolAuth from .base import BaseFactory @@ -13,7 +13,7 @@ class Meta: model = ToolAuth user_id = factory.Faker("uuid4") - tool_id = ToolName.Google_Drive + tool_id = Tool.Google_Drive.value.ID token_type = "Bearer" encrypted_access_token = bytes(b"foobar") encrypted_refresh_token = bytes(b"foobar") diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index b047318a82..725c2a752e 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from backend.config.deployments import ModelDeploymentName -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent as agent_crud from backend.crud import deployment as deployment_crud from backend.database_models.agent import Agent @@ -135,14 +135,14 @@ def test_create_agent_invalid_tool( "name": "test agent", "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Calculator, "not a real tool"], + "tools": [Tool.Calculator.value.ID, "fake_tool"], } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} ) assert response.status_code == 404 - assert response.json() == {"detail": "Tool not a real tool not found."} + assert response.json() == {"detail": "Tool fake_tool not found."} def test_create_existing_agent( @@ -372,7 +372,7 @@ def test_get_agent(session_client: TestClient, session: Session, user) -> None: agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", @@ -393,7 +393,7 @@ def test_get_agent(session_client: TestClient, session: Session, user) -> None: assert response.status_code == 200 response_agent = response.json() assert response_agent["name"] == agent.name - assert response_agent["tools_metadata"][0]["tool_name"] == ToolName.Google_Drive + assert response_agent["tools_metadata"][0]["tool_name"] == Tool.Google_Drive.value.ID assert ( response_agent["tools_metadata"][0]["artifacts"] == agent_tool_metadata.artifacts @@ -498,13 +498,13 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N description="test description", preamble="test preamble", temperature=0.5, - tools=[ToolName.Calculator], + tools=[Tool.Calculator.value.ID], user=user, ) request_json = { "name": "updated name", - "tools": [ToolName.Search_File, ToolName.Read_File], + "tools": [Tool.Search_File.value.ID, Tool.Read_File.value.ID], } response = session_client.put( @@ -519,7 +519,7 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N assert updated_agent["description"] == "test description" assert updated_agent["preamble"] == "test preamble" assert updated_agent["temperature"] == 0.5 - assert updated_agent["tools"] == [ToolName.Search_File, ToolName.Read_File] + assert updated_agent["tools"] == [Tool.Search_File.value.ID, Tool.Read_File.value.ID] def test_update_agent_with_tool_metadata( @@ -537,7 +537,7 @@ def test_update_agent_with_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -601,7 +601,7 @@ def test_update_agent_with_tool_metadata_and_new_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -681,7 +681,7 @@ def test_update_agent_remove_existing_tool_metadata( get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -809,7 +809,7 @@ def test_update_agent_invalid_tool( request_json = { "model": "not a real model", "deployment": "not a real deployment", - "tools": [ToolName.Calculator, "not a real tool"], + "tools": [Tool.Calculator.value.ID, "not a real tool"], } response = session_client.put( @@ -1036,7 +1036,7 @@ def test_create_agent_tool_metadata( ) -> None: agent = get_factory("Agent", session).create(user=user) request_json = { - "tool_name": ToolName.Google_Drive, + "tool_name": Tool.Google_Drive.value.ID, "artifacts": [ { "name": "/folder1", @@ -1065,7 +1065,7 @@ def test_create_agent_tool_metadata( agent_tool_metadata = session.get( AgentToolMetadata, response_agent_tool_metadata["id"] ) - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [ { "name": "/folder1", @@ -1087,7 +1087,7 @@ def test_update_agent_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", @@ -1148,7 +1148,7 @@ def test_get_agent_tool_metadata( agent_tool_metadata_1 = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ {"name": "/folder", "ids": ["folder1", "folder2"], "type": "folder_ids"} ], @@ -1156,7 +1156,7 @@ def test_get_agent_tool_metadata( agent_tool_metadata_2 = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Search_File, + tool_name=Tool.Search_File.value.ID, artifacts=[{"name": "file.txt", "ids": ["file1", "file2"], "type": "file_ids"}], ) @@ -1182,7 +1182,7 @@ def test_delete_agent_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 7e8d06ea2e..559865f040 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -13,7 +13,7 @@ from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User -from backend.schemas.tool import Category +from backend.schemas.tool import ToolCategory from backend.tests.unit.factories import get_factory is_cohere_env_set = ( @@ -375,36 +375,11 @@ def test_streaming_fail_chat_missing_message( } -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_streaming_chat_with_custom_tools(session_client_chat, session_chat, user): - response = session_client_chat.post( - "/v1/chat-stream", - json={ - "message": "Give me a number", - "tools": [ - { - "name": "random_number_generator", - "description": "generate a random number", - } - ], - }, - headers={ - "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, - }, - ) - - assert response.status_code == 200 - validate_chat_streaming_response( - response, user, session_chat, session_client_chat, 0, is_custom_tools=True - ) - - @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -446,7 +421,7 @@ def test_streaming_chat_with_managed_and_custom_tools( ): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -806,7 +781,7 @@ def test_non_streaming_chat( def test_non_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -831,7 +806,7 @@ def test_non_streaming_chat_with_managed_and_custom_tools( ): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -856,30 +831,6 @@ def test_non_streaming_chat_with_managed_and_custom_tools( assert response.status_code == 400 assert response.json() == {"detail": "Cannot mix both managed and custom tools"} - -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_non_streaming_chat_with_custom_tools(session_client_chat, session_chat, user): - response = session_client_chat.post( - "/v1/chat", - json={ - "message": "Give me a number", - "tools": [ - { - "name": "random_number_generator", - "description": "generate a random number", - } - ], - }, - headers={ - "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, - }, - ) - - assert response.status_code == 200 - assert len(response.json()["tool_calls"]) == 1 - - @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat_with_search_queries_only( session_client_chat: TestClient, session_chat: Session, user: User diff --git a/src/backend/tests/unit/routers/test_tool.py b/src/backend/tests/unit/routers/test_tool.py index 8636bb1181..943dd5bb89 100644 --- a/src/backend/tests/unit/routers/test_tool.py +++ b/src/backend/tests/unit/routers/test_tool.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.config.tools import Tool, get_available_tools from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -9,18 +9,14 @@ def test_list_tools(session_client: TestClient, session: Session) -> None: response = session_client.get("/v1/tools") assert response.status_code == 200 + available_tools = get_available_tools() for tool in response.json(): - assert tool["name"] in AVAILABLE_TOOLS.keys() - - # get tool that has the same name as the tool in the response - tool_definition = AVAILABLE_TOOLS[tool["name"]] - - assert tool["kwargs"] == tool_definition.kwargs - assert tool["is_visible"] == tool_definition.is_visible - assert tool["is_available"] == tool_definition.is_available - assert tool["error_message"] == tool_definition.error_message - assert tool["category"] == tool_definition.category - assert tool["description"] == tool_definition.description + assert tool["name"] in available_tools.keys() + assert tool["kwargs"] is not None + assert tool["is_visible"] is not None + assert tool["is_available"] is not None + assert tool["category"] is not None + assert tool["description"] is not None def test_list_tools_error_message_none_if_available(client: TestClient) -> None: @@ -35,7 +31,7 @@ def test_list_tools_with_agent( session_client: TestClient, session: Session, user: User ) -> None: agent = get_factory("Agent", session).create( - name="test agent", tools=[ToolName.Wiki_Retriever_LangChain], user=user + name="test agent", tools=[Tool.Wiki_Retriever_LangChain.value.ID], user=user ) response = session_client.get("/v1/tools", params={"agent_id": agent.id}) @@ -43,10 +39,10 @@ def test_list_tools_with_agent( assert len(response.json()) == 1 tool = response.json()[0] - assert tool["name"] == ToolName.Wiki_Retriever_LangChain + assert tool["name"] == Tool.Wiki_Retriever_LangChain.value.ID # get tool that has the same name as the tool in the response - tool_definition = AVAILABLE_TOOLS[tool["name"]] + tool_definition = get_available_tools()[tool["name"]] assert tool["kwargs"] == tool_definition.kwargs assert tool["is_visible"] == tool_definition.is_visible diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index aa66fcd2c9..203a9328dd 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -1,5 +1,5 @@ import datetime -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Dict, List from fastapi import Request @@ -8,31 +8,28 @@ from backend.crud import tool_auth as tool_auth_crud from backend.database_models.database import DBSessionDep from backend.database_models.tool_auth import ToolAuth +from backend.schemas.tool import ToolDefinition from backend.services.logger.utils import LoggerFactory logger = LoggerFactory().get_logger() -class BaseTool: +class BaseTool(): """ Abstract base class for all Tools. Attributes: - NAME (str): The name of the tool. + ID (str): The name of the tool. """ - - NAME = None + ID = None def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): - if any( - [ - self.NAME is None, - ] - ): - raise ValueError(f"{self.__name__} must have NAME attribute defined.") + @classmethod + def _post_init_check(cls): + if cls.ID is None: + raise ValueError(f"{cls.__name__} must have ID attribute defined.") @classmethod @abstractmethod @@ -40,6 +37,16 @@ def is_available(cls) -> bool: ... @classmethod @abstractmethod + def get_tool_definition(cls) -> ToolDefinition: ... + + @classmethod + def generate_error_message(cls) -> str | None: + if cls.is_available(): + return None + + return f"{cls.__name__} is not available. Please make sure all required config variables are set." + + @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: pass @@ -49,7 +56,7 @@ async def call( ) -> List[Dict[str, Any]]: ... -class BaseToolAuthentication: +class BaseToolAuthentication(ABC): """ Abstract base class for Tool Authentication. """ @@ -61,12 +68,13 @@ def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): + @classmethod + def _post_init_check(cls): if any( [ - self.BACKEND_HOST is None, - self.FRONTEND_HOST is None, - self.AUTH_SECRET_KEY is None, + cls.BACKEND_HOST is None, + cls.FRONTEND_HOST is None, + cls.AUTH_SECRET_KEY is None, ] ): raise ValueError( diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 85899b6a9d..0fcd9bf207 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -4,13 +4,14 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.brave_search.client import BraveClient from backend.tools.utils.mixins import WebSearchFilteringMixin class BraveWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "brave_web_search" + ID = "brave_web_search" BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key') def __init__(self): @@ -21,6 +22,29 @@ def __init__(self): def is_available(cls) -> bool: return cls.BRAVE_API_KEY is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Brave Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description=( + "Returns a list of relevant document snippets for a textual query retrieved " + "from the internet using Brave Search." + ), + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/calculator.py b/src/backend/tools/calculator.py index f4566e2875..3b96859663 100644 --- a/src/backend/tools/calculator.py +++ b/src/backend/tools/calculator.py @@ -2,6 +2,7 @@ from py_expression_eval import Parser +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -10,12 +11,32 @@ class Calculator(BaseTool): Function Tool that evaluates mathematical expressions. """ - NAME = "toolkit_calculator" + ID = "toolkit_calculator" @classmethod def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Calculator", + implementation=Calculator, + parameter_definitions={ + "code": { + "description": "The expression for the calculator to evaluate, it should be a valid mathematical expression.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=Calculator.is_available(), + category=ToolCategory.Function, + error_message=cls.generate_error_message(), + description="A powerful multi-purpose calculator capable of a wide array of math calculations.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 3b72b662ad..146a741c0e 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import backend.crud.file as file_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -13,7 +14,7 @@ class ReadFileTool(BaseTool): Tool to read a file from the file system. """ - NAME = "read_file" + ID = "read_file" MAX_NUM_CHUNKS = 10 SEARCH_LIMIT = 5 @@ -24,6 +25,33 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Read Document", + implementation=cls, + parameter_definitions={ + "file": { + "description": "A file represented as a tuple (filename, file ID) to read over", + "type": "tuple[str, str]", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description="Returns the chunked textual contents of an uploaded file.", + ) + + def get_info(cls) -> ToolDefinition: + return ToolDefinition( + display_name="Calculator", + description="A powerful multi-purpose calculator capable of a wide array of math calculations.", + error_message=cls.generate_error_message(), + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: file = parameters.get("file") @@ -50,7 +78,8 @@ class SearchFileTool(BaseTool): Tool to query a list of files. """ - NAME = "search_file" + ID = "search_file" + DISPLAY_NAME = "Search Files" MAX_NUM_CHUNKS = 10 SEARCH_LIMIT = 5 @@ -61,6 +90,31 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Search File", + implementation=cls, + parameter_definitions={ + "search_query": { + "description": "Textual search query to search over the file's content for", + "type": "str", + "required": True, + }, + "files": { + "description": "A list of files represented as tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", + "required": True, + }, + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description="Searches across one or more attached files based on a textual search query.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index 3691b75b56..a8c732223b 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -4,8 +4,10 @@ from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool +from backend.tools.google_drive.auth import GoogleDriveAuth from backend.tools.google_drive.constants import GOOGLE_DRIVE_TOOL_ID, SEARCH_LIMIT from backend.tools.google_drive.utils import ( extract_export_link, @@ -23,9 +25,7 @@ class GoogleDrive(BaseTool): """ Tool that searches Google Drive """ - - NAME = GOOGLE_DRIVE_TOOL_ID - + ID = GOOGLE_DRIVE_TOOL_ID CLIENT_ID = Settings().get('tools.google_drive.client_id') CLIENT_SECRET = Settings().get('tools.google_drive.client_secret') @@ -33,6 +33,27 @@ class GoogleDrive(BaseTool): def is_available(cls) -> bool: return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Google Drive", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search Google Drive documents with.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=GoogleDrive.is_available(), + auth_implementation=GoogleDriveAuth, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from the user's Google drive.", + ) + def _handle_tool_specific_errors(self, error: Exception, **kwargs: Any): message = "[Google Drive] Tool Error: {}".format(str(error)) diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index c8df4216e6..cc2ddc40cd 100644 --- a/src/backend/tools/google_search.py +++ b/src/backend/tools/google_search.py @@ -5,12 +5,13 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.utils.mixins import WebSearchFilteringMixin class GoogleWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "google_web_search" + ID = "google_web_search" API_KEY = Settings().get('tools.google_web_search.api_key') CSE_ID = Settings().get('tools.google_web_search.cse_id') @@ -21,6 +22,26 @@ def __init__(self): def is_available(cls) -> bool: return bool(cls.API_KEY) and bool(cls.CSE_ID) + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Google Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "A search query for the Google search engine.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description="Returns relevant results by performing a Google web search.", + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index 8af1e98cc3..e6bf4973ec 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -6,6 +6,7 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.brave_search.tool import BraveWebSearch from backend.tools.google_search import GoogleWebSearch @@ -15,7 +16,7 @@ class HybridWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "hybrid_web_search" + ID = "hybrid_web_search" POST_RERANK_MAX_RESULTS = 6 AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch] ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches') @@ -38,13 +39,36 @@ def is_available(cls) -> bool: # False if empty, True otherwise return bool(available_searches) + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Hybrid Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description=( + "Returns a list of relevant document snippets for a textual query " + "retrieved from the internet using a mix of any existing Web Search tools." + ) + ) + @classmethod def get_available_search_tools(cls): available_search_tools = [] for search_name in cls.ENABLED_WEB_SEARCH_TOOLS: for search_tool in cls.AVAILABLE_WEB_SEARCH_TOOLS: - if search_name == search_tool.NAME and search_tool.is_available(): + if search_name == search_tool.ID and search_tool.is_available(): available_search_tools.append(search_tool) return available_search_tools diff --git a/src/backend/tools/lang_chain.py b/src/backend/tools/lang_chain.py index 9dd64f8eec..345d5f3d79 100644 --- a/src/backend/tools/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -7,6 +7,7 @@ from langchain_community.vectorstores import Chroma from backend.config.settings import Settings +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool """ @@ -22,8 +23,7 @@ class LangChainWikiRetriever(BaseTool): This class retrieves documents from Wikipedia using the langchain package. This requires wikipedia package to be installed. """ - - NAME = "wikipedia" + ID = "wikipedia" def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): self.chunk_size = chunk_size @@ -33,6 +33,27 @@ def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Wikipedia", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + kwargs={"chunk_size": 300, "chunk_overlap": 0}, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Wikipedia.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: @@ -58,8 +79,7 @@ class LangChainVectorDBRetriever(BaseTool): """ This class retrieves documents from a vector database using the langchain package. """ - - NAME = "vector_retriever" + ID = "vector_retriever" COHERE_API_KEY = Settings().get('deployments.cohere_platform.api_key') def __init__(self, filepath: str): diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py index 3ebc664124..426844ab48 100644 --- a/src/backend/tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from backend.config.settings import Settings +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool load_dotenv() @@ -16,13 +17,41 @@ class PythonInterpreter(BaseTool): It requires a URL at which the interpreter lives """ - NAME = "toolkit_python_interpreter" + ID = "toolkit_python_interpreter" INTERPRETER_URL = Settings().get('tools.python_interpreter.url') @classmethod def is_available(cls) -> bool: return cls.INTERPRETER_URL is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Python Interpreter", + implementation=cls, + parameter_definitions={ + "code": { + "description": ( + "Python code to execute using the Python interpreter with no internet access. " + "Do not generate code that tries to open files directly, instead use file contents passed to the interpreter, " + "then print output or save output to a file." + ), + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description=( + "Executes python code and returns the result. The code runs " + "in a static sandbox without internet access and without interactive mode, " + "so print output or save output to a file." + ), + ) + async def call(self, parameters: dict, ctx: Any, **kwargs: Any): if not self.INTERPRETER_URL: raise Exception("Python Interpreter tool called while URL not set") diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index c1adee118e..35e78f0aea 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -2,8 +2,10 @@ from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool +from backend.tools.slack.auth import SlackAuth from backend.tools.slack.constants import SEARCH_LIMIT, SLACK_TOOL_ID from backend.tools.slack.utils import get_slack_service @@ -15,7 +17,7 @@ class SlackTool(BaseTool): Tool that searches Slack for messages and files based on a query. """ - NAME = SLACK_TOOL_ID + ID = SLACK_TOOL_ID CLIENT_ID = Settings().get('tools.slack.client_id') CLIENT_SECRET = Settings().get('tools.slack.client_secret') @@ -23,6 +25,27 @@ class SlackTool(BaseTool): def is_available(cls) -> bool: return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Slack", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search slack.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + auth_implementation=SlackAuth, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from slack.", + ) + @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: message = "[Slack] Tool Error: {}".format(str(error)) diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index abf30db883..0750a2517b 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -6,12 +6,13 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.utils.mixins import WebSearchFilteringMixin class TavilyWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "tavily_web_search" + ID = "tavily_web_search" TAVILY_API_KEY = Settings().get('tools.tavily_web_search.api_key') POST_RERANK_MAX_RESULTS = 6 @@ -22,6 +23,26 @@ def __init__(self): def is_available(cls) -> bool: return cls.TAVILY_API_KEY is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search the internet with", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description="Returns a list of relevant document snippets for a textual query retrieved from the internet.", + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: @@ -57,19 +78,22 @@ async def call( # Append original search result expanded.append(result) - # Get other snippets - snippets = result["raw_content"].split("\n") - for snippet in snippets: - if result["content"] != snippet: - if len(snippet.split()) <= 10: - continue # Skip snippets with less than 10 words - - new_result = { - "url": result["url"], - "title": result["title"], - "content": snippet.strip(), - } - expanded.append(new_result) + # Retrieve snippets from raw content if exists + raw_content = result["raw_content"] + if raw_content: + # Get other snippets + snippets = result["raw_content"].split("\n") + for snippet in snippets: + if result["content"] != snippet: + if len(snippet.split()) <= 10: + continue # Skip snippets with less than 10 words + + new_result = { + "url": result["url"], + "title": result["title"], + "content": snippet.strip(), + } + expanded.append(new_result) reranked_results = await self.rerank_page_snippets( query, expanded, model=kwargs.get("model_deployment"), ctx=ctx, **kwargs diff --git a/src/backend/tools/utils/mixins.py b/src/backend/tools/utils/mixins.py index cd04024aab..ce4827140f 100644 --- a/src/backend/tools/utils/mixins.py +++ b/src/backend/tools/utils/mixins.py @@ -42,7 +42,7 @@ def get_filters( agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata( db=session, agent_id=agent_id, - tool_name=self.NAME, + tool_name=self.ID, user_id=user_id, ) diff --git a/src/backend/tools/utils/tools_checkers.py b/src/backend/tools/utils/tools_checkers.py index 3a9acc66d3..f666cd7845 100644 --- a/src/backend/tools/utils/tools_checkers.py +++ b/src/backend/tools/utils/tools_checkers.py @@ -1,29 +1,29 @@ -from backend.schemas.tool import Category, ManagedTool -from community.config.tools import CommunityToolName +from backend.schemas.tool import ToolCategory, ToolDefinition +from community.config.tools import CommunityTool -def tool_has_category(tool: ManagedTool, category: Category) -> bool: +def tool_has_category(tool: ToolDefinition, category: ToolCategory) -> bool: """ Check if a tool has a specific category. Args: - tool (ManagedTool): The tool to check. - category (Category): The category to check for. + tool (ToolDefinition): The tool to check. + category (ToolCategory): The category to check for. Returns: - bool: True if the tool has the category, False otherwise. + bool: True if the tool has the category, False otherwise. """ return tool.category == category -def is_community_tool(tool: ManagedTool) -> bool: +def is_community_tool(tool: ToolDefinition) -> bool: """ Check if a tool is a community tool. Args: - tool (ManagedTool): The tool to check. + tool (ToolDefinition): The tool to check. Returns: - bool: True if the tool is a community tool, False otherwise. + bool: True if the tool is a community tool, False otherwise. """ - return tool.name in CommunityToolName + return tool.name in CommunityTool diff --git a/src/backend/tools/web_scrape.py b/src/backend/tools/web_scrape.py index 66ccf20f71..5479e951fe 100644 --- a/src/backend/tools/web_scrape.py +++ b/src/backend/tools/web_scrape.py @@ -4,6 +4,7 @@ import aiohttp from langchain_text_splitters import MarkdownHeaderTextSplitter +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool @@ -11,7 +12,7 @@ class WebScrapeTool(BaseTool): - NAME = "web_scrape" + ID = "web_scrape" ENDPOINT: ClassVar[str] = "http://co-reader" ENABLE_CHUNKING: ClassVar[bool] = True @@ -19,6 +20,31 @@ class WebScrapeTool(BaseTool): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Web Scrape", + implementation=cls, + parameter_definitions={ + "url": { + "description": "The url to scrape.", + "type": "str", + "required": True, + }, + "query": { + "description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage", + "type": "str", + "required": False, + }, + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/community/config/tools.py b/src/community/config/tools.py index 7ea673690f..8b97859bb8 100644 --- a/src/community/config/tools.py +++ b/src/community/config/tools.py @@ -1,137 +1,31 @@ -from enum import StrEnum +from enum import Enum +from backend.schemas.tool import ToolDefinition from community.tools import ( ArxivRetriever, - Category, ClinicalTrials, ConnectorRetriever, LlamaIndexUploadPDFRetriever, - ManagedTool, PubMedRetriever, WolframAlpha, ) -class CommunityToolName(StrEnum): - Arxiv = ArxivRetriever.NAME - Connector = ConnectorRetriever.NAME - Pub_Med = PubMedRetriever.NAME - File_Upload_LlamaIndex = LlamaIndexUploadPDFRetriever.NAME - Wolfram_Alpha = WolframAlpha.NAME - ClinicalTrials = ClinicalTrials.NAME +class CommunityTool(Enum): + Arxiv = ArxivRetriever + Connector = ConnectorRetriever + Pub_Med = PubMedRetriever + File_Upload_LlamaIndex = LlamaIndexUploadPDFRetriever + Wolfram_Alpha = WolframAlpha + ClinicalTrials = ClinicalTrials -COMMUNITY_TOOLS = { - CommunityToolName.Arxiv: ManagedTool( - display_name="Arxiv", - implementation=ArxivRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=ArxivRetriever.is_available(), - error_message="ArxivRetriever is not available.", - category=Category.DataLoader, - description="Retrieves documents from Arxiv.", - ), - CommunityToolName.Connector: ManagedTool( - display_name="Example Connector", - implementation=ConnectorRetriever, - is_visible=True, - is_available=ConnectorRetriever.is_available(), - error_message="ConnectorRetriever is not available.", - category=Category.DataLoader, - description="Connects to a data source.", - ), - CommunityToolName.Pub_Med: ManagedTool( - display_name="PubMed", - implementation=PubMedRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=PubMedRetriever.is_available(), - error_message="PubMedRetriever is not available.", - category=Category.DataLoader, - description="Retrieves documents from Pub Med.", - ), - CommunityToolName.File_Upload_LlamaIndex: ManagedTool( - display_name="Llama File Reader", - implementation=LlamaIndexUploadPDFRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - }, - "files": { - "description": "A list of files represented as tuples of (filename, file ID) to search over", - "type": "list[tuple[str, str]]", - "required": True, - }, +def get_community_tools() -> dict[str, ToolDefinition]: + # Get list of implementations from Tool Enum + tool_classes = [tool.value for tool in CommunityTool] + # Generate dictionary of ToolDefinitions keyed by Tool ID + community_tools = { + tool.ID: tool.get_tool_definition() for tool in tool_classes + } - }, - is_visible=True, - is_available=LlamaIndexUploadPDFRetriever.is_available(), - error_message="LlamaIndexUploadPDFRetriever is not available.", - category=Category.FileLoader, - description="Retrieves the most relevant documents from the uploaded files based on the query using Llama Index.", - ), - CommunityToolName.Wolfram_Alpha: ManagedTool( - display_name="Wolfram Alpha", - implementation=WolframAlpha, - is_visible=False, - is_available=WolframAlpha.is_available(), - error_message="WolframAlphaFunctionTool is not available, please set tools.wolfram_alpha.app_id in secrets.yaml", - category=Category.Function, - description="Evaluate arithmetic expressions.", - ), - CommunityToolName.ClinicalTrials: ManagedTool( - display_name="Clinical Trials", - implementation=ClinicalTrials, - is_visible=True, - is_available=ClinicalTrials.is_available(), - error_message="ClinicalTrialsTool is not available.", - category=Category.Function, - description="Retrieves clinical studies from ClinicalTrials.gov.", - parameter_definitions={ - "condition": { - "description": "Filters clinical studies to a specified disease or condition", - "type": "str", - "required": False, - }, - "location": { - "description": "Filters clinical studies to a specified city, state, or country.", - "type": "str", - "required": False, - }, - "intervention": { - "description": "Filters clinical studies to a specified drug or treatment.", - "type": "str", - "required": False, - }, - "is_recruiting": { - "description": "Filters clinical studies to those that are actively recruiting.", - "type": "bool", - "required": False, - }, - }, - ), -} - -# For main.py cli setup script -COMMUNITY_TOOLS_SETUP = { - CommunityToolName.Wolfram_Alpha: { - "secrets": { - "WOLFRAM_APP_ID": None, # default value - }, - }, -} + return community_tools diff --git a/src/community/tools/__init__.py b/src/community/tools/__init__.py index 86a0013172..1cffba1972 100644 --- a/src/community/tools/__init__.py +++ b/src/community/tools/__init__.py @@ -1,5 +1,3 @@ -from backend.schemas.tool import Category, ManagedTool -from backend.tools.base import BaseTool from community.tools.arxiv import ArxivRetriever from community.tools.clinicaltrials import ClinicalTrials from community.tools.connector import ConnectorRetriever @@ -14,7 +12,4 @@ "ConnectorRetriever", "LlamaIndexUploadPDFRetriever", "PubMedRetriever", - "Category", - "ManagedTool", - "BaseTool", ] diff --git a/src/community/tools/arxiv.py b/src/community/tools/arxiv.py index 7d92d87549..ce5cfac71c 100644 --- a/src/community/tools/arxiv.py +++ b/src/community/tools/arxiv.py @@ -2,11 +2,12 @@ from langchain_community.utilities import ArxivAPIWrapper -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class ArxivRetriever(BaseTool): - NAME = "arxiv" + ID = "arxiv" def __init__(self): self.client = ArxivAPIWrapper() @@ -15,6 +16,26 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Arxiv", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Arxiv.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") result = self.client.run(query) diff --git a/src/community/tools/clinicaltrials.py b/src/community/tools/clinicaltrials.py index 0a8af52aed..3db15271ac 100644 --- a/src/community/tools/clinicaltrials.py +++ b/src/community/tools/clinicaltrials.py @@ -2,7 +2,8 @@ import requests -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class ClinicalTrials(BaseTool): @@ -12,7 +13,7 @@ class ClinicalTrials(BaseTool): See: https://clinicaltrials.gov/data-api/api """ - NAME = "clinical_trials" + ID = "clinical_trials" def __init__(self, url="https://clinicaltrials.gov/api/v2/studies"): self._url = url @@ -21,6 +22,41 @@ def __init__(self, url="https://clinicaltrials.gov/api/v2/studies"): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Clinical Trials", + implementation=cls, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description="Retrieves clinical studies from ClinicalTrials.gov.", + parameter_definitions={ + "condition": { + "description": "Filters clinical studies to a specified disease or condition", + "type": "str", + "required": False, + }, + "location": { + "description": "Filters clinical studies to a specified city, state, or country.", + "type": "str", + "required": False, + }, + "intervention": { + "description": "Filters clinical studies to a specified drug or treatment.", + "type": "str", + "required": False, + }, + "is_recruiting": { + "description": "Filters clinical studies to those that are actively recruiting.", + "type": "bool", + "required": False, + }, + }, + ) + async def call( self, parameters: Dict[str, Any], n_max_studies: int = 10, **kwargs ) -> List[Dict[str, Any]]: diff --git a/src/community/tools/connector.py b/src/community/tools/connector.py index a2af411a6a..b19445ddad 100644 --- a/src/community/tools/connector.py +++ b/src/community/tools/connector.py @@ -2,7 +2,8 @@ import requests -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool """ Plug in your Connector configuration here. For example: @@ -10,28 +11,43 @@ Url: http://example_connector.com/search Auth: Bearer token for the connector +To see SSO examples, check out our Google Drive or Slack tool implementations + More details: https://docs.cohere.com/docs/connectors """ class ConnectorRetriever(BaseTool): - NAME = "example_connector" + ID = "example_connector" - def __init__(self, url: str, auth: str): + def __init__(self, url: str, api_key: str): self.url = url - self.auth = auth + self.api_key = api_key @classmethod def is_available(cls) -> bool: - return True + return False + + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Example Connector Template - Do not use", + implementation=ConnectorRetriever, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Example connector for a data source using a basic API.", + ) async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: body = {"query": parameters} headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.auth}", + "Authorization": f"Bearer {self.api_key}", } - response = requests.post(self.url, json=body, headers=headers) + response = requests.get(self.url, json=body, headers=headers) return response.json()["results"] diff --git a/src/community/tools/llama_index.py b/src/community/tools/llama_index.py index 6cdef8da4c..aafdc1b491 100644 --- a/src/community/tools/llama_index.py +++ b/src/community/tools/llama_index.py @@ -7,7 +7,8 @@ import backend.crud.file as file_crud from backend.config import Settings -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool """ Plug in your llama index retrieval implementation here. @@ -25,7 +26,7 @@ class LlamaIndexUploadPDFRetriever(BaseTool): This requires llama_index package to be installed. """ - NAME = "file_reader_llamaindex" + ID = "file_reader_llamaindex" CHUNK_SIZE = 512 def __init__(self): @@ -39,11 +40,39 @@ def _get_embedding(self, embed_type): input_type=embed_type, ) - @classmethod def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Llama File Reader", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + }, + "files": { + "description": "A list of files represented as tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", + "required": True, + }, + + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description=( + "Retrieves the most relevant documents from the uploaded " + "files based on the query using Llama Index." + ) + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/community/tools/pub_med.py b/src/community/tools/pub_med.py index 1ce46c7f80..6968e57ea3 100644 --- a/src/community/tools/pub_med.py +++ b/src/community/tools/pub_med.py @@ -2,11 +2,12 @@ from langchain_community.tools.pubmed.tool import PubmedQueryRun -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class PubMedRetriever(BaseTool): - NAME = "pub_med" + ID = "pub_med" def __init__(self): self.client = PubmedQueryRun() @@ -15,6 +16,26 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Pub Med", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Pub Med.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") result = self.client.invoke(query) diff --git a/src/community/tools/wolfram.py b/src/community/tools/wolfram.py index 9fc022ebab..dc098e77ed 100644 --- a/src/community/tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -3,7 +3,8 @@ from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper from backend.config.settings import Settings -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class WolframAlpha(BaseTool): @@ -13,7 +14,7 @@ class WolframAlpha(BaseTool): See: https://python.langchain.com/docs/integrations/tools/wolfram_alpha/ """ - NAME = "wolfram_alpha" + ID = "wolfram_alpha" wolfram_app_id = Settings().get('tools.wolfram_alpha.app_id') @@ -25,6 +26,19 @@ def __init__(self): def is_available(cls) -> bool: return cls.wolfram_app_id is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Wolfram Alpha", + implementation=cls, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description="Evaluate arithmetic expressions using Wolfram Alpha.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: to_evaluate = parameters.get("expression", "") result = self.tool.run(to_evaluate) diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx index 68fc43bbee..9d0a5aa71b 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx @@ -2,7 +2,7 @@ import { useEffect } from 'react'; -import { Document, ManagedTool } from '@/cohere-client'; +import { Document, ToolDefinition } from '@/cohere-client'; import { Conversation, ConversationError } from '@/components/Conversation'; import { TOOL_PYTHON_INTERPRETER_ID } from '@/constants'; import { useAgent, useAvailableTools, useConversation, useListTools } from '@/hooks'; @@ -24,7 +24,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ const { setConversation } = useConversationStore(); const { addCitation, saveOutputFiles } = useCitationsStore(); const { setParams, resetFileParams } = useParamsStore(); - const { availableTools } = useAvailableTools({ agent, managedTools: tools }); + const { availableTools } = useAvailableTools({ agent, allTools: tools }); const { data: conversation, @@ -44,7 +44,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ .map((name) => (tools ?? [])?.find((t) => t.name === name)) .filter( (t) => t !== undefined && availableTools.some((at) => at.name === t?.name) - ) as ManagedTool[]); + ) as ToolDefinition[]); const fileIds = conversation?.files.map((file) => file.id); diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts index 81dfd72bcd..2043c8e93d 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts @@ -270,12 +270,6 @@ export const $Body_batch_upload_file_v1_conversations_batch_upload_file_post = { title: 'Body_batch_upload_file_v1_conversations_batch_upload_file_post', } as const; -export const $Category = { - type: 'string', - enum: ['Data loader', 'File loader', 'Function', 'Web search'], - title: 'Category', -} as const; - export const $ChatMessage = { properties: { role: { @@ -1919,114 +1913,6 @@ export const $Logout = { title: 'Logout', } as const; -export const $ManagedTool = { - properties: { - name: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Name', - default: '', - }, - display_name: { - type: 'string', - title: 'Display Name', - default: '', - }, - description: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Description', - default: '', - }, - parameter_definitions: { - anyOf: [ - { - type: 'object', - }, - { - type: 'null', - }, - ], - title: 'Parameter Definitions', - default: {}, - }, - kwargs: { - type: 'object', - title: 'Kwargs', - default: {}, - }, - is_visible: { - type: 'boolean', - title: 'Is Visible', - default: false, - }, - is_available: { - type: 'boolean', - title: 'Is Available', - default: false, - }, - error_message: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Error Message', - default: '', - }, - category: { - $ref: '#/components/schemas/Category', - default: 'Data loader', - }, - is_auth_required: { - type: 'boolean', - title: 'Is Auth Required', - default: false, - }, - auth_url: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Auth Url', - default: '', - }, - token: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Token', - default: '', - }, - }, - type: 'object', - title: 'ManagedTool', -} as const; - export const $Message = { properties: { text: { @@ -3228,6 +3114,120 @@ export const $ToolCallDelta = { title: 'ToolCallDelta', } as const; +export const $ToolCategory = { + type: 'string', + enum: ['Data loader', 'File loader', 'Function', 'Web search'], + title: 'ToolCategory', +} as const; + +export const $ToolDefinition = { + properties: { + name: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Name', + default: '', + }, + display_name: { + type: 'string', + title: 'Display Name', + default: '', + }, + description: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Description', + default: '', + }, + parameter_definitions: { + anyOf: [ + { + type: 'object', + }, + { + type: 'null', + }, + ], + title: 'Parameter Definitions', + default: {}, + }, + kwargs: { + type: 'object', + title: 'Kwargs', + default: {}, + }, + is_visible: { + type: 'boolean', + title: 'Is Visible', + default: false, + }, + is_available: { + type: 'boolean', + title: 'Is Available', + default: false, + }, + error_message: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Error Message', + default: '', + }, + category: { + $ref: '#/components/schemas/ToolCategory', + default: 'Data loader', + }, + is_auth_required: { + type: 'boolean', + title: 'Is Auth Required', + default: false, + }, + auth_url: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Auth Url', + default: '', + }, + token: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Token', + default: '', + }, + }, + type: 'object', + title: 'ToolDefinition', +} as const; + export const $ToolInputType = { type: 'string', enum: ['QUERY', 'CODE'], diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts index f281bd5624..2bc613de69 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts @@ -293,7 +293,7 @@ export class DefaultService { * If completed, the corresponding ToolAuth for the requesting user is removed from the DB. * * Args: - * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the ToolName string enum class. + * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool enum. * request (Request): current Request object. * session (DBSessionDep): Database session. * ctx (Context): Context object. @@ -990,10 +990,10 @@ export class DefaultService { * agent_id (str): Agent ID. * ctx (Context): Context object. * Returns: - * list[ManagedTool]: List of available tools. + * list[ToolDefinition]: List of available tools. * @param data The data for the request. * @param data.agentId - * @returns ManagedTool Successful Response + * @returns ToolDefinition Successful Response * @throws ApiError */ public listToolsV1ToolsGet( @@ -1785,7 +1785,7 @@ export class DefaultService { * session (DBSessionDep): Database session. * * Returns: - * list[ManagedTool]: List of available organizations. + * list[Organization]: List of available organizations. * @returns Organization Successful Response * @throws ApiError */ @@ -1866,9 +1866,10 @@ export class DefaultService { * Args: * organization_id (str): Tool ID. * session (DBSessionDep): Database session. + * ctx: Context. * * Returns: - * ManagedTool: Organization with the given ID. + * Organization: Organization with the given ID. * @param data The data for the request. * @param data.organizationId * @returns Organization Successful Response diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts index c6cb588614..c8329f7488 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts @@ -56,13 +56,6 @@ export type Body_batch_upload_file_v1_conversations_batch_upload_file_post = { files: Array; }; -export enum Category { - DATA_LOADER = 'Data loader', - FILE_LOADER = 'File loader', - FUNCTION = 'Function', - WEB_SEARCH = 'Web search', -} - /** * A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message. */ @@ -393,25 +386,6 @@ export type Login = { export type Logout = unknown; -export type ManagedTool = { - name?: string | null; - display_name?: string; - description?: string | null; - parameter_definitions?: { - [key: string]: unknown; - } | null; - kwargs?: { - [key: string]: unknown; - }; - is_visible?: boolean; - is_available?: boolean; - error_message?: string | null; - category?: Category; - is_auth_required?: boolean; - auth_url?: string | null; - token?: string | null; -}; - export type Message = { text: string; id: string; @@ -677,6 +651,32 @@ export type ToolCallDelta = { parameters: string | null; }; +export enum ToolCategory { + DATA_LOADER = 'Data loader', + FILE_LOADER = 'File loader', + FUNCTION = 'Function', + WEB_SEARCH = 'Web search', +} + +export type ToolDefinition = { + name?: string | null; + display_name?: string; + description?: string | null; + parameter_definitions?: { + [key: string]: unknown; + } | null; + kwargs?: { + [key: string]: unknown; + }; + is_visible?: boolean; + is_available?: boolean; + error_message?: string | null; + category?: ToolCategory; + is_auth_required?: boolean; + auth_url?: string | null; + token?: string | null; +}; + /** * Type of input passed to the tool */ @@ -961,7 +961,7 @@ export type ListToolsV1ToolsGetData = { agentId?: string | null; }; -export type ListToolsV1ToolsGetResponse = Array; +export type ListToolsV1ToolsGetResponse = Array; export type CreateDeploymentV1DeploymentsPostData = { requestBody: DeploymentCreate; @@ -1611,7 +1611,7 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: Array; + 200: Array; /** * Validation Error */ diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx index 595dd0812b..b712c5d890 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx @@ -1,12 +1,12 @@ import Link from 'next/link'; -import { ManagedTool } from '@/cohere-client'; +import { ToolDefinition } from '@/cohere-client'; import { StatusConnection } from '@/components/AgentSettingsForm/StatusConnection'; import { Button, Icon, IconName, Switch, Text } from '@/components/UI'; import { AGENT_SETTINGS_TOOLS, TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO } from '@/constants'; type Props = { - tools?: ManagedTool[]; + tools?: ToolDefinition[]; activeTools?: string[]; setActiveTools: (tools: string[]) => void; handleAuthButtonClick: (toolName: string) => void; diff --git a/src/interfaces/assistants_web/src/components/Composer/Composer.tsx b/src/interfaces/assistants_web/src/components/Composer/Composer.tsx index 418109480d..90580c0411 100644 --- a/src/interfaces/assistants_web/src/components/Composer/Composer.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/Composer.tsx @@ -3,7 +3,7 @@ import { useResizeObserver } from '@react-hookz/web'; import React, { useEffect, useRef, useState } from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { ComposerError, ComposerFiles, ComposerToolbar } from '@/components/Composer'; import { DragDropFileInput, Icon, STYLE_LEVEL_TO_CLASSES } from '@/components/UI'; import { CHAT_COMPOSER_TEXTAREA_ID } from '@/constants'; @@ -21,7 +21,7 @@ type Props = { onChange: (message: string) => void; onUploadFile: (files: File[]) => void; agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; chatWindowRef?: React.RefObject; lastUserMessage?: ChatMessage; }; @@ -42,7 +42,7 @@ export const Composer: React.FC = ({ const breakpoint = useBreakpoint(); const isSmallBreakpoint = breakpoint === 'sm'; const textareaRef = useRef(null); - const { unauthedTools } = useAvailableTools({ agent, managedTools: tools }); + const { unauthedTools } = useAvailableTools({ agent, allTools: tools }); const isToolAuthRequired = unauthedTools.length > 0; const [chatWindowHeight, setChatWindowHeight] = useState(0); diff --git a/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx b/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx index 75df594412..4b78be2838 100644 --- a/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx @@ -2,13 +2,13 @@ import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { DataSourceMenu, FilesMenu } from '@/components/Composer'; import { cn } from '@/utils'; type Props = { agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; onUploadFile: (files: File[]) => void; }; diff --git a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx index 1bbdd5099d..24761cb6d8 100644 --- a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx @@ -3,7 +3,7 @@ import { Popover, PopoverButton, PopoverPanel } from '@headlessui/react'; import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { Icon, Switch, Text } from '@/components/UI'; import { useAvailableTools, useBrandedColors } from '@/hooks'; import { useParamsStore } from '@/stores'; @@ -11,7 +11,7 @@ import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; export type Props = { agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; }; /** @@ -23,7 +23,7 @@ export const DataSourceMenu: React.FC = ({ agent, tools }) => { } = useParamsStore(); const { availableTools, handleToggle } = useAvailableTools({ agent, - managedTools: tools, + allTools: tools, }); const { text, contrastText, border, bg } = useBrandedColors(agent?.id); diff --git a/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx b/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx index 4e6c1158e0..336033fb5c 100644 --- a/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx +++ b/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx @@ -2,7 +2,7 @@ import React, { useRef } from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { Composer } from '@/components/Composer'; import { Header } from '@/components/Conversation'; import { MessagingContainer, WelcomeGuideTooltip } from '@/components/MessagingContainer'; @@ -19,7 +19,7 @@ import { ChatMessage } from '@/types/message'; type Props = { startOptionsEnabled?: boolean; agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; history?: ChatMessage[]; }; diff --git a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx index 84b7c56482..9faad2f7bf 100644 --- a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx +++ b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx @@ -2,7 +2,7 @@ import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { WelcomeGuideTooltip } from '@/components/MessagingContainer'; import { Button, Icon, Text, ToggleCard } from '@/components/UI'; import { useAvailableTools } from '@/hooks'; @@ -13,7 +13,7 @@ import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; * @description Tools for the assistant to use in the conversation. */ export const AssistantTools: React.FC<{ - tools: ManagedTool[]; + tools: ToolDefinition[]; agent?: AgentPublic; className?: string; }> = ({ tools, agent, className = '' }) => { @@ -23,7 +23,7 @@ export const AssistantTools: React.FC<{ const enabledTools = paramTools ?? []; const { availableTools, unauthedTools, handleToggle } = useAvailableTools({ agent, - managedTools: tools, + allTools: tools, }); if (availableTools.length === 0) return null; diff --git a/src/interfaces/assistants_web/src/hooks/use-tools.ts b/src/interfaces/assistants_web/src/hooks/use-tools.ts index 584e437f1d..2c4959f9eb 100644 --- a/src/interfaces/assistants_web/src/hooks/use-tools.ts +++ b/src/interfaces/assistants_web/src/hooks/use-tools.ts @@ -3,7 +3,7 @@ import { useMemo } from 'react'; import useDrivePicker from 'react-google-drive-picker'; import type { PickerCallback } from 'react-google-drive-picker/dist/typeDefs'; -import { AgentPublic, ApiError, ManagedTool, useCohereClient } from '@/cohere-client'; +import { AgentPublic, ApiError, ToolDefinition, useCohereClient } from '@/cohere-client'; import { BASE_AGENT_EXCLUDED_TOOLS, DEFAULT_AGENT_TOOLS, TOOL_GOOGLE_DRIVE_ID } from '@/constants'; import { env } from '@/env.mjs'; import { useNotify } from '@/hooks'; @@ -13,7 +13,7 @@ import { checkIsBaseAgent } from '@/utils'; export const useListTools = (enabled: boolean = true) => { const client = useCohereClient(); - return useQuery({ + return useQuery({ queryKey: ['tools'], queryFn: async () => { const tools = await client.listTools({}); @@ -84,10 +84,10 @@ export const useOpenGoogleDrivePicker = (callbackFunction: (data: PickerCallback export const useAvailableTools = ({ agent, - managedTools, + allTools, }: { agent?: AgentPublic; - managedTools?: ManagedTool[]; + allTools?: ToolDefinition[]; }) => { const requiredTools = agent?.tools; @@ -106,14 +106,14 @@ export const useAvailableTools = ({ ) ?? []; const availableTools = useMemo(() => { - return (managedTools ?? []).filter( + return (allTools ?? []).filter( (t) => t.is_visible && t.is_available && (!requiredTools || requiredTools.some((rt) => rt === t.name)) && !(isBaseAgent && BASE_AGENT_EXCLUDED_TOOLS.some((rt) => rt === t.name)) ); - }, [managedTools, requiredTools]); + }, [allTools, requiredTools]); const handleToggle = (name: string, checked: boolean) => { const newParams: Partial = { From 63f80ec2bfde1d912c497d385cbd5c0a92fdde18 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Fri, 8 Nov 2024 11:48:26 -0800 Subject: [PATCH 03/14] backend: Make post_init_check an instance method (#834) Make post_init_check an instance method --- src/backend/services/auth/strategies/base.py | 7 +++---- src/backend/tools/base.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/backend/services/auth/strategies/base.py b/src/backend/services/auth/strategies/base.py index e305545824..38112519a2 100644 --- a/src/backend/services/auth/strategies/base.py +++ b/src/backend/services/auth/strategies/base.py @@ -43,14 +43,13 @@ class BaseOAuthStrategy: def __init__(self, *args, **kwargs): self._post_init_check() - @classmethod - def _post_init_check(cls): + def _post_init_check(self): if any( [ - cls.NAME is None, + self.NAME is None, ] ): - raise ValueError(f"{cls.__name__} must have NAME attribute defined.") + raise ValueError(f"{self.__name__} must have NAME attribute defined.") @abstractmethod def get_client_id(self, **kwargs: Any): diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index 203a9328dd..af5456b217 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -26,10 +26,9 @@ class BaseTool(): def __init__(self, *args, **kwargs): self._post_init_check() - @classmethod - def _post_init_check(cls): - if cls.ID is None: - raise ValueError(f"{cls.__name__} must have ID attribute defined.") + def _post_init_check(self): + if self.ID is None: + raise ValueError(f"{self.__name__} must have ID attribute defined.") @classmethod @abstractmethod @@ -68,13 +67,12 @@ def __init__(self, *args, **kwargs): self._post_init_check() - @classmethod - def _post_init_check(cls): + def _post_init_check(self): if any( [ - cls.BACKEND_HOST is None, - cls.FRONTEND_HOST is None, - cls.AUTH_SECRET_KEY is None, + self.BACKEND_HOST is None, + self.FRONTEND_HOST is None, + self.AUTH_SECRET_KEY is None, ] ): raise ValueError( From fa642356a775841c6c05f5383f3103cef7e9e5af Mon Sep 17 00:00:00 2001 From: Eugene P <144219719+EugeneLightsOn@users.noreply.github.com> Date: Fri, 8 Nov 2024 22:06:12 +0100 Subject: [PATCH 04/14] Fix the auth tools issue (#829) * TLK-1999 Fix the auth tools issue: display an error message if a tool is enabled and visible but not available. * TLK-1999 Fix the slack.md * TLK-1999 Fix typing --- docs/custom_tool_guides/slack.md | 4 +- src/backend/routers/tool.py | 6 ++- .../src/app/(main)/settings/Settings.tsx | 48 ++++++++++++++----- .../AgentSettingsForm/ToolsStep.tsx | 32 +++++++++---- .../assistants_web/src/constants/tools.ts | 2 +- 5 files changed, 66 insertions(+), 26 deletions(-) diff --git a/docs/custom_tool_guides/slack.md b/docs/custom_tool_guides/slack.md index 1f20e78b44..697cbe570c 100644 --- a/docs/custom_tool_guides/slack.md +++ b/docs/custom_tool_guides/slack.md @@ -47,7 +47,7 @@ SLACK_CLIENT_SECRET= ## 4. Enable the Slack Tool in the Frontend -To enable the Slack tool in the frontend, you will need to modify the `src/community/config/tools.py` file. Add the `TOOL_SLACK_ID` to the `AGENT_SETTINGS_TOOLS` list. +To enable the Slack tool in the frontend, you will need to modify the `src/interfaces/assistants_web/src/constants/tools.ts` file. Add the `TOOL_SLACK_ID` to the `AGENT_SETTINGS_TOOLS` list. ```typescript export const AGENT_SETTINGS_TOOLS = [ @@ -58,7 +58,7 @@ export const AGENT_SETTINGS_TOOLS = [ ]; ``` -To enable the Slack tool in the frontend for Base Agent, you will need to modify the `src/community/config/tools.py` file. Remove the `TOOL_SLACK_ID` from the `BASE_AGENT_EXCLUDED_TOOLS` list. +To enable the Slack tool in the frontend for Base Agent, you will need to modify the `src/interfaces/assistants_web/src/constants/tools.ts` file. Remove the `TOOL_SLACK_ID` from the `BASE_AGENT_EXCLUDED_TOOLS` list. By default, the Slack Tool is disabled for the Base Agent. Also if you need to exclude some tool from the Base Agent just add it to the `BASE_AGENT_EXCLUDED_TOOLS` list. ```typescript export const BASE_AGENT_EXCLUDED_TOOLS = []; diff --git a/src/backend/routers/tool.py b/src/backend/routers/tool.py index b9078ebc0c..78792cad66 100644 --- a/src/backend/routers/tool.py +++ b/src/backend/routers/tool.py @@ -45,7 +45,9 @@ def list_tools( all_tools = agent_tools for tool in all_tools: - if tool.is_available and tool.auth_implementation is not None: + # Tools with auth implementation can be enabled and visible but not accessible (e.g., if secrets are not set). + # Therefore, we need to set is_auth_required for these types of tools as well for the frontend. + if (tool.is_available or tool.is_visible) and tool.auth_implementation is not None: try: tool_auth_service = tool.auth_implementation() @@ -56,7 +58,7 @@ def list_tools( tool.token = tool_auth_service.get_token(session, user_id) except Exception as e: logger.error(event=f"Error while fetching Tool Auth: {str(e)}") - + tool.is_auth_required = True tool.is_available = False tool.error_message = ( f"Error while calling Tool Auth implementation {str(e)}" diff --git a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx index 4e1f34bd3c..2863deb160 100644 --- a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx @@ -135,6 +135,8 @@ const GoogleDriveConnection = () => { }; const isGoogleDriveConnected = !(googleDriveTool.is_auth_required ?? false); + const isGoogleDriveAvailable = googleDriveTool.is_available ?? false; + const googleDriveError = googleDriveTool.error_message ?? ''; const authUrl = getToolAuthUrl(googleDriveTool.auth_url); return ( @@ -150,7 +152,16 @@ const GoogleDriveConnection = () => { Connect to Google Drive and add files to the assistant
- {isGoogleDriveConnected ? ( + {!isGoogleDriveAvailable ? ( +
+
+

+ {googleDriveError || + 'Google Drive connection is not available. Please set the required configuration parameters.'} +

+
+
+ ) : isGoogleDriveConnected ? (
@@ -208,20 +219,33 @@ const SlackConnection = () => { }; const isSlackConnected = !(slackTool.is_auth_required ?? false); + const isSlackAvailable = slackTool.is_available ?? false; + const slackError = slackTool.error_message ?? ''; return ( -
-
-
- - Slack -
- -
- Connect to Slack +
+
+
+
+ + Slack +
+ +
+ Connect to Slack +
- {isSlackConnected ? ( -
+ {!isSlackAvailable ? ( +
+
+

+ {slackError || + 'Slack connection is not available. Please set the required configuration parameters.'} +

+
+
+ ) : isSlackConnected ? ( +
{description} {!isAuthRequired && !!authUrl && } - {isAuthRequired && !!authUrl && ( + {!isAvailable && ( + + {errorMessage || + 'Connection is not available. Please set the required configuration parameters.'} + + )} + {isAuthRequired && !!authUrl && isAvailable && (
{files && files.length > 0 && (
- {files.map(({ file_name: name, id }) => ( -
+ {files.map(({ id, conversation_id, file_name: name }) => ( +
= () => { /> {name}
- handleDeleteFile(id)} - disabled={isDeletingFile} - iconName="close" - className="invisible group-hover:visible" - /> +
+ + handleOpenFile({ fileId: id, conversationId: conversation_id }) + } + /> + handleDeleteFile(id)} + /> +
))} diff --git a/src/interfaces/assistants_web/src/components/UI/FileViewer.tsx b/src/interfaces/assistants_web/src/components/UI/FileViewer.tsx new file mode 100644 index 0000000000..76df78fd8b --- /dev/null +++ b/src/interfaces/assistants_web/src/components/UI/FileViewer.tsx @@ -0,0 +1,37 @@ +import { Markdown } from '@/components/Markdown'; +import { Icon } from '@/components/UI/Icon'; +import { Spinner } from '@/components/UI/Spinner'; +import { Text } from '@/components/UI/Text'; +import { useFile } from '@/hooks'; + +type Props = { + fileId: string; + agentId?: string; + conversationId?: string; +}; + +export const FileViewer: React.FC = ({ fileId, agentId, conversationId }) => { + const { data: file, isLoading } = useFile({ fileId, agentId, conversationId }); + + if (isLoading) { + return ; + } + + return ( +
+
+
+ +
+ + {file?.file_name ?? 'Failed to load file content'} + +
+ {file && ( +
+ +
+ )} +
+ ); +}; diff --git a/src/interfaces/assistants_web/src/components/UI/Modal.tsx b/src/interfaces/assistants_web/src/components/UI/Modal.tsx index 0271f6c345..b45232d249 100644 --- a/src/interfaces/assistants_web/src/components/UI/Modal.tsx +++ b/src/interfaces/assistants_web/src/components/UI/Modal.tsx @@ -11,6 +11,7 @@ type ModalProps = { title?: string; children?: React.ReactNode; onClose?: VoidFunction; + dialogPaddingClassName?: string; }; /** @@ -21,6 +22,7 @@ export const Modal: React.FC = ({ isOpen, children, onClose = () => {}, + dialogPaddingClassName, }) => { return ( @@ -60,7 +62,8 @@ export const Modal: React.FC = ({ {children && ( = ({ return ( <> {children ? ( -
+
{children}
) : ( diff --git a/src/interfaces/assistants_web/src/context/ModalContext.tsx b/src/interfaces/assistants_web/src/context/ModalContext.tsx index 646cf19103..bf0ed9db8a 100644 --- a/src/interfaces/assistants_web/src/context/ModalContext.tsx +++ b/src/interfaces/assistants_web/src/context/ModalContext.tsx @@ -7,6 +7,7 @@ import { Modal } from '@/components/UI'; interface OpenParams { title?: string; content?: React.ReactNode | React.FC; + dialogPaddingClassName?: string; } export type OpenFunction = (params: OpenParams) => void; @@ -18,6 +19,7 @@ interface Context { open: OpenFunction; close: CloseFunction; content: React.ReactNode | React.FC; + dialogPaddingClassName?: string; } /** @@ -27,18 +29,22 @@ const useModal = (): Context => { const [isOpen, setIsOpen] = useState(false); const [title, setTitle] = useState(undefined); const [content, setContent] = useState(undefined); + const [dialogPaddingClassName, setDialogPaddingClassName] = useState( + undefined + ); - const open = ({ title, content }: OpenParams) => { + const open = ({ title, content, dialogPaddingClassName }: OpenParams) => { setIsOpen(true); setTitle(title); setContent(content); + setDialogPaddingClassName(dialogPaddingClassName); }; const close = () => { setIsOpen(false); }; - return { isOpen, open, close, content, title }; + return { isOpen, open, close, content, title, dialogPaddingClassName }; }; /** @@ -56,15 +62,21 @@ const ModalContext = createContext({ open: () => {}, close: () => {}, content: undefined, + dialogPaddingClassName: undefined, }); const ModalProvider: React.FC = ({ children }) => { - const { isOpen, title, open, close, content } = useModal(); + const { isOpen, title, open, close, content, dialogPaddingClassName } = useModal(); return ( - + <>{children} - + <>{content} diff --git a/src/interfaces/assistants_web/src/hooks/use-files.ts b/src/interfaces/assistants_web/src/hooks/use-files.ts index fcb6122d86..5997b68336 100644 --- a/src/interfaces/assistants_web/src/hooks/use-files.ts +++ b/src/interfaces/assistants_web/src/hooks/use-files.ts @@ -12,6 +12,29 @@ import { useConversationStore, useFilesStore, useParamsStore } from '@/stores'; import { UploadingFile } from '@/stores/slices/filesSlice'; import { fileSizeToBytes, formatFileSize, getFileExtension, mapExtensionToMimeType } from '@/utils'; +export const useFile = ({ + fileId, + agentId, + conversationId, +}: { + fileId: string; + agentId?: string; + conversationId?: string; +}) => { + const cohereClient = useCohereClient(); + return useQuery({ + queryKey: ['file', fileId], + queryFn: async () => { + if ((!agentId && !conversationId) || (agentId && conversationId)) { + throw new Error('Exactly one of agentId or conversationId must be provided'); + } + return agentId + ? await cohereClient.getAgentFile({ agentId: agentId!, fileId }) + : await cohereClient.getConversationFile({ conversationId: conversationId!, fileId }); + }, + }); +}; + export const useListConversationFiles = ( conversationId?: string, options?: { enabled?: boolean } From ccf41ea5d4b0ba83b7717e343eac878014bb8c44 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Wed, 13 Nov 2024 10:23:42 -0500 Subject: [PATCH 08/14] frontend: Allow toggling tools for non-default agents (#839) Allow toggling tools for non-default agents --- .../components/Composer/DataSourceMenu.tsx | 21 ++++++++----------- .../MessagingContainer/AssistantTools.tsx | 6 ++---- .../components/MessagingContainer/Welcome.tsx | 10 ++++----- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx index 24761cb6d8..3a7ebc3184 100644 --- a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx @@ -7,7 +7,7 @@ import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { Icon, Switch, Text } from '@/components/UI'; import { useAvailableTools, useBrandedColors } from '@/hooks'; import { useParamsStore } from '@/stores'; -import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; +import { cn, getToolIcon } from '@/utils'; export type Props = { agent?: AgentPublic; @@ -25,9 +25,8 @@ export const DataSourceMenu: React.FC = ({ agent, tools }) => { agent, allTools: tools, }); + const { theme, text, contrastText, border, bg } = useBrandedColors(agent?.id); - const { text, contrastText, border, bg } = useBrandedColors(agent?.id); - const isBaseAgent = checkIsBaseAgent(agent); return ( = ({ agent, tools }) => { as="span" className={cn('font-medium', text, { [contrastText]: open })} > - Tools: {isBaseAgent ? paramsTools?.length ?? 0 : availableTools.length ?? 0} + Tools: {paramsTools ? paramsTools?.length ?? 0 : availableTools.length ?? 0} )} @@ -107,14 +106,12 @@ export const DataSourceMenu: React.FC = ({ agent, tools }) => { {tool.display_name}
- {isBaseAgent && ( - t.name === tool.name)} - onChange={(checked) => handleToggle(tool.name!, checked)} - showCheckedState - /> - )} + t.name === tool.name)} + onChange={(checked) => handleToggle(tool.name!, checked)} + showCheckedState + />
))} diff --git a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx index 9faad2f7bf..c970b4863c 100644 --- a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx +++ b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx @@ -7,7 +7,7 @@ import { WelcomeGuideTooltip } from '@/components/MessagingContainer'; import { Button, Icon, Text, ToggleCard } from '@/components/UI'; import { useAvailableTools } from '@/hooks'; import { useParamsStore } from '@/stores'; -import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; +import { cn, getToolIcon } from '@/utils'; /** * @description Tools for the assistant to use in the conversation. @@ -41,18 +41,16 @@ export const AssistantTools: React.FC<{ {availableTools.map(({ name, display_name, description, error_message }) => { const enabledTool = enabledTools.find((enabledTool) => enabledTool.name === name); const checked = !!enabledTool; - const disabled = !checkIsBaseAgent(agent); return ( handleToggle(name ?? '', checked)} + onToggle={(checked) => handleToggle(name!, checked)} agentId={agent?.id} /> ); diff --git a/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx b/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx index b775bae8bc..1c88555916 100644 --- a/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx +++ b/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx @@ -71,12 +71,10 @@ export const Welcome: React.FC = ({ show, agentId }) => { {agent?.description || 'Ask questions and get answers based on your files.'} - {isBaseAgent && ( -
- - Toggle Tools On/Off -
- )} +
+ + Toggle Tools On/Off +
From 8fd8cf1785928402e9f7e86d8202dc6ba30fe16a Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Mon, 18 Nov 2024 10:27:40 -0500 Subject: [PATCH 09/14] backend/assistants_web: Move Default Agent usage to backend (#842) * Add router for default agent * Working app with default agent * Lint * Working app and tests * add back global exc handler * prettify coral * PR review --- src/backend/config/default_agent.py | 33 +++++++++ src/backend/main.py | 2 - src/backend/routers/agent.py | 33 +++++---- src/backend/tests/unit/routers/test_agent.py | 43 ++++++------ .../(main)/(chat)/c/[conversationId]/page.tsx | 6 +- .../src/app/(main)/(chat)/page.tsx | 8 ++- .../app/(main)/discover/DiscoverAgentCard.tsx | 10 +-- .../assistants_web/src/app/(main)/layout.tsx | 4 +- .../src/app/(main)/new/CreateAgent.tsx | 4 +- .../src/cohere-client/client.ts | 5 ++ .../cohere-client/generated/schemas.gen.ts | 67 ++++++++----------- .../cohere-client/generated/services.gen.ts | 2 +- .../src/cohere-client/generated/types.gen.ts | 9 ++- .../AgentSettingsForm/ToolsStep.tsx | 7 +- .../components/AgentSettingsForm/index.tsx | 5 +- .../src/components/Agents/AgentLogo.tsx | 6 +- .../HotKeys/custom-views/Search.tsx | 7 +- .../components/MessagingContainer/Welcome.tsx | 27 ++++---- .../src/components/SideNavPanel/AgentIcon.tsx | 11 +-- .../SideNavPanel/ConversationList.tsx | 5 +- .../src/constants/conversation.ts | 28 +------- .../assistants_web/src/constants/tools.ts | 9 +-- .../assistants_web/src/hooks/use-agents.ts | 36 ++++++---- .../src/hooks/use-brandedColors.ts | 22 +++--- .../assistants_web/src/hooks/use-chat.ts | 4 +- .../assistants_web/src/hooks/use-tools.ts | 16 ++--- .../assistants_web/src/utils/agents.ts | 5 +- .../src/components/Agents/AgentCard.tsx | 18 ++--- .../src/components/Agents/AgentForm.tsx | 4 +- .../src/components/Agents/AgentsList.tsx | 2 +- .../src/components/Agents/CreateAgent.tsx | 4 +- .../components/Agents/DiscoverAgentCard.tsx | 12 ++-- .../src/components/Agents/DiscoverAgents.tsx | 2 +- src/interfaces/coral_web/src/constants.ts | 2 +- 34 files changed, 235 insertions(+), 223 deletions(-) create mode 100644 src/backend/config/default_agent.py diff --git a/src/backend/config/default_agent.py b/src/backend/config/default_agent.py new file mode 100644 index 0000000000..69ca098ec8 --- /dev/null +++ b/src/backend/config/default_agent.py @@ -0,0 +1,33 @@ +import datetime + +from backend.config.deployments import ModelDeploymentName +from backend.config.tools import Tool +from backend.schemas.agent import AgentPublic + +DEFAULT_AGENT_ID = "default" +DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform +DEFAULT_MODEL = "command-r-plus" + +def get_default_agent() -> AgentPublic: + return AgentPublic( + id=DEFAULT_AGENT_ID, + name='Command R+', + description='Ask questions and get answers based on your files.', + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + preamble="", + version=1, + temperature=0.3, + tools=[ + Tool.Read_File.value.ID, + Tool.Search_File.value.ID, + Tool.Python_Interpreter.value.ID, + Tool.Hybrid_Web_Search.value.ID, + ], + tools_metadata=[], + deployment=DEFAULT_DEPLOYMENT, + model=DEFAULT_MODEL, + user_id='', + organization_id=None, + is_private=False, + ) diff --git a/src/backend/main.py b/src/backend/main.py index 3bdd288a30..8efeed67d8 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -91,7 +91,6 @@ def create_app(): app = create_app() - @app.exception_handler(Exception) async def validation_exception_handler(request: Request, exc: Exception): ctx = get_context(request) @@ -115,7 +114,6 @@ async def validation_exception_handler(request: Request, exc: Exception): }, ) - @app.on_event("startup") async def startup_event(): """ diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 65c8d19efa..0c63f3dc16 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -5,6 +5,7 @@ from fastapi import File as RequestFile from fastapi import UploadFile as FastAPIUploadFile +from backend.config.default_agent import DEFAULT_AGENT_ID, get_default_agent from backend.config.routers import RouterName from backend.crud import agent as agent_crud from backend.crud import agent_tool_metadata as agent_tool_metadata_crud @@ -71,9 +72,9 @@ ], ) async def create_agent( - session: DBSessionDep, - agent: CreateAgentRequest, - ctx: Context = Depends(get_context), + session: DBSessionDep, + agent: CreateAgentRequest, + ctx: Context = Depends(get_context), ) -> AgentPublic: """ Create an agent. @@ -127,13 +128,13 @@ async def create_agent( @router.get("", response_model=list[AgentPublic]) async def list_agents( - *, - offset: int = 0, - limit: int = 100, - session: DBSessionDep, - visibility: AgentVisibility = AgentVisibility.ALL, - organization_id: Optional[str] = None, - ctx: Context = Depends(get_context), + *, + offset: int = 0, + limit: int = 100, + session: DBSessionDep, + visibility: AgentVisibility = AgentVisibility.ALL, + organization_id: Optional[str] = None, + ctx: Context = Depends(get_context), ) -> list[AgentPublic]: """ List all agents. @@ -163,6 +164,8 @@ async def list_agents( visibility=visibility, organization_id=organization_id, ) + # Tradeoff: This appends the default Agent regardless of pagination + agents.append(get_default_agent()) return agents except Exception as e: logger.exception(event=e) @@ -171,8 +174,8 @@ async def list_agents( @router.get("/{agent_id}", response_model=AgentPublic) async def get_agent_by_id( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) -) -> Agent: + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) +) -> AgentPublic: """ Args: agent_id (str): Agent ID. @@ -189,7 +192,11 @@ async def get_agent_by_id( agent = None try: - agent = agent_crud.get_agent_by_id(session, agent_id, user_id) + # Intentionally not adding Default Agent to DB so it's more flexible + if agent_id == DEFAULT_AGENT_ID: + agent = get_default_agent() + else: + agent = agent_crud.get_agent_by_id(session, agent_id, user_id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index e7b7d5df75..20bee41aa0 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -4,6 +4,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session +from backend.config.default_agent import DEFAULT_AGENT_ID from backend.config.deployments import ModelDeploymentName from backend.config.tools import Tool from backend.crud import deployment as deployment_crud @@ -17,6 +18,9 @@ and os.environ.get("COHERE_API_KEY") != "" ) +def filter_default_agent(agents: list) -> list: + return [agent for agent in agents if agent.get("id") != DEFAULT_AGENT_ID] + def test_create_agent_missing_name( session_client: TestClient, session: Session, user ) -> None: @@ -159,24 +163,23 @@ def test_create_existing_agent( assert response.json() == {"detail": "Agent test agent already exists."} -def test_list_agents_empty(session_client: TestClient, session: Session) -> None: - # Delete default agent - session.query(Agent).delete() +def test_list_agents_empty_returns_default_agent(session_client: TestClient, session: Session) -> None: response = session_client.get("/v1/agents", headers={"User-Id": "123"}) assert response.status_code == 200 response_agents = response.json() - assert len(response_agents) == 0 + # Returns default agent + assert len(response_agents) == 1 def test_list_agents(session_client: TestClient, session: Session, user) -> None: - session.query(Agent).delete() - for _ in range(3): + num_agents = 3 + for _ in range(num_agents): _ = get_factory("Agent", session).create(user=user) response = session_client.get("/v1/agents", headers={"User-Id": user.id}) assert response.status_code == 200 - response_agents = response.json() - assert len(response_agents) == 3 + response_agents = filter_default_agent(response.json()) + assert len(response_agents) == num_agents def test_list_organization_agents( @@ -184,10 +187,10 @@ def test_list_organization_agents( session: Session, user, ) -> None: - session.query(Agent).delete() + num_agents = 3 organization = get_factory("Organization", session).create() organization1 = get_factory("Organization", session).create() - for i in range(3): + for i in range(num_agents): _ = get_factory("Agent", session).create( user=user, organization_id=organization.id, @@ -201,9 +204,9 @@ def test_list_organization_agents( "/v1/agents", headers={"User-Id": user.id, "Organization-Id": organization.id} ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) agents = sorted(response_agents, key=lambda x: x["name"]) - for i in range(3): + for i in range(num_agents): assert agents[i]["name"] == f"agent-{i}-{organization.id}" @@ -212,10 +215,10 @@ def test_list_organization_agents_query_param( session: Session, user, ) -> None: - session.query(Agent).delete() + num_agents = 3 organization = get_factory("Organization", session).create() organization1 = get_factory("Organization", session).create() - for i in range(3): + for i in range(num_agents): _ = get_factory("Agent", session).create( user=user, organization_id=organization.id ) @@ -230,9 +233,9 @@ def test_list_organization_agents_query_param( headers={"User-Id": user.id, "Organization-Id": organization.id}, ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) agents = sorted(response_agents, key=lambda x: x["name"]) - for i in range(3): + for i in range(num_agents): assert agents[i]["name"] == f"agent-{i}-{organization1.id}" @@ -263,7 +266,7 @@ def test_list_private_agents( ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) # Only the agents created by user should be returned assert len(response_agents) == 3 @@ -282,7 +285,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) # Only the agents created by user should be returned assert len(response_agents) == 2 @@ -319,14 +322,14 @@ def test_list_agents_with_pagination( "/v1/agents?limit=3&offset=2", headers={"User-Id": user.id} ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) assert len(response_agents) == 3 response = session_client.get( "/v1/agents?limit=2&offset=4", headers={"User-Id": user.id} ) assert response.status_code == 200 - response_agents = response.json() + response_agents = filter_default_agent(response.json()) assert len(response_agents) == 1 diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/c/[conversationId]/page.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/c/[conversationId]/page.tsx index e2f485b15a..4f90e8dab3 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/c/[conversationId]/page.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/c/[conversationId]/page.tsx @@ -2,7 +2,7 @@ import { HydrationBoundary, QueryClient, dehydrate } from '@tanstack/react-query import { NextPage } from 'next'; import Chat from '@/app/(main)/(chat)/Chat'; -import { BASE_AGENT } from '@/constants'; +import { DEFAULT_AGENT_ID } from '@/constants'; import { getCohereServerClient } from '@/server/cohereServerClient'; type Props = { @@ -23,8 +23,8 @@ const Page: NextPage = async ({ params }) => { cohereServerClient.getConversation({ conversationId: params.conversationId }), }), queryClient.prefetchQuery({ - queryKey: ['agent', null], - queryFn: () => BASE_AGENT, + queryKey: ['agent', DEFAULT_AGENT_ID], + queryFn: () => cohereServerClient.getDefaultAgent(), }), ]); diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/page.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/page.tsx index 9575430abd..bdc11e49b9 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/page.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/page.tsx @@ -2,14 +2,16 @@ import { HydrationBoundary, QueryClient, dehydrate } from '@tanstack/react-query import { NextPage } from 'next'; import Chat from '@/app/(main)/(chat)/Chat'; -import { BASE_AGENT } from '@/constants'; +import { DEFAULT_AGENT_ID } from '@/constants'; +import { getCohereServerClient } from '@/server/cohereServerClient'; const Page: NextPage = async () => { const queryClient = new QueryClient(); + const cohereServerClient = getCohereServerClient(); await queryClient.prefetchQuery({ - queryKey: ['agent', null], - queryFn: () => BASE_AGENT, + queryKey: ['agent', DEFAULT_AGENT_ID], + queryFn: () => cohereServerClient.getDefaultAgent(), }); return ( diff --git a/src/interfaces/assistants_web/src/app/(main)/discover/DiscoverAgentCard.tsx b/src/interfaces/assistants_web/src/app/(main)/discover/DiscoverAgentCard.tsx index 89ebb38bbd..c0a0c7f4fa 100644 --- a/src/interfaces/assistants_web/src/app/(main)/discover/DiscoverAgentCard.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/discover/DiscoverAgentCard.tsx @@ -7,7 +7,7 @@ import { DeleteAgent } from '@/components/Modals/DeleteAgent'; import { CoralLogo, KebabMenu, Text } from '@/components/UI'; import { useContextStore } from '@/context'; import { useBrandedColors, useSession } from '@/hooks'; -import { checkIsBaseAgent, cn } from '@/utils'; +import { checkIsDefaultAgent, cn } from '@/utils'; type Props = { agent?: AgentPublic; @@ -17,11 +17,11 @@ type Props = { * @description renders a card for an agent with the agent's name, description */ export const DiscoverAgentCard: React.FC = ({ agent }) => { - const isBaseAgent = checkIsBaseAgent(agent); + const isDefaultAgent = checkIsDefaultAgent(agent); const { bg, contrastText, contrastFill } = useBrandedColors(agent?.id); const session = useSession(); const isCreator = agent?.user_id === session.userId; - const createdBy = isBaseAgent ? 'COHERE' : isCreator ? 'YOU' : 'TEAM'; + const createdBy = isDefaultAgent ? 'COHERE' : isCreator ? 'YOU' : 'TEAM'; const { open, close } = useContextStore(); @@ -36,7 +36,7 @@ export const DiscoverAgentCard: React.FC = ({ agent }) => { return (
@@ -46,7 +46,7 @@ export const DiscoverAgentCard: React.FC = ({ agent }) => { bg )} > - {isBaseAgent ? ( + {isDefaultAgent ? ( ) : ( diff --git a/src/interfaces/assistants_web/src/app/(main)/layout.tsx b/src/interfaces/assistants_web/src/app/(main)/layout.tsx index e5112cd503..b5a504efab 100644 --- a/src/interfaces/assistants_web/src/app/(main)/layout.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/layout.tsx @@ -6,7 +6,7 @@ import { redirect } from 'next/navigation'; import { Swipeable } from '@/components/Global'; import { HotKeys } from '@/components/HotKeys'; import { SideNavPanel } from '@/components/SideNavPanel'; -import { COOKIE_KEYS, DEFAULT_AGENT_TOOLS } from '@/constants'; +import { BACKGROUND_TOOLS, COOKIE_KEYS } from '@/constants'; import { getCohereServerClient } from '@/server/cohereServerClient'; const MainLayout: NextPage = async ({ children }) => { @@ -35,7 +35,7 @@ const MainLayout: NextPage = async ({ children }) => { queryKey: ['tools'], queryFn: async () => { const tools = await cohereServerClient.listTools({}); - return tools.filter((tool) => !DEFAULT_AGENT_TOOLS.includes(tool.name ?? '')); + return tools.filter((tool) => !BACKGROUND_TOOLS.includes(tool.name ?? '')); }, }), queryClient.prefetchQuery({ diff --git a/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx b/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx index 67ff7731e3..fddd3035f6 100644 --- a/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx @@ -9,8 +9,8 @@ import { AgentSettingsFields, AgentSettingsForm } from '@/components/AgentSettin import { MobileHeader } from '@/components/Global'; import { Button, Icon, Text } from '@/components/UI'; import { + BACKGROUND_TOOLS, DEFAULT_AGENT_MODEL, - DEFAULT_AGENT_TOOLS, DEFAULT_PREAMBLE, DEPLOYMENT_COHERE_PLATFORM, } from '@/constants'; @@ -23,7 +23,7 @@ const DEFAULT_FIELD_VALUES = { preamble: DEFAULT_PREAMBLE, deployment: DEPLOYMENT_COHERE_PLATFORM, model: DEFAULT_AGENT_MODEL, - tools: DEFAULT_AGENT_TOOLS, + tools: BACKGROUND_TOOLS, is_private: false, }; /** diff --git a/src/interfaces/assistants_web/src/cohere-client/client.ts b/src/interfaces/assistants_web/src/cohere-client/client.ts index 2f45f1f9de..c5372f6fa5 100644 --- a/src/interfaces/assistants_web/src/cohere-client/client.ts +++ b/src/interfaces/assistants_web/src/cohere-client/client.ts @@ -17,6 +17,7 @@ import { UpdateConversationRequest, UpdateDeploymentEnv, } from '@/cohere-client'; +import { DEFAULT_AGENT_ID } from '@/constants'; import { mapToChatRequest } from './mappings'; @@ -285,6 +286,10 @@ export class CohereClient { // this.cohereService.default.oidcAuthorizeV1OidcAuthGet(); } + public getDefaultAgent() { + return this.cohereService.default.getAgentByIdV1AgentsAgentIdGet({ agentId: DEFAULT_AGENT_ID }); + } + public getAgent(agentId: string) { return this.cohereService.default.getAgentByIdV1AgentsAgentIdGet({ agentId }); } diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts index d517eda38d..0766845281 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts @@ -6,6 +6,17 @@ export const $AgentPublic = { type: 'string', title: 'User Id', }, + organization_id: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Organization Id', + }, id: { type: 'string', title: 'Id', @@ -2928,23 +2939,6 @@ export const $Tool = { title: 'Name', default: '', }, - display_name: { - type: 'string', - title: 'Display Name', - default: '', - }, - description: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Description', - default: '', - }, parameter_definitions: { anyOf: [ { @@ -3040,34 +3034,39 @@ export const $ToolDefinition = { title: 'Name', default: '', }, - display_name: { - type: 'string', - title: 'Display Name', - default: '', - }, - description: { + parameter_definitions: { anyOf: [ { - type: 'string', + type: 'object', }, { type: 'null', }, ], + title: 'Parameter Definitions', + default: {}, + }, + display_name: { + type: 'string', + title: 'Display Name', + default: '', + }, + description: { + type: 'string', title: 'Description', default: '', }, - parameter_definitions: { + error_message: { anyOf: [ { - type: 'object', + type: 'string', }, { type: 'null', }, ], - title: 'Parameter Definitions', - default: {}, + title: 'Error Message', + default: '', }, kwargs: { type: 'object', @@ -3084,18 +3083,6 @@ export const $ToolDefinition = { title: 'Is Available', default: false, }, - error_message: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Error Message', - default: '', - }, category: { $ref: '#/components/schemas/ToolCategory', default: 'Data loader', diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts index a84f49c81e..c373afc050 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts @@ -297,7 +297,7 @@ export class DefaultService { * If completed, the corresponding ToolAuth for the requesting user is removed from the DB. * * Args: - * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool enum. + * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool string enum class. * request (Request): current Request object. * session (DBSessionDep): Database session. * ctx (Context): Context object. diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts index 860fee09a0..2dd641109e 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts @@ -2,6 +2,7 @@ export type AgentPublic = { user_id: string; + organization_id?: string | null; id: string; created_at: string; updated_at: string; @@ -621,8 +622,6 @@ export type ToggleConversationPinRequest = { export type Tool = { name?: string | null; - display_name?: string; - description?: string | null; parameter_definitions?: { [key: string]: unknown; } | null; @@ -650,17 +649,17 @@ export enum ToolCategory { export type ToolDefinition = { name?: string | null; - display_name?: string; - description?: string | null; parameter_definitions?: { [key: string]: unknown; } | null; + display_name?: string; + description?: string; + error_message?: string | null; kwargs?: { [key: string]: unknown; }; is_visible?: boolean; is_available?: boolean; - error_message?: string | null; category?: ToolCategory; is_auth_required?: boolean; auth_url?: string | null; diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx index 0124b35f70..b9482f3e83 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx @@ -3,7 +3,7 @@ import Link from 'next/link'; import { ToolDefinition } from '@/cohere-client'; import { StatusConnection } from '@/components/AgentSettingsForm/StatusConnection'; import { Button, Icon, IconName, Switch, Text } from '@/components/UI'; -import { AGENT_SETTINGS_TOOLS, TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO } from '@/constants'; +import { TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO } from '@/constants'; type Props = { tools?: ToolDefinition[]; @@ -18,10 +18,7 @@ export const ToolsStep: React.FC = ({ setActiveTools, handleAuthButtonClick, }) => { - const availableTools = tools?.filter( - (tool) => tool.name && AGENT_SETTINGS_TOOLS.includes(tool.name) - ); - const toolsAuthRequired = tools?.filter((tool) => tool.is_auth_required && tool.auth_url); + const availableTools = tools?.filter((tool) => tool.name && tool.is_available && tool.is_visible); const handleUpdateActiveTools = (checked: boolean, name: string) => { if (checked) { diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx index 83d26adb4b..c945b958a1 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx @@ -83,7 +83,8 @@ export const AgentSettingsForm: React.FC = (props) => { fields.tools_metadata?.find((metadata) => metadata.tool_name === TOOL_GOOGLE_DRIVE_ID) ?.artifacts as DataSourceArtifact[] ); - // read_file and search_file have identical metadata -> using read_file as base + + // read_file and search_file have identical metadata -> use read_file as base const [defaultUploadFiles, setDefaultUploadFiles] = useState( fields.tools_metadata?.find((metadata) => metadata.tool_name === TOOL_READ_DOCUMENT_ID) ?.artifacts as DataSourceArtifact[] @@ -258,7 +259,7 @@ export const AgentSettingsForm: React.FC = (props) => { setCurrentStep(expanded ? 'tools' : undefined)} > diff --git a/src/interfaces/assistants_web/src/components/Agents/AgentLogo.tsx b/src/interfaces/assistants_web/src/components/Agents/AgentLogo.tsx index 4207ce0711..0dd1c9ef9f 100644 --- a/src/interfaces/assistants_web/src/components/Agents/AgentLogo.tsx +++ b/src/interfaces/assistants_web/src/components/Agents/AgentLogo.tsx @@ -4,13 +4,13 @@ import { useBrandedColors } from '@/hooks'; import { cn } from '@/utils'; export const AgentLogo = ({ agent }: { agent: AgentPublic }) => { - const isBaseAgent = !agent.id; + const isDefaultAgent = !agent.id; const { bg, contrastText, contrastFill } = useBrandedColors(agent.id); return (
- {isBaseAgent && } - {!isBaseAgent && ( + {isDefaultAgent && } + {!isDefaultAgent && ( {agent.name[0]} diff --git a/src/interfaces/assistants_web/src/components/HotKeys/custom-views/Search.tsx b/src/interfaces/assistants_web/src/components/HotKeys/custom-views/Search.tsx index 08e5a58fd4..c681eaed88 100644 --- a/src/interfaces/assistants_web/src/components/HotKeys/custom-views/Search.tsx +++ b/src/interfaces/assistants_web/src/components/HotKeys/custom-views/Search.tsx @@ -7,7 +7,7 @@ import { useState } from 'react'; import { AgentLogo } from '@/components/Agents/AgentLogo'; import { CommandActionGroup, HotKeyGroupOption, HotKeysDialogInput } from '@/components/HotKeys'; -import { BASE_AGENT } from '@/constants'; +import { DEFAULT_AGENT_ID } from '@/constants'; import { useConversations, useListAgents } from '@/hooks'; type Props = { @@ -29,10 +29,11 @@ export const Search: React.FC = ({ isOpen, close, onBack }) => { quickActions: [], }, ]); - const router = useRouter(); + const router = useRouter(); const { data: assistants } = useListAgents(); const { data: conversations } = useConversations({}); + let defaultAgent = assistants?.find((a) => a.id == DEFAULT_AGENT_ID); useDebouncedEffect( () => { @@ -79,7 +80,7 @@ export const Search: React.FC = ({ isOpen, close, onBack }) => { assistant.id === conversation.agent_id) ?? - BASE_AGENT + defaultAgent! } /> {conversation.title} diff --git a/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx b/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx index 1c88555916..e74f3bd212 100644 --- a/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx +++ b/src/interfaces/assistants_web/src/components/MessagingContainer/Welcome.tsx @@ -5,10 +5,9 @@ import React from 'react'; import { AssistantTools } from '@/components/MessagingContainer'; import { CoralLogo, Icon, Text } from '@/components/UI'; -import { BASE_AGENT_EXCLUDED_TOOLS } from '@/constants'; import { useAgent, useBrandedColors, useListTools } from '@/hooks'; +import { checkIsDefaultAgent } from '@/utils'; import { cn } from '@/utils'; -import { checkIsBaseAgent } from '@/utils'; type Props = { show: boolean; @@ -23,12 +22,10 @@ export const Welcome: React.FC = ({ show, agentId }) => { const { data: tools = [], isLoading: isToolsLoading } = useListTools(); const { contrastText, bg, contrastFill } = useBrandedColors(agentId); - const isBaseAgent = checkIsBaseAgent(agent); - // Filter out tools that are excluded for the base agent + const isDefaultAgent = checkIsDefaultAgent(agent); + + // // Filter out tools that are excluded for the base agent let toolsFiltered = [...tools]; - if (isBaseAgent) { - toolsFiltered = tools.filter((tool) => !BASE_AGENT_EXCLUDED_TOOLS.includes(tool.name ?? '')); - } return ( = ({ show, agentId }) => { bg )} > - {isBaseAgent ? ( + {isDefaultAgent ? ( ) : ( @@ -59,9 +56,9 @@ export const Welcome: React.FC = ({ show, agentId }) => { )}
- {isBaseAgent ? 'Your Public Assistant' : agent?.name} + {isDefaultAgent ? 'Your Public Assistant' : agent?.name} - {isBaseAgent && ( + {isDefaultAgent && ( By Cohere @@ -71,10 +68,12 @@ export const Welcome: React.FC = ({ show, agentId }) => { {agent?.description || 'Ask questions and get answers based on your files.'}
-
- - Toggle Tools On/Off -
+ {tools.length > 0 && ( +
+ + Toggle Tools On/Off +
+ )}
diff --git a/src/interfaces/assistants_web/src/components/SideNavPanel/AgentIcon.tsx b/src/interfaces/assistants_web/src/components/SideNavPanel/AgentIcon.tsx index 04166d3f37..b6e128d4ae 100644 --- a/src/interfaces/assistants_web/src/components/SideNavPanel/AgentIcon.tsx +++ b/src/interfaces/assistants_web/src/components/SideNavPanel/AgentIcon.tsx @@ -3,6 +3,7 @@ import { usePathname, useRouter } from 'next/navigation'; import { CoralLogo, Text, Tooltip } from '@/components/UI'; +import { DEFAULT_AGENT_ID } from '@/constants'; import { useBrandedColors, useChatRoutes, useConversationFileActions, useIsDesktop } from '@/hooks'; import { useCitationsStore, @@ -14,7 +15,6 @@ import { cn } from '@/utils'; type Props = { name: string; - isBaseAgent?: boolean; id?: string; }; @@ -23,15 +23,16 @@ type Props = { * It shows a tooltip of the agent's name and a colored icon with the first letter of the agent's name. * If the agent is a base agent, it shows the Coral logo instead. */ -export const AgentIcon: React.FC = ({ name, id, isBaseAgent }) => { +export const AgentIcon: React.FC = ({ name, id }) => { const { conversationId } = useChatRoutes(); const router = useRouter(); const isDesktop = useIsDesktop(); const isMobile = !isDesktop; const pathname = usePathname(); const { setLeftPanelOpen } = useSettingsStore(); + const isDefaultAgent = id === DEFAULT_AGENT_ID; - const isActive = isBaseAgent + const isActive = isDefaultAgent ? conversationId ? pathname === `/c/${conversationId}` : pathname === '/' @@ -56,7 +57,7 @@ export const AgentIcon: React.FC = ({ name, id, isBaseAgent }) => { const handleClick = () => { if (isActive) return; - const url = isBaseAgent ? '/' : `/a/${id}`; + const url = isDefaultAgent ? '/' : `/a/${id}`; router.push(url); @@ -81,7 +82,7 @@ export const AgentIcon: React.FC = ({ name, id, isBaseAgent }) => { bg )} > - {isBaseAgent ? ( + {isDefaultAgent ? ( ) : ( diff --git a/src/interfaces/assistants_web/src/components/SideNavPanel/ConversationList.tsx b/src/interfaces/assistants_web/src/components/SideNavPanel/ConversationList.tsx index 6cd4a907db..76156f64d9 100644 --- a/src/interfaces/assistants_web/src/components/SideNavPanel/ConversationList.tsx +++ b/src/interfaces/assistants_web/src/components/SideNavPanel/ConversationList.tsx @@ -3,6 +3,7 @@ import { useMemo, useState } from 'react'; import { Flipped, Flipper } from 'react-flip-toolkit'; +import { AgentPublic } from '@/cohere-client'; import { ConversationWithoutMessages as Conversation } from '@/cohere-client'; import { AgentIcon, @@ -48,7 +49,7 @@ export const ConversationList: React.FC = () => { const RecentAgents: React.FC = () => { const { isLeftPanelOpen } = useSettingsStore(); - const recentAgents = useRecentAgents(); + const recentAgents = useRecentAgents() as AgentPublic[]; const flipKey = recentAgents.map((agent) => agent?.id || agent?.name).join(','); return ( @@ -61,7 +62,7 @@ const RecentAgents: React.FC = () => { {(flippedProps) => (
- +
)}
diff --git a/src/interfaces/assistants_web/src/constants/conversation.ts b/src/interfaces/assistants_web/src/constants/conversation.ts index 31a8042e18..5947da1062 100644 --- a/src/interfaces/assistants_web/src/constants/conversation.ts +++ b/src/interfaces/assistants_web/src/constants/conversation.ts @@ -1,40 +1,14 @@ -import { AgentPublic } from '@/cohere-client'; import { FileAccept } from '@/components/UI'; -import { DEPLOYMENT_COHERE_PLATFORM } from '@/constants/setup'; -import { - AGENT_SETTINGS_TOOLS, - FILE_UPLOAD_TOOLS, - TOOL_READ_DOCUMENT_ID, - TOOL_SEARCH_FILE_ID, - TOOL_WEB_SCRAPE_ID, -} from '@/constants/tools'; export const DEFAULT_CONVERSATION_NAME = 'New Conversation'; export const DEFAULT_AGENT_MODEL = 'command-r-plus'; +export const DEFAULT_AGENT_ID = 'default'; export const DEFAULT_TYPING_VELOCITY = 35; export const CONVERSATION_HISTORY_OFFSET = 100; export const DEFAULT_PREAMBLE = "## Task And Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling."; -export const DEFAULT_AGENT_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID, TOOL_WEB_SCRAPE_ID]; - -export const BASE_AGENT: AgentPublic = { - id: '', - name: 'Command R+', - description: 'Ask questions and get answers based on your files.', - created_at: new Date().toISOString(), - updated_at: new Date().toISOString(), - preamble: '', - version: 1, - temperature: 0.3, - tools: [...AGENT_SETTINGS_TOOLS, ...FILE_UPLOAD_TOOLS], - model: DEFAULT_AGENT_MODEL, - deployment: DEPLOYMENT_COHERE_PLATFORM, - user_id: '', - is_private: false, -}; - export const ACCEPTED_FILE_TYPES: FileAccept[] = [ 'text/csv', 'text/plain', diff --git a/src/interfaces/assistants_web/src/constants/tools.ts b/src/interfaces/assistants_web/src/constants/tools.ts index 68c0db5efc..5e62648cbe 100644 --- a/src/interfaces/assistants_web/src/constants/tools.ts +++ b/src/interfaces/assistants_web/src/constants/tools.ts @@ -13,15 +13,8 @@ export const TOOL_CALCULATOR_ID = 'toolkit_calculator'; export const TOOL_WEB_SCRAPE_ID = 'web_scrape'; export const TOOL_GOOGLE_DRIVE_ID = 'google_drive'; export const TOOL_SLACK_ID = 'slack'; -export const FILE_UPLOAD_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID]; -export const AGENT_SETTINGS_TOOLS = [ - TOOL_HYBRID_WEB_SEARCH_ID, - TOOL_PYTHON_INTERPRETER_ID, - TOOL_WEB_SCRAPE_ID, -]; -// Tools won't be available for the base agent -export const BASE_AGENT_EXCLUDED_TOOLS: string[] = [TOOL_SLACK_ID]; +export const BACKGROUND_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID, TOOL_WEB_SCRAPE_ID]; export const TOOL_FALLBACK_ICON = 'circles-four'; export const TOOL_ID_TO_DISPLAY_INFO: { [id: string]: { icon: IconName } } = { diff --git a/src/interfaces/assistants_web/src/hooks/use-agents.ts b/src/interfaces/assistants_web/src/hooks/use-agents.ts index 1ad4f2815b..116f291deb 100644 --- a/src/interfaces/assistants_web/src/hooks/use-agents.ts +++ b/src/interfaces/assistants_web/src/hooks/use-agents.ts @@ -9,7 +9,7 @@ import { UpdateAgentRequest, useCohereClient, } from '@/cohere-client'; -import { BASE_AGENT } from '@/constants'; +import { DEFAULT_AGENT_ID } from '@/constants'; import { useConversations } from '@/hooks'; export const useListAgents = () => { @@ -17,8 +17,7 @@ export const useListAgents = () => { return useQuery({ queryKey: ['listAgents'], queryFn: async () => { - const agents = await cohereClient.listAgents({}); - return agents.concat(BASE_AGENT); + return await cohereClient.listAgents({}); }, }); }; @@ -59,7 +58,7 @@ export const useAgent = ({ agentId }: { agentId?: string }) => { queryFn: async () => { try { if (!agentId) { - return BASE_AGENT; + return await cohereClient.getAgent(DEFAULT_AGENT_ID); } return await cohereClient.getAgent(agentId); } catch (e) { @@ -70,6 +69,21 @@ export const useAgent = ({ agentId }: { agentId?: string }) => { }); }; +export const useDefaultAgent = () => { + const cohereClient = useCohereClient(); + return useQuery({ + queryKey: ['agent', DEFAULT_AGENT_ID], + queryFn: async () => { + try { + return await cohereClient.getAgent(DEFAULT_AGENT_ID); + } catch (e) { + console.error(e); + throw e; + } + }, + }); +}; + /** * @description Returns a function to check if an agent name is unique. */ @@ -109,9 +123,12 @@ export const useRecentAgents = (limit: number = 5) => { }, []); const recentAgents = useMemo(() => { - let recent = uniq(conversations.sort(sortByDate).map((conversation) => conversation.agent_id)) + let recent = uniq( + conversations + .sort(sortByDate) + .map((conversation) => conversation.agent_id ?? DEFAULT_AGENT_ID) + ) .map((agentId) => agents.find((agent) => agent.id === agentId)) - .map((agent) => (!agent ? BASE_AGENT : agent)) .slice(0, limit); // if there are less than `limit` recent agents, fill with the latest created agents @@ -124,12 +141,7 @@ export const useRecentAgents = (limit: number = 5) => { recent = recent.concat(remainingRecentAgents); } - // if still there are less than `limit` recent agents, fill with base agent - if (recent.length < limit && recent.every((agent) => agent?.id !== BASE_AGENT.id)) { - recent = recent.concat(BASE_AGENT); - } - - return recent; + return recent.filter((a) => a !== undefined); }, [conversations, agents, sortByDate, limit]); return recentAgents; diff --git a/src/interfaces/assistants_web/src/hooks/use-brandedColors.ts b/src/interfaces/assistants_web/src/hooks/use-brandedColors.ts index 7b4129d9f4..40d84bc4f3 100644 --- a/src/interfaces/assistants_web/src/hooks/use-brandedColors.ts +++ b/src/interfaces/assistants_web/src/hooks/use-brandedColors.ts @@ -2,7 +2,7 @@ import { useMemo } from 'react'; -import { COHERE_BRANDED_COLORS } from '@/constants'; +import { COHERE_BRANDED_COLORS, DEFAULT_AGENT_ID } from '@/constants'; import { cn } from '@/utils'; const COHERE_THEMES_MAP: { default: COHERE_BRANDED_COLORS; branded: COHERE_BRANDED_COLORS[] } = { @@ -35,34 +35,38 @@ const DEFAULT_COLOR = 'evolved-blue-500'; const DEFAULT_LIGHT_COLOR = 'blue-800'; const DEFAULT_CONTRAST_COLOR = 'marble-950'; +const shouldUseDefault = (assistantId: string | undefined): boolean => { + return !assistantId || assistantId == DEFAULT_AGENT_ID; +}; + const getAssistantColor = (assistantId: string | undefined): string => { - if (!assistantId) return DEFAULT_COLOR; + if (shouldUseDefault(assistantId)) return DEFAULT_COLOR; - const idNumber = assistantId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); + const idNumber = assistantId!.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); const index = idNumber % ASSISTANT_COLORS.length; return ASSISTANT_COLORS[index]; }; const getAssistantLightColor = (assistantId: string | undefined): string => { - if (!assistantId) return DEFAULT_LIGHT_COLOR; + if (shouldUseDefault(assistantId)) return DEFAULT_LIGHT_COLOR; - const idNumber = assistantId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); + const idNumber = assistantId!.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); const index = idNumber % ASSISTANT_LIGHT_COLORS.length; return ASSISTANT_LIGHT_COLORS[index]; }; const getAssistantContrastColor = (assistantId: string | undefined): string => { - if (!assistantId) return DEFAULT_CONTRAST_COLOR; + if (shouldUseDefault(assistantId)) return DEFAULT_CONTRAST_COLOR; - const idNumber = assistantId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); + const idNumber = assistantId!.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); const index = idNumber % ASSISTANT_CONTRAST_COLORS.length; return ASSISTANT_CONTRAST_COLORS[index]; }; export const getCohereTheme = (assistantId?: string): COHERE_BRANDED_COLORS => { - if (!assistantId) return COHERE_THEMES_MAP.default; - const idNumber = assistantId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); + if (shouldUseDefault(assistantId)) return COHERE_THEMES_MAP.default; + const idNumber = assistantId!.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0); const index = idNumber % COHERE_THEMES_MAP.branded.length; return COHERE_THEMES_MAP.branded[index]; }; diff --git a/src/interfaces/assistants_web/src/hooks/use-chat.ts b/src/interfaces/assistants_web/src/hooks/use-chat.ts index a64bdc8c65..3fdf47afee 100644 --- a/src/interfaces/assistants_web/src/hooks/use-chat.ts +++ b/src/interfaces/assistants_web/src/hooks/use-chat.ts @@ -16,7 +16,7 @@ import { isStreamError, } from '@/cohere-client'; import { - DEFAULT_AGENT_TOOLS, + BACKGROUND_TOOLS, DEFAULT_TYPING_VELOCITY, DEPLOYMENT_COHERE_PLATFORM, TOOL_PYTHON_INTERPRETER_ID, @@ -528,7 +528,7 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => { conversation_id: currentConversationId, tools: requestTools ?.map((tool) => ({ name: tool.name })) - .concat(DEFAULT_AGENT_TOOLS.map((defaultTool) => ({ name: defaultTool }))), + .concat(BACKGROUND_TOOLS.map((backgroundTool) => ({ name: backgroundTool }))), file_ids: fileIds && fileIds.length > 0 ? fileIds : undefined, temperature, preamble, diff --git a/src/interfaces/assistants_web/src/hooks/use-tools.ts b/src/interfaces/assistants_web/src/hooks/use-tools.ts index 2c4959f9eb..12d66711b9 100644 --- a/src/interfaces/assistants_web/src/hooks/use-tools.ts +++ b/src/interfaces/assistants_web/src/hooks/use-tools.ts @@ -4,12 +4,11 @@ import useDrivePicker from 'react-google-drive-picker'; import type { PickerCallback } from 'react-google-drive-picker/dist/typeDefs'; import { AgentPublic, ApiError, ToolDefinition, useCohereClient } from '@/cohere-client'; -import { BASE_AGENT_EXCLUDED_TOOLS, DEFAULT_AGENT_TOOLS, TOOL_GOOGLE_DRIVE_ID } from '@/constants'; +import { BACKGROUND_TOOLS, TOOL_GOOGLE_DRIVE_ID } from '@/constants'; import { env } from '@/env.mjs'; import { useNotify } from '@/hooks'; import { useParamsStore } from '@/stores'; import { ConfigurableParams } from '@/stores/slices/paramsSlice'; -import { checkIsBaseAgent } from '@/utils'; export const useListTools = (enabled: boolean = true) => { const client = useCohereClient(); @@ -17,7 +16,7 @@ export const useListTools = (enabled: boolean = true) => { queryKey: ['tools'], queryFn: async () => { const tools = await client.listTools({}); - return tools.filter((tool) => !DEFAULT_AGENT_TOOLS.includes(tool.name ?? '')); + return tools.filter((tool) => !BACKGROUND_TOOLS.includes(tool.name ?? '')); }, refetchOnWindowFocus: false, enabled, @@ -95,14 +94,10 @@ export const useAvailableTools = ({ const { params, setParams } = useParamsStore(); const { tools: paramTools } = params; const enabledTools = paramTools ?? []; - const isBaseAgent = checkIsBaseAgent(agent); + const unauthedTools = tools?.filter( - (tool) => - tool.is_auth_required && - tool.name && - requiredTools?.includes(tool.name) && - !(isBaseAgent && BASE_AGENT_EXCLUDED_TOOLS.includes(tool.name)) + (tool) => tool.is_auth_required && tool.name && requiredTools?.includes(tool.name) ) ?? []; const availableTools = useMemo(() => { @@ -110,8 +105,7 @@ export const useAvailableTools = ({ (t) => t.is_visible && t.is_available && - (!requiredTools || requiredTools.some((rt) => rt === t.name)) && - !(isBaseAgent && BASE_AGENT_EXCLUDED_TOOLS.some((rt) => rt === t.name)) + (!requiredTools || requiredTools.some((rt) => rt === t.name)) ); }, [allTools, requiredTools]); diff --git a/src/interfaces/assistants_web/src/utils/agents.ts b/src/interfaces/assistants_web/src/utils/agents.ts index fe109d5a16..fd4d3159f2 100644 --- a/src/interfaces/assistants_web/src/utils/agents.ts +++ b/src/interfaces/assistants_web/src/utils/agents.ts @@ -1,9 +1,10 @@ import { AgentPublic } from '@/cohere-client'; +import { DEFAULT_AGENT_ID } from '@/constants'; /** * @description Checks if the agent is the base agent. * @param agent - The agent to check. */ -export const checkIsBaseAgent = (agent: AgentPublic | undefined) => { - return agent?.id === ''; +export const checkIsDefaultAgent = (agent: AgentPublic | undefined) => { + return agent?.id === DEFAULT_AGENT_ID; }; diff --git a/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx b/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx index 12d336d53c..a123722eb1 100644 --- a/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx +++ b/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx @@ -26,7 +26,7 @@ import { getCohereColor } from '@/utils/getCohereColor'; type Props = { isExpanded: boolean; name: string; - isBaseAgent?: boolean; + isDefaultAgent?: boolean; id?: string; }; @@ -35,14 +35,14 @@ type Props = { * It shows the agent's name and a colored icon with the first letter of the agent's name. * If the agent is a base agent, it shows the Coral logo instead. */ -export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded }) => { +export const AgentCard: React.FC = ({ name, id, isDefaultAgent, isExpanded }) => { const isTouchDevice = getIsTouchDevice(); const { conversationId } = useChatRoutes(); const router = useRouter(); const pathname = usePathname(); const { data: conversations } = useConversations({ agentId: id }); - const isActive = isBaseAgent + const isActive = isDefaultAgent ? conversationId ? pathname === `/c/${conversationId}` : pathname === '/' @@ -74,7 +74,7 @@ export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded } conversations?.sort((a, b) => Date.parse(b.updated_at) - Date.parse(a.updated_at))[0]?.id ?? ''; const conversationPath = newestConversationId ? `c/${newestConversationId}` : ''; - const url = isBaseAgent + const url = isDefaultAgent ? `/c/${newestConversationId}` : id ? `/a/${id}/${conversationPath}` @@ -84,7 +84,7 @@ export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded } }; const handleNewChat = () => { - const url = isBaseAgent ? `/c` : id ? `/a/${id}` : '/'; + const url = isDefaultAgent ? `/c` : id ? `/a/${id}` : '/'; router.push(url, undefined); setEditAgentPanelOpen(false); resetConversationSettings(); @@ -127,12 +127,12 @@ export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded } 'flex h-8 w-8 flex-shrink-0 items-center justify-center rounded duration-300', id && getCohereColor(id), { - 'bg-mushroom-700': isBaseAgent, + 'bg-mushroom-700': isDefaultAgent, } )} > - {isBaseAgent && } - {!isBaseAgent && ( + {isDefaultAgent && } + {!isDefaultAgent && ( {name[0]} @@ -150,7 +150,7 @@ export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded } ) { const { data: toolsData } = useListTools(); const tools = - toolsData?.filter((t) => t.is_available && !DEFAULT_AGENT_TOOLS.includes(t.name ?? '')) ?? []; + toolsData?.filter((t) => t.is_available && !BACKGROUND_TOOLS.includes(t.name ?? '')) ?? []; const googleDrivefiles: GoogleDriveToolArtifact[] = useMemo(() => { const toolsMetadata = fields.tools_metadata ?? []; diff --git a/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx b/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx index 5cfcb29c1b..76cfea603b 100644 --- a/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx +++ b/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx @@ -31,7 +31,7 @@ export const AgentsList: React.FC = () => {
- + {recentAgents.map((agent) => ( = ({ id, name, description, isBaseAgent }) => { +export const DiscoverAgentCard: React.FC = ({ id, name, description, isDefaultAgent }) => { const { open, close } = useContextStore(); const handleDeleteAssistant = () => { @@ -39,11 +39,11 @@ export const DiscoverAgentCard: React.FC = ({ id, name, description, isBa 'truncate', id && getCohereColor(id), { - 'bg-mushroom-700': isBaseAgent, + 'bg-mushroom-700': isDefaultAgent, } )} > - {isBaseAgent ? ( + {isDefaultAgent ? ( ) : ( @@ -54,7 +54,7 @@ export const DiscoverAgentCard: React.FC = ({ id, name, description, isBa {name} - {!isBaseAgent && ( + {!isDefaultAgent && (
= ({ id, name, description, isBa {description}
+ } + tooltip="Hot keys" + iconName="menu" + theme="mushroom" + onClick={() => openHotKeysDialog()} + stretch + /> + { ), + ['hot-keys']: ( + + + + ), }[name]; }; diff --git a/src/interfaces/assistants_web/src/stores/persistedStore.ts b/src/interfaces/assistants_web/src/stores/persistedStore.ts index 5422e91872..8cf2f99d13 100644 --- a/src/interfaces/assistants_web/src/stores/persistedStore.ts +++ b/src/interfaces/assistants_web/src/stores/persistedStore.ts @@ -55,6 +55,8 @@ export const useSettingsStore = () => { setRightPanelOpen: state.setRightPanelOpen, setUseAssistantKnowledge: state.setUseAssistantKnowledge, setShowSteps: state.setShowSteps, + isHotKeysDialogOpen: state.isHotKeysDialogOpen, + setIsHotKeysDialogOpen: state.setIsHotKeysDialogOpen, }), shallow ); diff --git a/src/interfaces/assistants_web/src/stores/slices/settingsSlice.ts b/src/interfaces/assistants_web/src/stores/slices/settingsSlice.ts index d8292c662e..8b278667df 100644 --- a/src/interfaces/assistants_web/src/stores/slices/settingsSlice.ts +++ b/src/interfaces/assistants_web/src/stores/slices/settingsSlice.ts @@ -5,6 +5,7 @@ const INITIAL_STATE = { isLeftPanelOpen: true, isRightPanelOpen: false, showSteps: true, + isHotKeysDialogOpen: false, }; type State = { @@ -12,6 +13,7 @@ type State = { isLeftPanelOpen: boolean; isRightPanelOpen: boolean; showSteps: boolean; + isHotKeysDialogOpen: boolean; }; type Actions = { @@ -19,6 +21,7 @@ type Actions = { setLeftPanelOpen: (isOpen: boolean) => void; setRightPanelOpen: (isOpen: boolean) => void; setShowSteps: (showSteps: boolean) => void; + setIsHotKeysDialogOpen: (isOpen: boolean) => void; }; export type SettingsStore = State & Actions; @@ -50,5 +53,11 @@ export const createSettingsSlice: StateCreator ({ + ...state, + isHotKeysDialogOpen: isOpen, + })); + }, ...INITIAL_STATE, }); From 85de86a406b4c97249a659080da61f1a8e68f09c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:17:59 -0500 Subject: [PATCH 14/14] chore(deps): bump cross-spawn in /src/interfaces/slack_bot (#850) Bumps and [cross-spawn](https://github.com/moxystudio/node-cross-spawn). These dependencies needed to be updated together. Updates `cross-spawn` from 7.0.3 to 7.0.6 - [Changelog](https://github.com/moxystudio/node-cross-spawn/blob/master/CHANGELOG.md) - [Commits](https://github.com/moxystudio/node-cross-spawn/compare/v7.0.3...v7.0.6) Updates `cross-spawn` from 6.0.5 to 7.0.6 - [Changelog](https://github.com/moxystudio/node-cross-spawn/blob/master/CHANGELOG.md) - [Commits](https://github.com/moxystudio/node-cross-spawn/compare/v7.0.3...v7.0.6) --- updated-dependencies: - dependency-name: cross-spawn dependency-type: indirect - dependency-name: cross-spawn dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- src/interfaces/slack_bot/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/interfaces/slack_bot/package-lock.json b/src/interfaces/slack_bot/package-lock.json index 60c3e26853..993bb4b87f 100644 --- a/src/interfaces/slack_bot/package-lock.json +++ b/src/interfaces/slack_bot/package-lock.json @@ -2473,9 +2473,9 @@ "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==" }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -5430,9 +5430,9 @@ } }, "node_modules/npm-run-all/node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dependencies": { "nice-try": "^1.0.4", "path-key": "^2.0.1",