From a7c91f94430015897237aa799c04f46961ef0707 Mon Sep 17 00:00:00 2001 From: Samuel Paccoud - DINUM Date: Sat, 12 Apr 2025 09:11:33 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F(backend)=20refactor=20resour?= =?UTF-8?q?ce=20access=20viewset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The document viewset was overriding the get_queryset method from its own mixin. This was a sign that the mixin was not optimal anymore. In the next commit I will need to complexify it further so it's time to refactor the mixin. --- src/backend/core/api/viewsets.py | 110 ++++++++---------- .../documents/test_api_document_accesses.py | 30 ++--- .../templates/test_api_template_accesses.py | 11 +- 3 files changed, 59 insertions(+), 92 deletions(-) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 03c1752c..4e4b5661 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -223,14 +223,10 @@ class UserViewSet( class ResourceAccessViewsetMixin: """Mixin with methods common to all access viewsets.""" - def get_permissions(self): - """User only needs to be authenticated to list resource accesses""" - if self.action == "list": - permission_classes = [permissions.IsAuthenticated] - else: - return super().get_permissions() - - return [permission() for permission in permission_classes] + def filter_queryset(self, queryset): + """Override to filter on related resource.""" + queryset = super().filter_queryset(queryset) + return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]}) def get_serializer_context(self): """Extra context provided to the serializer class.""" @@ -238,43 +234,6 @@ class ResourceAccessViewsetMixin: context["resource_id"] = self.kwargs["resource_id"] return context - def get_queryset(self): - """Return the queryset according to the action.""" - queryset = super().get_queryset() - queryset = queryset.filter( - **{self.resource_field_name: self.kwargs["resource_id"]} - ) - - if self.action == "list": - user = self.request.user - teams = user.teams - user_roles_query = ( - queryset.filter( - db.Q(user=user) | db.Q(team__in=teams), - **{self.resource_field_name: self.kwargs["resource_id"]}, - ) - .values(self.resource_field_name) - .annotate(roles_array=ArrayAgg("role")) - .values("roles_array") - ) - - # Limit to resource access instances related to a resource THAT also has - # a resource access - # instance for the logged-in user (we don't want to list only the resource - # access instances pointing to the logged-in user) - queryset = ( - queryset.filter( - db.Q(**{f"{self.resource_field_name}__accesses__user": user}) - | db.Q( - **{f"{self.resource_field_name}__accesses__team__in": teams} - ), - **{self.resource_field_name: self.kwargs["resource_id"]}, - ) - .annotate(user_roles=db.Subquery(user_roles_query)) - .distinct() - ) - return queryset - def destroy(self, request, *args, **kwargs): """Forbid deleting the last owner access""" instance = self.get_object() @@ -1551,7 +1510,11 @@ class DocumentViewSet( class DocumentAccessViewSet( ResourceAccessViewsetMixin, - viewsets.ModelViewSet, + drf.mixins.CreateModelMixin, + drf.mixins.RetrieveModelMixin, + drf.mixins.UpdateModelMixin, + drf.mixins.DestroyModelMixin, + viewsets.GenericViewSet, ): """ API ViewSet for all interactions with document accesses. @@ -1578,31 +1541,35 @@ class DocumentAccessViewSet( """ lookup_field = "pk" - pagination_class = Pagination permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission] queryset = models.DocumentAccess.objects.select_related("user").all() resource_field_name = "document" serializer_class = serializers.DocumentAccessSerializer is_current_user_owner_or_admin = False - def get_queryset(self): - """Return the queryset according to the action.""" - queryset = super().get_queryset() + def list(self, request, *args, **kwargs): + """Return accesses for the current document with filters and annotations.""" + user = self.request.user + queryset = self.filter_queryset(self.get_queryset()) - if self.action == "list": - try: - document = models.Document.objects.get(pk=self.kwargs["resource_id"]) - except models.Document.DoesNotExist: - return queryset.none() + try: + document = models.Document.objects.get(pk=self.kwargs["resource_id"]) + except models.Document.DoesNotExist: + return drf.response.Response([]) - roles = set(document.get_roles(self.request.user)) - is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES))) - self.is_current_user_owner_or_admin = is_owner_or_admin - if not is_owner_or_admin: - # Return only the document owner access - queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES) + roles = set(document.get_roles(user)) + if not roles: + return drf.response.Response([]) - return queryset + is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES))) + self.is_current_user_owner_or_admin = is_owner_or_admin + if not is_owner_or_admin: + # Return only the document's privileged accesses + queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES) + + queryset = queryset.distinct() + serializer = self.get_serializer(queryset, many=True) + return drf.response.Response(serializer.data) def get_serializer_class(self): if self.action == "list" and not self.is_current_user_owner_or_admin: @@ -1720,7 +1687,6 @@ class TemplateAccessViewSet( ResourceAccessViewsetMixin, drf.mixins.CreateModelMixin, drf.mixins.DestroyModelMixin, - drf.mixins.ListModelMixin, drf.mixins.RetrieveModelMixin, drf.mixins.UpdateModelMixin, viewsets.GenericViewSet, @@ -1750,12 +1716,28 @@ class TemplateAccessViewSet( """ lookup_field = "pk" - pagination_class = Pagination permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission] queryset = models.TemplateAccess.objects.select_related("user").all() resource_field_name = "template" serializer_class = serializers.TemplateAccessSerializer + def list(self, request, *args, **kwargs): + """Restrict templates returned by the list endpoint""" + user = self.request.user + teams = user.teams + queryset = self.filter_queryset(self.get_queryset()) + + # Limit to resource access instances related to a resource THAT also has + # a resource access instance for the logged-in user (we don't want to list + # only the resource access instances pointing to the logged-in user) + queryset = queryset.filter( + db.Q(template__accesses__user=user) + | db.Q(template__accesses__team__in=teams), + ).distinct() + + serializer = self.get_serializer(queryset, many=True) + return drf.response.Response(serializer.data) + class InvitationViewset( drf.mixins.CreateModelMixin, diff --git a/src/backend/core/tests/documents/test_api_document_accesses.py b/src/backend/core/tests/documents/test_api_document_accesses.py index bf5ef182..946ffcf0 100644 --- a/src/backend/core/tests/documents/test_api_document_accesses.py +++ b/src/backend/core/tests/documents/test_api_document_accesses.py @@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated(): f"/api/v1.0/documents/{document.id!s}/accesses/", ) assert response.status_code == 200 - assert response.json() == { - "count": 0, - "next": None, - "previous": None, - "results": [], - } + assert response.json() == [] def test_api_document_accesses_list_unexisting_document(): @@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document(): response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/") assert response.status_code == 200 - assert response.json() == { - "count": 0, - "next": None, - "previous": None, - "results": [], - } + assert response.json() == [] @pytest.mark.parametrize("via", VIA) @@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged( f"/api/v1.0/documents/{document.id!s}/accesses/", ) - # Return only owners - owners_accesses = [ + # Return only privileged roles + privileged_accesses = [ access for access in accesses if access.role in models.PRIVILEGED_ROLES ] assert response.status_code == 200 content = response.json() - assert content["count"] == len(owners_accesses) - assert sorted(content["results"], key=lambda x: x["id"]) == sorted( + assert len(content) == len(privileged_accesses) + assert sorted(content, key=lambda x: x["id"]) == sorted( [ { "id": str(access.id), @@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged( "role": access.role, "abilities": access.get_abilities(user), } - for access in owners_accesses + for access in privileged_accesses ], key=lambda x: x["id"], ) - for access in content["results"]: + for access in content: assert access["role"] in models.PRIVILEGED_ROLES @@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles( assert response.status_code == 200 content = response.json() - assert len(content["results"]) == 4 - assert sorted(content["results"], key=lambda x: x["id"]) == sorted( + assert len(content) == 4 + assert sorted(content, key=lambda x: x["id"]) == sorted( [ { "id": str(user_access.id), diff --git a/src/backend/core/tests/templates/test_api_template_accesses.py b/src/backend/core/tests/templates/test_api_template_accesses.py index 86e5f2bd..6d110776 100644 --- a/src/backend/core/tests/templates/test_api_template_accesses.py +++ b/src/backend/core/tests/templates/test_api_template_accesses.py @@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated(): f"/api/v1.0/templates/{template.id!s}/accesses/", ) assert response.status_code == 200 - assert response.json() == { - "count": 0, - "next": None, - "previous": None, - "results": [], - } + assert response.json() == [] @pytest.mark.parametrize("via", VIA) @@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams): assert response.status_code == 200 content = response.json() - assert len(content["results"]) == 3 - assert sorted(content["results"], key=lambda x: x["id"]) == sorted( + assert len(content) == 3 + assert sorted(content, key=lambda x: x["id"]) == sorted( [ { "id": str(user_access.id),