Skip to content

Commit

Permalink
test: Fix transaction unit tests (#425)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Sarago <[email protected]>
Co-authored-by: smohiudd <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent d6aedd7 commit e53f81b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 32 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,17 @@ jobs:
- name: Install reqs for ingest api
run: python -m pip install -r ingest_api/runtime/requirements_dev.txt

- name: Install reqs for stac api
run: python -m pip install stac_api/runtime/

- name: Install veda auth for ingest api
run: python -m pip install common/auth

- name: Ingest unit tests
run: NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest ingest_api/runtime/tests/ -vv -s

# - name: Stac-api transactions unit tests
# run: python -m pytest stac_api/runtime/tests/ -vv -s
- name: Stac-api transactions unit tests
run: python -m pytest stac_api/runtime/tests/ --asyncio-mode=auto -vv -s

- name: Stop services
run: docker compose stop
Expand Down
5 changes: 4 additions & 1 deletion scripts/run-local-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,7 @@ docker exec veda.db /tmp/scripts/bin/load-data.sh
python -m pytest .github/workflows/tests/ -vv -s

# Run ingest unit tests
NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest --cov=ingest_api/runtime/src ingest_api/runtime/tests/ -vv -s
NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest --cov=ingest_api/runtime/src ingest_api/runtime/tests/ -vv -s

# Transactions tests
python -m pytest stac_api/runtime/tests/ --asyncio-mode=auto -vv -s
14 changes: 6 additions & 8 deletions stac_api/runtime/src/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@

from pydantic import BaseModel, Field
from pystac import STACObjectType
from pystac.errors import STACValidationError
from pystac.errors import STACTypeError, STACValidationError
from pystac.validation import validate_dict
from src.config import api_settings

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

path_prefix = api_settings.root_path or ""


class BulkItems(BaseModel):
"""Validation model for bulk-items endpoint request"""
Expand All @@ -33,24 +30,25 @@ async def dispatch(self, request: Request, call_next):
try:
body = await request.body()
request_data = json.loads(body)

if re.match(
f"^{path_prefix}/collections(?:/[^/]+)?$",
"^.*?/collections(?:/[^/]+)?$",
request.url.path,
):
validate_dict(request_data, STACObjectType.COLLECTION)
elif re.match(
f"^{path_prefix}/collections/[^/]+/items(?:/[^/]+)?$",
"^.*?/collections/[^/]+/items(?:/[^/]+)?$",
request.url.path,
):
validate_dict(request_data, STACObjectType.ITEM)
elif re.match(
f"^{path_prefix}/collections/[^/]+/bulk-items$",
"^.*?/collections/[^/]+/bulk_items$",
request.url.path,
):
bulk_items = BulkItems(**request_data)
for item_data in bulk_items.items.values():
validate_dict(item_data, STACObjectType.ITEM)
except STACValidationError as e:
except (STACValidationError, STACTypeError) as e:
return JSONResponse(
status_code=422,
content={"detail": "Validation Error", "errors": str(e)},
Expand Down
27 changes: 18 additions & 9 deletions stac_api/runtime/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import os

import pytest
from httpx import ASGITransport, AsyncClient

from fastapi.testclient import TestClient
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db

VALID_COLLECTION = {
"id": "CMIP245-winter-median-pr",
Expand Down Expand Up @@ -209,7 +210,7 @@
}


@pytest.fixture
@pytest.fixture(autouse=True)
def test_environ():
"""
Set up the test environment with mocked AWS and PostgreSQL credentials.
Expand All @@ -235,8 +236,8 @@ def test_environ():
os.environ["POSTGRES_USER"] = "username"
os.environ["POSTGRES_PASS"] = "password"
os.environ["POSTGRES_DBNAME"] = "postgis"
os.environ["POSTGRES_HOST_READER"] = "database"
os.environ["POSTGRES_HOST_WRITER"] = "database"
os.environ["POSTGRES_HOST_READER"] = "0.0.0.0"
os.environ["POSTGRES_HOST_WRITER"] = "0.0.0.0"
os.environ["POSTGRES_PORT"] = "5432"


Expand All @@ -251,7 +252,7 @@ def override_validated_token():


@pytest.fixture
def app(test_environ):
async def app():
"""
Fixture to initialize the FastAPI application.
Expand All @@ -266,11 +267,13 @@ def app(test_environ):
"""
from src.app import app

return app
await connect_to_db(app)
yield app
await close_db_connection(app)


@pytest.fixture
def api_client(app):
@pytest.fixture(scope="function")
async def api_client(app):
"""
Fixture to initialize the API client for making requests.
Expand All @@ -286,7 +289,13 @@ def api_client(app):
from src.app import auth

app.dependency_overrides[auth.validated_token] = override_validated_token
yield TestClient(app)
base_url = "http://test"

async with AsyncClient(
transport=ASGITransport(app=app), base_url=base_url
) as client:
yield client

app.dependency_overrides.clear()


Expand Down
24 changes: 12 additions & 12 deletions stac_api/runtime/tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,59 +53,59 @@ def setup(
self.invalid_stac_collection = invalid_stac_collection
self.invalid_stac_item = invalid_stac_item

def test_post_invalid_collection(self):
async def test_post_invalid_collection(self):
"""
Test the API's response to posting an invalid STAC collection.
Asserts that the response status code is 422 and the detail
is "Validation Error".
"""
response = self.api_client.post(
response = await self.api_client.post(
collections_endpoint, json=self.invalid_stac_collection
)
assert response.json()["detail"] == "Validation Error"
assert response.status_code == 422

def test_post_valid_collection(self):
async def test_post_valid_collection(self):
"""
Test the API's response to posting a valid STAC collection.
Asserts that the response status code is 200.
"""
response = self.api_client.post(
response = await self.api_client.post(
collections_endpoint, json=self.valid_stac_collection
)
# assert response.json() == {}
assert response.status_code == 200

def test_post_invalid_item(self):
async def test_post_invalid_item(self):
"""
Test the API's response to posting an invalid STAC item.
Asserts that the response status code is 422 and the detail
is "Validation Error".
"""
response = self.api_client.post(
response = await self.api_client.post(
items_endpoint.format(self.invalid_stac_item["collection"]),
json=self.invalid_stac_item,
)
assert response.json()["detail"] == "Validation Error"
assert response.status_code == 422

def test_post_valid_item(self):
async def test_post_valid_item(self):
"""
Test the API's response to posting a valid STAC item.
Asserts that the response status code is 200.
"""
response = self.api_client.post(
response = await self.api_client.post(
items_endpoint.format(self.valid_stac_item["collection"]),
json=self.valid_stac_item,
)
# assert response.json() == {}
assert response.status_code == 200

def test_post_invalid_bulk_items(self):
async def test_post_invalid_bulk_items(self):
"""
Test the API's response to posting invalid bulk STAC items.
Expand All @@ -117,12 +117,12 @@ def test_post_invalid_bulk_items(self):
"items": {item_id: self.invalid_stac_item},
"method": "upsert",
}
response = self.api_client.post(
response = await self.api_client.post(
bulk_endpoint.format(collection_id), json=invalid_request
)
assert response.status_code == 422

def test_post_valid_bulk_items(self):
async def test_post_valid_bulk_items(self):
"""
Test the API's response to posting valid bulk STAC items.
Expand All @@ -131,7 +131,7 @@ def test_post_valid_bulk_items(self):
item_id = self.valid_stac_item["id"]
collection_id = self.valid_stac_item["collection"]
valid_request = {"items": {item_id: self.valid_stac_item}, "method": "upsert"}
response = self.api_client.post(
response = await self.api_client.post(
bulk_endpoint.format(collection_id), json=valid_request
)
assert response.status_code == 200

0 comments on commit e53f81b

Please sign in to comment.