♻️(backend) refactor resource access viewset

The document viewset was overriding the get_queryset method from its
own mixin. This was a sign that the mixin was not optimal anymore.
In the next commit I will need to complexify it further so it's time
to refactor the mixin.
This commit is contained in:
Samuel Paccoud - DINUM
2025-04-12 09:11:33 +02:00
committed by Anthony LC
parent 0a5887c162
commit a7c91f9443
3 changed files with 59 additions and 92 deletions

View File

@@ -223,14 +223,10 @@ class UserViewSet(
class ResourceAccessViewsetMixin:
"""Mixin with methods common to all access viewsets."""
def get_permissions(self):
"""User only needs to be authenticated to list resource accesses"""
if self.action == "list":
permission_classes = [permissions.IsAuthenticated]
else:
return super().get_permissions()
return [permission() for permission in permission_classes]
def filter_queryset(self, queryset):
"""Override to filter on related resource."""
queryset = super().filter_queryset(queryset)
return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]})
def get_serializer_context(self):
"""Extra context provided to the serializer class."""
@@ -238,43 +234,6 @@ class ResourceAccessViewsetMixin:
context["resource_id"] = self.kwargs["resource_id"]
return context
def get_queryset(self):
"""Return the queryset according to the action."""
queryset = super().get_queryset()
queryset = queryset.filter(
**{self.resource_field_name: self.kwargs["resource_id"]}
)
if self.action == "list":
user = self.request.user
teams = user.teams
user_roles_query = (
queryset.filter(
db.Q(user=user) | db.Q(team__in=teams),
**{self.resource_field_name: self.kwargs["resource_id"]},
)
.values(self.resource_field_name)
.annotate(roles_array=ArrayAgg("role"))
.values("roles_array")
)
# Limit to resource access instances related to a resource THAT also has
# a resource access
# instance for the logged-in user (we don't want to list only the resource
# access instances pointing to the logged-in user)
queryset = (
queryset.filter(
db.Q(**{f"{self.resource_field_name}__accesses__user": user})
| db.Q(
**{f"{self.resource_field_name}__accesses__team__in": teams}
),
**{self.resource_field_name: self.kwargs["resource_id"]},
)
.annotate(user_roles=db.Subquery(user_roles_query))
.distinct()
)
return queryset
def destroy(self, request, *args, **kwargs):
"""Forbid deleting the last owner access"""
instance = self.get_object()
@@ -1551,7 +1510,11 @@ class DocumentViewSet(
class DocumentAccessViewSet(
ResourceAccessViewsetMixin,
viewsets.ModelViewSet,
drf.mixins.CreateModelMixin,
drf.mixins.RetrieveModelMixin,
drf.mixins.UpdateModelMixin,
drf.mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
"""
API ViewSet for all interactions with document accesses.
@@ -1578,31 +1541,35 @@ class DocumentAccessViewSet(
"""
lookup_field = "pk"
pagination_class = Pagination
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
queryset = models.DocumentAccess.objects.select_related("user").all()
resource_field_name = "document"
serializer_class = serializers.DocumentAccessSerializer
is_current_user_owner_or_admin = False
def get_queryset(self):
"""Return the queryset according to the action."""
queryset = super().get_queryset()
def list(self, request, *args, **kwargs):
"""Return accesses for the current document with filters and annotations."""
user = self.request.user
queryset = self.filter_queryset(self.get_queryset())
if self.action == "list":
try:
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
except models.Document.DoesNotExist:
return queryset.none()
try:
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
except models.Document.DoesNotExist:
return drf.response.Response([])
roles = set(document.get_roles(self.request.user))
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
self.is_current_user_owner_or_admin = is_owner_or_admin
if not is_owner_or_admin:
# Return only the document owner access
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
roles = set(document.get_roles(user))
if not roles:
return drf.response.Response([])
return queryset
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
self.is_current_user_owner_or_admin = is_owner_or_admin
if not is_owner_or_admin:
# Return only the document's privileged accesses
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
queryset = queryset.distinct()
serializer = self.get_serializer(queryset, many=True)
return drf.response.Response(serializer.data)
def get_serializer_class(self):
if self.action == "list" and not self.is_current_user_owner_or_admin:
@@ -1720,7 +1687,6 @@ class TemplateAccessViewSet(
ResourceAccessViewsetMixin,
drf.mixins.CreateModelMixin,
drf.mixins.DestroyModelMixin,
drf.mixins.ListModelMixin,
drf.mixins.RetrieveModelMixin,
drf.mixins.UpdateModelMixin,
viewsets.GenericViewSet,
@@ -1750,12 +1716,28 @@ class TemplateAccessViewSet(
"""
lookup_field = "pk"
pagination_class = Pagination
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
queryset = models.TemplateAccess.objects.select_related("user").all()
resource_field_name = "template"
serializer_class = serializers.TemplateAccessSerializer
def list(self, request, *args, **kwargs):
"""Restrict templates returned by the list endpoint"""
user = self.request.user
teams = user.teams
queryset = self.filter_queryset(self.get_queryset())
# Limit to resource access instances related to a resource THAT also has
# a resource access instance for the logged-in user (we don't want to list
# only the resource access instances pointing to the logged-in user)
queryset = queryset.filter(
db.Q(template__accesses__user=user)
| db.Q(template__accesses__team__in=teams),
).distinct()
serializer = self.get_serializer(queryset, many=True)
return drf.response.Response(serializer.data)
class InvitationViewset(
drf.mixins.CreateModelMixin,

View File

@@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated():
f"/api/v1.0/documents/{document.id!s}/accesses/",
)
assert response.status_code == 200
assert response.json() == {
"count": 0,
"next": None,
"previous": None,
"results": [],
}
assert response.json() == []
def test_api_document_accesses_list_unexisting_document():
@@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document():
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
assert response.status_code == 200
assert response.json() == {
"count": 0,
"next": None,
"previous": None,
"results": [],
}
assert response.json() == []
@pytest.mark.parametrize("via", VIA)
@@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
f"/api/v1.0/documents/{document.id!s}/accesses/",
)
# Return only owners
owners_accesses = [
# Return only privileged roles
privileged_accesses = [
access for access in accesses if access.role in models.PRIVILEGED_ROLES
]
assert response.status_code == 200
content = response.json()
assert content["count"] == len(owners_accesses)
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
assert len(content) == len(privileged_accesses)
assert sorted(content, key=lambda x: x["id"]) == sorted(
[
{
"id": str(access.id),
@@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
"role": access.role,
"abilities": access.get_abilities(user),
}
for access in owners_accesses
for access in privileged_accesses
],
key=lambda x: x["id"],
)
for access in content["results"]:
for access in content:
assert access["role"] in models.PRIVILEGED_ROLES
@@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles(
assert response.status_code == 200
content = response.json()
assert len(content["results"]) == 4
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
assert len(content) == 4
assert sorted(content, key=lambda x: x["id"]) == sorted(
[
{
"id": str(user_access.id),

View File

@@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated():
f"/api/v1.0/templates/{template.id!s}/accesses/",
)
assert response.status_code == 200
assert response.json() == {
"count": 0,
"next": None,
"previous": None,
"results": [],
}
assert response.json() == []
@pytest.mark.parametrize("via", VIA)
@@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
assert response.status_code == 200
content = response.json()
assert len(content["results"]) == 3
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
assert len(content) == 3
assert sorted(content, key=lambda x: x["id"]) == sorted(
[
{
"id": str(user_access.id),