From 81243cfc9a7e3754027f011a3d34aad3307aff66 Mon Sep 17 00:00:00 2001 From: Marie PUPO JEAMMET Date: Mon, 26 Feb 2024 18:03:31 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(api)=20return=20user=20id,=20name=20a?= =?UTF-8?q?nd=20email=20on=20/team//accesses/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add serializers to return basic user info when listing /team//accesses/ endpoint. This will allow front-end to retrieve members info without having to query API for each user.id. --- src/backend/core/api/serializers.py | 70 ++++++++++- src/backend/core/api/viewsets.py | 33 ++++- .../test_api_team_accesses_create.py | 4 +- .../test_api_team_accesses_list.py | 114 ++++++++++++++++-- .../test_api_team_accesses_retrieve.py | 29 ++--- .../teams/test_core_api_teams_retrieve.py | 21 +--- src/backend/core/tests/test_api_users.py | 80 +++++++++--- 7 files changed, 283 insertions(+), 68 deletions(-) diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index e15273e..5e19082 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -26,22 +26,66 @@ class ContactSerializer(serializers.ModelSerializer): return super().update(instance, validated_data) -class UserSerializer(serializers.ModelSerializer): +class DynamicFieldsModelSerializer(serializers.ModelSerializer): + """ + A ModelSerializer that takes an additional `fields` argument that + controls which fields should be displayed. + """ + + def __init__(self, *args, **kwargs): + # Don't pass the 'fields' arg up to the superclass + fields = kwargs.pop("fields", None) + + # Instantiate the superclass normally + super().__init__(*args, **kwargs) + + if fields is not None: + # Drop any fields that are not specified in the `fields` argument. + allowed = set(fields) + existing = set(self.fields) + for field_name in existing - allowed: + self.fields.pop(field_name) + + +class UserSerializer(DynamicFieldsModelSerializer): """Serialize users.""" timezone = TimeZoneSerializerField(use_pytz=False, required=True) + name = serializers.SerializerMethodField(read_only=True) + email = serializers.SerializerMethodField(read_only=True) class Meta: model = models.User fields = [ "id", + "name", "email", "language", "timezone", "is_device", "is_staff", ] - read_only_fields = ["id", "email", "is_device", "is_staff"] + read_only_fields = ["id", "name", "email", "is_device", "is_staff"] + + def _get_main_identity_attr(self, obj, attribute_name): + """Return the specified attribute of the main identity.""" + try: + return getattr(obj.main_identity[0], attribute_name) + except TypeError: + return getattr(obj.main_identity, attribute_name) + except IndexError: + main_identity = obj.identities.filter(is_main=True).first() + return getattr(obj.main_identity, attribute_name) if main_identity else None + except AttributeError: + return None + + def get_name(self, obj): + """Return main identity's name.""" + return self._get_main_identity_attr(obj, "name") + + def get_email(self, obj): + """Return main identity's email.""" + return self._get_main_identity_attr(obj, "email") class TeamAccessSerializer(serializers.ModelSerializer): @@ -120,11 +164,31 @@ class TeamAccessSerializer(serializers.ModelSerializer): return attrs +class TeamAccessReadOnlySerializer(TeamAccessSerializer): + """Serialize team accesses for list and retrieve actions.""" + + user = UserSerializer(read_only=True, fields=["id", "name", "email"]) + + class Meta: + model = models.TeamAccess + fields = [ + "id", + "user", + "role", + "abilities", + ] + read_only_fields = [ + "id", + "user", + "role", + "abilities", + ] + + class TeamSerializer(serializers.ModelSerializer): """Serialize teams.""" abilities = serializers.SerializerMethodField(read_only=True) - accesses = TeamAccessSerializer(many=True, read_only=True) slug = serializers.SerializerMethodField() class Meta: diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index f17fb86..9d32a74 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1,6 +1,6 @@ """API endpoints""" from django.contrib.postgres.search import TrigramSimilarity -from django.db.models import Func, Max, OuterRef, Q, Subquery, Value +from django.db.models import Func, Max, OuterRef, Prefetch, Q, Subquery, Value from rest_framework import ( decorators, @@ -199,6 +199,12 @@ class UserViewSet( # Exclude inactive contacts queryset = queryset.filter( is_active=True, + ).prefetch_related( + Prefetch( + "identities", + queryset=models.Identity.objects.filter(is_main=True), + to_attr="main_identity", + ) ) # Search by case-insensitive and accent-insensitive trigram similarity @@ -227,9 +233,12 @@ class UserViewSet( """ Return information on currently logged user """ - context = {"request": request} + user = request.user + user.main_identity = models.Identity.objects.filter( + user=user, is_main=True + ).first() return response.Response( - self.serializer_class(request.user, context=context).data + self.serializer_class(user, context={"request": request}).data ) @@ -305,7 +314,8 @@ class TeamAccessViewSet( pagination_class = Pagination permission_classes = [permissions.AccessPermission] queryset = models.TeamAccess.objects.all().select_related("user") - serializer_class = serializers.TeamAccessSerializer + list_serializer_class = serializers.TeamAccessReadOnlySerializer + detail_serializer_class = serializers.TeamAccessSerializer def get_permissions(self): """User only needs to be authenticated to list team accesses""" @@ -322,22 +332,35 @@ class TeamAccessViewSet( context["team_id"] = self.kwargs["team_id"] return context + def get_serializer_class(self): + if self.action in {"list", "retrieve"}: + return self.list_serializer_class + return self.detail_serializer_class + def get_queryset(self): """Return the queryset according to the action.""" queryset = super().get_queryset() queryset = queryset.filter(team=self.kwargs["team_id"]) - if self.action == "list": + if self.action in {"list", "retrieve"}: # Limit to team access instances related to a team THAT also has a team access # instance for the logged-in user (we don't want to list only the team access # instances pointing to the logged-in user) user_role_query = models.TeamAccess.objects.filter( team__accesses__user=self.request.user ).values("role")[:1] + queryset = ( queryset.filter( team__accesses__user=self.request.user, ) + .prefetch_related( + Prefetch( + "user__identities", + queryset=models.Identity.objects.filter(is_main=True), + to_attr="main_identity", + ) + ) .annotate(user_role=Subquery(user_role_query)) .distinct() ) diff --git a/src/backend/core/tests/team_accesses/test_api_team_accesses_create.py b/src/backend/core/tests/team_accesses/test_api_team_accesses_create.py index 5375b61..5343cf8 100644 --- a/src/backend/core/tests/team_accesses/test_api_team_accesses_create.py +++ b/src/backend/core/tests/team_accesses/test_api_team_accesses_create.py @@ -45,7 +45,7 @@ def test_api_team_accesses_create_authenticated_unrelated(): client = APIClient() client.force_login(user) - response = APIClient().post( + response = client.post( f"/api/v1.0/teams/{team.id!s}/accesses/", { "user": str(other_user.id), @@ -155,7 +155,7 @@ def test_api_team_accesses_create_authenticated_owner(): client = APIClient() client.force_login(user) - response = APIClient().post( + response = client.post( f"/api/v1.0/teams/{team.id!s}/accesses/", { "user": str(other_user.id), diff --git a/src/backend/core/tests/team_accesses/test_api_team_accesses_list.py b/src/backend/core/tests/team_accesses/test_api_team_accesses_list.py index e69cf07..0ffea9b 100644 --- a/src/backend/core/tests/team_accesses/test_api_team_accesses_list.py +++ b/src/backend/core/tests/team_accesses/test_api_team_accesses_list.py @@ -55,12 +55,16 @@ def test_api_team_accesses_list_authenticated_related(): Authenticated users should be able to list team accesses for a team to which they are related, whatever their role in the team. """ - identity = factories.IdentityFactory() + identity = factories.IdentityFactory(is_main=True) user = identity.user team = factories.TeamFactory() user_access = models.TeamAccess.objects.create(team=team, user=user) # random role - access1, access2 = factories.TeamAccessFactory.create_batch(2, team=team) + + # other team members should appear + other_member = factories.UserFactory() + other_member_identity = factories.IdentityFactory(is_main=True, user=other_member) + access1 = factories.TeamAccessFactory.create(team=team, user=other_member) # Accesses for other teams to which the user is related should not be listed either other_access = factories.TeamAccessFactory(user=user) @@ -73,27 +77,111 @@ def test_api_team_accesses_list_authenticated_related(): ) assert response.status_code == 200 - assert response.json()["count"] == 3 + assert response.json()["count"] == 2 assert sorted(response.json()["results"], key=lambda x: x["id"]) == sorted( [ { "id": str(user_access.id), - "user": str(user.id), - "role": user_access.role, + "user": { + "id": str(user_access.user.id), + "email": str(identity.email), + "name": str(identity.name), + }, + "role": str(user_access.role), "abilities": user_access.get_abilities(user), }, { "id": str(access1.id), - "user": str(access1.user.id), - "role": access1.role, + "user": { + "id": str(access1.user.id), + "email": str(other_member_identity.email), + "name": str(other_member_identity.name), + }, + "role": str(access1.role), "abilities": access1.get_abilities(user), }, - { - "id": str(access2.id), - "user": str(access2.user.id), - "role": access2.role, - "abilities": access2.get_abilities(user), - }, ], key=lambda x: x["id"], ) + + +def test_api_team_accesses_list_authenticated_main_identity(): + """ + Name and email should be returned from main identity only + """ + user = factories.UserFactory() + identity = factories.IdentityFactory(user=user, is_main=True) + factories.IdentityFactory(user=user) # additional non-main identity + + team = factories.TeamFactory() + models.TeamAccess.objects.create(team=team, user=user) # random role + + # other team members should appear, with correct identity + other_user = factories.UserFactory() + other_main_identity = factories.IdentityFactory(is_main=True, user=other_user) + factories.IdentityFactory(user=other_user) + factories.TeamAccessFactory.create(team=team, user=other_user) + + # Accesses for other teams to which the user is related should not be listed either + other_access = factories.TeamAccessFactory(user=user) + factories.TeamAccessFactory(team=other_access.team) + + client = APIClient() + client.force_login(user) + response = client.get( + f"/api/v1.0/teams/{team.id!s}/accesses/", + ) + + assert response.status_code == 200 + assert response.json()["count"] == 2 + users_info = [ + (access["user"]["email"], access["user"]["name"]) + for access in response.json()["results"] + ] + # user information should be returned from main identity + assert sorted(users_info) == sorted( + [ + (str(identity.email), str(identity.name)), + (str(other_main_identity.email), str(other_main_identity.name)), + ] + ) + + +def test_api_team_accesses_list_authenticated_constant_numqueries( + django_assert_num_queries, +): + """ + The number of queries should not depend on the amount of fetched accesses. + """ + user = factories.UserFactory() + factories.IdentityFactory(user=user, is_main=True) + + team = factories.TeamFactory() + models.TeamAccess.objects.create(team=team, user=user) # random role + + client = APIClient() + client.force_login(user) + # Only 4 queries are needed to efficiently fetch team accesses, + # related users and identities : + # - query retrieving logged-in user for user_role annotation + # - count from pagination + # - query prefetching users' main identity + # - distinct from viewset + with django_assert_num_queries(4): + response = client.get( + f"/api/v1.0/teams/{team.id!s}/accesses/", + ) + + # create 20 new team members + for _ in range(20): + extra_user = factories.IdentityFactory(is_main=True).user + factories.TeamAccessFactory(team=team, user=extra_user) + + # num queries should still be 4 + with django_assert_num_queries(4): + response = client.get( + f"/api/v1.0/teams/{team.id!s}/accesses/", + ) + + assert response.status_code == 200 + assert response.json()["count"] == 21 diff --git a/src/backend/core/tests/team_accesses/test_api_team_accesses_retrieve.py b/src/backend/core/tests/team_accesses/test_api_team_accesses_retrieve.py index 7c4aad1..21f628d 100644 --- a/src/backend/core/tests/team_accesses/test_api_team_accesses_retrieve.py +++ b/src/backend/core/tests/team_accesses/test_api_team_accesses_retrieve.py @@ -33,26 +33,23 @@ def test_api_team_accesses_retrieve_authenticated_unrelated(): identity = factories.IdentityFactory() user = identity.user - team = factories.TeamFactory() - access = factories.TeamAccessFactory(team=team) + access = factories.TeamAccessFactory(team=factories.TeamFactory()) client = APIClient() client.force_login(user) response = client.get( - f"/api/v1.0/teams/{team.id!s}/accesses/{access.id!s}/", + f"/api/v1.0/teams/{access.team.id!s}/accesses/{access.id!s}/", ) - assert response.status_code == 403 - assert response.json() == { - "detail": "You do not have permission to perform this action." - } + assert response.status_code == 404 + assert response.json() == {"detail": "Not found."} # Accesses related to another team should be excluded even if the user is related to it - for access in [ + for other_access in [ factories.TeamAccessFactory(), factories.TeamAccessFactory(user=user), ]: response = client.get( - f"/api/v1.0/teams/{team.id!s}/accesses/{access.id!s}/", + f"/api/v1.0/teams/{access.team.id!s}/accesses/{other_access.id!s}/", ) assert response.status_code == 404 @@ -64,11 +61,11 @@ def test_api_team_accesses_retrieve_authenticated_related(): A user who is related to a team should be allowed to retrieve the associated team user accesses. """ - identity = factories.IdentityFactory() + identity = factories.IdentityFactory(is_main=True) user = identity.user - team = factories.TeamFactory(users=[user]) - access = factories.TeamAccessFactory(team=team) + team = factories.TeamFactory() + access = factories.TeamAccessFactory(team=team, user=user) client = APIClient() client.force_login(user) @@ -79,7 +76,11 @@ def test_api_team_accesses_retrieve_authenticated_related(): assert response.status_code == 200 assert response.json() == { "id": str(access.id), - "user": str(access.user.id), - "role": access.role, + "user": { + "id": str(access.user.id), + "email": str(identity.email), + "name": str(identity.name), + }, + "role": str(access.role), "abilities": access.get_abilities(user), } diff --git a/src/backend/core/tests/teams/test_core_api_teams_retrieve.py b/src/backend/core/tests/teams/test_core_api_teams_retrieve.py index 6dab09d..3fa6243 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_retrieve.py +++ b/src/backend/core/tests/teams/test_core_api_teams_retrieve.py @@ -58,24 +58,13 @@ def test_api_teams_retrieve_authenticated_related(): response = client.get( f"/api/v1.0/teams/{team.id!s}/", ) + assert response.status_code == HTTP_200_OK - content = response.json() - assert sorted(content.pop("accesses"), key=lambda x: x["user"]) == sorted( + assert sorted(response.json().pop("accesses")) == sorted( [ - { - "id": str(access1.id), - "user": str(user.id), - "role": access1.role, - "abilities": access1.get_abilities(user), - }, - { - "id": str(access2.id), - "user": str(access2.user.id), - "role": access2.role, - "abilities": access2.get_abilities(user), - }, - ], - key=lambda x: x["user"], + str(access1.id), + str(access2.id), + ] ) assert response.json() == { "id": str(team.id), diff --git a/src/backend/core/tests/test_api_users.py b/src/backend/core/tests/test_api_users.py index 1c48c85..21b48fe 100644 --- a/src/backend/core/tests/test_api_users.py +++ b/src/backend/core/tests/test_api_users.py @@ -58,10 +58,10 @@ def test_api_users_authenticated_list_by_email(): client = APIClient() client.force_login(user) - dave = factories.IdentityFactory(email="david.bowman@work.com") - nicole = factories.IdentityFactory(email="nicole_foole@work.com") - frank = factories.IdentityFactory(email="frank_poole@work.com") - factories.IdentityFactory(email="heywood_floyd@work.com") + dave = factories.IdentityFactory(email="david.bowman@work.com", is_main=True) + nicole = factories.IdentityFactory(email="nicole_foole@work.com", is_main=True) + frank = factories.IdentityFactory(email="frank_poole@work.com", is_main=True) + factories.IdentityFactory(email="heywood_floyd@work.com", is_main=True) # Full query should work response = client.get( @@ -91,8 +91,26 @@ def test_api_users_authenticated_list_by_email(): response = client.get("/api/v1.0/users/?q=ool") assert response.status_code == HTTP_200_OK - user_ids = [user["id"] for user in response.json()["results"]] - assert user_ids == [str(nicole.user.id), str(frank.user.id)] + assert response.json()["results"] == [ + { + "id": str(nicole.user.id), + "email": nicole.email, + "name": nicole.name, + "is_device": nicole.user.is_device, + "is_staff": nicole.user.is_staff, + "language": nicole.user.language, + "timezone": str(nicole.user.timezone), + }, + { + "id": str(frank.user.id), + "email": frank.email, + "name": frank.name, + "is_device": frank.user.is_device, + "is_staff": frank.user.is_staff, + "language": frank.user.language, + "timezone": str(frank.user.timezone), + }, + ] def test_api_users_authenticated_list_multiple_identities_single_user(): @@ -132,12 +150,16 @@ def test_api_users_authenticated_list_multiple_identities_multiple_users(): client.force_login(user) dave = factories.UserFactory() - davina = factories.UserFactory() - prudence = factories.UserFactory() - factories.IdentityFactory(user=dave, email="david.bowman@work.com") + dave_identity = factories.IdentityFactory( + user=dave, email="david.bowman@work.com", is_main=True + ) factories.IdentityFactory(user=dave, email="babibou@ehehe.com") - factories.IdentityFactory(user=davina, email="davina.bowan@work.com") - factories.IdentityFactory(user=prudence, email="prudence.crandall@work.com") + davina_identity = factories.IdentityFactory( + user=factories.UserFactory(), email="davina.bowan@work.com" + ) + prue_identity = factories.IdentityFactory( + user=factories.UserFactory(), email="prudence.crandall@work.com" + ) # Full query should work response = client.get( @@ -146,8 +168,35 @@ def test_api_users_authenticated_list_multiple_identities_multiple_users(): assert response.status_code == HTTP_200_OK assert response.json()["count"] == 3 - user_ids = [user["id"] for user in response.json()["results"]] - assert user_ids == [str(dave.id), str(davina.id), str(prudence.id)] + assert response.json()["results"] == [ + { + "id": str(dave.id), + "email": dave_identity.email, + "name": dave_identity.name, + "is_device": dave_identity.user.is_device, + "is_staff": dave_identity.user.is_staff, + "language": dave_identity.user.language, + "timezone": str(dave_identity.user.timezone), + }, + { + "id": str(davina_identity.user.id), + "email": davina_identity.email, + "name": davina_identity.name, + "is_device": davina_identity.user.is_device, + "is_staff": davina_identity.user.is_staff, + "language": davina_identity.user.language, + "timezone": str(davina_identity.user.timezone), + }, + { + "id": str(prue_identity.user.id), + "email": prue_identity.email, + "name": prue_identity.name, + "is_device": prue_identity.user.is_device, + "is_staff": prue_identity.user.is_staff, + "language": prue_identity.user.language, + "timezone": str(prue_identity.user.timezone), + }, + ] def test_api_users_authenticated_list_uppercase_content(): @@ -291,7 +340,7 @@ def test_api_users_retrieve_me_anonymous(): def test_api_users_retrieve_me_authenticated(): """Authenticated users should be able to retrieve their own user via the "/users/me" path.""" - identity = factories.IdentityFactory() + identity = factories.IdentityFactory(is_main=True) user = identity.user client = APIClient() @@ -310,7 +359,8 @@ def test_api_users_retrieve_me_authenticated(): assert response.status_code == HTTP_200_OK assert response.json() == { "id": str(user.id), - "email": str(user.email), + "name": str(identity.name), + "email": str(identity.email), "language": user.language, "timezone": str(user.timezone), "is_device": False,