♻️(backend) generalize use of SerializerPerActionMixin in viewsets

We improve the serializer and generalize its use for clarity.
This commit is contained in:
Samuel Paccoud - DINUM
2025-01-17 19:50:03 +01:00
committed by Anthony LC
parent 8c247c8777
commit 5042f4ca47
7 changed files with 42 additions and 43 deletions

View File

@@ -117,16 +117,23 @@ class SerializerPerActionMixin:
This mixin is useful to avoid to define a serializer class for each action in the This mixin is useful to avoid to define a serializer class for each action in the
`get_serializer_class` method. `get_serializer_class` method.
"""
serializer_classes: dict[str, type] = {} Example:
default_serializer_class: type = None ```
class MyViewSet(SerializerPerActionMixin, viewsets.GenericViewSet):
serializer_class = MySerializer
list_serializer_class = MyListSerializer
retrieve_serializer_class = MyRetrieveSerializer
```
"""
def get_serializer_class(self): def get_serializer_class(self):
""" """
Return the serializer class to use depending on the action. Return the serializer class to use depending on the action.
""" """
return self.serializer_classes.get(self.action, self.default_serializer_class) if serializer_class := getattr(self, f"{self.action}_serializer_class", None):
return serializer_class
return super().get_serializer_class()
class Pagination(drf.pagination.PageNumberPagination): class Pagination(drf.pagination.PageNumberPagination):
@@ -314,6 +321,7 @@ class DocumentMetadata(drf.metadata.SimpleMetadata):
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
class DocumentViewSet( class DocumentViewSet(
SerializerPerActionMixin,
drf.mixins.CreateModelMixin, drf.mixins.CreateModelMixin,
drf.mixins.DestroyModelMixin, drf.mixins.DestroyModelMixin,
drf.mixins.ListModelMixin, drf.mixins.ListModelMixin,
@@ -424,16 +432,10 @@ class DocumentViewSet(
] ]
queryset = models.Document.objects.all() queryset = models.Document.objects.all()
serializer_class = serializers.DocumentSerializer serializer_class = serializers.DocumentSerializer
list_serializer_class = serializers.ListDocumentSerializer
def get_serializer_class(self): trashbin_serializer_class = serializers.ListDocumentSerializer
""" children_serializer_class = serializers.ListDocumentSerializer
Use ListDocumentSerializer for list actions; otherwise, use DocumentSerializer. ai_translate_serializer_class = serializers.AITranslateSerializer
"""
return (
serializers.ListDocumentSerializer
if self.action == "list"
else self.serializer_class
)
def annotate_is_favorite(self, queryset): def annotate_is_favorite(self, queryset):
""" """
@@ -613,7 +615,6 @@ class DocumentViewSet(
@drf.decorators.action( @drf.decorators.action(
detail=False, detail=False,
methods=["get"], methods=["get"],
serializer_class=serializers.ListDocumentSerializer,
) )
def trashbin(self, request, *args, **kwargs): def trashbin(self, request, *args, **kwargs):
""" """
@@ -734,7 +735,6 @@ class DocumentViewSet(
detail=True, detail=True,
methods=["get", "post"], methods=["get", "post"],
ordering=["path"], ordering=["path"],
serializer_class=serializers.ListDocumentSerializer,
url_path="children", url_path="children",
) )
def children(self, request, *args, **kwargs): def children(self, request, *args, **kwargs):
@@ -1091,7 +1091,6 @@ class DocumentViewSet(
detail=True, detail=True,
methods=["post"], methods=["post"],
name="Translate a piece of text with AI", name="Translate a piece of text with AI",
serializer_class=serializers.AITranslateSerializer,
url_path="ai-translate", url_path="ai-translate",
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
) )

View File

@@ -76,14 +76,14 @@ def test_api_document_accesses_list_authenticated_related(via, mock_user_teams):
user_access = models.DocumentAccess.objects.create( user_access = models.DocumentAccess.objects.create(
document=document, document=document,
user=user, user=user,
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
elif via == TEAM: elif via == TEAM:
mock_user_teams.return_value = ["lasuite", "unknown"] mock_user_teams.return_value = ["lasuite", "unknown"]
user_access = models.DocumentAccess.objects.create( user_access = models.DocumentAccess.objects.create(
document=document, document=document,
team="lasuite", team="lasuite",
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
access1 = factories.TeamDocumentAccessFactory(document=document) access1 = factories.TeamDocumentAccessFactory(document=document)
@@ -227,7 +227,7 @@ def test_api_document_accesses_update_anonymous():
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
api_client = APIClient() api_client = APIClient()
@@ -260,7 +260,7 @@ def test_api_document_accesses_update_authenticated_unrelated():
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -302,7 +302,7 @@ def test_api_document_accesses_update_authenticated_reader_or_editor(
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -413,7 +413,7 @@ def test_api_document_accesses_update_administrator_from_owner(via, mock_user_te
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user_id": factories.UserFactory().id, "user_id": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -527,7 +527,7 @@ def test_api_document_accesses_update_owner(
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user_id": factories.UserFactory().id, "user_id": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():

View File

@@ -26,7 +26,7 @@ def test_api_document_accesses_create_anonymous():
{ {
"user_id": str(other_user.id), "user_id": str(other_user.id),
"document": str(document.id), "document": str(document.id),
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
}, },
format="json", format="json",
) )

View File

@@ -304,7 +304,7 @@ def test_api_document_invitations_create_anonymous():
document = factories.DocumentFactory() document = factories.DocumentFactory()
invitation_values = { invitation_values = {
"email": "guest@example.com", "email": "guest@example.com",
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
response = APIClient().post( response = APIClient().post(
@@ -325,7 +325,7 @@ def test_api_document_invitations_create_authenticated_outsider():
document = factories.DocumentFactory() document = factories.DocumentFactory()
invitation_values = { invitation_values = {
"email": "guest@example.com", "email": "guest@example.com",
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
client = APIClient() client = APIClient()
@@ -550,7 +550,7 @@ def test_api_document_invitations_create_issuer_and_document_override():
"document": str(other_document.id), "document": str(other_document.id),
"issuer": str(factories.UserFactory().id), "issuer": str(factories.UserFactory().id),
"email": "guest@example.com", "email": "guest@example.com",
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
client = APIClient() client = APIClient()
@@ -611,7 +611,7 @@ def test_api_document_invitations_create_cannot_invite_existing_users():
# Build an invitation to the email of an exising identity in the db # Build an invitation to the email of an exising identity in the db
invitation_values = { invitation_values = {
"email": existing_user.email, "email": existing_user.email,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
client = APIClient() client = APIClient()

View File

@@ -75,14 +75,14 @@ def test_api_document_versions_list_authenticated_related_success(via, mock_user
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=document, document=document,
user=user, user=user,
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
elif via == TEAM: elif via == TEAM:
mock_user_teams.return_value = ["lasuite", "unknown"] mock_user_teams.return_value = ["lasuite", "unknown"]
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=document, document=document,
team="lasuite", team="lasuite",
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
# Other versions of documents to which the user has access should not be listed # Other versions of documents to which the user has access should not be listed
@@ -134,14 +134,14 @@ def test_api_document_versions_list_authenticated_related_pagination(
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=document, document=document,
user=user, user=user,
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
elif via == TEAM: elif via == TEAM:
mock_user_teams.return_value = ["lasuite", "unknown"] mock_user_teams.return_value = ["lasuite", "unknown"]
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=document, document=document,
team="lasuite", team="lasuite",
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
for i in range(4): for i in range(4):
@@ -210,14 +210,14 @@ def test_api_document_versions_list_authenticated_related_pagination_parent(
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=grand_parent, document=grand_parent,
user=user, user=user,
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
elif via == TEAM: elif via == TEAM:
mock_user_teams.return_value = ["lasuite", "unknown"] mock_user_teams.return_value = ["lasuite", "unknown"]
models.DocumentAccess.objects.create( models.DocumentAccess.objects.create(
document=grand_parent, document=grand_parent,
team="lasuite", team="lasuite",
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
for i in range(4): for i in range(4):

View File

@@ -73,14 +73,14 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
user_access = models.TemplateAccess.objects.create( user_access = models.TemplateAccess.objects.create(
template=template, template=template,
user=user, user=user,
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
elif via == TEAM: elif via == TEAM:
mock_user_teams.return_value = ["lasuite", "unknown"] mock_user_teams.return_value = ["lasuite", "unknown"]
user_access = models.TemplateAccess.objects.create( user_access = models.TemplateAccess.objects.create(
template=template, template=template,
team="lasuite", team="lasuite",
role=random.choice(models.RoleChoices.choices)[0], role=random.choice(models.RoleChoices.values),
) )
access1 = factories.TeamTemplateAccessFactory(template=template) access1 = factories.TeamTemplateAccessFactory(template=template)
@@ -219,7 +219,7 @@ def test_api_template_accesses_update_anonymous():
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
api_client = APIClient() api_client = APIClient()
@@ -252,7 +252,7 @@ def test_api_template_accesses_update_authenticated_unrelated():
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -294,7 +294,7 @@ def test_api_template_accesses_update_authenticated_editor_or_reader(
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user": factories.UserFactory().id, "user": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -398,7 +398,7 @@ def test_api_template_accesses_update_administrator_from_owner(via, mock_user_te
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user_id": factories.UserFactory().id, "user_id": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():
@@ -497,7 +497,7 @@ def test_api_template_accesses_update_owner(via, mock_user_teams):
new_values = { new_values = {
"id": uuid4(), "id": uuid4(),
"user_id": factories.UserFactory().id, "user_id": factories.UserFactory().id,
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
} }
for field, value in new_values.items(): for field, value in new_values.items():

View File

@@ -23,7 +23,7 @@ def test_api_template_accesses_create_anonymous():
{ {
"user": str(other_user.id), "user": str(other_user.id),
"template": str(template.id), "template": str(template.id),
"role": random.choice(models.RoleChoices.choices)[0], "role": random.choice(models.RoleChoices.values),
}, },
format="json", format="json",
) )