diff --git a/api/catalog/api/views/health_views.py b/api/catalog/api/views/health_views.py index 314199a52..e83d2959d 100644 --- a/api/catalog/api/views/health_views.py +++ b/api/catalog/api/views/health_views.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.db import connection from rest_framework import status from rest_framework.exceptions import APIException from rest_framework.request import Request @@ -21,19 +22,33 @@ class HealthCheck(APIView): swagger_schema = None - def _check_es(self) -> Response | None: - """Check ES cluster health and raise an exception if ES is not healthy.""" + @staticmethod + def _check_db() -> None: + """ + Check that the database is available. + Returns nothing if everything is OK, throws error otherwise. + """ + connection.ensure_connection() + + @staticmethod + def _check_es() -> None: + """ + Check Elasticsearch cluster health. + + Raises an exception if ES is not healthy. + """ es_health = settings.ES.cluster.health(timeout="5s") if es_health["timed_out"]: raise ElasticsearchHealthcheckException("es_timed_out") - if (status := es_health["status"]) != "green": - raise ElasticsearchHealthcheckException(f"es_status_{status}") + if (es_status := es_health["status"]) != "green": + raise ElasticsearchHealthcheckException(f"es_status_{es_status}") def get(self, request: Request): if "check_es" in request.query_params: self._check_es() + self._check_db() return Response({"status": "200 OK"}, status=200) diff --git a/api/test/unit/views/health_views_test.py b/api/test/unit/views/health_views_test.py index 2ddf3b3d5..b2de4a81f 100644 --- a/api/test/unit/views/health_views_test.py +++ b/api/test/unit/views/health_views_test.py @@ -1,3 +1,5 @@ +from unittest import mock + import pook import pytest @@ -16,11 +18,21 @@ def mock_health_response(status="green", timed_out=False): ) +@pytest.mark.django_db def test_health_check_plain(api_client): res = api_client.get("/healthcheck/") assert res.status_code == 200 +def test_health_check_calls__check_db(api_client): + with mock.patch( + "catalog.api.views.health_views.HealthCheck._check_db" + ) as mock_check_db: + res = api_client.get("/healthcheck/") + assert res.status_code == 200 + mock_check_db.assert_called_once() + + def test_health_check_es_timed_out(api_client): mock_health_response(timed_out=True) pook.on() @@ -42,6 +54,7 @@ def test_health_check_es_status_bad(status, api_client): assert res.json()["detail"] == f"es_status_{status}" +@pytest.mark.django_db def test_health_check_es_all_good(api_client): mock_health_response(status="green") pook.on()