♻️(backend) refactor resource access viewset

The document viewset was overriding the get_queryset method from its
own mixin. This was a sign that the mixin was not optimal anymore.
In the next commit I will need to complexify it further so it's time
to refactor the mixin.
This commit is contained in:
Samuel Paccoud - DINUM
2025-04-12 09:11:33 +02:00
committed by Anthony LC
parent 0a5887c162
commit a7c91f9443
3 changed files with 59 additions and 92 deletions

View File

@@ -223,14 +223,10 @@ class UserViewSet(
class ResourceAccessViewsetMixin: class ResourceAccessViewsetMixin:
"""Mixin with methods common to all access viewsets.""" """Mixin with methods common to all access viewsets."""
def get_permissions(self): def filter_queryset(self, queryset):
"""User only needs to be authenticated to list resource accesses""" """Override to filter on related resource."""
if self.action == "list": queryset = super().filter_queryset(queryset)
permission_classes = [permissions.IsAuthenticated] return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]})
else:
return super().get_permissions()
return [permission() for permission in permission_classes]
def get_serializer_context(self): def get_serializer_context(self):
"""Extra context provided to the serializer class.""" """Extra context provided to the serializer class."""
@@ -238,43 +234,6 @@ class ResourceAccessViewsetMixin:
context["resource_id"] = self.kwargs["resource_id"] context["resource_id"] = self.kwargs["resource_id"]
return context return context
def get_queryset(self):
"""Return the queryset according to the action."""
queryset = super().get_queryset()
queryset = queryset.filter(
**{self.resource_field_name: self.kwargs["resource_id"]}
)
if self.action == "list":
user = self.request.user
teams = user.teams
user_roles_query = (
queryset.filter(
db.Q(user=user) | db.Q(team__in=teams),
**{self.resource_field_name: self.kwargs["resource_id"]},
)
.values(self.resource_field_name)
.annotate(roles_array=ArrayAgg("role"))
.values("roles_array")
)
# Limit to resource access instances related to a resource THAT also has
# a resource access
# instance for the logged-in user (we don't want to list only the resource
# access instances pointing to the logged-in user)
queryset = (
queryset.filter(
db.Q(**{f"{self.resource_field_name}__accesses__user": user})
| db.Q(
**{f"{self.resource_field_name}__accesses__team__in": teams}
),
**{self.resource_field_name: self.kwargs["resource_id"]},
)
.annotate(user_roles=db.Subquery(user_roles_query))
.distinct()
)
return queryset
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""Forbid deleting the last owner access""" """Forbid deleting the last owner access"""
instance = self.get_object() instance = self.get_object()
@@ -1551,7 +1510,11 @@ class DocumentViewSet(
class DocumentAccessViewSet( class DocumentAccessViewSet(
ResourceAccessViewsetMixin, ResourceAccessViewsetMixin,
viewsets.ModelViewSet, drf.mixins.CreateModelMixin,
drf.mixins.RetrieveModelMixin,
drf.mixins.UpdateModelMixin,
drf.mixins.DestroyModelMixin,
viewsets.GenericViewSet,
): ):
""" """
API ViewSet for all interactions with document accesses. API ViewSet for all interactions with document accesses.
@@ -1578,31 +1541,35 @@ class DocumentAccessViewSet(
""" """
lookup_field = "pk" lookup_field = "pk"
pagination_class = Pagination
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission] permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
queryset = models.DocumentAccess.objects.select_related("user").all() queryset = models.DocumentAccess.objects.select_related("user").all()
resource_field_name = "document" resource_field_name = "document"
serializer_class = serializers.DocumentAccessSerializer serializer_class = serializers.DocumentAccessSerializer
is_current_user_owner_or_admin = False is_current_user_owner_or_admin = False
def get_queryset(self): def list(self, request, *args, **kwargs):
"""Return the queryset according to the action.""" """Return accesses for the current document with filters and annotations."""
queryset = super().get_queryset() user = self.request.user
queryset = self.filter_queryset(self.get_queryset())
if self.action == "list": try:
try: document = models.Document.objects.get(pk=self.kwargs["resource_id"])
document = models.Document.objects.get(pk=self.kwargs["resource_id"]) except models.Document.DoesNotExist:
except models.Document.DoesNotExist: return drf.response.Response([])
return queryset.none()
roles = set(document.get_roles(self.request.user)) roles = set(document.get_roles(user))
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES))) if not roles:
self.is_current_user_owner_or_admin = is_owner_or_admin return drf.response.Response([])
if not is_owner_or_admin:
# Return only the document owner access
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
return queryset is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
self.is_current_user_owner_or_admin = is_owner_or_admin
if not is_owner_or_admin:
# Return only the document's privileged accesses
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
queryset = queryset.distinct()
serializer = self.get_serializer(queryset, many=True)
return drf.response.Response(serializer.data)
def get_serializer_class(self): def get_serializer_class(self):
if self.action == "list" and not self.is_current_user_owner_or_admin: if self.action == "list" and not self.is_current_user_owner_or_admin:
@@ -1720,7 +1687,6 @@ class TemplateAccessViewSet(
ResourceAccessViewsetMixin, ResourceAccessViewsetMixin,
drf.mixins.CreateModelMixin, drf.mixins.CreateModelMixin,
drf.mixins.DestroyModelMixin, drf.mixins.DestroyModelMixin,
drf.mixins.ListModelMixin,
drf.mixins.RetrieveModelMixin, drf.mixins.RetrieveModelMixin,
drf.mixins.UpdateModelMixin, drf.mixins.UpdateModelMixin,
viewsets.GenericViewSet, viewsets.GenericViewSet,
@@ -1750,12 +1716,28 @@ class TemplateAccessViewSet(
""" """
lookup_field = "pk" lookup_field = "pk"
pagination_class = Pagination
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission] permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
queryset = models.TemplateAccess.objects.select_related("user").all() queryset = models.TemplateAccess.objects.select_related("user").all()
resource_field_name = "template" resource_field_name = "template"
serializer_class = serializers.TemplateAccessSerializer serializer_class = serializers.TemplateAccessSerializer
def list(self, request, *args, **kwargs):
"""Restrict templates returned by the list endpoint"""
user = self.request.user
teams = user.teams
queryset = self.filter_queryset(self.get_queryset())
# Limit to resource access instances related to a resource THAT also has
# a resource access instance for the logged-in user (we don't want to list
# only the resource access instances pointing to the logged-in user)
queryset = queryset.filter(
db.Q(template__accesses__user=user)
| db.Q(template__accesses__team__in=teams),
).distinct()
serializer = self.get_serializer(queryset, many=True)
return drf.response.Response(serializer.data)
class InvitationViewset( class InvitationViewset(
drf.mixins.CreateModelMixin, drf.mixins.CreateModelMixin,

View File

@@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated():
f"/api/v1.0/documents/{document.id!s}/accesses/", f"/api/v1.0/documents/{document.id!s}/accesses/",
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == []
"count": 0,
"next": None,
"previous": None,
"results": [],
}
def test_api_document_accesses_list_unexisting_document(): def test_api_document_accesses_list_unexisting_document():
@@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document():
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/") response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == []
"count": 0,
"next": None,
"previous": None,
"results": [],
}
@pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize("via", VIA)
@@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
f"/api/v1.0/documents/{document.id!s}/accesses/", f"/api/v1.0/documents/{document.id!s}/accesses/",
) )
# Return only owners # Return only privileged roles
owners_accesses = [ privileged_accesses = [
access for access in accesses if access.role in models.PRIVILEGED_ROLES access for access in accesses if access.role in models.PRIVILEGED_ROLES
] ]
assert response.status_code == 200 assert response.status_code == 200
content = response.json() content = response.json()
assert content["count"] == len(owners_accesses) assert len(content) == len(privileged_accesses)
assert sorted(content["results"], key=lambda x: x["id"]) == sorted( assert sorted(content, key=lambda x: x["id"]) == sorted(
[ [
{ {
"id": str(access.id), "id": str(access.id),
@@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
"role": access.role, "role": access.role,
"abilities": access.get_abilities(user), "abilities": access.get_abilities(user),
} }
for access in owners_accesses for access in privileged_accesses
], ],
key=lambda x: x["id"], key=lambda x: x["id"],
) )
for access in content["results"]: for access in content:
assert access["role"] in models.PRIVILEGED_ROLES assert access["role"] in models.PRIVILEGED_ROLES
@@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles(
assert response.status_code == 200 assert response.status_code == 200
content = response.json() content = response.json()
assert len(content["results"]) == 4 assert len(content) == 4
assert sorted(content["results"], key=lambda x: x["id"]) == sorted( assert sorted(content, key=lambda x: x["id"]) == sorted(
[ [
{ {
"id": str(user_access.id), "id": str(user_access.id),

View File

@@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated():
f"/api/v1.0/templates/{template.id!s}/accesses/", f"/api/v1.0/templates/{template.id!s}/accesses/",
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == []
"count": 0,
"next": None,
"previous": None,
"results": [],
}
@pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize("via", VIA)
@@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
assert response.status_code == 200 assert response.status_code == 200
content = response.json() content = response.json()
assert len(content["results"]) == 3 assert len(content) == 3
assert sorted(content["results"], key=lambda x: x["id"]) == sorted( assert sorted(content, key=lambda x: x["id"]) == sorted(
[ [
{ {
"id": str(user_access.id), "id": str(user_access.id),