Skip to content

Commit

Permalink
CDD-2432 Implement Middleware - auth decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
mxrman committed Feb 25, 2025
1 parent 173cf19 commit 59c3a17
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 23 deletions.
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@
# easy way to inject secrets into serverless lambda functions
if APP_MODE == "INGESTION":
POSTGRES_PASSWORD = get_database_password()

# Enables RBAC group permissions
AUTH_ENABLED = os.environ.get("AUTH_ENABLED")
Empty file.
51 changes: 51 additions & 0 deletions metrics/api/decorators/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from enum import Enum
from functools import wraps

from django.db import Error
from django.http import JsonResponse

from config import AUTH_ENABLED
from metrics.data.models.rbac_models import RBACGroupPermission

RBAC_AUTH_X_HEADER = "X-GroupId"


class ErrorCode(Enum):
INVALID_GROUP_ID = 1115


def authorised_route(func):
@wraps(func)
def wrap(self, request, *args, **kwargs):
if not AUTH_ENABLED:
return func(self, request, *args, **kwargs)
try:
if RBAC_AUTH_X_HEADER in request.headers:
group_id = request.headers.get(RBAC_AUTH_X_HEADER)
if group_id == "":
raise InvalidGroupIdError
_set_rbac_group_permissions(request, group_id)
except InvalidGroupIdError:
return JsonResponse(
{"error": "Access Denied", "code": ErrorCode.INVALID_GROUP_ID.value},
status=403,
)
return func(self, request, *args, **kwargs)

return wrap


def _set_rbac_group_permissions(request, group_id: str) -> None:
try:
group_permissions = RBACGroupPermission.objects.get_group(name=group_id)
if group_permissions:
request.group_permissions = list(group_permissions.permissions.all())
else:
raise InvalidGroupIdError
except Error:
"""Catch all for database related errors"""
raise InvalidGroupIdError


class InvalidGroupIdError(Exception):
"""Custom exception for invalid RBAC group ID"""
2 changes: 2 additions & 0 deletions metrics/api/views/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rest_framework.views import APIView

from caching.private_api.decorators import cache_response
from metrics.api.decorators.auth import authorised_route
from metrics.api.serializers import (
BulkDownloadsSerializer,
CoreHeadlineSerializer,
Expand Down Expand Up @@ -99,6 +100,7 @@ def _handle_csv(

return write_data_to_csv(file=response, core_time_series_queryset=queryset)

@authorised_route
@extend_schema(request=DownloadsSerializer, tags=[DOWNLOADS_API_TAG])
@cache_response()
def post(self, request, *args, **kwargs):
Expand Down
4 changes: 4 additions & 0 deletions metrics/data/managers/rbac_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
RBACPermissionQuerySet,
RBACPermissionManager,
)
from metrics.data.managers.rbac_models.rbac_group_permissions import (
RBACGroupPermissionQuerySet,
RBACGroupPermissionManager,
)
43 changes: 43 additions & 0 deletions metrics/data/managers/rbac_models/rbac_group_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from django.db import models


class RBACGroupPermissionQuerySet(models.QuerySet):
"""Custom queryset for the `RBACGroupPermission` model."""

def get_group(self, name: str) -> "RBACGroupPermission":
"""
Retrieves a single `RBACGroupPermission` instance based on the given name.
Since the `name` field has a unique constraint, this method returns at most one group.
Args:
name (str): The name of the group permission to retrieve.
Returns:
RBACGroupPermission | None: The matching group permission instance if found, otherwise None.
"""
return self.filter(name=name).first()


class RBACGroupPermissionManager(models.Manager):
"""Custom manager for the `RBACGroupPermission` model."""

def get_queryset(self) -> RBACGroupPermissionQuerySet:
"""
Returns the custom queryset for RBACGroupPermission.
This allows access to custom queryset methods like `get_group()`.
"""
return RBACGroupPermissionQuerySet(self.model, using=self._db)

def get_group(self, name: str) -> "RBACGroupPermission":
"""
Retrieves a single `RBACGroupPermission` instance by name using the queryset method.
Args:
name (str): The name of the group permission to retrieve.
Returns:
RBACGroupPermission | None: The matching group permission instance if found, otherwise None.
"""
return self.get_queryset().get_group(name)
6 changes: 6 additions & 0 deletions metrics/data/models/rbac_models/rbac_group_permissions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from django.db import models

from metrics.data.managers.rbac_models.rbac_group_permissions import (
RBACGroupPermissionManager,
)


class RBACGroupPermission(models.Model):

Expand All @@ -12,5 +16,7 @@ class Meta:
"RBACPermission", related_name="rbac_group_permissions"
)

objects = RBACGroupPermissionManager()

def __str__(self):
return self.name
23 changes: 0 additions & 23 deletions tests/factories/metrics/rbac_models/rbac_group_permissions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import contextlib
import datetime
from typing import List

import factory
Expand All @@ -9,9 +7,6 @@
)


from django.utils import timezone


class RBACPermissionGroupFactory(factory.django.DjangoModelFactory):

class Meta:
Expand All @@ -30,21 +25,3 @@ def create_record(
for permission in permissions:
group.permissions.add(permission)
return group

@classmethod
def _make_datetime_timezone_aware(
cls, datetime_obj: str | datetime.datetime | None
) -> datetime.datetime:

if datetime_obj is None:
return datetime_obj

with contextlib.suppress(TypeError):
# If it is already a datetime object then suppress the resulting TypeError
datetime_obj = datetime.datetime.strptime(datetime_obj, "%Y-%m-%d")

try:
return timezone.make_aware(value=datetime_obj)
except ValueError:
# The object is already timezone aware
return datetime_obj
Empty file.
119 changes: 119 additions & 0 deletions tests/integration/metrics/api/decorators/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
from http import HTTPStatus
from unittest import mock
from django.urls import path
from rest_framework.test import APIClient
from rest_framework.views import APIView
from django.test import override_settings
from metrics.api.decorators.auth import authorised_route, RBAC_AUTH_X_HEADER
from django.http import JsonResponse

from tests.factories.metrics.rbac_models.rbac_group_permissions import (
RBACPermissionGroupFactory,
)
from tests.factories.metrics.rbac_models.rbac_permission import RBACPermissionFactory


MODULE_PATH = "metrics.api.decorators.auth"


class MockDownloadView(APIView):
@authorised_route
def post(self, request, *args, **kwargs):
permissions = getattr(request, "group_permissions", None)
if permissions:
permissions_data = [p.name for p in permissions]
else:
permissions_data = []
return JsonResponse(
{"message": "Success", "permissions": permissions_data},
status=HTTPStatus.OK,
)


urlpatterns = [
path("api/mock-downloads/", MockDownloadView.as_view(), name="mock-downloads"),
]


class TestAuthorisedRoute:
"""
Tests for the `authorised_route` decorator.
"""

@pytest.mark.django_db
@override_settings(ROOT_URLCONF=__name__)
def test_request_succeeds_when_auth_is_disabled(self):
"""
Given authentication is disabled
When a request is made to an authorised route
Then the response is successful
"""
# Given
client = APIClient()

with mock.patch(f"{MODULE_PATH}.AUTH_ENABLED", False):
# When
response = client.post("/api/mock-downloads/", format="json")

# Then
assert response.status_code == HTTPStatus.OK
assert response.json() == {"message": "Success", "permissions": []}

@pytest.mark.django_db
@override_settings(ROOT_URLCONF=__name__)
def test_request_succeeds_with_valid_group_id(self):
"""
Given authentication is enabled
And a valid `X-GroupId` header is provided
When a request is made to an authorised route
Then the response is successful
"""
# Given
client = APIClient()
headers = {f"HTTP_{RBAC_AUTH_X_HEADER}": "medical"}
all_infectious = RBACPermissionFactory.create_record(
name="all_infectious_respiratory_data",
theme_name="infectious_disease",
sub_theme_name="respiratory",
)
_ = RBACPermissionGroupFactory.create_record(
name="medical",
permissions=[all_infectious],
)

with mock.patch(f"{MODULE_PATH}.AUTH_ENABLED", True):
# When
response = client.post("/api/mock-downloads/", format="json", **headers)

# Then
assert response.status_code == HTTPStatus.OK
assert response.json() == {
"message": "Success",
"permissions": [all_infectious.name],
}
# mock_set_rbac.assert_called_once_with(mock.ANY, "medical")

@pytest.mark.django_db
@override_settings(ROOT_URLCONF=__name__)
def test_request_fails_with_invalid_group_id(self):
"""
Given authentication is enabled
And an invalid `X-Group-id` header is provided
When a request is made to an authorised route
Then the response contains an error message
"""
# Given
client = APIClient()
headers = {f"HTTP_{RBAC_AUTH_X_HEADER}": "invalid"}

with mock.patch(f"{MODULE_PATH}.AUTH_ENABLED", True):
# When
response = client.post("/api/mock-downloads/", format="json", **headers)

# Then
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {
"error": "Access Denied",
"code": 1115,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from metrics.data.models.rbac_models import RBACGroupPermission
from tests.factories.metrics.rbac_models.rbac_group_permissions import (
RBACPermissionGroupFactory,
)


class TestRBACGroupPermissionFactory:
group_permissions = {
"name": "admin_group",
}

@pytest.mark.django_db
def test_create_record_creates_valid_group_permission(self):
"""
Given valid input parameters,
When `create_record` is called,
Then an `RBACGroupPermission` instance is created with the correct attributes.
"""
# Given
group_permission = RBACPermissionGroupFactory.create_record(
**self.group_permissions
)

# When
assert RBACGroupPermission.objects.filter(id=group_permission.id).exists()

# Then
assert group_permission.name == "admin_group"

@pytest.mark.django_db
def test_create_duplicate_group_permission_raises_error(self):
"""
Given an existing `RBACGroupPermission` record with a specific name,
When another group with the same name is created,
Then an integrity error should be raised.
"""
# Given
RBACPermissionGroupFactory.create_record(name="admin_group")

# When / Then
with pytest.raises(Exception): # Replace with the actual exception if needed
RBACPermissionGroupFactory.create_record(name="admin_group")

0 comments on commit 59c3a17

Please sign in to comment.