(api) return user id, name and email on /team/<id>/accesses/

Add serializers to return basic user info when listing /team/<id>/accesses/
endpoint. This will allow front-end to retrieve members info without having
to query API for each user.id.
This commit is contained in:
Marie PUPO JEAMMET
2024-02-26 18:03:31 +01:00
committed by aleb_the_flash
parent 70b1b996df
commit 81243cfc9a
7 changed files with 283 additions and 68 deletions

View File

@@ -26,22 +26,66 @@ class ContactSerializer(serializers.ModelSerializer):
return super().update(instance, validated_data) return super().update(instance, validated_data)
class UserSerializer(serializers.ModelSerializer): class DynamicFieldsModelSerializer(serializers.ModelSerializer):
"""
A ModelSerializer that takes an additional `fields` argument that
controls which fields should be displayed.
"""
def __init__(self, *args, **kwargs):
# Don't pass the 'fields' arg up to the superclass
fields = kwargs.pop("fields", None)
# Instantiate the superclass normally
super().__init__(*args, **kwargs)
if fields is not None:
# Drop any fields that are not specified in the `fields` argument.
allowed = set(fields)
existing = set(self.fields)
for field_name in existing - allowed:
self.fields.pop(field_name)
class UserSerializer(DynamicFieldsModelSerializer):
"""Serialize users.""" """Serialize users."""
timezone = TimeZoneSerializerField(use_pytz=False, required=True) timezone = TimeZoneSerializerField(use_pytz=False, required=True)
name = serializers.SerializerMethodField(read_only=True)
email = serializers.SerializerMethodField(read_only=True)
class Meta: class Meta:
model = models.User model = models.User
fields = [ fields = [
"id", "id",
"name",
"email", "email",
"language", "language",
"timezone", "timezone",
"is_device", "is_device",
"is_staff", "is_staff",
] ]
read_only_fields = ["id", "email", "is_device", "is_staff"] read_only_fields = ["id", "name", "email", "is_device", "is_staff"]
def _get_main_identity_attr(self, obj, attribute_name):
"""Return the specified attribute of the main identity."""
try:
return getattr(obj.main_identity[0], attribute_name)
except TypeError:
return getattr(obj.main_identity, attribute_name)
except IndexError:
main_identity = obj.identities.filter(is_main=True).first()
return getattr(obj.main_identity, attribute_name) if main_identity else None
except AttributeError:
return None
def get_name(self, obj):
"""Return main identity's name."""
return self._get_main_identity_attr(obj, "name")
def get_email(self, obj):
"""Return main identity's email."""
return self._get_main_identity_attr(obj, "email")
class TeamAccessSerializer(serializers.ModelSerializer): class TeamAccessSerializer(serializers.ModelSerializer):
@@ -120,11 +164,31 @@ class TeamAccessSerializer(serializers.ModelSerializer):
return attrs return attrs
class TeamAccessReadOnlySerializer(TeamAccessSerializer):
"""Serialize team accesses for list and retrieve actions."""
user = UserSerializer(read_only=True, fields=["id", "name", "email"])
class Meta:
model = models.TeamAccess
fields = [
"id",
"user",
"role",
"abilities",
]
read_only_fields = [
"id",
"user",
"role",
"abilities",
]
class TeamSerializer(serializers.ModelSerializer): class TeamSerializer(serializers.ModelSerializer):
"""Serialize teams.""" """Serialize teams."""
abilities = serializers.SerializerMethodField(read_only=True) abilities = serializers.SerializerMethodField(read_only=True)
accesses = TeamAccessSerializer(many=True, read_only=True)
slug = serializers.SerializerMethodField() slug = serializers.SerializerMethodField()
class Meta: class Meta:

View File

@@ -1,6 +1,6 @@
"""API endpoints""" """API endpoints"""
from django.contrib.postgres.search import TrigramSimilarity from django.contrib.postgres.search import TrigramSimilarity
from django.db.models import Func, Max, OuterRef, Q, Subquery, Value from django.db.models import Func, Max, OuterRef, Prefetch, Q, Subquery, Value
from rest_framework import ( from rest_framework import (
decorators, decorators,
@@ -199,6 +199,12 @@ class UserViewSet(
# Exclude inactive contacts # Exclude inactive contacts
queryset = queryset.filter( queryset = queryset.filter(
is_active=True, is_active=True,
).prefetch_related(
Prefetch(
"identities",
queryset=models.Identity.objects.filter(is_main=True),
to_attr="main_identity",
)
) )
# Search by case-insensitive and accent-insensitive trigram similarity # Search by case-insensitive and accent-insensitive trigram similarity
@@ -227,9 +233,12 @@ class UserViewSet(
""" """
Return information on currently logged user Return information on currently logged user
""" """
context = {"request": request} user = request.user
user.main_identity = models.Identity.objects.filter(
user=user, is_main=True
).first()
return response.Response( return response.Response(
self.serializer_class(request.user, context=context).data self.serializer_class(user, context={"request": request}).data
) )
@@ -305,7 +314,8 @@ class TeamAccessViewSet(
pagination_class = Pagination pagination_class = Pagination
permission_classes = [permissions.AccessPermission] permission_classes = [permissions.AccessPermission]
queryset = models.TeamAccess.objects.all().select_related("user") queryset = models.TeamAccess.objects.all().select_related("user")
serializer_class = serializers.TeamAccessSerializer list_serializer_class = serializers.TeamAccessReadOnlySerializer
detail_serializer_class = serializers.TeamAccessSerializer
def get_permissions(self): def get_permissions(self):
"""User only needs to be authenticated to list team accesses""" """User only needs to be authenticated to list team accesses"""
@@ -322,22 +332,35 @@ class TeamAccessViewSet(
context["team_id"] = self.kwargs["team_id"] context["team_id"] = self.kwargs["team_id"]
return context return context
def get_serializer_class(self):
if self.action in {"list", "retrieve"}:
return self.list_serializer_class
return self.detail_serializer_class
def get_queryset(self): def get_queryset(self):
"""Return the queryset according to the action.""" """Return the queryset according to the action."""
queryset = super().get_queryset() queryset = super().get_queryset()
queryset = queryset.filter(team=self.kwargs["team_id"]) queryset = queryset.filter(team=self.kwargs["team_id"])
if self.action == "list": if self.action in {"list", "retrieve"}:
# Limit to team access instances related to a team THAT also has a team access # Limit to team access instances related to a team THAT also has a team access
# instance for the logged-in user (we don't want to list only the team access # instance for the logged-in user (we don't want to list only the team access
# instances pointing to the logged-in user) # instances pointing to the logged-in user)
user_role_query = models.TeamAccess.objects.filter( user_role_query = models.TeamAccess.objects.filter(
team__accesses__user=self.request.user team__accesses__user=self.request.user
).values("role")[:1] ).values("role")[:1]
queryset = ( queryset = (
queryset.filter( queryset.filter(
team__accesses__user=self.request.user, team__accesses__user=self.request.user,
) )
.prefetch_related(
Prefetch(
"user__identities",
queryset=models.Identity.objects.filter(is_main=True),
to_attr="main_identity",
)
)
.annotate(user_role=Subquery(user_role_query)) .annotate(user_role=Subquery(user_role_query))
.distinct() .distinct()
) )

View File

@@ -45,7 +45,7 @@ def test_api_team_accesses_create_authenticated_unrelated():
client = APIClient() client = APIClient()
client.force_login(user) client.force_login(user)
response = APIClient().post( response = client.post(
f"/api/v1.0/teams/{team.id!s}/accesses/", f"/api/v1.0/teams/{team.id!s}/accesses/",
{ {
"user": str(other_user.id), "user": str(other_user.id),
@@ -155,7 +155,7 @@ def test_api_team_accesses_create_authenticated_owner():
client = APIClient() client = APIClient()
client.force_login(user) client.force_login(user)
response = APIClient().post( response = client.post(
f"/api/v1.0/teams/{team.id!s}/accesses/", f"/api/v1.0/teams/{team.id!s}/accesses/",
{ {
"user": str(other_user.id), "user": str(other_user.id),

View File

@@ -55,12 +55,16 @@ def test_api_team_accesses_list_authenticated_related():
Authenticated users should be able to list team accesses for a team Authenticated users should be able to list team accesses for a team
to which they are related, whatever their role in the team. to which they are related, whatever their role in the team.
""" """
identity = factories.IdentityFactory() identity = factories.IdentityFactory(is_main=True)
user = identity.user user = identity.user
team = factories.TeamFactory() team = factories.TeamFactory()
user_access = models.TeamAccess.objects.create(team=team, user=user) # random role user_access = models.TeamAccess.objects.create(team=team, user=user) # random role
access1, access2 = factories.TeamAccessFactory.create_batch(2, team=team)
# other team members should appear
other_member = factories.UserFactory()
other_member_identity = factories.IdentityFactory(is_main=True, user=other_member)
access1 = factories.TeamAccessFactory.create(team=team, user=other_member)
# Accesses for other teams to which the user is related should not be listed either # Accesses for other teams to which the user is related should not be listed either
other_access = factories.TeamAccessFactory(user=user) other_access = factories.TeamAccessFactory(user=user)
@@ -73,27 +77,111 @@ def test_api_team_accesses_list_authenticated_related():
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["count"] == 3 assert response.json()["count"] == 2
assert sorted(response.json()["results"], key=lambda x: x["id"]) == sorted( assert sorted(response.json()["results"], key=lambda x: x["id"]) == sorted(
[ [
{ {
"id": str(user_access.id), "id": str(user_access.id),
"user": str(user.id), "user": {
"role": user_access.role, "id": str(user_access.user.id),
"email": str(identity.email),
"name": str(identity.name),
},
"role": str(user_access.role),
"abilities": user_access.get_abilities(user), "abilities": user_access.get_abilities(user),
}, },
{ {
"id": str(access1.id), "id": str(access1.id),
"user": str(access1.user.id), "user": {
"role": access1.role, "id": str(access1.user.id),
"email": str(other_member_identity.email),
"name": str(other_member_identity.name),
},
"role": str(access1.role),
"abilities": access1.get_abilities(user), "abilities": access1.get_abilities(user),
}, },
{
"id": str(access2.id),
"user": str(access2.user.id),
"role": access2.role,
"abilities": access2.get_abilities(user),
},
], ],
key=lambda x: x["id"], key=lambda x: x["id"],
) )
def test_api_team_accesses_list_authenticated_main_identity():
"""
Name and email should be returned from main identity only
"""
user = factories.UserFactory()
identity = factories.IdentityFactory(user=user, is_main=True)
factories.IdentityFactory(user=user) # additional non-main identity
team = factories.TeamFactory()
models.TeamAccess.objects.create(team=team, user=user) # random role
# other team members should appear, with correct identity
other_user = factories.UserFactory()
other_main_identity = factories.IdentityFactory(is_main=True, user=other_user)
factories.IdentityFactory(user=other_user)
factories.TeamAccessFactory.create(team=team, user=other_user)
# Accesses for other teams to which the user is related should not be listed either
other_access = factories.TeamAccessFactory(user=user)
factories.TeamAccessFactory(team=other_access.team)
client = APIClient()
client.force_login(user)
response = client.get(
f"/api/v1.0/teams/{team.id!s}/accesses/",
)
assert response.status_code == 200
assert response.json()["count"] == 2
users_info = [
(access["user"]["email"], access["user"]["name"])
for access in response.json()["results"]
]
# user information should be returned from main identity
assert sorted(users_info) == sorted(
[
(str(identity.email), str(identity.name)),
(str(other_main_identity.email), str(other_main_identity.name)),
]
)
def test_api_team_accesses_list_authenticated_constant_numqueries(
django_assert_num_queries,
):
"""
The number of queries should not depend on the amount of fetched accesses.
"""
user = factories.UserFactory()
factories.IdentityFactory(user=user, is_main=True)
team = factories.TeamFactory()
models.TeamAccess.objects.create(team=team, user=user) # random role
client = APIClient()
client.force_login(user)
# Only 4 queries are needed to efficiently fetch team accesses,
# related users and identities :
# - query retrieving logged-in user for user_role annotation
# - count from pagination
# - query prefetching users' main identity
# - distinct from viewset
with django_assert_num_queries(4):
response = client.get(
f"/api/v1.0/teams/{team.id!s}/accesses/",
)
# create 20 new team members
for _ in range(20):
extra_user = factories.IdentityFactory(is_main=True).user
factories.TeamAccessFactory(team=team, user=extra_user)
# num queries should still be 4
with django_assert_num_queries(4):
response = client.get(
f"/api/v1.0/teams/{team.id!s}/accesses/",
)
assert response.status_code == 200
assert response.json()["count"] == 21

View File

@@ -33,26 +33,23 @@ def test_api_team_accesses_retrieve_authenticated_unrelated():
identity = factories.IdentityFactory() identity = factories.IdentityFactory()
user = identity.user user = identity.user
team = factories.TeamFactory() access = factories.TeamAccessFactory(team=factories.TeamFactory())
access = factories.TeamAccessFactory(team=team)
client = APIClient() client = APIClient()
client.force_login(user) client.force_login(user)
response = client.get( response = client.get(
f"/api/v1.0/teams/{team.id!s}/accesses/{access.id!s}/", f"/api/v1.0/teams/{access.team.id!s}/accesses/{access.id!s}/",
) )
assert response.status_code == 403 assert response.status_code == 404
assert response.json() == { assert response.json() == {"detail": "Not found."}
"detail": "You do not have permission to perform this action."
}
# Accesses related to another team should be excluded even if the user is related to it # Accesses related to another team should be excluded even if the user is related to it
for access in [ for other_access in [
factories.TeamAccessFactory(), factories.TeamAccessFactory(),
factories.TeamAccessFactory(user=user), factories.TeamAccessFactory(user=user),
]: ]:
response = client.get( response = client.get(
f"/api/v1.0/teams/{team.id!s}/accesses/{access.id!s}/", f"/api/v1.0/teams/{access.team.id!s}/accesses/{other_access.id!s}/",
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -64,11 +61,11 @@ def test_api_team_accesses_retrieve_authenticated_related():
A user who is related to a team should be allowed to retrieve the A user who is related to a team should be allowed to retrieve the
associated team user accesses. associated team user accesses.
""" """
identity = factories.IdentityFactory() identity = factories.IdentityFactory(is_main=True)
user = identity.user user = identity.user
team = factories.TeamFactory(users=[user]) team = factories.TeamFactory()
access = factories.TeamAccessFactory(team=team) access = factories.TeamAccessFactory(team=team, user=user)
client = APIClient() client = APIClient()
client.force_login(user) client.force_login(user)
@@ -79,7 +76,11 @@ def test_api_team_accesses_retrieve_authenticated_related():
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"id": str(access.id), "id": str(access.id),
"user": str(access.user.id), "user": {
"role": access.role, "id": str(access.user.id),
"email": str(identity.email),
"name": str(identity.name),
},
"role": str(access.role),
"abilities": access.get_abilities(user), "abilities": access.get_abilities(user),
} }

View File

@@ -58,24 +58,13 @@ def test_api_teams_retrieve_authenticated_related():
response = client.get( response = client.get(
f"/api/v1.0/teams/{team.id!s}/", f"/api/v1.0/teams/{team.id!s}/",
) )
assert response.status_code == HTTP_200_OK assert response.status_code == HTTP_200_OK
content = response.json() assert sorted(response.json().pop("accesses")) == sorted(
assert sorted(content.pop("accesses"), key=lambda x: x["user"]) == sorted(
[ [
{ str(access1.id),
"id": str(access1.id), str(access2.id),
"user": str(user.id), ]
"role": access1.role,
"abilities": access1.get_abilities(user),
},
{
"id": str(access2.id),
"user": str(access2.user.id),
"role": access2.role,
"abilities": access2.get_abilities(user),
},
],
key=lambda x: x["user"],
) )
assert response.json() == { assert response.json() == {
"id": str(team.id), "id": str(team.id),

View File

@@ -58,10 +58,10 @@ def test_api_users_authenticated_list_by_email():
client = APIClient() client = APIClient()
client.force_login(user) client.force_login(user)
dave = factories.IdentityFactory(email="david.bowman@work.com") dave = factories.IdentityFactory(email="david.bowman@work.com", is_main=True)
nicole = factories.IdentityFactory(email="nicole_foole@work.com") nicole = factories.IdentityFactory(email="nicole_foole@work.com", is_main=True)
frank = factories.IdentityFactory(email="frank_poole@work.com") frank = factories.IdentityFactory(email="frank_poole@work.com", is_main=True)
factories.IdentityFactory(email="heywood_floyd@work.com") factories.IdentityFactory(email="heywood_floyd@work.com", is_main=True)
# Full query should work # Full query should work
response = client.get( response = client.get(
@@ -91,8 +91,26 @@ def test_api_users_authenticated_list_by_email():
response = client.get("/api/v1.0/users/?q=ool") response = client.get("/api/v1.0/users/?q=ool")
assert response.status_code == HTTP_200_OK assert response.status_code == HTTP_200_OK
user_ids = [user["id"] for user in response.json()["results"]] assert response.json()["results"] == [
assert user_ids == [str(nicole.user.id), str(frank.user.id)] {
"id": str(nicole.user.id),
"email": nicole.email,
"name": nicole.name,
"is_device": nicole.user.is_device,
"is_staff": nicole.user.is_staff,
"language": nicole.user.language,
"timezone": str(nicole.user.timezone),
},
{
"id": str(frank.user.id),
"email": frank.email,
"name": frank.name,
"is_device": frank.user.is_device,
"is_staff": frank.user.is_staff,
"language": frank.user.language,
"timezone": str(frank.user.timezone),
},
]
def test_api_users_authenticated_list_multiple_identities_single_user(): def test_api_users_authenticated_list_multiple_identities_single_user():
@@ -132,12 +150,16 @@ def test_api_users_authenticated_list_multiple_identities_multiple_users():
client.force_login(user) client.force_login(user)
dave = factories.UserFactory() dave = factories.UserFactory()
davina = factories.UserFactory() dave_identity = factories.IdentityFactory(
prudence = factories.UserFactory() user=dave, email="david.bowman@work.com", is_main=True
factories.IdentityFactory(user=dave, email="david.bowman@work.com") )
factories.IdentityFactory(user=dave, email="babibou@ehehe.com") factories.IdentityFactory(user=dave, email="babibou@ehehe.com")
factories.IdentityFactory(user=davina, email="davina.bowan@work.com") davina_identity = factories.IdentityFactory(
factories.IdentityFactory(user=prudence, email="prudence.crandall@work.com") user=factories.UserFactory(), email="davina.bowan@work.com"
)
prue_identity = factories.IdentityFactory(
user=factories.UserFactory(), email="prudence.crandall@work.com"
)
# Full query should work # Full query should work
response = client.get( response = client.get(
@@ -146,8 +168,35 @@ def test_api_users_authenticated_list_multiple_identities_multiple_users():
assert response.status_code == HTTP_200_OK assert response.status_code == HTTP_200_OK
assert response.json()["count"] == 3 assert response.json()["count"] == 3
user_ids = [user["id"] for user in response.json()["results"]] assert response.json()["results"] == [
assert user_ids == [str(dave.id), str(davina.id), str(prudence.id)] {
"id": str(dave.id),
"email": dave_identity.email,
"name": dave_identity.name,
"is_device": dave_identity.user.is_device,
"is_staff": dave_identity.user.is_staff,
"language": dave_identity.user.language,
"timezone": str(dave_identity.user.timezone),
},
{
"id": str(davina_identity.user.id),
"email": davina_identity.email,
"name": davina_identity.name,
"is_device": davina_identity.user.is_device,
"is_staff": davina_identity.user.is_staff,
"language": davina_identity.user.language,
"timezone": str(davina_identity.user.timezone),
},
{
"id": str(prue_identity.user.id),
"email": prue_identity.email,
"name": prue_identity.name,
"is_device": prue_identity.user.is_device,
"is_staff": prue_identity.user.is_staff,
"language": prue_identity.user.language,
"timezone": str(prue_identity.user.timezone),
},
]
def test_api_users_authenticated_list_uppercase_content(): def test_api_users_authenticated_list_uppercase_content():
@@ -291,7 +340,7 @@ def test_api_users_retrieve_me_anonymous():
def test_api_users_retrieve_me_authenticated(): def test_api_users_retrieve_me_authenticated():
"""Authenticated users should be able to retrieve their own user via the "/users/me" path.""" """Authenticated users should be able to retrieve their own user via the "/users/me" path."""
identity = factories.IdentityFactory() identity = factories.IdentityFactory(is_main=True)
user = identity.user user = identity.user
client = APIClient() client = APIClient()
@@ -310,7 +359,8 @@ def test_api_users_retrieve_me_authenticated():
assert response.status_code == HTTP_200_OK assert response.status_code == HTTP_200_OK
assert response.json() == { assert response.json() == {
"id": str(user.id), "id": str(user.id),
"email": str(user.email), "name": str(identity.name),
"email": str(identity.email),
"language": user.language, "language": user.language,
"timezone": str(user.timezone), "timezone": str(user.timezone),
"is_device": False, "is_device": False,