♻️(backend) refactor resource access viewset
The document viewset was overriding the get_queryset method from its own mixin. This was a sign that the mixin was not optimal anymore. In the next commit I will need to complexify it further so it's time to refactor the mixin.
This commit is contained in:
committed by
Anthony LC
parent
0a5887c162
commit
a7c91f9443
@@ -223,14 +223,10 @@ class UserViewSet(
|
||||
class ResourceAccessViewsetMixin:
|
||||
"""Mixin with methods common to all access viewsets."""
|
||||
|
||||
def get_permissions(self):
|
||||
"""User only needs to be authenticated to list resource accesses"""
|
||||
if self.action == "list":
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
else:
|
||||
return super().get_permissions()
|
||||
|
||||
return [permission() for permission in permission_classes]
|
||||
def filter_queryset(self, queryset):
|
||||
"""Override to filter on related resource."""
|
||||
queryset = super().filter_queryset(queryset)
|
||||
return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]})
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""Extra context provided to the serializer class."""
|
||||
@@ -238,43 +234,6 @@ class ResourceAccessViewsetMixin:
|
||||
context["resource_id"] = self.kwargs["resource_id"]
|
||||
return context
|
||||
|
||||
def get_queryset(self):
|
||||
"""Return the queryset according to the action."""
|
||||
queryset = super().get_queryset()
|
||||
queryset = queryset.filter(
|
||||
**{self.resource_field_name: self.kwargs["resource_id"]}
|
||||
)
|
||||
|
||||
if self.action == "list":
|
||||
user = self.request.user
|
||||
teams = user.teams
|
||||
user_roles_query = (
|
||||
queryset.filter(
|
||||
db.Q(user=user) | db.Q(team__in=teams),
|
||||
**{self.resource_field_name: self.kwargs["resource_id"]},
|
||||
)
|
||||
.values(self.resource_field_name)
|
||||
.annotate(roles_array=ArrayAgg("role"))
|
||||
.values("roles_array")
|
||||
)
|
||||
|
||||
# Limit to resource access instances related to a resource THAT also has
|
||||
# a resource access
|
||||
# instance for the logged-in user (we don't want to list only the resource
|
||||
# access instances pointing to the logged-in user)
|
||||
queryset = (
|
||||
queryset.filter(
|
||||
db.Q(**{f"{self.resource_field_name}__accesses__user": user})
|
||||
| db.Q(
|
||||
**{f"{self.resource_field_name}__accesses__team__in": teams}
|
||||
),
|
||||
**{self.resource_field_name: self.kwargs["resource_id"]},
|
||||
)
|
||||
.annotate(user_roles=db.Subquery(user_roles_query))
|
||||
.distinct()
|
||||
)
|
||||
return queryset
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
"""Forbid deleting the last owner access"""
|
||||
instance = self.get_object()
|
||||
@@ -1551,7 +1510,11 @@ class DocumentViewSet(
|
||||
|
||||
class DocumentAccessViewSet(
|
||||
ResourceAccessViewsetMixin,
|
||||
viewsets.ModelViewSet,
|
||||
drf.mixins.CreateModelMixin,
|
||||
drf.mixins.RetrieveModelMixin,
|
||||
drf.mixins.UpdateModelMixin,
|
||||
drf.mixins.DestroyModelMixin,
|
||||
viewsets.GenericViewSet,
|
||||
):
|
||||
"""
|
||||
API ViewSet for all interactions with document accesses.
|
||||
@@ -1578,31 +1541,35 @@ class DocumentAccessViewSet(
|
||||
"""
|
||||
|
||||
lookup_field = "pk"
|
||||
pagination_class = Pagination
|
||||
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
||||
queryset = models.DocumentAccess.objects.select_related("user").all()
|
||||
resource_field_name = "document"
|
||||
serializer_class = serializers.DocumentAccessSerializer
|
||||
is_current_user_owner_or_admin = False
|
||||
|
||||
def get_queryset(self):
|
||||
"""Return the queryset according to the action."""
|
||||
queryset = super().get_queryset()
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""Return accesses for the current document with filters and annotations."""
|
||||
user = self.request.user
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
if self.action == "list":
|
||||
try:
|
||||
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
|
||||
except models.Document.DoesNotExist:
|
||||
return queryset.none()
|
||||
try:
|
||||
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
|
||||
except models.Document.DoesNotExist:
|
||||
return drf.response.Response([])
|
||||
|
||||
roles = set(document.get_roles(self.request.user))
|
||||
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
|
||||
self.is_current_user_owner_or_admin = is_owner_or_admin
|
||||
if not is_owner_or_admin:
|
||||
# Return only the document owner access
|
||||
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
|
||||
roles = set(document.get_roles(user))
|
||||
if not roles:
|
||||
return drf.response.Response([])
|
||||
|
||||
return queryset
|
||||
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
|
||||
self.is_current_user_owner_or_admin = is_owner_or_admin
|
||||
if not is_owner_or_admin:
|
||||
# Return only the document's privileged accesses
|
||||
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
|
||||
|
||||
queryset = queryset.distinct()
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return drf.response.Response(serializer.data)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "list" and not self.is_current_user_owner_or_admin:
|
||||
@@ -1720,7 +1687,6 @@ class TemplateAccessViewSet(
|
||||
ResourceAccessViewsetMixin,
|
||||
drf.mixins.CreateModelMixin,
|
||||
drf.mixins.DestroyModelMixin,
|
||||
drf.mixins.ListModelMixin,
|
||||
drf.mixins.RetrieveModelMixin,
|
||||
drf.mixins.UpdateModelMixin,
|
||||
viewsets.GenericViewSet,
|
||||
@@ -1750,12 +1716,28 @@ class TemplateAccessViewSet(
|
||||
"""
|
||||
|
||||
lookup_field = "pk"
|
||||
pagination_class = Pagination
|
||||
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
||||
queryset = models.TemplateAccess.objects.select_related("user").all()
|
||||
resource_field_name = "template"
|
||||
serializer_class = serializers.TemplateAccessSerializer
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""Restrict templates returned by the list endpoint"""
|
||||
user = self.request.user
|
||||
teams = user.teams
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
# Limit to resource access instances related to a resource THAT also has
|
||||
# a resource access instance for the logged-in user (we don't want to list
|
||||
# only the resource access instances pointing to the logged-in user)
|
||||
queryset = queryset.filter(
|
||||
db.Q(template__accesses__user=user)
|
||||
| db.Q(template__accesses__team__in=teams),
|
||||
).distinct()
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return drf.response.Response(serializer.data)
|
||||
|
||||
|
||||
class InvitationViewset(
|
||||
drf.mixins.CreateModelMixin,
|
||||
|
||||
@@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated():
|
||||
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"count": 0,
|
||||
"next": None,
|
||||
"previous": None,
|
||||
"results": [],
|
||||
}
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_api_document_accesses_list_unexisting_document():
|
||||
@@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document():
|
||||
|
||||
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"count": 0,
|
||||
"next": None,
|
||||
"previous": None,
|
||||
"results": [],
|
||||
}
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("via", VIA)
|
||||
@@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
|
||||
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
||||
)
|
||||
|
||||
# Return only owners
|
||||
owners_accesses = [
|
||||
# Return only privileged roles
|
||||
privileged_accesses = [
|
||||
access for access in accesses if access.role in models.PRIVILEGED_ROLES
|
||||
]
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert content["count"] == len(owners_accesses)
|
||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
||||
assert len(content) == len(privileged_accesses)
|
||||
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||
[
|
||||
{
|
||||
"id": str(access.id),
|
||||
@@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
|
||||
"role": access.role,
|
||||
"abilities": access.get_abilities(user),
|
||||
}
|
||||
for access in owners_accesses
|
||||
for access in privileged_accesses
|
||||
],
|
||||
key=lambda x: x["id"],
|
||||
)
|
||||
|
||||
for access in content["results"]:
|
||||
for access in content:
|
||||
assert access["role"] in models.PRIVILEGED_ROLES
|
||||
|
||||
|
||||
@@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles(
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert len(content["results"]) == 4
|
||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
||||
assert len(content) == 4
|
||||
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||
[
|
||||
{
|
||||
"id": str(user_access.id),
|
||||
|
||||
@@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated():
|
||||
f"/api/v1.0/templates/{template.id!s}/accesses/",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"count": 0,
|
||||
"next": None,
|
||||
"previous": None,
|
||||
"results": [],
|
||||
}
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("via", VIA)
|
||||
@@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.json()
|
||||
assert len(content["results"]) == 3
|
||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
||||
assert len(content) == 3
|
||||
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||
[
|
||||
{
|
||||
"id": str(user_access.id),
|
||||
|
||||
Reference in New Issue
Block a user