diff --git a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py index 26b70961cd..53801d5b2a 100644 --- a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py @@ -162,6 +162,22 @@ def test_pagination(self): self.assertEqual(response.status_code, 200) self.assertEqual(len(response.data), 1) + def test_filtering_by_project(self): + """Filter by project id works""" + self._project_create() + project_2 = Project.objects.create( + name="Other project", + created_by=self.user, + organization=self.user, + ) + EntityList.objects.create(name="dataset_1", project=self.project) + EntityList.objects.create(name="dataset_2", project=project_2) + request = self.factory.get("/", data={"project": project_2.pk}, **self.extra) + response = self.view(request) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]["name"], "dataset_2") + @override_settings(TIME_ZONE="UTC") class GetSingleEntityListTestCase(TestAbstractViewSet): diff --git a/onadata/apps/api/viewsets/entity_list_viewset.py b/onadata/apps/api/viewsets/entity_list_viewset.py index 47d042a93e..0a10c965a1 100644 --- a/onadata/apps/api/viewsets/entity_list_viewset.py +++ b/onadata/apps/api/viewsets/entity_list_viewset.py @@ -6,6 +6,7 @@ from onadata.apps.api.tools import get_baseviewset_class from onadata.apps.logger.models import Entity, EntityList +from onadata.libs.filters import EntityListProjectFilter from onadata.libs.mixins.cache_control_mixin import CacheControlMixin from onadata.libs.mixins.etags_mixin import ETagsMixin from onadata.libs.mixins.anon_user_public_entity_lists_mixin import ( @@ -35,6 +36,7 @@ class EntityListViewSet( serializer_class = EntityListSerializer permission_classes = (AllowAny,) pagination_class = StandardPageNumberPagination + filter_backends = (EntityListProjectFilter,) def get_queryset(self): if self.action == "retrieve": diff --git a/onadata/libs/filters.py b/onadata/libs/filters.py index e210ad658f..8238a97f13 100644 --- a/onadata/libs/filters.py +++ b/onadata/libs/filters.py @@ -344,17 +344,19 @@ def _xform_filter(self, request, view, keyword): if dataview: int_or_parse_error( dataview, - "Invalid value for dataview ID. It must be a positive integer." + "Invalid value for dataview ID. It must be a positive integer.", ) self.dataview = get_object_or_404(DataView, pk=dataview) # filter with fitlered dataset query dataview_kwargs = self._add_instance_prefix_to_dataview_filter_kwargs( - get_filter_kwargs(self.dataview.query)) + get_filter_kwargs(self.dataview.query) + ) xform_qs = XForm.objects.filter(pk=self.dataview.xform.pk) elif merged_xform: int_or_parse_error( merged_xform, - "Invalid value for Merged Dataset ID. It must be a positive integer.") + "Invalid value for Merged Dataset ID. It must be a positive integer.", + ) self.merged_xform = get_object_or_404(MergedXForm, pk=merged_xform) xform_qs = self.merged_xform.xforms.all() elif xform: @@ -378,10 +380,7 @@ def _xform_filter(self, request, view, keyword): xforms = xform_qs.filter(shared_data=True) else: xforms = super().filter_queryset(request, xform_qs, view) | public_forms - return { - **{f"{keyword}__in": xforms}, - **dataview_kwargs - } + return {**{f"{keyword}__in": xforms}, **dataview_kwargs} def _xform_filter_queryset(self, request, queryset, view, keyword): kwarg = self._xform_filter(request, view, keyword) @@ -748,3 +747,18 @@ def filter_queryset(self, request, queryset, view): return queryset.filter(shared=True) return queryset + + +# pylint: disable=too-few-public-methods +class EntityListProjectFilter(filters.BaseFilterBackend): + """EntityList `project` filter.""" + + # pylint: disable=unused-argument + def filter_queryset(self, request, queryset, view): + """Filter by project id""" + project_id = request.query_params.get("project") + + if project_id: + return queryset.filter(project__pk=project_id) + + return super().filter_queryset(request, queryset, view)