diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 00eee4a..6f6eeac 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -36,13 +36,14 @@ class UserSerializer(serializers.ModelSerializer): model = models.User fields = [ "id", + "email", "data", "language", "timezone", "is_device", "is_staff", ] - read_only_fields = ["id", "data", "is_device", "is_staff"] + read_only_fields = ["id", "email", "data", "is_device", "is_staff"] def get_data(self, user) -> dict: """Return contact data for the user.""" diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 4e248f6..56781f1 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1,13 +1,6 @@ """API endpoints""" from django.contrib.postgres.search import TrigramSimilarity -from django.core.cache import cache -from django.db.models import ( - Func, - OuterRef, - Q, - Subquery, - Value, -) +from django.db.models import Func, OuterRef, Q, Subquery, Value from rest_framework import ( decorators, @@ -17,11 +10,16 @@ from rest_framework import ( response, viewsets, ) +from rest_framework.throttling import UserRateThrottle from core import models from . import permissions, serializers +EMAIL_SIMILARITY_THRESHOLD = 0.01 +# TrigramSimilarity threshold is lower for searching email than for names, +# to improve matching results + class NestedGenericViewSet(viewsets.GenericViewSet): """ @@ -103,6 +101,22 @@ class Pagination(pagination.PageNumberPagination): page_size_query_param = "page_size" +class BurstRateThrottle(UserRateThrottle): + """ + Throttle rate for minutes. See DRF section in settings for default value. + """ + + scope = "burst" + + +class SustainedRateThrottle(UserRateThrottle): + """ + Throttle rate for hours. See DRF section in settings for default value. + """ + + scope = "sustained" + + # pylint: disable=too-many-ancestors class ContactViewSet( mixins.CreateModelMixin, @@ -116,15 +130,13 @@ class ContactViewSet( permission_classes = [permissions.IsOwnedOrPublic] queryset = models.Contact.objects.all() serializer_class = serializers.ContactSerializer + throttle_classes = [BurstRateThrottle, SustainedRateThrottle] def list(self, request, *args, **kwargs): """Limit listed users by a query with throttle protection.""" user = self.request.user queryset = self.filter_queryset(self.get_queryset()) - if not user.is_authenticated: - return queryset.none() - # Exclude contacts that: queryset = queryset.filter( # - belong to another user (keep public and owned contacts) @@ -150,26 +162,6 @@ class ContactViewSet( .order_by("-similarity") ) - # Throttle protection - key_base = f"throttle-contact-list-{user.id!s}" - key_minute = f"{key_base:s}-minute" - key_hour = f"{key_base:s}-hour" - - try: - count_minute = cache.incr(key_minute) - except ValueError: - cache.set(key_minute, 1, 60) - count_minute = 1 - - try: - count_hour = cache.incr(key_hour) - except ValueError: - cache.set(key_hour, 1, 3600) - count_hour = 1 - - if count_minute > 20 or count_hour > 150: - raise exceptions.Throttled() - serializer = self.get_serializer(queryset, many=True) return response.Response(serializer.data) @@ -181,21 +173,52 @@ class ContactViewSet( class UserViewSet( - mixins.UpdateModelMixin, - viewsets.GenericViewSet, + mixins.UpdateModelMixin, viewsets.GenericViewSet, mixins.ListModelMixin ): - """User ViewSet""" + """ + User viewset for all interactions with user infos and teams. + + GET /api/users/&q=query + Return a list of users whose email matches the query. Similarity is + calculated using trigram similarity, allowing for partial, case + insensitive matches and accentuated queries. + """ permission_classes = [permissions.IsSelf] queryset = models.User.objects.all().select_related("profile_contact") serializer_class = serializers.UserSerializer + throttle_classes = [BurstRateThrottle, SustainedRateThrottle] + pagination_class = Pagination + + def get_queryset(self): + """Limit listed users by a query. Pagination and throttle protection apply.""" + queryset = self.queryset + + if self.action == "list": + # Exclude inactive contacts + queryset = queryset.filter( + is_active=True, + ) + + # Search by case-insensitive and accent-insensitive trigram similarity + if query := self.request.GET.get("q", ""): + similarity = TrigramSimilarity( + Func("email", function="unaccent"), + Func(Value(query), function="unaccent"), + ) + queryset = ( + queryset.annotate(similarity=similarity) + .filter(similarity__gte=EMAIL_SIMILARITY_THRESHOLD) + .order_by("-similarity") + ) + + return queryset @decorators.action( detail=False, methods=["get"], url_name="me", url_path="me", - permission_classes=[permissions.IsAuthenticated], ) def get_me(self, request): """ diff --git a/src/backend/core/tests/test_api_users.py b/src/backend/core/tests/test_api_users.py index c1cad6f..efa9743 100644 --- a/src/backend/core/tests/test_api_users.py +++ b/src/backend/core/tests/test_api_users.py @@ -1,11 +1,15 @@ """ Test users API endpoints in the People core app. """ +from unittest import mock + import pytest +from rest_framework.status import HTTP_200_OK from rest_framework.test import APIClient from core import factories, models from core.api import serializers +from core.api.viewsets import Pagination from .utils import OIDCToken @@ -17,13 +21,15 @@ def test_api_users_list_anonymous(): factories.UserFactory() client = APIClient() response = client.get("/api/v1.0/users/") - assert response.status_code == 404 - assert "Not Found" in response.content.decode("utf-8") + assert response.status_code == 401 + assert "Authentication credentials were not provided." in response.content.decode( + "utf-8" + ) def test_api_users_list_authenticated(): """ - Authenticated users should not be able to list users. + Authenticated users should be able to list all users. """ identity = factories.IdentityFactory() jwt_token = OIDCToken.for_user(identity.user) @@ -32,8 +38,184 @@ def test_api_users_list_authenticated(): response = APIClient().get( "/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" ) - assert response.status_code == 404 - assert "Not Found" in response.content.decode("utf-8") + assert response.status_code == 200 + assert len(response.json()["results"]) == 3 + + +def test_api_users_authenticated_list_by_email(): + """ + Authenticated users should be able to search users with a case-insensitive and + partial query on the email. + """ + user = factories.UserFactory(email="tester@ministry.fr") + factories.IdentityFactory(user=user, email=user.email) + jwt_token = OIDCToken.for_user(user) + + dave = factories.UserFactory(email="david.bowman@work.com") + nicole = factories.UserFactory(email="nicole_foole@work.com") + frank = factories.UserFactory(email="frank_poole@work.com") + factories.UserFactory(email="heywood_floyd@work.com") + + # Full query should work + response = APIClient().get( + "/api/v1.0/users/?q=david.bowman@work.com", + HTTP_AUTHORIZATION=f"Bearer {jwt_token}", + ) + + assert response.status_code == HTTP_200_OK + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids[0] == str(dave.id) + + # Partial query should work + response = APIClient().get( + "/api/v1.0/users/?q=fran", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids[0] == str(frank.id) + + # Result that matches a trigram twice ranks better than result that matches once + response = APIClient().get( + "/api/v1.0/users/?q=ole", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + # "Nicole Foole" matches twice on "ole" + assert user_ids == [str(nicole.id), str(frank.id)] + + # Even with a low similarity threshold, query should match expected emails + response = APIClient().get( + "/api/v1.0/users/?q=ool", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(nicole.id), str(frank.id)] + + +def test_api_users_authenticated_list_uppercase_content(): + """Upper case content should be found by lower case query.""" + user = factories.UserFactory(email="tester@ministry.fr") + factories.IdentityFactory(user=user, email=user.email) + jwt_token = OIDCToken.for_user(user) + + dave = factories.UserFactory(email="DAVID.BOWMAN@INTENSEWORK.COM") + + # Unaccented full address + response = APIClient().get( + "/api/v1.0/users/?q=david.bowman@intensework.com", + HTTP_AUTHORIZATION=f"Bearer {jwt_token}", + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(dave.id)] + + # Partial query + response = APIClient().get( + "/api/v1.0/users/?q=david", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(dave.id)] + + +def test_api_users_list_authenticated_capital_query(): + """Upper case query should find lower case content.""" + user = factories.UserFactory(email="tester@ministry.fr") + factories.IdentityFactory(user=user, email=user.email) + jwt_token = OIDCToken.for_user(user) + + dave = factories.UserFactory(email="david.bowman@work.com") + + # Full uppercase query + response = APIClient().get( + "/api/v1.0/users/?q=DAVID.BOWMAN@WORK.COM", + HTTP_AUTHORIZATION=f"Bearer {jwt_token}", + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(dave.id)] + + # Partial uppercase email + response = APIClient().get( + "/api/v1.0/users/?q=DAVID", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(dave.id)] + + +def test_api_contacts_list_authenticated_accented_query(): + """Accented content should be found by unaccented query.""" + user = factories.UserFactory(email="tester@ministry.fr") + factories.IdentityFactory(user=user, email=user.email) + jwt_token = OIDCToken.for_user(user) + + helene = factories.UserFactory(email="helene.bowman@work.com") + + # Accented full query + response = APIClient().get( + "/api/v1.0/users/?q=hélène.bowman@work.com", + HTTP_AUTHORIZATION=f"Bearer {jwt_token}", + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(helene.id)] + + # Unaccented partial email + response = APIClient().get( + "/api/v1.0/users/?q=hélène", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == 200 + user_ids = [user["id"] for user in response.json()["results"]] + assert user_ids == [str(helene.id)] + + +@mock.patch.object(Pagination, "get_page_size", return_value=3) +def test_api_users_list_pagination( + _mock_page_size, +): + """Pagination should work as expected.""" + identity = factories.IdentityFactory() + user = identity.user + jwt_token = OIDCToken.for_user(user) + + factories.UserFactory.create_batch(4) + + # Get page 1 + response = APIClient().get( + "/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == HTTP_200_OK + content = response.json() + + assert content["count"] == 5 + assert len(content["results"]) == 3 + assert content["next"] == "http://testserver/api/v1.0/users/?page=2" + assert content["previous"] is None + + # Get page 2 + response = APIClient().get( + "/api/v1.0/users/?page=2", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" + ) + + assert response.status_code == HTTP_200_OK + content = response.json() + + assert content["count"] == 5 + assert content["next"] is None + assert content["previous"] == "http://testserver/api/v1.0/users/" + + assert len(content["results"]) == 2 def test_api_users_retrieve_me_anonymous(): @@ -66,6 +248,7 @@ def test_api_users_retrieve_me_authenticated(): assert response.status_code == 200 assert response.json() == { "id": str(user.id), + "email": str(user.email), "language": user.language, "timezone": str(user.timezone), "is_device": False, @@ -128,8 +311,10 @@ def test_api_users_create_anonymous(): "password": "mypassword", }, ) - assert response.status_code == 404 - assert "Not Found" in response.content.decode("utf-8") + assert response.status_code == 401 + assert "Authentication credentials were not provided." in response.content.decode( + "utf-8" + ) assert models.User.objects.exists() is False @@ -148,8 +333,8 @@ def test_api_users_create_authenticated(): format="json", HTTP_AUTHORIZATION=f"Bearer {jwt_token}", ) - assert response.status_code == 404 - assert "Not Found" in response.content.decode("utf-8") + assert response.status_code == 405 + assert response.json() == {"detail": 'Method "POST" not allowed.'} assert models.User.objects.exclude(id=user.id).exists() is False @@ -322,7 +507,7 @@ def test_api_users_delete_list_anonymous(): client = APIClient() response = client.delete("/api/v1.0/users/") - assert response.status_code == 404 + assert response.status_code == 401 assert models.User.objects.count() == 2 @@ -337,7 +522,7 @@ def test_api_users_delete_list_authenticated(): "/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}" ) - assert response.status_code == 404 + assert response.status_code == 405 assert models.User.objects.count() == 3 diff --git a/src/backend/core/tests/test_throttle.py b/src/backend/core/tests/test_throttle.py new file mode 100644 index 0000000..f3bbf15 --- /dev/null +++ b/src/backend/core/tests/test_throttle.py @@ -0,0 +1,36 @@ +""" +Test Throttle in People's app. +""" +import pytest +from rest_framework.test import APIClient + +from core import factories + +from people.settings import REST_FRAMEWORK # pylint: disable=E0611 + +from .utils import OIDCToken + +pytestmark = pytest.mark.django_db + + +def test_throttle(): + """ + Throttle protection should block requests if too many. + """ + identity = factories.IdentityFactory() + user = identity.user + jwt_token = OIDCToken.for_user(user) + + client = APIClient() + endpoint = "/api/v1.0/users/" + + # loop to activate throttle protection + throttle_limit = int( + REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["burst"].replace("/minute", "") + ) + for _ in range(0, throttle_limit): + client.get(endpoint, HTTP_AUTHORIZATION=f"Bearer {jwt_token}") + + # this call should err + response = client.get(endpoint, HTTP_AUTHORIZATION=f"Bearer {jwt_token}") + assert response.status_code == 429 # too many requests diff --git a/src/backend/people/settings.py b/src/backend/people/settings.py index 2a44cdc..d7db127 100755 --- a/src/backend/people/settings.py +++ b/src/backend/people/settings.py @@ -217,11 +217,13 @@ class Base(Configuration): "rest_framework.parsers.JSONParser", "nested_multipart_parser.drf.DrfNestedParser", ], + "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "EXCEPTION_HANDLER": "core.api.exception_handler", "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", "PAGE_SIZE": 20, "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning", "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_THROTTLE_RATES": {"sustained": "150/hour", "burst": "20/minute"}, } SPECTACULAR_SETTINGS = {