From 0e57e64e517424381797d545957c349475f70528 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Mon, 13 Jan 2025 11:27:09 -0800 Subject: [PATCH 1/5] chore(backend): add type hints to integration tests --- .../tests/integration/crud/test_deployment.py | 30 +++++++++--------- .../tests/integration/crud/test_model.py | 31 ++++++++++--------- .../tests/integration/routers/test_agent.py | 9 +++--- .../tests/integration/routers/test_model.py | 30 +++++++++++++----- 4 files changed, 60 insertions(+), 40 deletions(-) diff --git a/src/backend/tests/integration/crud/test_deployment.py b/src/backend/tests/integration/crud/test_deployment.py index 48c2c7bc74..908d57e844 100644 --- a/src/backend/tests/integration/crud/test_deployment.py +++ b/src/backend/tests/integration/crud/test_deployment.py @@ -1,12 +1,14 @@ import pytest +from sqlalchemy.orm import Session from backend.crud import deployment as deployment_crud from backend.database_models.deployment import Deployment from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate +from backend.schemas.user import User from backend.tests.unit.factories import get_factory -def test_create_deployment(session, deployment): +def test_create_deployment(session: Session, deployment: Deployment) -> None: deployment_data = DeploymentCreate( name="Test Deployment", deployment_class_name="CohereDeployment", @@ -29,7 +31,7 @@ def test_create_deployment(session, deployment): assert deployment.name == deployment_data.name -def test_create_deployment_invalid_class_name(session): +def test_create_deployment_invalid_class_name(session: Session) -> None: with pytest.raises(ValueError) as e: deployment_data = DeploymentCreate( name="Test Deployment", @@ -43,19 +45,19 @@ def test_create_deployment_invalid_class_name(session): assert "Deployment class not found" in str(e.value) -def test_get_deployment(session): +def test_get_deployment(session: Session) -> None: deployment = get_factory("Deployment", session).create(name="Test Deployment") retrieved_deployment = deployment_crud.get_deployment(session, deployment.id) assert retrieved_deployment.id == deployment.id assert retrieved_deployment.name == deployment.name -def test_fail_get_nonexistent_deployment(session): +def test_fail_get_nonexistent_deployment(session: Session) -> None: deployment = deployment_crud.get_deployment(session, "123") assert deployment is None -def test_list_deployments(session): +def test_list_deployments(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() _ = get_factory("Deployment", session).create(name="Test Deployment") @@ -65,14 +67,14 @@ def test_list_deployments(session): assert deployments[0].name == "Test Deployment" -def test_list_deployments_empty(session): +def test_list_deployments_empty(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() deployments = deployment_crud.get_deployments(session) assert len(deployments) == 0 -def test_list_deployments_with_pagination(session): +def test_list_deployments_with_pagination(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() for i in range(10): @@ -82,7 +84,7 @@ def test_list_deployments_with_pagination(session): assert len(deployments) == 5 -def test_get_available_deployments(session, user): +def test_get_available_deployments(session: Session, user: User) -> None: session.query(Deployment).delete() deployment = get_factory("Deployment", session).create() _ = get_factory("Deployment", session).create( @@ -95,14 +97,14 @@ def test_get_available_deployments(session, user): assert deployments[0].id == deployment.id -def test_get_available_deployments_empty(session, user): +def test_get_available_deployments_empty(session: Session, user: User) -> None: session.query(Deployment).delete() deployments = deployment_crud.get_available_deployments(session) assert len(deployments) == 0 -def test_update_deployment(session, deployment): +def test_update_deployment(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate( name="NewName", description="New Description", @@ -122,7 +124,7 @@ def test_update_deployment(session, deployment): assert updated_deployment.id == deployment.id -def test_update_deployment_partial(session, deployment): +def test_update_deployment_partial(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate(name="Cohere") updated_deployment = deployment_crud.update_deployment( @@ -133,7 +135,7 @@ def test_update_deployment_partial(session, deployment): assert updated_deployment.id == deployment.id -def test_do_not_update_deployment(session, deployment): +def test_do_not_update_deployment(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate(name="Test Deployment") updated_deployment = deployment_crud.update_deployment( @@ -142,7 +144,7 @@ def test_do_not_update_deployment(session, deployment): assert updated_deployment.name == deployment.name -def test_delete_deployment(session): +def test_delete_deployment(session: Session) -> None: deployment = get_factory("Deployment", session).create() deployment_crud.delete_deployment(session, deployment.id) @@ -151,7 +153,7 @@ def test_delete_deployment(session): assert deployment is None -def test_delete_nonexistent_deployment(session): +def test_delete_nonexistent_deployment(session: Session) -> None: deployment_crud.delete_deployment(session, "123") # no error deployment = deployment_crud.get_deployment(session, "123") assert deployment is None diff --git a/src/backend/tests/integration/crud/test_model.py b/src/backend/tests/integration/crud/test_model.py index 5fbd3bacf0..95389177bc 100644 --- a/src/backend/tests/integration/crud/test_model.py +++ b/src/backend/tests/integration/crud/test_model.py @@ -1,10 +1,13 @@ +from sqlalchemy.orm import Session + from backend.crud import model as model_crud +from backend.database_models.deployment import Deployment from backend.database_models.model import Model from backend.schemas.model import ModelCreate, ModelUpdate from backend.tests.unit.factories import get_factory -def test_create_model(session, deployment): +def test_create_model(session: Session, deployment: Deployment) -> None: model_data = ModelCreate( name="Test Model", cohere_name="Test Cohere Model", @@ -21,7 +24,7 @@ def test_create_model(session, deployment): assert model.name == model_data.name -def test_get_model(session, deployment): +def test_get_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model", deployment=deployment ) @@ -30,12 +33,12 @@ def test_get_model(session, deployment): assert retrieved_model.name == model.name -def test_fail_get_nonexistent_model(session): +def test_fail_get_nonexistent_model(session: Session) -> None: model = model_crud.get_model(session, "123") assert model is None -def test_list_models(session, deployment): +def test_list_models(session: Session, deployment: Deployment) -> None: # Delete default models session.query(Model).delete() _ = get_factory("Model", session).create(name="Test Model", deployment=deployment) @@ -45,14 +48,14 @@ def test_list_models(session, deployment): assert models[0].name == "Test Model" -def test_list_models_empty(session): +def test_list_models_empty(session: Session) -> None: # Delete default models session.query(Model).delete() models = model_crud.get_models(session) assert len(models) == 0 -def test_list_models_with_pagination(session, deployment): +def test_list_models_with_pagination(session: Session, deployment: Deployment) -> None: # Delete default models session.query(Model).delete() for i in range(10): @@ -67,7 +70,7 @@ def test_list_models_with_pagination(session, deployment): assert model.name == f"Test Model {i + 5}" -def test_get_models_by_deployment_id(session, deployment): +def test_get_models_by_deployment_id(session: Session, deployment: Deployment) -> None: for i in range(10): model = get_factory("Model", session).create( name=f"Test Model {i}", deployment=deployment @@ -80,12 +83,12 @@ def test_get_models_by_deployment_id(session, deployment): assert model.name == f"Test Model {i}" -def test_get_models_by_deployment_id_empty(session, deployment): +def test_get_models_by_deployment_id_empty(session: Session, deployment: Deployment) -> None: models = model_crud.get_models_by_deployment_id(session, deployment.id) assert len(models) == 0 -def test_get_models_by_deployment_id_with_pagination(session, deployment): +def test_get_models_by_deployment_id_with_pagination(session: Session, deployment: Deployment) -> None: for i in range(10): model = get_factory("Model", session).create( name=f"Test Model {i}", deployment=deployment @@ -100,7 +103,7 @@ def test_get_models_by_deployment_id_with_pagination(session, deployment): assert model.name == f"Test Model {i + 5}" -def test_update_model(session, deployment): +def test_update_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Sagemaker model", deployment=deployment ) @@ -127,7 +130,7 @@ def test_update_model(session, deployment): assert model.deployment_id == new_model_data.deployment_id -def test_update_model_partial(session, deployment): +def test_update_model_partial(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model U", deployment=deployment ) @@ -148,7 +151,7 @@ def test_update_model_partial(session, deployment): assert model.deployment_id == model.deployment_id -def test_do_not_update_model(session, deployment): +def test_do_not_update_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model", deployment=deployment ) @@ -159,7 +162,7 @@ def test_do_not_update_model(session, deployment): assert updated_model.name == model.name -def test_delete_model(session, deployment): +def test_delete_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create(deployment=deployment) model_crud.delete_model(session, model.id) @@ -168,7 +171,7 @@ def test_delete_model(session, deployment): assert model is None -def test_delete_nonexistent_model(session): +def test_delete_nonexistent_model(session: Session) -> None: model_crud.delete_model(session, "123") # no error model = model_crud.get_model(session, "123") assert model is None diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index e80c23842a..191eaf9831 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -13,6 +13,7 @@ from backend.database_models.snapshot import Snapshot from backend.exceptions import DeploymentNotFoundError from backend.model_deployments.cohere_platform import CohereDeployment +from backend.schemas.user import User from backend.tests.unit.factories import get_factory is_cohere_env_set = ( @@ -20,7 +21,7 @@ and os.environ.get("COHERE_API_KEY") != "" ) -def test_create_agent(session_client: TestClient, session: Session, user, mock_cohere_list_models) -> None: +def test_create_agent(session_client: TestClient, session: Session, user: User, mock_cohere_list_models) -> None: request_json = { "name": "test agent", "version": 1, @@ -60,7 +61,7 @@ def test_create_agent(session_client: TestClient, session: Session, user, mock_c def test_create_agent_with_tool_metadata( - session_client: TestClient, session: Session, user, mock_cohere_list_models + session_client: TestClient, session: Session, user: User, mock_cohere_list_models ) -> None: request_json = { "name": "test agent", @@ -118,7 +119,7 @@ def test_create_agent_with_tool_metadata( def test_create_agent_missing_non_required_fields( - session_client: TestClient, session: Session, user, mock_cohere_list_models + session_client: TestClient, session: Session, user: User, mock_cohere_list_models ) -> None: request_json = { "name": "test agent", @@ -149,7 +150,7 @@ def test_create_agent_missing_non_required_fields( assert agent.model == request_json["model"] -def test_update_agent(session_client: TestClient, session: Session, user, mock_cohere_list_models) -> None: +def test_update_agent(session_client: TestClient, session: Session, user: User, mock_cohere_list_models) -> None: agent = get_factory("Agent", session).create( name="test agent", version=1, diff --git a/src/backend/tests/integration/routers/test_model.py b/src/backend/tests/integration/routers/test_model.py index 133a684859..5f56bf929c 100644 --- a/src/backend/tests/integration/routers/test_model.py +++ b/src/backend/tests/integration/routers/test_model.py @@ -2,10 +2,11 @@ from sqlalchemy.orm import Session from backend.database_models import Model +from backend.database_models.deployment import Deployment from backend.tests.unit.factories import get_factory -def test_create_model(session_client: TestClient, session: Session, deployment) -> None: +def test_create_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: request_json = { "name": "sagemaker-command-created", "cohere_name": "command", @@ -29,7 +30,7 @@ def test_create_model(session_client: TestClient, session: Session, deployment) def test_create_model_non_existing_deployment( - session_client: TestClient, session: Session + session_client: TestClient, session: Session, ) -> None: request_json = { "name": "sagemaker-command-created", @@ -50,7 +51,7 @@ def test_create_model_non_existing_deployment( ) -def test_update_model(session_client: TestClient, session: Session, deployment) -> None: +def test_update_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: request_json = { "name": "sagemaker-command-updated", "cohere_name": "command", @@ -69,7 +70,7 @@ def test_update_model(session_client: TestClient, session: Session, deployment) assert model.deployment_id == response_json["deployment_id"] -def test_get_model(session_client: TestClient, session: Session, deployment) -> None: +def test_get_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: # Delete all models session.query(Model).delete() model = get_factory("Model", session).create(deployment=deployment) @@ -89,7 +90,11 @@ def test_get_model_non_existing(session_client: TestClient, session: Session) -> assert "Model not found" in response_json["detail"] -def test_list_models(session_client: TestClient, session: Session, deployment) -> None: +def test_list_models( + session_client: TestClient, + session: Session, + deployment: Deployment, +) -> None: # Delete all models session.query(Model).delete() for _ in range(5): @@ -101,7 +106,10 @@ def test_list_models(session_client: TestClient, session: Session, deployment) - assert len(models) == 5 -def test_list_models_empty(session_client: TestClient, session: Session) -> None: +def test_list_models_empty( + session_client: TestClient, + session: Session, +) -> None: session.query(Model).delete() response = session_client.get("/v1/models") assert response.status_code == 200 @@ -110,7 +118,9 @@ def test_list_models_empty(session_client: TestClient, session: Session) -> None def test_list_models_with_pagination( - session_client: TestClient, session: Session, deployment + session_client: TestClient, + session: Session, + deployment: Deployment, ) -> None: # Delete all models session.query(Model).delete() @@ -128,7 +138,11 @@ def test_list_models_with_pagination( assert model["name"] == f"Test Model {i + 5}" -def test_delete_model(session_client: TestClient, session: Session, deployment) -> None: +def test_delete_model( + session_client: TestClient, + session: Session, + deployment: Deployment, +) -> None: model = get_factory("Model", session).create(deployment=deployment) response = session_client.delete(f"/v1/models/{model.id}") assert response.status_code == 200 From 37e01b33bea8c803d4ebb2f7354cfc0341f183e8 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 06:07:38 -0800 Subject: [PATCH 2/5] chore(backend): remove deprecated method --- src/backend/routers/organization.py | 2 +- src/backend/tests/unit/routers/test_organization.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/backend/routers/organization.py b/src/backend/routers/organization.py index 33bbef3278..c26b51c035 100644 --- a/src/backend/routers/organization.py +++ b/src/backend/routers/organization.py @@ -38,7 +38,7 @@ def create_organization( """ Create a new organization. """ - organization_data = OrganizationModel(**organization.dict()) + organization_data = OrganizationModel(**organization.model_dump()) return organization_crud.create_organization(session, organization_data) diff --git a/src/backend/tests/unit/routers/test_organization.py b/src/backend/tests/unit/routers/test_organization.py index 658321c992..67459a10fb 100644 --- a/src/backend/tests/unit/routers/test_organization.py +++ b/src/backend/tests/unit/routers/test_organization.py @@ -8,7 +8,7 @@ def test_create_organization(session_client: TestClient, session: Session) -> None: organization = CreateOrganization(name="test organization") - response = session_client.post("/v1/organizations", json=organization.dict()) + response = session_client.post("/v1/organizations", json=organization.model_dump()) assert response.status_code == 200 assert response.json()["name"] == organization.name @@ -18,7 +18,7 @@ def test_create_organization_with_existing_name( ) -> None: get_factory("Organization", session).create(name="test organization") new_organization = CreateOrganization(name="test organization") - response = session_client.post("/v1/organizations", json=new_organization.dict()) + response = session_client.post("/v1/organizations", json=new_organization.model_dump()) assert response.status_code == 400 assert response.json() == { "detail": "Organization with name: test organization already exists." @@ -29,7 +29,7 @@ def test_update_organization(session_client: TestClient, session: Session) -> No organization = get_factory("Organization", session).create(name="test organization") new_organization = UpdateOrganization(name="new organization") response = session_client.put( - f"/v1/organizations/{organization.id}", json=new_organization.dict() + f"/v1/organizations/{organization.id}", json=new_organization.model_dump() ) assert response.status_code == 200 assert response.json()["name"] == new_organization.name @@ -39,7 +39,7 @@ def test_update_not_existing_organization( session_client: TestClient, session: Session ) -> None: new_organization = UpdateOrganization(name="new organization") - response = session_client.put("/v1/organizations/123", json=new_organization.dict()) + response = session_client.put("/v1/organizations/123", json=new_organization.model_dump()) assert response.status_code == 404 assert response.json() == {"detail": "Organization with ID: 123 not found."} From dc7506682f99e8bd76e34931e0873ac20a1343b9 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 12:28:46 -0800 Subject: [PATCH 3/5] chore(ci): improve readability --- .github/workflows/backend_integration_tests.yml | 9 ++++++++- .github/workflows/backend_unit_tests.yml | 7 +++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/backend_integration_tests.yml b/.github/workflows/backend_integration_tests.yml index 8537b85670..9d398d51b5 100644 --- a/.github/workflows/backend_integration_tests.yml +++ b/.github/workflows/backend_integration_tests.yml @@ -8,16 +8,18 @@ on: jobs: pytest: permissions: write-all - environment: development + # environment: development runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 + - uses: actions/setup-python@v5 with: python-version: '3.11' cache: 'pip' + - name: Install poetry uses: snok/install-poetry@v1 with: @@ -25,23 +27,28 @@ jobs: virtualenvs-in-project: true virtualenvs-path: .venv installer-parallel: true + - name: Load cached venv id: cached-poetry-dependencies uses: actions/cache@v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --with dev --no-interaction --no-root + - name: Setup test DB container run: make test-db + - name: Test with pytest if: github.actor != 'dependabot[bot]' run: | make run-integration-tests env: PYTHONPATH: src + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 with: diff --git a/.github/workflows/backend_unit_tests.yml b/.github/workflows/backend_unit_tests.yml index 8664706c75..a175c3a79e 100644 --- a/.github/workflows/backend_unit_tests.yml +++ b/.github/workflows/backend_unit_tests.yml @@ -14,10 +14,12 @@ jobs: steps: - name: Checkout repo uses: actions/checkout@v3 + - uses: actions/setup-python@v5 with: python-version: '3.11' cache: 'pip' + - name: Install poetry uses: snok/install-poetry@v1 with: @@ -25,23 +27,28 @@ jobs: virtualenvs-in-project: true virtualenvs-path: .venv installer-parallel: true + - name: Load cached venv id: cached-poetry-dependencies uses: actions/cache@v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --with dev --no-interaction --no-root + - name: Setup test DB container run: make test-db + - name: Test with pytest if: github.actor != 'dependabot[bot]' run: | make run-unit-tests-debug env: PYTHONPATH: src + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 with: From 912bfedb525046654951f9e45b05a4d165976934 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Wed, 15 Jan 2025 13:05:28 -0800 Subject: [PATCH 4/5] feat(backend): improved typehints and decorators --- src/backend/model_deployments/azure.py | 24 ++--- src/backend/model_deployments/base.py | 26 ++--- src/backend/model_deployments/bedrock.py | 24 ++--- .../model_deployments/cohere_platform.py | 24 ++--- src/backend/model_deployments/sagemaker.py | 22 ++-- .../model_deployments/single_container.py | 24 ++--- src/backend/schemas/context.py | 26 ++--- src/backend/services/conversation.py | 16 +-- src/backend/tests/integration/conftest.py | 7 +- .../mock_deployments/mock_azure.py | 24 ++--- .../mock_deployments/mock_bedrock.py | 24 ++--- .../mock_deployments/mock_cohere_platform.py | 101 +++++++++++------- .../mock_deployments/mock_sagemaker.py | 24 ++--- .../mock_deployments/mock_single_container.py | 24 ++--- 14 files changed, 211 insertions(+), 179 deletions(-) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index bea01b7743..dce6660516 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -43,33 +43,33 @@ def __init__(self, **kwargs: Any): base_url=self.chat_endpoint_url, api_key=self.api_key ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( AzureDeployment.default_api_key is not None and AzureDeployment.default_chat_endpoint_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), ) @@ -86,6 +86,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index cae22e68fe..f6bec64471 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Dict, List +from typing import Any from backend.config.settings import Settings from backend.schemas.cohere_chat import CohereChatRequest @@ -25,32 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any): def id(cls) -> str: return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower() - @classmethod + @staticmethod @abstractmethod - def name(cls) -> str: ... + def name() -> str: ... - @classmethod + @staticmethod @abstractmethod - def env_vars(cls) -> List[str]: ... + def env_vars() -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def rerank_enabled(cls) -> bool: ... + def rerank_enabled() -> bool: ... @classmethod @abstractmethod - def list_models(cls) -> List[str]: ... + def list_models(cls) -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def is_available(cls) -> bool: ... + def is_available() -> bool: ... @classmethod def is_community(cls) -> bool: return False @classmethod - def config(cls) -> Dict[str, Any]: + def config(cls) -> dict[str, Any]: config = Settings().get(f"deployments.{cls.id()}") config_dict = {} if not config else dict(config) for key, value in config_dict.items(): @@ -79,9 +79,9 @@ async def invoke_chat( @abstractmethod async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any - ) -> AsyncGenerator[Any, Any]: ... + ) -> Any: ... @abstractmethod async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: ... diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index 7241c79dd1..9403deed47 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -42,12 +42,12 @@ def __init__(self, **kwargs: Any): ), ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ BEDROCK_ACCESS_KEY_ENV_VAR, BEDROCK_SECRET_KEY_ENV_VAR, @@ -55,19 +55,19 @@ def env_vars(cls) -> List[str]: BEDROCK_REGION_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( BedrockDeployment.access_key is not None and BedrockDeployment.secret_access_key is not None @@ -75,7 +75,7 @@ def is_available(cls) -> bool: and BedrockDeployment.region_name is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: # bedrock accepts a subset of the chat request fields bedrock_chat_req = chat_request.model_dump( exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True @@ -101,6 +101,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index cbddb750ea..a718d0b68e 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any import cohere import requests @@ -29,20 +29,20 @@ def __init__(self, **kwargs: Any): ) self.client = cohere.Client(api_key, client_name=self.client_name) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [COHERE_API_KEY_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: logger = LoggerFactory().get_logger() if not CohereDeployment.is_available(): return [] @@ -64,12 +64,12 @@ def list_models(cls) -> List[str]: models = response.json()["models"] return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return CohereDeployment.api_key is not None async def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), @@ -99,7 +99,7 @@ async def invoke_chat_stream( yield event_dict async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: response = self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index b8de329230..6a686a378c 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -1,6 +1,6 @@ import io import json -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import boto3 @@ -65,12 +65,12 @@ def __init__(self, **kwargs: Any): "ContentType": "application/json", } - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ SAGE_MAKER_ACCESS_KEY_ENV_VAR, SAGE_MAKER_SECRET_KEY_ENV_VAR, @@ -79,19 +79,19 @@ def env_vars(cls) -> List[str]: SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not SageMakerDeployment.is_available(): return [] return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SageMakerDeployment.region_name is not None and SageMakerDeployment.aws_access_key_id is not None @@ -121,7 +121,7 @@ async def invoke_chat_stream( yield stream_event async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index a9d69ab6a9..4ddcb2e174 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -33,33 +33,33 @@ def __init__(self, **kwargs: Any): base_url=self.url, client_name=self.client_name, api_key="none" ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [SC_URL_ENV_VAR, SC_MODEL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return SingleContainerDeployment.default_model.startswith("rerank") @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not SingleContainerDeployment.is_available(): return [] return [SingleContainerDeployment.default_model] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SingleContainerDeployment.default_model is not None and SingleContainerDeployment.default_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump( exclude={"stream", "file_ids", "model", "agent_id"} @@ -80,7 +80,7 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/schemas/context.py b/src/backend/schemas/context.py index 6faaa3ec93..5d10ec2476 100644 --- a/src/backend/schemas/context.py +++ b/src/backend/schemas/context.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Self from pydantic import BaseModel @@ -68,7 +68,7 @@ def with_deployment_name(self, deployment_name: str): def with_user( self, session: DBSessionDep | None = None, user: User | None = None - ) -> "Context": + ) -> Self: if not user and not session: return self @@ -81,42 +81,42 @@ def with_user( return self - def with_agent(self, agent: Agent | None) -> "Context": + def with_agent(self, agent: Agent | None) -> Self: self.agent = agent return self def with_agent_tool_metadata( self, agent_tool_metadata: AgentToolMetadata - ) -> "Context": + ) -> Self: self.agent_tool_metadata = agent_tool_metadata return self - def with_model(self, model: str) -> "Context": + def with_model(self, model: str) -> Self: self.model = model return self - def with_deployment_config(self, deployment_config=None) -> "Context": + def with_deployment_config(self, deployment_config=None) -> Self: if deployment_config: self.deployment_config = deployment_config else: self.deployment_config = get_deployment_config(self.request) return self - def with_conversation_id(self, conversation_id: str) -> "Context": + def with_conversation_id(self, conversation_id: str) -> Self: self.conversation_id = conversation_id return self - def with_stream_start_ms(self, now_ms: float) -> "Context": + def with_stream_start_ms(self, now_ms: float) -> Self: self.stream_start_ms = now_ms - def with_agent_id(self, agent_id: str) -> "Context": + def with_agent_id(self, agent_id: str) -> Self: if not agent_id: return self self.agent_id = agent_id return self - def with_organization_id(self, organization_id: str) -> "Context": + def with_organization_id(self, organization_id: str) -> Self: self.organization_id = organization_id return self @@ -124,7 +124,7 @@ def with_organization( self, session: DBSessionDep | None = None, organization: Organization | None = None, - ) -> "Context": + ) -> Self: if not organization and not session: return self @@ -141,11 +141,11 @@ def with_organization( return self - def with_global_filtering(self) -> "Context": + def with_global_filtering(self) -> Self: self.use_global_filtering = True return self - def without_global_filtering(self) -> "Context": + def without_global_filtering(self) -> Self: self.use_global_filtering = False return self diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index a3b4306404..3bc358a9b5 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Optional +from typing import Optional from fastapi import HTTPException @@ -142,7 +142,7 @@ def get_messages_with_files( return messages_with_file -def get_documents_to_rerank(conversations: List[Conversation]) -> List[str]: +def get_documents_to_rerank(conversations: list[Conversation]) -> list[str]: """Get documents (strings) to rerank from a list of conversations Args: @@ -166,22 +166,22 @@ def get_documents_to_rerank(conversations: List[Conversation]) -> List[str]: async def filter_conversations( query: str, - conversations: List[Conversation], - rerank_documents: List[str], + conversations: list[Conversation], + rerank_documents: list[str], model_deployment: BaseDeployment, ctx: Context, -) -> List[Conversation]: +) -> list[Conversation]: """Filter conversations based on the rerank score Args: query (str): The query to filter conversations - conversations (List[Conversation]): List of conversations - rerank_documents (List[str]): List of documents to rerank + conversations (list[Conversation]): List of conversations + rerank_documents (list[str]): List of documents to rerank model_deployment: Model deployment object ctx (Context): Context object Returns: - List[Conversation]: List of filtered conversations + list[Conversation]: List of filtered conversations """ # if rerank is not enabled, filter out conversations that don't contain the query if not model_deployment.rerank_enabled(): diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 0b005901a7..ee5252cb3e 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -22,7 +22,7 @@ @pytest.fixture -def client(): +def client() -> Generator[TestClient, None, None]: yield TestClient(app) @@ -194,5 +194,8 @@ def mock_available_model_deployments(request): @pytest.fixture def mock_cohere_list_models(): - with patch("backend.model_deployments.cohere_platform.CohereDeployment.list_models", return_value=["command", "command-r", "command-r-plus", "command-light-nightly"]) as mock: + with patch( + "backend.model_deployments.cohere_platform.CohereDeployment.list_models", + return_value=["command", "command-r", "command-r-plus", "command-light-nightly"] + ) as mock: yield mock diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index 610fd2595d..12279f1c23 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,31 +18,31 @@ class MockAzureDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["AZURE_API_KEY", "AZURE_CHAT_ENDPOINT_URL"] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: event = { "text": "Hi! Hello there! How's it going?", @@ -97,6 +97,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index cb9b84a910..53fa171faa 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,28 +18,28 @@ class MockBedrockDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,6 +93,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index f15312b24f..6e23919573 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, Generator, List +import random +from typing import Any, Generator from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( MockDeployment, ) @@ -18,28 +20,28 @@ class MockCohereDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["COHERE_API_KEY"] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True @classmethod - def config(cls) -> Dict[str, Any]: + def config(cls) -> dict[str, Any]: return {"COHERE_API_KEY": "fake-api-key"} @@ -69,36 +71,63 @@ def invoke_chat( } yield event - def invoke_chat_stream( + async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: - events = [ - { - "event_type": StreamEvent.STREAM_START, - "generation_id": "test", - }, - { - "event_type": StreamEvent.TEXT_GENERATION, - "text": "This is a test.", - }, - { - "event_type": StreamEvent.STREAM_END, - "response": { - "generation_id": "test", - "citations": [], - "documents": [], - "search_results": [], - "search_queries": [], - }, - "finish_reason": "MAX_TOKENS", + # Start Event Stream + events = [{ + "event_type": StreamEvent.STREAM_START, + "generation_id": "ca0f398e-f8c8-48f0-b093-12d1754d00ed", + }] + + # Add Tool Calls + for tool in chat_request.tools: + events.append({ + "event_type": StreamEvent.TOOL_CALLS_GENERATION, + "text": "", + "tool_calls": [ + { "name": tool.name, "parameters": {} }, + ], + }) + + # Add Text Generation + events.append({ + "event_type": StreamEvent.TEXT_GENERATION, + "text": "This is a test.", + }) + + # End Stream + events.append({ + "event_type": StreamEvent.STREAM_END, + "response": { + "generation_id": "ca0f398e-f8c8-48f0-b093-12d1754d00ed", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], }, - ] + "finish_reason": "MAX_TOKENS", + }) for event in events: yield event - def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + async def invoke_rerank( + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: - # TODO: Add - pass + results = [] + for idx, doc in enumerate(documents): + if query in doc: + results.append({ + "index": idx, + "relevance_score": random.uniform(SEARCH_RELEVANCE_THRESHOLD, 1), + }) + event = { + "id": "eae2b023-bf49-4139-bf15-9825022762f4", + "results": results, + "meta": { + "api_version":{"version":"1"}, + "billed_units":{"search_units":1} + } + } + return event diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py index d40c2737ee..2f0f577562 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,28 +18,28 @@ class MockSageMakerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: pass @@ -72,6 +72,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index e8cf3ac124..2ed0464ff4 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,28 +18,28 @@ class MockSingleContainerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,7 +93,7 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: # TODO: Add pass From ee32c24e6ca8b4d37e442157445115980ab01d17 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Wed, 15 Jan 2025 13:05:53 -0800 Subject: [PATCH 5/5] feat(backend): removed dependency on cohere api key --- .../tests/integration/routers/test_agent.py | 32 ++- .../integration/routers/test_conversation.py | 32 +-- .../routers/test_chat.py | 241 +++++++++++------- 3 files changed, 183 insertions(+), 122 deletions(-) rename src/backend/tests/{integration => unit}/routers/test_chat.py (86%) diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 191eaf9831..1e48dcd2bb 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -21,7 +21,13 @@ and os.environ.get("COHERE_API_KEY") != "" ) -def test_create_agent(session_client: TestClient, session: Session, user: User, mock_cohere_list_models) -> None: + +def test_create_agent( + session_client: TestClient, + session: Session, + user: User, + mock_cohere_list_models, +) -> None: request_json = { "name": "test agent", "version": 1, @@ -297,7 +303,7 @@ def test_create_agent_deployment_not_in_db( def test_create_agent_invalid_tool( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: request_json = { "name": "test agent", @@ -314,7 +320,7 @@ def test_create_agent_invalid_tool( def test_create_existing_agent( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: agent = get_factory("Agent", session).create(name="test agent") request_json = { @@ -336,7 +342,9 @@ def test_list_agents_empty_returns_default_agent(session_client: TestClient, ses assert len(response_agents) == 1 -def test_list_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_agents( + session_client: TestClient, session: Session, user: User, +) -> None: num_agents = 3 for _ in range(num_agents): _ = get_factory("Agent", session).create(user=user) @@ -350,7 +358,7 @@ def test_list_agents(session_client: TestClient, session: Session, user) -> None def test_list_organization_agents( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -379,7 +387,7 @@ def test_list_organization_agents( def test_list_organization_agents_query_param( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -408,7 +416,7 @@ def test_list_organization_agents_query_param( def test_list_organization_agents_nonexistent_organization( session_client: TestClient, session: Session, - user, + user: User, ) -> None: response = session_client.get( "/v1/agents", headers={"User-Id": user.id, "Organization-Id": "123"} @@ -418,7 +426,7 @@ def test_list_organization_agents_nonexistent_organization( def test_list_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -438,7 +446,9 @@ def test_list_private_agents( assert len(response_agents) == 3 -def test_list_public_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_public_agents( + session_client: TestClient, session: Session, user: User, +) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -458,7 +468,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) def list_public_and_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -479,7 +489,7 @@ def list_public_and_private_agents( def test_list_agents_with_pagination( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(5): _ = get_factory("Agent", session).create(user=user) diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 1700c7fd1e..919610f937 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -1,5 +1,3 @@ -import os - import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -10,11 +8,14 @@ from backend.schemas.user import User from backend.tests.unit.factories import get_factory +_IS_GOOGLE_CLOUD_API_KEY_SET = bool(Settings().get('google_cloud.api_key')) + def test_search_conversations( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: conversation = get_factory("Conversation", session).create( title="test title", user_id=user.id @@ -24,8 +25,6 @@ def test_search_conversations( headers={"User-Id": user.id}, params={"query": "test"}, ) - print("here") - print(response.json) results = response.json() assert response.status_code == 200 @@ -33,14 +32,11 @@ def test_search_conversations( assert results[0]["id"] == conversation.id -@pytest.mark.skipif( - os.environ.get("COHERE_API_KEY") is None, - reason="Cohere API key not set, skipping test", -) def test_search_conversations_with_reranking( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: _ = get_factory("Conversation", session).create( title="Hello, how are you?", text_messages=[], user_id=user.id @@ -83,19 +79,16 @@ def test_search_conversations_no_conversations( assert response.json() == [] -# MISC - - -@pytest.mark.skip(reason="Restore this test when we get access to run models on Huggingface") def test_generate_title( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) + conversation_initial = get_factory("Conversation", session).create(user_id=user.id) response = session_client.post( - f"/v1/conversations/{conversation.id}/generate-title", - headers={"User-Id": conversation.user_id}, + f"/v1/conversations/{conversation_initial.id}/generate-title", + headers={"User-Id": conversation_initial.user_id}, ) response_json = response.json() @@ -105,7 +98,7 @@ def test_generate_title( # Check if the conversation was updated conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=conversation.user_id) + .filter_by(id=conversation_initial.id, user_id=conversation_initial.user_id) .first() ) assert conversation is not None @@ -165,10 +158,7 @@ def test_generate_title_error_invalid_model( # SYNTHESIZE -is_google_cloud_api_key_set = bool(Settings().get('google_cloud.api_key')) - - -@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test") def test_synthesize_english_message( session_client: TestClient, session: Session, @@ -186,7 +176,7 @@ def test_synthesize_english_message( assert response.headers["Content-Type"] == "audio/mp3" -@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test") def test_synthesize_non_english_message( session_client: TestClient, session: Session, diff --git a/src/backend/tests/integration/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py similarity index 86% rename from src/backend/tests/integration/routers/test_chat.py rename to src/backend/tests/unit/routers/test_chat.py index 9d59ccbb29..572a31a19e 100644 --- a/src/backend/tests/integration/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -30,10 +30,12 @@ def user(session_chat: Session) -> User: # STREAMING CHAT TESTS -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_new_chat( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -49,11 +51,12 @@ def test_streaming_new_chat( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -@pytest.mark.skip(reason="Failing due to error 405 calling API, but works in practice. Requires debugging") def test_streaming_new_chat_with_agent( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) agent = get_factory("Agent", session_chat).create(user=user, tools=[], deployment_id=deployment.id, @@ -74,11 +77,12 @@ def test_streaming_new_chat_with_agent( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -@pytest.mark.skip(reason="Failing due to error 405 calling API, but works in practice. Requires debugging") def test_streaming_new_chat_with_agent_existing_conversation( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) agent = get_factory("Agent", session_chat).create(user=user, tools=[], deployment_id=deployment.id, @@ -126,10 +130,12 @@ def test_streaming_new_chat_with_agent_existing_conversation( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_existing_conversation_from_other_agent( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: agent = get_factory("Agent", session_chat).create(user=user) _ = get_factory("Agent", session_chat).create(user=user, id="123") conversation = get_factory("Conversation", session_chat).create( @@ -171,11 +177,12 @@ def test_streaming_chat_with_existing_conversation_from_other_agent( } -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -@pytest.mark.skip(reason="Failing due to error 405 calling API, but works in practice. Requires debugging") def test_streaming_chat_with_tools_not_in_agent_tools( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) agent = get_factory("Agent", session_chat).create(user=user, tools=["wikipedia"], deployment_id=deployment.id, @@ -198,10 +205,12 @@ def test_streaming_chat_with_tools_not_in_agent_tools( validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"]) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_agent_tools_and_empty_request_tools( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: + Session, user: User, + mock_available_model_deployments: list[dict], +) -> None: deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"], @@ -224,10 +233,12 @@ def test_streaming_chat_with_agent_tools_and_empty_request_tools( validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"]) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_existing_chat( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) _ = get_factory("Message", session_chat).create( @@ -269,10 +280,12 @@ def test_streaming_existing_chat( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_fail_chat_missing_user_id( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat", json={"message": "Hello"}, @@ -283,10 +296,12 @@ def test_fail_chat_missing_user_id( assert response.json() == {"detail": "User-Id required in request headers."} -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_default_chat_missing_deployment_name( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat", json={"message": "Hello"}, @@ -296,10 +311,12 @@ def test_default_chat_missing_deployment_name( assert response.status_code == 200 -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_fail_chat_missing_message( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -323,7 +340,12 @@ def test_streaming_fail_chat_missing_message( @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): +def test_streaming_chat_with_managed_tools( + session_client_chat: TestClient, + session_chat: Session, + user: User, + # mock_available_model_deployments: list[dict], +) -> None: tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ @@ -344,11 +366,12 @@ def test_streaming_chat_with_managed_tools(session_client_chat, session_chat, us response, user, session_chat, session_client_chat, 2 ) - -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_invalid_tool( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat-stream", json={"message": "Hello", "tools": [{"name": "invalid_tool"}]}, @@ -362,10 +385,12 @@ def test_streaming_chat_with_invalid_tool( assert response.json() == {"detail": "Custom tools must have a description"} -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_managed_and_custom_tools( - session_client_chat, session_chat, user -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ @@ -396,8 +421,11 @@ def test_streaming_chat_with_managed_and_custom_tools( @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_search_queries_only( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + # mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat-stream", json={ @@ -421,9 +449,11 @@ def test_streaming_chat_with_search_queries_only( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_chat_history( - session_client_chat: TestClient, session_chat: Session, user: User + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], ) -> None: response = session_client_chat.post( "/v1/chat-stream", @@ -451,10 +481,12 @@ def test_streaming_chat_with_chat_history( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_existing_chat_with_files_attaches_to_user_message( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) file1 = get_factory("File", session_chat).create(user_id=user.id) file2 = get_factory("File", session_chat).create(user_id=user.id) @@ -487,10 +519,12 @@ def test_streaming_existing_chat_with_files_attaches_to_user_message( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_existing_chat_with_attached_files_does_not_attach( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: file1 = get_factory("File", session_chat).create( user_id=user.id, ) @@ -544,10 +578,12 @@ def test_streaming_existing_chat_with_attached_files_does_not_attach( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_private_agent( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: agent = get_factory("Agent", session_chat).create( user=user, is_private=True, tools=[] ) @@ -567,10 +603,12 @@ def test_streaming_chat_private_agent( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_public_agent( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: agent = get_factory("Agent", session_chat).create( user_id=user.id, is_private=False, tools=[] ) @@ -611,10 +649,12 @@ def test_streaming_chat_private_agent_by_another_user( assert response.json() == {"detail": f"Agent with ID {agent.id} not found."} -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_stream_regenerate_existing_chat( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) _ = get_factory("Message", session_chat).create( @@ -656,10 +696,12 @@ def test_stream_regenerate_existing_chat( ) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_stream_regenerate_not_existing_chat( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation_id = "test_conversation_id" response = session_client_chat.post( @@ -679,10 +721,12 @@ def test_stream_regenerate_not_existing_chat( assert response.json() == {"detail": f"Conversation with ID: {conversation_id} not found."} -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_stream_regenerate_existing_chat_not_existing_user_messages( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) session_chat.refresh(conversation) @@ -705,10 +749,12 @@ def test_stream_regenerate_existing_chat_not_existing_user_messages( # NON-STREAMING CHAT TESTS -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat", json={"message": "Hello", "max_tokens": 10}, @@ -725,7 +771,12 @@ def test_non_streaming_chat( @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_non_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): +def test_non_streaming_chat_with_managed_tools( + session_client_chat: TestClient, + session_chat: Session, + user: User, + # mock_available_model_deployments: list[dict], +) -> None: tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ @@ -747,10 +798,12 @@ def test_non_streaming_chat_with_managed_tools(session_client_chat, session_chat validate_conversation(session_chat, user, conversation_id, 2) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat_with_managed_and_custom_tools( - session_client_chat, session_chat, user -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ @@ -779,10 +832,12 @@ def test_non_streaming_chat_with_managed_and_custom_tools( assert response.json() == {"detail": "Cannot mix both managed and custom tools"} -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat_with_search_queries_only( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: response = session_client_chat.post( "/v1/chat", json={ @@ -801,9 +856,11 @@ def test_non_streaming_chat_with_search_queries_only( validate_conversation(session_chat, user, conversation_id, 2) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat_with_chat_history( - session_client_chat: TestClient, session_chat: Session, user: User + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], ) -> None: response = session_client_chat.post( "/v1/chat", @@ -826,10 +883,12 @@ def test_non_streaming_chat_with_chat_history( validate_conversation(session_chat, user, conversation_id, 0) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_existing_chat_with_files_attaches_to_user_message( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) file1 = get_factory("File", session_chat).create(user_id=user.id) file2 = get_factory("File", session_chat).create(user_id=user.id) @@ -861,10 +920,12 @@ def test_non_streaming_existing_chat_with_files_attaches_to_user_message( assert file2.id in message.file_ids -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_existing_chat_with_attached_files_does_not_attach( - session_client_chat: TestClient, session_chat: Session, user: User -): + session_client_chat: TestClient, + session_chat: Session, + user: User, + mock_available_model_deployments: list[dict], +) -> None: conversation = get_factory("Conversation", session_chat).create(user_id=user.id) existing_message = get_factory("Message", session_chat).create( conversation_id=conversation.id, user_id=user.id, position=0, is_active=True @@ -948,12 +1009,12 @@ def validate_chat_streaming_response( validate_conversation(session, user, conversation_id, expected_num_messages) -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_files( session_client_chat: TestClient, session_chat: Session, user: User, -): + mock_available_model_deployments: list[dict], +) -> None: # Create convo conversation = get_factory("Conversation", session_chat).create(user_id=user.id) @@ -1011,10 +1072,10 @@ def validate_conversation( assert conversation.user_id == user.id assert len(conversation.messages) == expected_num_messages # Also test DB object - conversation = session.get(Conversation, (conversation_id, user.id)) - assert conversation is not None - assert conversation.user_id == user.id - assert len(conversation.messages) == expected_num_messages + conversation_db = session.get(Conversation, (conversation_id, user.id)) + assert conversation_db is not None + assert conversation_db.user_id == user.id + assert len(conversation_db.messages) == expected_num_messages def validate_stream_end_event(