From 5042f4ca47d441a7212a818462f7cb8eb60e5c52 Mon Sep 17 00:00:00 2001 From: Samuel Paccoud - DINUM Date: Fri, 17 Jan 2025 19:50:03 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F(backend)=20generalize=20use?= =?UTF-8?q?=20of=20SerializerPerActionMixin=20in=20viewsets?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We improve the serializer and generalize its use for clarity. --- src/backend/core/api/viewsets.py | 33 +++++++++---------- .../documents/test_api_document_accesses.py | 14 ++++---- .../test_api_document_accesses_create.py | 2 +- .../test_api_document_invitations.py | 8 ++--- .../documents/test_api_document_versions.py | 12 +++---- .../templates/test_api_template_accesses.py | 14 ++++---- .../test_api_template_accesses_create.py | 2 +- 7 files changed, 42 insertions(+), 43 deletions(-) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 18ff2163..d41a62f9 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -117,16 +117,23 @@ class SerializerPerActionMixin: This mixin is useful to avoid to define a serializer class for each action in the `get_serializer_class` method. - """ - serializer_classes: dict[str, type] = {} - default_serializer_class: type = None + Example: + ``` + class MyViewSet(SerializerPerActionMixin, viewsets.GenericViewSet): + serializer_class = MySerializer + list_serializer_class = MyListSerializer + retrieve_serializer_class = MyRetrieveSerializer + ``` + """ def get_serializer_class(self): """ Return the serializer class to use depending on the action. """ - return self.serializer_classes.get(self.action, self.default_serializer_class) + if serializer_class := getattr(self, f"{self.action}_serializer_class", None): + return serializer_class + return super().get_serializer_class() class Pagination(drf.pagination.PageNumberPagination): @@ -314,6 +321,7 @@ class DocumentMetadata(drf.metadata.SimpleMetadata): # pylint: disable=too-many-public-methods class DocumentViewSet( + SerializerPerActionMixin, drf.mixins.CreateModelMixin, drf.mixins.DestroyModelMixin, drf.mixins.ListModelMixin, @@ -424,16 +432,10 @@ class DocumentViewSet( ] queryset = models.Document.objects.all() serializer_class = serializers.DocumentSerializer - - def get_serializer_class(self): - """ - Use ListDocumentSerializer for list actions; otherwise, use DocumentSerializer. - """ - return ( - serializers.ListDocumentSerializer - if self.action == "list" - else self.serializer_class - ) + list_serializer_class = serializers.ListDocumentSerializer + trashbin_serializer_class = serializers.ListDocumentSerializer + children_serializer_class = serializers.ListDocumentSerializer + ai_translate_serializer_class = serializers.AITranslateSerializer def annotate_is_favorite(self, queryset): """ @@ -613,7 +615,6 @@ class DocumentViewSet( @drf.decorators.action( detail=False, methods=["get"], - serializer_class=serializers.ListDocumentSerializer, ) def trashbin(self, request, *args, **kwargs): """ @@ -734,7 +735,6 @@ class DocumentViewSet( detail=True, methods=["get", "post"], ordering=["path"], - serializer_class=serializers.ListDocumentSerializer, url_path="children", ) def children(self, request, *args, **kwargs): @@ -1091,7 +1091,6 @@ class DocumentViewSet( detail=True, methods=["post"], name="Translate a piece of text with AI", - serializer_class=serializers.AITranslateSerializer, url_path="ai-translate", throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], ) diff --git a/src/backend/core/tests/documents/test_api_document_accesses.py b/src/backend/core/tests/documents/test_api_document_accesses.py index 5b1ea283..dca6fa79 100644 --- a/src/backend/core/tests/documents/test_api_document_accesses.py +++ b/src/backend/core/tests/documents/test_api_document_accesses.py @@ -76,14 +76,14 @@ def test_api_document_accesses_list_authenticated_related(via, mock_user_teams): user_access = models.DocumentAccess.objects.create( document=document, user=user, - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) elif via == TEAM: mock_user_teams.return_value = ["lasuite", "unknown"] user_access = models.DocumentAccess.objects.create( document=document, team="lasuite", - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) access1 = factories.TeamDocumentAccessFactory(document=document) @@ -227,7 +227,7 @@ def test_api_document_accesses_update_anonymous(): new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } api_client = APIClient() @@ -260,7 +260,7 @@ def test_api_document_accesses_update_authenticated_unrelated(): new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -302,7 +302,7 @@ def test_api_document_accesses_update_authenticated_reader_or_editor( new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -413,7 +413,7 @@ def test_api_document_accesses_update_administrator_from_owner(via, mock_user_te new_values = { "id": uuid4(), "user_id": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -527,7 +527,7 @@ def test_api_document_accesses_update_owner( new_values = { "id": uuid4(), "user_id": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): diff --git a/src/backend/core/tests/documents/test_api_document_accesses_create.py b/src/backend/core/tests/documents/test_api_document_accesses_create.py index 8be7865f..3d95bd62 100644 --- a/src/backend/core/tests/documents/test_api_document_accesses_create.py +++ b/src/backend/core/tests/documents/test_api_document_accesses_create.py @@ -26,7 +26,7 @@ def test_api_document_accesses_create_anonymous(): { "user_id": str(other_user.id), "document": str(document.id), - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), }, format="json", ) diff --git a/src/backend/core/tests/documents/test_api_document_invitations.py b/src/backend/core/tests/documents/test_api_document_invitations.py index b30d0a48..5a5e40f7 100644 --- a/src/backend/core/tests/documents/test_api_document_invitations.py +++ b/src/backend/core/tests/documents/test_api_document_invitations.py @@ -304,7 +304,7 @@ def test_api_document_invitations_create_anonymous(): document = factories.DocumentFactory() invitation_values = { "email": "guest@example.com", - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } response = APIClient().post( @@ -325,7 +325,7 @@ def test_api_document_invitations_create_authenticated_outsider(): document = factories.DocumentFactory() invitation_values = { "email": "guest@example.com", - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } client = APIClient() @@ -550,7 +550,7 @@ def test_api_document_invitations_create_issuer_and_document_override(): "document": str(other_document.id), "issuer": str(factories.UserFactory().id), "email": "guest@example.com", - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } client = APIClient() @@ -611,7 +611,7 @@ def test_api_document_invitations_create_cannot_invite_existing_users(): # Build an invitation to the email of an exising identity in the db invitation_values = { "email": existing_user.email, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } client = APIClient() diff --git a/src/backend/core/tests/documents/test_api_document_versions.py b/src/backend/core/tests/documents/test_api_document_versions.py index 4546d18c..83b8c7f5 100644 --- a/src/backend/core/tests/documents/test_api_document_versions.py +++ b/src/backend/core/tests/documents/test_api_document_versions.py @@ -75,14 +75,14 @@ def test_api_document_versions_list_authenticated_related_success(via, mock_user models.DocumentAccess.objects.create( document=document, user=user, - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) elif via == TEAM: mock_user_teams.return_value = ["lasuite", "unknown"] models.DocumentAccess.objects.create( document=document, team="lasuite", - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) # Other versions of documents to which the user has access should not be listed @@ -134,14 +134,14 @@ def test_api_document_versions_list_authenticated_related_pagination( models.DocumentAccess.objects.create( document=document, user=user, - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) elif via == TEAM: mock_user_teams.return_value = ["lasuite", "unknown"] models.DocumentAccess.objects.create( document=document, team="lasuite", - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) for i in range(4): @@ -210,14 +210,14 @@ def test_api_document_versions_list_authenticated_related_pagination_parent( models.DocumentAccess.objects.create( document=grand_parent, user=user, - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) elif via == TEAM: mock_user_teams.return_value = ["lasuite", "unknown"] models.DocumentAccess.objects.create( document=grand_parent, team="lasuite", - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) for i in range(4): diff --git a/src/backend/core/tests/templates/test_api_template_accesses.py b/src/backend/core/tests/templates/test_api_template_accesses.py index cd08abfb..86e5f2bd 100644 --- a/src/backend/core/tests/templates/test_api_template_accesses.py +++ b/src/backend/core/tests/templates/test_api_template_accesses.py @@ -73,14 +73,14 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams): user_access = models.TemplateAccess.objects.create( template=template, user=user, - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) elif via == TEAM: mock_user_teams.return_value = ["lasuite", "unknown"] user_access = models.TemplateAccess.objects.create( template=template, team="lasuite", - role=random.choice(models.RoleChoices.choices)[0], + role=random.choice(models.RoleChoices.values), ) access1 = factories.TeamTemplateAccessFactory(template=template) @@ -219,7 +219,7 @@ def test_api_template_accesses_update_anonymous(): new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } api_client = APIClient() @@ -252,7 +252,7 @@ def test_api_template_accesses_update_authenticated_unrelated(): new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -294,7 +294,7 @@ def test_api_template_accesses_update_authenticated_editor_or_reader( new_values = { "id": uuid4(), "user": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -398,7 +398,7 @@ def test_api_template_accesses_update_administrator_from_owner(via, mock_user_te new_values = { "id": uuid4(), "user_id": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): @@ -497,7 +497,7 @@ def test_api_template_accesses_update_owner(via, mock_user_teams): new_values = { "id": uuid4(), "user_id": factories.UserFactory().id, - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), } for field, value in new_values.items(): diff --git a/src/backend/core/tests/templates/test_api_template_accesses_create.py b/src/backend/core/tests/templates/test_api_template_accesses_create.py index 9a5072ed..f52a5344 100644 --- a/src/backend/core/tests/templates/test_api_template_accesses_create.py +++ b/src/backend/core/tests/templates/test_api_template_accesses_create.py @@ -23,7 +23,7 @@ def test_api_template_accesses_create_anonymous(): { "user": str(other_user.id), "template": str(template.id), - "role": random.choice(models.RoleChoices.choices)[0], + "role": random.choice(models.RoleChoices.values), }, format="json", )