Skip to content

Commit

Permalink
backend: Filter Tool Auth client access token (#859)
Browse files Browse the repository at this point in the history
* WIP

* Add tests

* Update lock file

* Mock Google Drive for CI

* testing

* Resolve test for remote env
  • Loading branch information
tianjing-li authored Nov 26, 2024
1 parent 5502d8d commit 1de7f97
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 22 deletions.
42 changes: 40 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pytest = "^7.1.2"
pytest-env = "^1.1.3"
pytest-cov = "^5.0.0"
factory-boy = "^3.3.0"
fakeredis = "^2.26.1"
freezegun = "^1.5.1"
pre-commit = "^2.20.0"
ruff = "^0.6.0"
Expand Down
6 changes: 5 additions & 1 deletion src/backend/routers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def list_tools(
session, user_id
)
tool.auth_url = tool_auth_service.get_auth_url(user_id)
tool.token = tool_auth_service.get_token(session, user_id)

# Return access token to client when required by frontend
# e.g: to enable Google Drive picker in client
if tool.should_return_token:
tool.token = tool_auth_service.get_token(session, user_id)
except Exception as e:
logger.error(event=f"Error while fetching Tool Auth: {str(e)}")
tool.is_auth_required = True
Expand Down
1 change: 1 addition & 0 deletions src/backend/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ToolDefinition(Tool):
is_auth_required: bool = False # Per user
auth_url: Optional[str] = "" # Per user
token: Optional[str] = "" # Per user
should_return_token: bool = False

implementation: Any = Field(exclude=True)
auth_implementation: Any = Field(default=None, exclude=True)
Expand Down
14 changes: 14 additions & 0 deletions src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any, Generator
from unittest.mock import patch

import fakeredis
import pytest
from alembic.command import upgrade
from alembic.config import Config
from fastapi.testclient import TestClient
from redis import Redis
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -146,6 +148,18 @@ def override_get_session() -> Generator[Session, Any, None]:
app.dependency_overrides = {}



@pytest.fixture(autouse=True)
def mock_redis_client():
"""
A pytest fixture that globally replaces `Redis.from_url` with `fakeredis`.
"""
fake_redis = fakeredis.FakeStrictRedis(decode_responses=True)

# Patch Redis.from_url to always return the fake Redis instance
with patch.object(Redis, 'from_url', return_value=fake_redis):
yield fake_redis

@pytest.fixture
def user(session: Session) -> User:
return get_factory("User", session).create(id="1")
Expand Down
105 changes: 86 additions & 19 deletions src/backend/tests/unit/routers/test_tool.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,104 @@
from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.tools import Tool, get_available_tools
from backend.database_models.database import DBSessionDep
from backend.schemas.tool import ToolCategory, ToolDefinition
from backend.schemas.user import User
from backend.tests.unit.factories import get_factory
from backend.tools.base import BaseTool

TOOL_DEFINITION_KEYS = [
"name",
"display_name",
"parameter_definitions",
"is_visible",
"is_available",
"should_return_token",
"category",
"description"
]

@pytest.fixture
def mock_get_available_tools():
with patch("backend.routers.tool.get_available_tools") as mock:
yield mock

def test_list_tools(session_client: TestClient, session: Session) -> None:
def test_list_tools(session_client: TestClient) -> None:
response = session_client.get("/v1/tools")
assert response.status_code == 200
available_tools = get_available_tools()
for tool in response.json():
assert tool["name"] in available_tools.keys()
assert tool["kwargs"] is not None
assert tool["is_visible"] is not None
assert tool["is_available"] is not None
assert tool["category"] is not None
assert tool["description"] is not None
tool_definition = available_tools.get(tool["name"])
assert tool_definition is not None

for key in TOOL_DEFINITION_KEYS:
assert tool[key] == getattr(tool_definition, key)

def test_list_authed_tool_should_return_token(session_client: TestClient, mock_get_available_tools) -> None:
class MockGoogleDriveAuth():
def is_auth_required(self, session: DBSessionDep, user_id: str) -> bool:
return False

def get_auth_url(self, user_id: str) -> str:
return ""

def get_token(self, session: DBSessionDep, user_id: str) -> str:
return "mock"
class MockGoogleDrive(BaseTool):
ID = "google_drive"
@classmethod
def get_tool_definition(cls) -> ToolDefinition:
return ToolDefinition(
name=cls.ID,
display_name="Google Drive",
implementation=cls,
parameter_definitions={
"query": {
"description": "Query to search Google Drive documents with.",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=True,
auth_implementation=MockGoogleDriveAuth,
should_return_token=True,
error_message=cls.generate_error_message(),
category=ToolCategory.DataLoader,
description="Returns a list of relevant document snippets from the user's Google drive.",
)

# Patch Google Drive tool
mock_get_available_tools.return_value = {Tool.Google_Drive.value.ID: MockGoogleDrive.get_tool_definition()}

response = session_client.get("/v1/tools")
assert response.status_code == 200

for tool in response.json():
print(tool)
if tool["should_return_token"]:
assert tool["token"] == "mock"

def test_list_authed_tool_should_not_return_token(session_client: TestClient) -> None:
response = session_client.get("/v1/tools")

assert response.status_code == 200

for tool in response.json():
if not tool["should_return_token"]:
assert tool["token"] == ""

def test_list_tools_error_message_none_if_available(client: TestClient) -> None:
response = client.get("/v1/tools")
def test_list_tools_error_message_none_if_available(session_client: TestClient) -> None:
response = session_client.get("/v1/tools")
assert response.status_code == 200
for tool in response.json():
if tool["is_available"]:
assert tool["error_message"] is None


def test_list_tools_with_agent(
session_client: TestClient, session: Session, user: User
) -> None:
Expand All @@ -42,18 +114,13 @@ def test_list_tools_with_agent(
assert tool["name"] == Tool.Wiki_Retriever_LangChain.value.ID

# get tool that has the same name as the tool in the response
tool_definition = get_available_tools()[tool["name"]]

assert tool["kwargs"] == tool_definition.kwargs
assert tool["is_visible"] == tool_definition.is_visible
assert tool["is_available"] == tool_definition.is_available
assert tool["error_message"] == tool_definition.error_message
assert tool["category"] == tool_definition.category
assert tool["description"] == tool_definition.description
tool_definition = get_available_tools().get(tool["name"])

for key in TOOL_DEFINITION_KEYS:
assert tool[key] == getattr(tool_definition, key)

def test_list_tools_with_agent_that_doesnt_exist(
session_client: TestClient, session: Session
session_client: TestClient
) -> None:
response = session_client.get("/v1/tools", params={"agent_id": "fake_id"})
assert response.status_code == 404
Expand Down
1 change: 1 addition & 0 deletions src/backend/tools/google_drive/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_tool_definition(cls) -> ToolDefinition:
is_visible=True,
is_available=GoogleDrive.is_available(),
auth_implementation=GoogleDriveAuth,
should_return_token=True,
error_message=cls.generate_error_message(),
category=ToolCategory.DataLoader,
description="Returns a list of relevant document snippets from the user's Google drive.",
Expand Down

0 comments on commit 1de7f97

Please sign in to comment.