WIP: CDD-2432 Implement Middleware for Group Context Enforcement #2125

Draft · wants to merge 4 commits into base: main
3 changes: 2 additions & 1 deletion .gitignore
@@ -249,4 +249,5 @@ fabric.properties


# Static files collected by the Django app
metrics/static/
metrics/static/
identifier.sqlite
4 changes: 4 additions & 0 deletions config.py
@@ -66,3 +66,7 @@
# easy way to inject secrets into serverless lambda functions
if APP_MODE == "INGESTION":
POSTGRES_PASSWORD = get_database_password()

PRIVATE_API_INSTANCE = os.environ.get("PRIVATE_API_INSTANCE", None)
API_PUBLIC_KEY = b"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwhvqCC+37A+UXgcvDl+7\nnbVjDI3QErdZBkI1VypVBMkKKWHMNLMdHk0bIKL+1aDYTRRsCKBy9ZmSSX1pwQlO\n/3+gRs/MWG27gdRNtf57uLk1+lQI6hBDozuyBR0YayQDIx6VsmpBn3Y8LS13p4pT\nBvirlsdX+jXrbOEaQphn0OdQo0WDoOwwsPCNCKoIMbUOtUCowvjesFXlWkwG1zeM\nzlD1aDDS478PDZdckPjT96ICzqe4O1Ok6fRGnor2UTmuPy0f1tI0F7Ol5DHAD6pZ\nbkhB70aTBuWDGLDR0iLenzyQecmD4aU19r1XC9AHsVbQzxHrP8FveZGlV/nJOBJw\nFwIDAQAB\n-----END PUBLIC KEY-----\n"

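As a sketch of how a matching key pair could be generated for local testing (this uses the cryptography package and is not part of this change), the public half is what API_PUBLIC_KEY above is expected to hold, while the private half stays with the token issuer:

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

# Generate a throwaway RS256 key pair (illustrative only).
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# Private half: kept by the token issuer, never stored in config.py.
private_pem = private_key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
)

# Public half: the PEM content that API_PUBLIC_KEY above is expected to hold.
public_pem = private_key.public_key().public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
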
6 changes: 6 additions & 0 deletions metrics/api/admin.py
@@ -1,5 +1,11 @@
from django.contrib import admin

from metrics.data.models.api_models import APITimeSeries
from metrics.api.models import (
DatasetGroup,
DatasetGroupMapping,
)

admin.site.register(APITimeSeries)
admin.site.register(DatasetGroup)
admin.site.register(DatasetGroupMapping)
43 changes: 43 additions & 0 deletions metrics/api/migrations/0001_initial.py
@@ -0,0 +1,43 @@
# Generated by Django 5.1.5 on 2025-01-30 11:22

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

initial = True

dependencies = []

operations = [
migrations.CreateModel(
name="DatasetGroup",
fields=[
("group_id", models.BigAutoField(primary_key=True, serialize=False)),
("name", models.CharField(max_length=255, unique=True)),
],
options={
"db_table": "dataset_groups",
},
),
migrations.CreateModel(
name="DatasetGroupMapping",
fields=[
("id", models.BigAutoField(primary_key=True, serialize=False)),
("dataset_name", models.CharField(max_length=255)),
(
"group",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="group_mappings",
to="api.datasetgroup",
),
),
],
options={
"db_table": "dataset_group_mappings",
"unique_together": {("dataset_name", "group")},
},
),
]
Empty file.
2 changes: 2 additions & 0 deletions metrics/api/models/__init__.py
@@ -0,0 +1,2 @@
from metrics.api.models.dataset_group import DatasetGroup
from metrics.api.models.group_dataset_mapping import DatasetGroupMapping
13 changes: 13 additions & 0 deletions metrics/api/models/dataset_group.py
@@ -0,0 +1,13 @@
from django.db import models


class DatasetGroup(models.Model):

class Meta:
db_table = "dataset_groups"

group_id = models.BigAutoField(primary_key=True)
name = models.CharField(max_length=255, unique=True)

def __str__(self):
return self.name
15 changes: 15 additions & 0 deletions metrics/api/models/group_dataset_mapping.py
@@ -0,0 +1,15 @@
from django.db import models


class DatasetGroupMapping(models.Model):

class Meta:
db_table = "dataset_group_mappings"
unique_together = ('dataset_name', 'group')

id = models.BigAutoField(primary_key=True)
dataset_name = models.CharField(max_length=255)
group = models.ForeignKey("DatasetGroup", on_delete=models.CASCADE, related_name='group_mappings')

def __str__(self):
return f"(Dataset Name: {self.dataset_name}), (Group: {self.group.name})"
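
To illustrate how these two models are meant to be used together, here is a minimal sketch (the group and dataset names are made up for illustration) of creating a group, mapping datasets to it, and reading back the allowed dataset names:

group = DatasetGroup.objects.create(name="example-group")  # hypothetical group name
DatasetGroupMapping.objects.create(group=group, dataset_name="theme")
DatasetGroupMapping.objects.create(group=group, dataset_name="stratum")

# The middleware later resolves a group name to its allowed dataset names:
allowed = list(
    DatasetGroupMapping.objects.filter(group__name="example-group").values_list(
        "dataset_name", flat=True
    )
)
# allowed == ["theme", "stratum"]
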
2 changes: 2 additions & 0 deletions metrics/api/serializers/timeseries.py
@@ -1,8 +1,10 @@
from rest_framework import serializers

from metrics.data.models.core_models import CoreTimeSeries
from metrics.utils.auth import serializer_permissions


@serializer_permissions(["theme", "age", "stratum"])
class CoreTimeSeriesSerializer(serializers.ModelSerializer):
    """This serializer returns a set of serialized fields from the `CoreTimeSeries` and related models.

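As a rough sketch of what the decorator does at runtime (request and queryset stand in for the objects the downloads view already has; the field names come from the decorator call above), restricted fields that are not in the caller's allowed dataset names are dropped from the output:

# Illustrative sketch only, not part of this diff.
request.dataset_names = ["theme", "stratum"]  # "age" is not allowed for this caller

serializer = CoreTimeSeriesSerializer(queryset, many=True, context={"request": request})
# serializer.data still contains "theme" and "stratum", but the "age" field was
# popped by @serializer_permissions before any data was serialized.
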
22 changes: 13 additions & 9 deletions metrics/api/views/downloads.py
@@ -27,6 +27,7 @@
)
from metrics.interfaces.downloads import access
from metrics.interfaces.plots.access import DataNotFoundForAnyPlotError
from metrics.utils.auth import authorised_route

DOWNLOADS_API_TAG = "downloads"

@@ -43,7 +44,7 @@ class DownloadsView(APIView):
renderer_classes = (JSONOpenAPIRenderer,)

def _get_serializer_class(
self, queryset: CoreTimeSeriesQuerySet | CoreHeadlineQuerySet, metric_group: str
self, queryset: CoreTimeSeriesQuerySet | CoreHeadlineQuerySet, metric_group: str, request
) -> CoreHeadlineSerializer | CoreTimeSeriesSerializer:
"""Returns the appropriate serializer class based on the
provided metric_group.
@@ -57,7 +58,7 @@ def _get_serializer_class(
return self.headline_serializer_class(queryset, many=True)

if DataSourceFileType[metric_group].is_timeseries:
return self.timeseries_serializer_class(queryset, many=True)
return self.timeseries_serializer_class(queryset, many=True, context={"request": request})

except KeyError:
raise ValueError(DEFAULT_VALUE_ERROR_MESSAGE)
@@ -67,10 +68,11 @@ def _handle_json(
*,
queryset: CoreTimeSeriesQuerySet | CoreHeadlineQuerySet,
metric_group: str,
request
) -> Response:
# Return the requested data in json format
serializer = self._get_serializer_class(
queryset=queryset, metric_group=metric_group
queryset=queryset, metric_group=metric_group, request=request,
)

response = Response(serializer.data)
@@ -83,13 +85,14 @@ def _handle_csv(
*,
queryset: CoreTimeSeriesQuerySet | CoreHeadlineQuerySet,
metric_group: str,
request,
) -> io.StringIO:
# Return the requested data in csv format
response = HttpResponse(content_type="text/csv")
response["Content-Disposition"] = 'attachment; filename="mymodel.csv"'

serializer = self._get_serializer_class(
queryset=queryset, metric_group=metric_group
queryset=queryset, metric_group=metric_group, request=request,
)

if DataSourceFileType[metric_group].is_headline:
@@ -99,8 +102,10 @@

return write_data_to_csv(file=response, core_time_series_queryset=queryset)

@extend_schema(request=DownloadsSerializer, tags=[DOWNLOADS_API_TAG])
@cache_response()

# @extend_schema(request=DownloadsSerializer, tags=[DOWNLOADS_API_TAG])
# @cache_response()
@authorised_route
def post(self, request, *args, **kwargs):
"""This endpoint will return the query output in json/csv format

@@ -133,7 +138,6 @@ def post(self, request, *args, **kwargs):

file_format: str = request_serializer.data["file_format"]
chart_plot_models = request_serializer.to_models()

try:
queryset: CoreTimeSeriesQuerySet = access.get_downloads_data(
chart_plots=chart_plot_models
@@ -146,11 +150,11 @@
match file_format:
case "json":
return self._handle_json(
queryset=queryset, metric_group=chart_plot_models.metric_group
queryset=queryset, metric_group=chart_plot_models.metric_group, request=request,
)
case "csv":
return self._handle_csv(
queryset=queryset, metric_group=chart_plot_models.metric_group
queryset=queryset, metric_group=chart_plot_models.metric_group, request=request,
)


Empty file added metrics/utils/__init__.py
Empty file.
67 changes: 67 additions & 0 deletions metrics/utils/auth.py
@@ -0,0 +1,67 @@
import logging
from typing import List
from functools import wraps

from django.http import JsonResponse
from rest_framework.serializers import Serializer
import jwt

from metrics.api.models import (
DatasetGroup,
DatasetGroupMapping,
)

from config import PRIVATE_API_INSTANCE, API_PUBLIC_KEY

logger = logging.getLogger(__name__)


def authorised_route(func):
@wraps(func)
def wrap(self, request, *args, **kwargs):
if not PRIVATE_API_INSTANCE:
return func(self, request, *args, **kwargs)
try:
token = request.headers.get("Authorization")
token = token.split("Bearer ")[1]
payload = jwt.decode(token, API_PUBLIC_KEY, algorithms=["RS256"])
group_id = payload["group_id"]
query = DatasetGroupMapping.objects.filter(group__name=group_id)
dataset_names = query.values_list('dataset_name', flat=True)
request.dataset_names = list(dataset_names)
            logger.debug("Allowed dataset names for group: %s", request.dataset_names)
return func(self, request, *args, **kwargs)
        except jwt.ExpiredSignatureError:
            return JsonResponse({"error": "Token expired!"}, status=401)
        except jwt.DecodeError:
            return JsonResponse({"error": "Token decode error!"}, status=401)
        except jwt.InvalidTokenError:
            return JsonResponse({"error": "Invalid token error!"}, status=401)
        except KeyError:
            return JsonResponse({"error": "Invalid payload error!"}, status=401)
        except Exception as err:
            logger.exception(err)
            return JsonResponse({"error": "Authorisation error!"}, status=500)

return wrap


def get_allowed_dataset_types(request) -> List[str]:
return getattr(request, "dataset_names", [])


def serializer_permissions(restricted_fields: List[str]):
def decorator(serializer_class):
_init = serializer_class.__init__

@wraps(_init)
def init(self, *args, **kwargs):
            # Call the serializer class's original __init__ captured above
            _init(self, *args, **kwargs)
request = self.context.get("request", None)
dataset_names: List[str] = getattr(request, "dataset_names", [])
for field in restricted_fields:
if field not in dataset_names:
self.fields.pop(field, None)
serializer_class.__init__ = init
return serializer_class
return decorator
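
Finally, a sketch of how a caller could mint a token that authorised_route will accept. The group name and expiry are illustrative; private_pem stands for the private half of the key pair whose public half is stored as API_PUBLIC_KEY (see the config.py sketch above):

from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

token = jwt.encode(
    {
        # "group_id" must match the name of an existing DatasetGroup
        "group_id": "example-group",
        "exp": datetime.now(timezone.utc) + timedelta(hours=1),
    },
    private_pem,  # private half of the key pair; public half is API_PUBLIC_KEY
    algorithm="RS256",
)

# The token is then supplied on the downloads request as:
#   Authorization: Bearer <token>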