From 7ef67037c39af27c1d089b06b27a6e0273f86dd1 Mon Sep 17 00:00:00 2001 From: Marie PUPO JEAMMET Date: Thu, 14 Mar 2024 19:30:49 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(backend)=20convert=20invitations=20to?= =?UTF-8?q?=20accesses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert related invitations to accesses upon creating a new identity. --- src/backend/core/models.py | 28 ++++- .../core/tests/test_models_invitations.py | 100 +++++++++++++++++- src/backend/pyproject.toml | 1 + 3 files changed, 127 insertions(+), 2 deletions(-) diff --git a/src/backend/core/models.py b/src/backend/core/models.py index ddc9983..ffb112f 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -284,8 +284,34 @@ class Identity(BaseModel): return f"{id_str:s}{main_str:s}" def save(self, *args, **kwargs): - """Ensure users always have one and only one main identity.""" + """ + Saves identity, ensuring users always have exactly one main identity. + Also converts valid invitations to accesses upon creating an identity. + """ + + # If new identity, convert all valid invitations to team accesses. + # Expired invitations are ignored. + if self._state.adding: + active_invitations = Invitation.objects.filter( + email=self.email, + created_at__gte=timezone.now() + - timedelta(seconds=settings.INVITATION_VALIDITY_DURATION), + ).select_related("team") + + if active_invitations.exists(): + TeamAccess.objects.bulk_create( + [ + TeamAccess( + user=self.user, team=invitation.team, role=invitation.role + ) + for invitation in active_invitations + ] + ) + active_invitations.delete() + super().save(*args, **kwargs) + + # Ensure users always have one and only one main identity. if self.is_main is True: self.user.identities.exclude(id=self.id).update(is_main=False) diff --git a/src/backend/core/tests/test_models_invitations.py b/src/backend/core/tests/test_models_invitations.py index 2a15eb4..8f487fe 100644 --- a/src/backend/core/tests/test_models_invitations.py +++ b/src/backend/core/tests/test_models_invitations.py @@ -3,17 +3,23 @@ Unit tests for the Invitation model """ import time +import uuid from django.contrib.auth.models import AnonymousUser from django.core import exceptions import pytest +from faker import Faker +from freezegun import freeze_time -from core import factories +from core import factories, models pytestmark = pytest.mark.django_db +fake = Faker() + + def test_models_invitations_readonly_after_create(): """Existing invitations should be readonly.""" invitation = factories.InvitationFactory() @@ -73,6 +79,98 @@ def test_models_invitations__is_expired(settings): assert expired_invitation.is_expired is True +def test_models_invitation__new_user__convert_invitations_to_accesses(): + """ + Upon creating a new identity, invitations linked to that email + should be converted to accesses and then deleted. + """ + # Two invitations to the same mail but to different teams + invitation_to_team1 = factories.InvitationFactory() + invitation_to_team2 = factories.InvitationFactory(email=invitation_to_team1.email) + + other_invitation = factories.InvitationFactory( + team=invitation_to_team2.team + ) # another person invited to team2 + + new_identity = factories.IdentityFactory( + is_main=True, email=invitation_to_team1.email + ) + + # The invitation regarding + assert models.TeamAccess.objects.filter( + team=invitation_to_team1.team, user=new_identity.user + ).exists() + assert models.TeamAccess.objects.filter( + team=invitation_to_team2.team, user=new_identity.user + ).exists() + assert not models.Invitation.objects.filter( + team=invitation_to_team1.team, email=invitation_to_team1.email + ).exists() # invitation "consumed" + assert not models.Invitation.objects.filter( + team=invitation_to_team2.team, email=invitation_to_team2.email + ).exists() # invitation "consumed" + assert models.Invitation.objects.filter( + team=invitation_to_team2.team, email=other_invitation.email + ).exists() # the other invitation remains + + +def test_models_invitation__new_user__filter_expired_invitations(): + """ + Upon creating a new identity, valid invitations should be converted into accesses + and expired invitations should remain unchanged. + """ + with freeze_time("2020-01-01"): + expired_invitation = factories.InvitationFactory() + user_email = expired_invitation.email + valid_invitation = factories.InvitationFactory(email=user_email) + + new_identity = factories.IdentityFactory(is_main=True, email=user_email) + + # valid invitation should have granted access to the related team + assert models.TeamAccess.objects.filter( + team=valid_invitation.team, user=new_identity.user + ).exists() + assert not models.Invitation.objects.filter( + team=valid_invitation.team, email=user_email + ).exists() + + # expired invitation should not have been consumed + assert not models.TeamAccess.objects.filter( + team=expired_invitation.team, user=new_identity.user + ).exists() + assert models.Invitation.objects.filter( + team=expired_invitation.team, email=user_email + ).exists() + + +@pytest.mark.parametrize("num_invitations, num_queries", [(0, 8), (1, 11), (20, 11)]) +def test_models_invitation__new_user__user_creation_constant_num_queries( + django_assert_num_queries, num_invitations, num_queries +): + """ + The number of queries executed during user creation should not be proportional + to the number of invitations being processed. + """ + user_email = fake.email() + + if num_invitations != 0: + for _ in range(0, num_invitations): + factories.InvitationFactory(email=user_email, team=factories.TeamFactory()) + + user = factories.UserFactory() + + # with no invitation, we skip an "if", resulting in 8 requests + # otherwise, we should have 11 queries with any number of invitations + with django_assert_num_queries(num_queries): + models.Identity.objects.create( + is_main=True, + email=user_email, + user=user, + name="Prudence C.", + sub=uuid.uuid4(), + ) + + # get_abilities diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index ce7890e..7156618 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -77,6 +77,7 @@ dev = [ "responses==0.25.0", "ruff==0.3.2", "types-requests==2.31.0.20240311", + "freezegun==1.4.0", ] [tool.setuptools]