From d966233262e5281510848416adb073c087fe3aa5 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Wed, 22 Jan 2025 09:36:17 -0500 Subject: [PATCH] Fix integration and unit tests (#912) * WIP * WIP * wip * Conftest revert * Add default database for testing * TLK-2723 Resolve current integration/unit test issues * TLK-2723 Resolve current integration/unit test issues * TLK-2723 Resolve current integration/unit test issues --------- Co-authored-by: EugeneP --- Makefile | 2 +- ...8_01_117f0d9b1d3d_seed_deployments_data.py | 10 ++--- src/backend/chat/custom/utils.py | 3 +- .../seeders/deployments_models_seed.py | 25 ----------- .../seeders/organization_seed.py | 42 +++++++++++++++++++ src/backend/pytest_integration.ini | 3 +- src/backend/services/deployment.py | 2 +- src/backend/services/request_validators.py | 8 +++- src/backend/tests/integration/conftest.py | 3 +- .../tests/integration/routers/test_agent.py | 3 +- .../integration/routers/test_conversation.py | 9 +++- .../integration/services/auth/__init__.py | 0 .../services/auth/strategies/__init__.py | 0 src/backend/tests/unit/configuration.yaml | 2 +- src/backend/tests/unit/conftest.py | 40 ++++++++++++++++-- .../crud/test_deployment.py | 0 .../{integration => unit}/crud/test_model.py | 0 .../routers/test_chat.py | 0 .../routers/test_deployment.py | 4 +- .../routers/test_model.py | 0 .../crud => unit/services/auth}/__init__.py | 0 .../services/auth/strategies}/__init__.py | 0 .../services/auth/strategies/test_basic.py | 0 .../services/auth/test_jwt.py | 0 .../services/auth/test_request_validators.py | 0 .../services/test_cache.py | 0 .../tests/unit/services/test_deployment.py | 7 +++- 27 files changed, 115 insertions(+), 48 deletions(-) delete mode 100644 src/backend/database_models/seeders/deployments_models_seed.py create mode 100644 src/backend/database_models/seeders/organization_seed.py delete mode 100644 src/backend/tests/integration/services/auth/__init__.py delete mode 100644 src/backend/tests/integration/services/auth/strategies/__init__.py rename src/backend/tests/{integration => unit}/crud/test_deployment.py (100%) rename src/backend/tests/{integration => unit}/crud/test_model.py (100%) rename src/backend/tests/{integration => unit}/routers/test_chat.py (100%) rename src/backend/tests/{integration => unit}/routers/test_deployment.py (98%) rename src/backend/tests/{integration => unit}/routers/test_model.py (100%) rename src/backend/tests/{integration/crud => unit/services/auth}/__init__.py (100%) rename src/backend/tests/{integration/services => unit/services/auth/strategies}/__init__.py (100%) rename src/backend/tests/{integration => unit}/services/auth/strategies/test_basic.py (100%) rename src/backend/tests/{integration => unit}/services/auth/test_jwt.py (100%) rename src/backend/tests/{integration => unit}/services/auth/test_request_validators.py (100%) rename src/backend/tests/{integration => unit}/services/test_cache.py (100%) diff --git a/Makefile b/Makefile index 65ba6df261..8c5586c274 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,7 @@ run-community-tests-debug: .PHONY: run-integration-tests run-integration-tests: - docker compose run --rm --build backend poetry run pytest -c src/backend/pytest_integration.ini src/backend/tests/integration/$(file) + poetry run pytest -c src/backend/pytest_integration.ini src/backend/tests/integration/$(file) .PHONY: test-db test-db: diff --git a/src/backend/alembic/versions/2024_08_01_117f0d9b1d3d_seed_deployments_data.py b/src/backend/alembic/versions/2024_08_01_117f0d9b1d3d_seed_deployments_data.py index 95a99f8b76..8e300def5a 100644 --- a/src/backend/alembic/versions/2024_08_01_117f0d9b1d3d_seed_deployments_data.py +++ b/src/backend/alembic/versions/2024_08_01_117f0d9b1d3d_seed_deployments_data.py @@ -10,9 +10,9 @@ from alembic import op -from backend.database_models.seeders.deployments_models_seed import ( - delete_default_models, - deployments_models_seed, +from backend.database_models.seeders.organization_seed import ( + delete_default_organization, + seed_default_organization, ) # revision identifiers, used by Alembic. @@ -23,8 +23,8 @@ def upgrade() -> None: - deployments_models_seed(op) + seed_default_organization(op) def downgrade() -> None: - delete_default_models(op) + delete_default_organization(op) diff --git a/src/backend/chat/custom/utils.py b/src/backend/chat/custom/utils.py index 7676c63f69..893ba221fc 100644 --- a/src/backend/chat/custom/utils.py +++ b/src/backend/chat/custom/utils.py @@ -1,7 +1,6 @@ from typing import Any from backend.database_models.database import get_session -from backend.exceptions import DeploymentNotFoundError from backend.model_deployments.base import BaseDeployment from backend.schemas.context import Context from backend.services import deployment as deployment_service @@ -20,7 +19,7 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment: try: session = next(get_session()) deployment = deployment_service.get_deployment_by_name(session, name, **kwargs) - except DeploymentNotFoundError: + except Exception: deployment = deployment_service.get_default_deployment(**kwargs) return deployment diff --git a/src/backend/database_models/seeders/deployments_models_seed.py b/src/backend/database_models/seeders/deployments_models_seed.py deleted file mode 100644 index 400735f52a..0000000000 --- a/src/backend/database_models/seeders/deployments_models_seed.py +++ /dev/null @@ -1,25 +0,0 @@ -from sqlalchemy.orm import Session - -from backend.database_models import Deployment, Model, Organization - - -def deployments_models_seed(op): - """ - Seed default deployments, models, organization, user and agent. - """ - # Previously we would seed the default deployments and models here. We've changed this - # behaviour during a refactor of the deployments module so that deployments and models - # are inserted when they're first used. This solves an issue where seed data would - # sometimes be inserted with invalid config data. - pass - - -def delete_default_models(op): - """ - Delete deployments and models. - """ - session = Session(op.get_bind()) - session.query(Deployment).delete() - session.query(Model).delete() - session.query(Organization).filter_by(id="default").delete() - session.commit() diff --git a/src/backend/database_models/seeders/organization_seed.py b/src/backend/database_models/seeders/organization_seed.py new file mode 100644 index 0000000000..c8d670dc67 --- /dev/null +++ b/src/backend/database_models/seeders/organization_seed.py @@ -0,0 +1,42 @@ +from sqlalchemy import text +from sqlalchemy.orm import Session + +from backend.database_models import Organization + + +def seed_default_organization(op): + """ + Seed default organization. + """ + # Previously we would seed the default deployments and models here. We've changed this + # behaviour during a refactor of the deployments module so that deployments and models + # are inserted when they're first used. This solves an issue where seed data would + # sometimes be inserted with invalid config data. + + _ = Session(op.get_bind()) + + # Seed default organization + sql_command = text( + """ + INSERT INTO organizations ( + id, name, created_at, updated_at + ) + VALUES ( + :id, :name, now(), now() + ) + ON CONFLICT (id) DO NOTHING; + """ + ).bindparams( + id="default", + name="Default Organization", + ) + op.execute(sql_command) + + +def delete_default_organization(op): + """ + Delete default organization. + """ + session = Session(op.get_bind()) + session.query(Organization).filter_by(id="default").delete() + session.commit() diff --git a/src/backend/pytest_integration.ini b/src/backend/pytest_integration.ini index c686703e0c..2c593c116a 100644 --- a/src/backend/pytest_integration.ini +++ b/src/backend/pytest_integration.ini @@ -1,5 +1,6 @@ [pytest] env = - DATABASE_URL=postgresql://postgres:postgres@db:5432/postgres + DATABASE_URL=postgresql://postgres:postgres@localhost:5433 filterwarnings = ignore::UserWarning:pydantic.* + ignore::DeprecationWarning \ No newline at end of file diff --git a/src/backend/services/deployment.py b/src/backend/services/deployment.py index ac4c597af4..3f838b58a6 100644 --- a/src/backend/services/deployment.py +++ b/src/backend/services/deployment.py @@ -91,7 +91,7 @@ def get_deployment_definition_by_name(session: DBSessionDep, deployment_name: st raise DeploymentNotFoundError(deployment_id=deployment_name) if definition.name not in [d.name for d in deployment_crud.get_deployments(session)]: - create_db_deployment(session, definition) + definition = create_db_deployment(session, definition) return definition diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index 5bbdd65248..f932d26e12 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -11,6 +11,7 @@ from backend.crud import model as model_crud from backend.crud import organization as organization_crud from backend.database_models.database import DBSessionDep +from backend.exceptions import DeploymentNotFoundError from backend.model_deployments.utils import class_name_validator from backend.services import deployment as deployment_service from backend.services.agent import validate_agent_exists @@ -217,7 +218,12 @@ async def validate_env_vars(session: DBSessionDep, request: Request): invalid_keys = [] deployment_id = unquote_plus(request.path_params.get("deployment_id")) - deployment = deployment_service.get_deployment(session, deployment_id) + try: + deployment = deployment_service.get_deployment(session, deployment_id) + except DeploymentNotFoundError: + raise HTTPException( + status_code=404, detail=f"Deployment {deployment_id} not found." + ) for key in env_vars: if key not in deployment.env_vars(): diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index d2207d04fc..fb8528249f 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -18,7 +18,7 @@ from backend.schemas.user import User from backend.tests.unit.factories import get_factory -DATABASE_URL = os.environ["DATABASE_URL"] +DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://postgres:postgres@localhost:5433") @pytest.fixture @@ -162,7 +162,6 @@ def deployment(session: Session) -> Deployment: deployment_class_name="CohereDeployment" ) - @pytest.fixture def model(session: Session) -> Model: return get_factory("Model", session).create() diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index e80c23842a..32b457d631 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -283,7 +283,8 @@ def test_create_agent_deployment_not_in_db( "deployment": CohereDeployment.name(), } cohere_deployment = deployment_crud.get_deployment_by_name(session, CohereDeployment.name()) - deployment_crud.delete_deployment(session, cohere_deployment.id) + if cohere_deployment: + deployment_crud.delete_deployment(session, cohere_deployment.id) response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} ) diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 1700c7fd1e..7bd296ca38 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -11,6 +11,10 @@ from backend.tests.unit.factories import get_factory +@pytest.mark.skipif( + os.environ.get("COHERE_API_KEY") is None, + reason="Cohere API key not set, skipping test", +) def test_search_conversations( session_client: TestClient, session: Session, @@ -64,7 +68,10 @@ def test_search_conversations_with_reranking( assert len(results) == 1 assert results[0]["id"] == conversation2.id - +@pytest.mark.skipif( + os.environ.get("COHERE_API_KEY") is None, + reason="Cohere API key not set, skipping test", +) def test_search_conversations_no_conversations( session_client: TestClient, session: Session, diff --git a/src/backend/tests/integration/services/auth/__init__.py b/src/backend/tests/integration/services/auth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/backend/tests/integration/services/auth/strategies/__init__.py b/src/backend/tests/integration/services/auth/strategies/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/backend/tests/unit/configuration.yaml b/src/backend/tests/unit/configuration.yaml index 6fa7a4e576..7213a15b14 100644 --- a/src/backend/tests/unit/configuration.yaml +++ b/src/backend/tests/unit/configuration.yaml @@ -1,5 +1,5 @@ deployments: - default_deployment: + default_deployment: cohere_platform enabled_deployments: sagemaker: access_key: "sagemaker_access_key" diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 9b180aaef5..3c12fe33ea 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -15,12 +15,13 @@ from backend.database_models import get_session from backend.database_models.base import CustomFilterQuery +from backend.database_models.deployment import Deployment from backend.main import app, create_app from backend.schemas.organization import Organization from backend.schemas.user import User from backend.tests.unit.factories import get_factory -DATABASE_URL = os.environ["DATABASE_URL"] +DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://postgres:postgres@localhost:5433") MASTER_DB_NAME = "postgres" TEST_DB_PREFIX = "postgres_" MASTER_DATABASE_FULL_URL = f"{DATABASE_URL}/{MASTER_DB_NAME}" @@ -58,7 +59,7 @@ def client(): yield TestClient(app) -@pytest.fixture(scope="session") +@pytest.fixture def engine(worker_id: str) -> Generator[Any, None, None]: """ Yields a SQLAlchemy engine which is disposed of after the test session @@ -81,6 +82,32 @@ def engine(worker_id: str) -> Generator[Any, None, None]: drop_test_database_if_exists(test_db_name) +@pytest.fixture(scope="session") +def engine_chat(worker_id: str) -> Generator[Any, None, None]: + """ + Yields a SQLAlchemy engine which is disposed of after the test session + """ + test_db_name = f"{TEST_DB_PREFIX}{worker_id}" + if worker_id == "master": + test_db_name = f"{TEST_DB_PREFIX}{worker_id}_chat" + + test_db_url = f"{DATABASE_URL}/{test_db_name}" + + drop_test_database_if_exists(test_db_name) + create_test_database(test_db_name) + engine = create_engine(test_db_url, echo=True) + + with engine.begin(): + alembic_cfg = Config("src/backend/alembic.ini") + alembic_cfg.set_main_option("sqlalchemy.url", test_db_url) + upgrade(alembic_cfg, "head") + + yield engine + + engine.dispose() + drop_test_database_if_exists(test_db_name) + + @pytest.fixture(scope="function") def session(engine: Any) -> Generator[Session, None, None]: """ @@ -122,7 +149,7 @@ def override_get_session() -> Generator[Session, Any, None]: @pytest.fixture(scope="session") -def session_chat(engine: Any) -> Generator[Session, None, None]: +def session_chat(engine_chat: Any) -> Generator[Session, None, None]: """ Yields a SQLAlchemy session within a transaction that is rolled back after every session @@ -130,7 +157,7 @@ def session_chat(engine: Any) -> Generator[Session, None, None]: We need to use the fixture in the session scope because the chat endpoint is asynchronous and needs to be open for the entire session """ - connection = engine.connect() + connection = engine_chat.connect() transaction = connection.begin() # Use connection within the started transaction session = Session(bind=connection, query_cls=CustomFilterQuery) @@ -188,6 +215,11 @@ def user(session: Session) -> User: def organization(session: Session) -> Organization: return get_factory("Organization", session).create() +@pytest.fixture +def deployment(session: Session) -> Deployment: + return get_factory("Deployment", session).create( + deployment_class_name="CohereDeployment" + ) @pytest.fixture def mock_available_model_deployments(request): diff --git a/src/backend/tests/integration/crud/test_deployment.py b/src/backend/tests/unit/crud/test_deployment.py similarity index 100% rename from src/backend/tests/integration/crud/test_deployment.py rename to src/backend/tests/unit/crud/test_deployment.py diff --git a/src/backend/tests/integration/crud/test_model.py b/src/backend/tests/unit/crud/test_model.py similarity index 100% rename from src/backend/tests/integration/crud/test_model.py rename to src/backend/tests/unit/crud/test_model.py diff --git a/src/backend/tests/integration/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py similarity index 100% rename from src/backend/tests/integration/routers/test_chat.py rename to src/backend/tests/unit/routers/test_chat.py diff --git a/src/backend/tests/integration/routers/test_deployment.py b/src/backend/tests/unit/routers/test_deployment.py similarity index 98% rename from src/backend/tests/integration/routers/test_deployment.py rename to src/backend/tests/unit/routers/test_deployment.py index 6df4d29a1a..2d1ebbbd69 100644 --- a/src/backend/tests/integration/routers/test_deployment.py +++ b/src/backend/tests/unit/routers/test_deployment.py @@ -157,9 +157,9 @@ def test_set_env_vars( def test_set_env_vars_with_invalid_deployment_name( - client: TestClient + session_client: TestClient ): - response = client.post("/v1/deployments/unknown/update_config", json={}) + response = session_client.post("/v1/deployments/unknown/update_config", json={}) assert response.status_code == 404 diff --git a/src/backend/tests/integration/routers/test_model.py b/src/backend/tests/unit/routers/test_model.py similarity index 100% rename from src/backend/tests/integration/routers/test_model.py rename to src/backend/tests/unit/routers/test_model.py diff --git a/src/backend/tests/integration/crud/__init__.py b/src/backend/tests/unit/services/auth/__init__.py similarity index 100% rename from src/backend/tests/integration/crud/__init__.py rename to src/backend/tests/unit/services/auth/__init__.py diff --git a/src/backend/tests/integration/services/__init__.py b/src/backend/tests/unit/services/auth/strategies/__init__.py similarity index 100% rename from src/backend/tests/integration/services/__init__.py rename to src/backend/tests/unit/services/auth/strategies/__init__.py diff --git a/src/backend/tests/integration/services/auth/strategies/test_basic.py b/src/backend/tests/unit/services/auth/strategies/test_basic.py similarity index 100% rename from src/backend/tests/integration/services/auth/strategies/test_basic.py rename to src/backend/tests/unit/services/auth/strategies/test_basic.py diff --git a/src/backend/tests/integration/services/auth/test_jwt.py b/src/backend/tests/unit/services/auth/test_jwt.py similarity index 100% rename from src/backend/tests/integration/services/auth/test_jwt.py rename to src/backend/tests/unit/services/auth/test_jwt.py diff --git a/src/backend/tests/integration/services/auth/test_request_validators.py b/src/backend/tests/unit/services/auth/test_request_validators.py similarity index 100% rename from src/backend/tests/integration/services/auth/test_request_validators.py rename to src/backend/tests/unit/services/auth/test_request_validators.py diff --git a/src/backend/tests/integration/services/test_cache.py b/src/backend/tests/unit/services/test_cache.py similarity index 100% rename from src/backend/tests/integration/services/test_cache.py rename to src/backend/tests/unit/services/test_cache.py diff --git a/src/backend/tests/unit/services/test_deployment.py b/src/backend/tests/unit/services/test_deployment.py index 44df0d0512..d3aae770d4 100644 --- a/src/backend/tests/unit/services/test_deployment.py +++ b/src/backend/tests/unit/services/test_deployment.py @@ -83,7 +83,12 @@ def test_get_deployment_definition_by_name(session, mock_available_model_deploym def test_get_deployment_definition_by_name_no_db_deployments(session, mock_available_model_deployments, clear_db_deployments) -> None: definition = deployment_service.get_deployment_definition_by_name(session, MockCohereDeployment.name()) - assert definition == MockCohereDeployment.to_deployment_definition() + mock = MockCohereDeployment.to_deployment_definition() + assert definition.name == mock.name + assert definition.models == mock.models + assert definition.class_name == mock.class_name + assert definition.config == mock.config + def test_get_deployment_definition_by_name_wrong_name(session, mock_available_model_deployments) -> None: with pytest.raises(DeploymentNotFoundError):