♻️(backend) refactor resource access viewset
The document viewset was overriding the get_queryset method from its own mixin. This was a sign that the mixin was not optimal anymore. In the next commit I will need to complexify it further so it's time to refactor the mixin.
This commit is contained in:
committed by
Anthony LC
parent
0a5887c162
commit
a7c91f9443
@@ -223,14 +223,10 @@ class UserViewSet(
|
|||||||
class ResourceAccessViewsetMixin:
|
class ResourceAccessViewsetMixin:
|
||||||
"""Mixin with methods common to all access viewsets."""
|
"""Mixin with methods common to all access viewsets."""
|
||||||
|
|
||||||
def get_permissions(self):
|
def filter_queryset(self, queryset):
|
||||||
"""User only needs to be authenticated to list resource accesses"""
|
"""Override to filter on related resource."""
|
||||||
if self.action == "list":
|
queryset = super().filter_queryset(queryset)
|
||||||
permission_classes = [permissions.IsAuthenticated]
|
return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]})
|
||||||
else:
|
|
||||||
return super().get_permissions()
|
|
||||||
|
|
||||||
return [permission() for permission in permission_classes]
|
|
||||||
|
|
||||||
def get_serializer_context(self):
|
def get_serializer_context(self):
|
||||||
"""Extra context provided to the serializer class."""
|
"""Extra context provided to the serializer class."""
|
||||||
@@ -238,43 +234,6 @@ class ResourceAccessViewsetMixin:
|
|||||||
context["resource_id"] = self.kwargs["resource_id"]
|
context["resource_id"] = self.kwargs["resource_id"]
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def get_queryset(self):
|
|
||||||
"""Return the queryset according to the action."""
|
|
||||||
queryset = super().get_queryset()
|
|
||||||
queryset = queryset.filter(
|
|
||||||
**{self.resource_field_name: self.kwargs["resource_id"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.action == "list":
|
|
||||||
user = self.request.user
|
|
||||||
teams = user.teams
|
|
||||||
user_roles_query = (
|
|
||||||
queryset.filter(
|
|
||||||
db.Q(user=user) | db.Q(team__in=teams),
|
|
||||||
**{self.resource_field_name: self.kwargs["resource_id"]},
|
|
||||||
)
|
|
||||||
.values(self.resource_field_name)
|
|
||||||
.annotate(roles_array=ArrayAgg("role"))
|
|
||||||
.values("roles_array")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Limit to resource access instances related to a resource THAT also has
|
|
||||||
# a resource access
|
|
||||||
# instance for the logged-in user (we don't want to list only the resource
|
|
||||||
# access instances pointing to the logged-in user)
|
|
||||||
queryset = (
|
|
||||||
queryset.filter(
|
|
||||||
db.Q(**{f"{self.resource_field_name}__accesses__user": user})
|
|
||||||
| db.Q(
|
|
||||||
**{f"{self.resource_field_name}__accesses__team__in": teams}
|
|
||||||
),
|
|
||||||
**{self.resource_field_name: self.kwargs["resource_id"]},
|
|
||||||
)
|
|
||||||
.annotate(user_roles=db.Subquery(user_roles_query))
|
|
||||||
.distinct()
|
|
||||||
)
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
def destroy(self, request, *args, **kwargs):
|
def destroy(self, request, *args, **kwargs):
|
||||||
"""Forbid deleting the last owner access"""
|
"""Forbid deleting the last owner access"""
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
@@ -1551,7 +1510,11 @@ class DocumentViewSet(
|
|||||||
|
|
||||||
class DocumentAccessViewSet(
|
class DocumentAccessViewSet(
|
||||||
ResourceAccessViewsetMixin,
|
ResourceAccessViewsetMixin,
|
||||||
viewsets.ModelViewSet,
|
drf.mixins.CreateModelMixin,
|
||||||
|
drf.mixins.RetrieveModelMixin,
|
||||||
|
drf.mixins.UpdateModelMixin,
|
||||||
|
drf.mixins.DestroyModelMixin,
|
||||||
|
viewsets.GenericViewSet,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
API ViewSet for all interactions with document accesses.
|
API ViewSet for all interactions with document accesses.
|
||||||
@@ -1578,31 +1541,35 @@ class DocumentAccessViewSet(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
lookup_field = "pk"
|
lookup_field = "pk"
|
||||||
pagination_class = Pagination
|
|
||||||
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
||||||
queryset = models.DocumentAccess.objects.select_related("user").all()
|
queryset = models.DocumentAccess.objects.select_related("user").all()
|
||||||
resource_field_name = "document"
|
resource_field_name = "document"
|
||||||
serializer_class = serializers.DocumentAccessSerializer
|
serializer_class = serializers.DocumentAccessSerializer
|
||||||
is_current_user_owner_or_admin = False
|
is_current_user_owner_or_admin = False
|
||||||
|
|
||||||
def get_queryset(self):
|
def list(self, request, *args, **kwargs):
|
||||||
"""Return the queryset according to the action."""
|
"""Return accesses for the current document with filters and annotations."""
|
||||||
queryset = super().get_queryset()
|
user = self.request.user
|
||||||
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
|
||||||
if self.action == "list":
|
try:
|
||||||
try:
|
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
|
||||||
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
|
except models.Document.DoesNotExist:
|
||||||
except models.Document.DoesNotExist:
|
return drf.response.Response([])
|
||||||
return queryset.none()
|
|
||||||
|
|
||||||
roles = set(document.get_roles(self.request.user))
|
roles = set(document.get_roles(user))
|
||||||
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
|
if not roles:
|
||||||
self.is_current_user_owner_or_admin = is_owner_or_admin
|
return drf.response.Response([])
|
||||||
if not is_owner_or_admin:
|
|
||||||
# Return only the document owner access
|
|
||||||
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
|
|
||||||
|
|
||||||
return queryset
|
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
|
||||||
|
self.is_current_user_owner_or_admin = is_owner_or_admin
|
||||||
|
if not is_owner_or_admin:
|
||||||
|
# Return only the document's privileged accesses
|
||||||
|
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
|
||||||
|
|
||||||
|
queryset = queryset.distinct()
|
||||||
|
serializer = self.get_serializer(queryset, many=True)
|
||||||
|
return drf.response.Response(serializer.data)
|
||||||
|
|
||||||
def get_serializer_class(self):
|
def get_serializer_class(self):
|
||||||
if self.action == "list" and not self.is_current_user_owner_or_admin:
|
if self.action == "list" and not self.is_current_user_owner_or_admin:
|
||||||
@@ -1720,7 +1687,6 @@ class TemplateAccessViewSet(
|
|||||||
ResourceAccessViewsetMixin,
|
ResourceAccessViewsetMixin,
|
||||||
drf.mixins.CreateModelMixin,
|
drf.mixins.CreateModelMixin,
|
||||||
drf.mixins.DestroyModelMixin,
|
drf.mixins.DestroyModelMixin,
|
||||||
drf.mixins.ListModelMixin,
|
|
||||||
drf.mixins.RetrieveModelMixin,
|
drf.mixins.RetrieveModelMixin,
|
||||||
drf.mixins.UpdateModelMixin,
|
drf.mixins.UpdateModelMixin,
|
||||||
viewsets.GenericViewSet,
|
viewsets.GenericViewSet,
|
||||||
@@ -1750,12 +1716,28 @@ class TemplateAccessViewSet(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
lookup_field = "pk"
|
lookup_field = "pk"
|
||||||
pagination_class = Pagination
|
|
||||||
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
|
||||||
queryset = models.TemplateAccess.objects.select_related("user").all()
|
queryset = models.TemplateAccess.objects.select_related("user").all()
|
||||||
resource_field_name = "template"
|
resource_field_name = "template"
|
||||||
serializer_class = serializers.TemplateAccessSerializer
|
serializer_class = serializers.TemplateAccessSerializer
|
||||||
|
|
||||||
|
def list(self, request, *args, **kwargs):
|
||||||
|
"""Restrict templates returned by the list endpoint"""
|
||||||
|
user = self.request.user
|
||||||
|
teams = user.teams
|
||||||
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
|
||||||
|
# Limit to resource access instances related to a resource THAT also has
|
||||||
|
# a resource access instance for the logged-in user (we don't want to list
|
||||||
|
# only the resource access instances pointing to the logged-in user)
|
||||||
|
queryset = queryset.filter(
|
||||||
|
db.Q(template__accesses__user=user)
|
||||||
|
| db.Q(template__accesses__team__in=teams),
|
||||||
|
).distinct()
|
||||||
|
|
||||||
|
serializer = self.get_serializer(queryset, many=True)
|
||||||
|
return drf.response.Response(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
class InvitationViewset(
|
class InvitationViewset(
|
||||||
drf.mixins.CreateModelMixin,
|
drf.mixins.CreateModelMixin,
|
||||||
|
|||||||
@@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated():
|
|||||||
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == []
|
||||||
"count": 0,
|
|
||||||
"next": None,
|
|
||||||
"previous": None,
|
|
||||||
"results": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_document_accesses_list_unexisting_document():
|
def test_api_document_accesses_list_unexisting_document():
|
||||||
@@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document():
|
|||||||
|
|
||||||
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
|
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == []
|
||||||
"count": 0,
|
|
||||||
"next": None,
|
|
||||||
"previous": None,
|
|
||||||
"results": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("via", VIA)
|
@pytest.mark.parametrize("via", VIA)
|
||||||
@@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
|
|||||||
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
f"/api/v1.0/documents/{document.id!s}/accesses/",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only owners
|
# Return only privileged roles
|
||||||
owners_accesses = [
|
privileged_accesses = [
|
||||||
access for access in accesses if access.role in models.PRIVILEGED_ROLES
|
access for access in accesses if access.role in models.PRIVILEGED_ROLES
|
||||||
]
|
]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
content = response.json()
|
content = response.json()
|
||||||
assert content["count"] == len(owners_accesses)
|
assert len(content) == len(privileged_accesses)
|
||||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": str(access.id),
|
"id": str(access.id),
|
||||||
@@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
|
|||||||
"role": access.role,
|
"role": access.role,
|
||||||
"abilities": access.get_abilities(user),
|
"abilities": access.get_abilities(user),
|
||||||
}
|
}
|
||||||
for access in owners_accesses
|
for access in privileged_accesses
|
||||||
],
|
],
|
||||||
key=lambda x: x["id"],
|
key=lambda x: x["id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
for access in content["results"]:
|
for access in content:
|
||||||
assert access["role"] in models.PRIVILEGED_ROLES
|
assert access["role"] in models.PRIVILEGED_ROLES
|
||||||
|
|
||||||
|
|
||||||
@@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles(
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
content = response.json()
|
content = response.json()
|
||||||
assert len(content["results"]) == 4
|
assert len(content) == 4
|
||||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": str(user_access.id),
|
"id": str(user_access.id),
|
||||||
|
|||||||
@@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated():
|
|||||||
f"/api/v1.0/templates/{template.id!s}/accesses/",
|
f"/api/v1.0/templates/{template.id!s}/accesses/",
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == []
|
||||||
"count": 0,
|
|
||||||
"next": None,
|
|
||||||
"previous": None,
|
|
||||||
"results": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("via", VIA)
|
@pytest.mark.parametrize("via", VIA)
|
||||||
@@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
content = response.json()
|
content = response.json()
|
||||||
assert len(content["results"]) == 3
|
assert len(content) == 3
|
||||||
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
|
assert sorted(content, key=lambda x: x["id"]) == sorted(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": str(user_access.id),
|
"id": str(user_access.id),
|
||||||
|
|||||||
Reference in New Issue
Block a user