✨(api) search users by email (#16)
* ✨(api) search users by email The front end should be able to search users by email. To that goal, we added a list method to the users viewset thus creating the /users/ endpoint. Results are filtered based on similarity with the query, based on what preexisted for the /contacts/ endpoint. * ✅(api) test list users by email Test search when complete, partial query, accentuated and capital. Also, lower similarity threshold for user search by email as it was too high for some tests to pass. * 💡(api) improve documentation and test comments Improve user viewset documentation and comments describing tests sections Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com> Co-authored-by: Anthony LC <anthony.le-courric@mail.numerique.gouv.fr> * 🛂(api) set isAuthenticated as base requirements Instead of checking permissions or adding decorators to every viewset, isAuthenticated is set as base requirement. * 🛂(api) define throttle limits in settings Use of Djando Rest Framework's throttle options, now set globally to avoid duplicate code. * 🩹(api) add email to user serializer email field added to serializer. Tests modified accordingly. I added the email field as "read only" to pass tests, but we need to discuss that point in review. * 🧱(api) move search logic to queryset User viewset "list" method was overridden to allow search by email. This removed the pagination. Instead of manually re-adding pagination at the end of this method, I moved the search/filter logic to get_queryset, to leave DRF handle pagination. * ✅(api) test throttle protection Test that throttle protection succesfully blocks too many requests. * 📝(tests) improve tests comment Fix typos on comments and clarify which setting are tested on test_throttle test (setting import required disabling pylint false positive error) Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com> --------- Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com> Co-authored-by: Anthony LC <anthony.le-courric@mail.numerique.gouv.fr>
This commit is contained in:
@@ -36,13 +36,14 @@ class UserSerializer(serializers.ModelSerializer):
|
||||
model = models.User
|
||||
fields = [
|
||||
"id",
|
||||
"email",
|
||||
"data",
|
||||
"language",
|
||||
"timezone",
|
||||
"is_device",
|
||||
"is_staff",
|
||||
]
|
||||
read_only_fields = ["id", "data", "is_device", "is_staff"]
|
||||
read_only_fields = ["id", "email", "data", "is_device", "is_staff"]
|
||||
|
||||
def get_data(self, user) -> dict:
|
||||
"""Return contact data for the user."""
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
"""API endpoints"""
|
||||
from django.contrib.postgres.search import TrigramSimilarity
|
||||
from django.core.cache import cache
|
||||
from django.db.models import (
|
||||
Func,
|
||||
OuterRef,
|
||||
Q,
|
||||
Subquery,
|
||||
Value,
|
||||
)
|
||||
from django.db.models import Func, OuterRef, Q, Subquery, Value
|
||||
|
||||
from rest_framework import (
|
||||
decorators,
|
||||
@@ -17,11 +10,16 @@ from rest_framework import (
|
||||
response,
|
||||
viewsets,
|
||||
)
|
||||
from rest_framework.throttling import UserRateThrottle
|
||||
|
||||
from core import models
|
||||
|
||||
from . import permissions, serializers
|
||||
|
||||
EMAIL_SIMILARITY_THRESHOLD = 0.01
|
||||
# TrigramSimilarity threshold is lower for searching email than for names,
|
||||
# to improve matching results
|
||||
|
||||
|
||||
class NestedGenericViewSet(viewsets.GenericViewSet):
|
||||
"""
|
||||
@@ -103,6 +101,22 @@ class Pagination(pagination.PageNumberPagination):
|
||||
page_size_query_param = "page_size"
|
||||
|
||||
|
||||
class BurstRateThrottle(UserRateThrottle):
|
||||
"""
|
||||
Throttle rate for minutes. See DRF section in settings for default value.
|
||||
"""
|
||||
|
||||
scope = "burst"
|
||||
|
||||
|
||||
class SustainedRateThrottle(UserRateThrottle):
|
||||
"""
|
||||
Throttle rate for hours. See DRF section in settings for default value.
|
||||
"""
|
||||
|
||||
scope = "sustained"
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class ContactViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
@@ -116,15 +130,13 @@ class ContactViewSet(
|
||||
permission_classes = [permissions.IsOwnedOrPublic]
|
||||
queryset = models.Contact.objects.all()
|
||||
serializer_class = serializers.ContactSerializer
|
||||
throttle_classes = [BurstRateThrottle, SustainedRateThrottle]
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""Limit listed users by a query with throttle protection."""
|
||||
user = self.request.user
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
if not user.is_authenticated:
|
||||
return queryset.none()
|
||||
|
||||
# Exclude contacts that:
|
||||
queryset = queryset.filter(
|
||||
# - belong to another user (keep public and owned contacts)
|
||||
@@ -150,26 +162,6 @@ class ContactViewSet(
|
||||
.order_by("-similarity")
|
||||
)
|
||||
|
||||
# Throttle protection
|
||||
key_base = f"throttle-contact-list-{user.id!s}"
|
||||
key_minute = f"{key_base:s}-minute"
|
||||
key_hour = f"{key_base:s}-hour"
|
||||
|
||||
try:
|
||||
count_minute = cache.incr(key_minute)
|
||||
except ValueError:
|
||||
cache.set(key_minute, 1, 60)
|
||||
count_minute = 1
|
||||
|
||||
try:
|
||||
count_hour = cache.incr(key_hour)
|
||||
except ValueError:
|
||||
cache.set(key_hour, 1, 3600)
|
||||
count_hour = 1
|
||||
|
||||
if count_minute > 20 or count_hour > 150:
|
||||
raise exceptions.Throttled()
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return response.Response(serializer.data)
|
||||
|
||||
@@ -181,21 +173,52 @@ class ContactViewSet(
|
||||
|
||||
|
||||
class UserViewSet(
|
||||
mixins.UpdateModelMixin,
|
||||
viewsets.GenericViewSet,
|
||||
mixins.UpdateModelMixin, viewsets.GenericViewSet, mixins.ListModelMixin
|
||||
):
|
||||
"""User ViewSet"""
|
||||
"""
|
||||
User viewset for all interactions with user infos and teams.
|
||||
|
||||
GET /api/users/&q=query
|
||||
Return a list of users whose email matches the query. Similarity is
|
||||
calculated using trigram similarity, allowing for partial, case
|
||||
insensitive matches and accentuated queries.
|
||||
"""
|
||||
|
||||
permission_classes = [permissions.IsSelf]
|
||||
queryset = models.User.objects.all().select_related("profile_contact")
|
||||
serializer_class = serializers.UserSerializer
|
||||
throttle_classes = [BurstRateThrottle, SustainedRateThrottle]
|
||||
pagination_class = Pagination
|
||||
|
||||
def get_queryset(self):
|
||||
"""Limit listed users by a query. Pagination and throttle protection apply."""
|
||||
queryset = self.queryset
|
||||
|
||||
if self.action == "list":
|
||||
# Exclude inactive contacts
|
||||
queryset = queryset.filter(
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Search by case-insensitive and accent-insensitive trigram similarity
|
||||
if query := self.request.GET.get("q", ""):
|
||||
similarity = TrigramSimilarity(
|
||||
Func("email", function="unaccent"),
|
||||
Func(Value(query), function="unaccent"),
|
||||
)
|
||||
queryset = (
|
||||
queryset.annotate(similarity=similarity)
|
||||
.filter(similarity__gte=EMAIL_SIMILARITY_THRESHOLD)
|
||||
.order_by("-similarity")
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
@decorators.action(
|
||||
detail=False,
|
||||
methods=["get"],
|
||||
url_name="me",
|
||||
url_path="me",
|
||||
permission_classes=[permissions.IsAuthenticated],
|
||||
)
|
||||
def get_me(self, request):
|
||||
"""
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
Test users API endpoints in the People core app.
|
||||
"""
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from rest_framework.status import HTTP_200_OK
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories, models
|
||||
from core.api import serializers
|
||||
from core.api.viewsets import Pagination
|
||||
|
||||
from .utils import OIDCToken
|
||||
|
||||
@@ -17,13 +21,15 @@ def test_api_users_list_anonymous():
|
||||
factories.UserFactory()
|
||||
client = APIClient()
|
||||
response = client.get("/api/v1.0/users/")
|
||||
assert response.status_code == 404
|
||||
assert "Not Found" in response.content.decode("utf-8")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication credentials were not provided." in response.content.decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
|
||||
def test_api_users_list_authenticated():
|
||||
"""
|
||||
Authenticated users should not be able to list users.
|
||||
Authenticated users should be able to list all users.
|
||||
"""
|
||||
identity = factories.IdentityFactory()
|
||||
jwt_token = OIDCToken.for_user(identity.user)
|
||||
@@ -32,8 +38,184 @@ def test_api_users_list_authenticated():
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert "Not Found" in response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()["results"]) == 3
|
||||
|
||||
|
||||
def test_api_users_authenticated_list_by_email():
|
||||
"""
|
||||
Authenticated users should be able to search users with a case-insensitive and
|
||||
partial query on the email.
|
||||
"""
|
||||
user = factories.UserFactory(email="tester@ministry.fr")
|
||||
factories.IdentityFactory(user=user, email=user.email)
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
dave = factories.UserFactory(email="david.bowman@work.com")
|
||||
nicole = factories.UserFactory(email="nicole_foole@work.com")
|
||||
frank = factories.UserFactory(email="frank_poole@work.com")
|
||||
factories.UserFactory(email="heywood_floyd@work.com")
|
||||
|
||||
# Full query should work
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=david.bowman@work.com",
|
||||
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_200_OK
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids[0] == str(dave.id)
|
||||
|
||||
# Partial query should work
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=fran", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids[0] == str(frank.id)
|
||||
|
||||
# Result that matches a trigram twice ranks better than result that matches once
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=ole", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
# "Nicole Foole" matches twice on "ole"
|
||||
assert user_ids == [str(nicole.id), str(frank.id)]
|
||||
|
||||
# Even with a low similarity threshold, query should match expected emails
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=ool", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(nicole.id), str(frank.id)]
|
||||
|
||||
|
||||
def test_api_users_authenticated_list_uppercase_content():
|
||||
"""Upper case content should be found by lower case query."""
|
||||
user = factories.UserFactory(email="tester@ministry.fr")
|
||||
factories.IdentityFactory(user=user, email=user.email)
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
dave = factories.UserFactory(email="DAVID.BOWMAN@INTENSEWORK.COM")
|
||||
|
||||
# Unaccented full address
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=david.bowman@intensework.com",
|
||||
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(dave.id)]
|
||||
|
||||
# Partial query
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=david", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(dave.id)]
|
||||
|
||||
|
||||
def test_api_users_list_authenticated_capital_query():
|
||||
"""Upper case query should find lower case content."""
|
||||
user = factories.UserFactory(email="tester@ministry.fr")
|
||||
factories.IdentityFactory(user=user, email=user.email)
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
dave = factories.UserFactory(email="david.bowman@work.com")
|
||||
|
||||
# Full uppercase query
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=DAVID.BOWMAN@WORK.COM",
|
||||
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(dave.id)]
|
||||
|
||||
# Partial uppercase email
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=DAVID", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(dave.id)]
|
||||
|
||||
|
||||
def test_api_contacts_list_authenticated_accented_query():
|
||||
"""Accented content should be found by unaccented query."""
|
||||
user = factories.UserFactory(email="tester@ministry.fr")
|
||||
factories.IdentityFactory(user=user, email=user.email)
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
helene = factories.UserFactory(email="helene.bowman@work.com")
|
||||
|
||||
# Accented full query
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=hélène.bowman@work.com",
|
||||
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(helene.id)]
|
||||
|
||||
# Unaccented partial email
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?q=hélène", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user_ids = [user["id"] for user in response.json()["results"]]
|
||||
assert user_ids == [str(helene.id)]
|
||||
|
||||
|
||||
@mock.patch.object(Pagination, "get_page_size", return_value=3)
|
||||
def test_api_users_list_pagination(
|
||||
_mock_page_size,
|
||||
):
|
||||
"""Pagination should work as expected."""
|
||||
identity = factories.IdentityFactory()
|
||||
user = identity.user
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
factories.UserFactory.create_batch(4)
|
||||
|
||||
# Get page 1
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_200_OK
|
||||
content = response.json()
|
||||
|
||||
assert content["count"] == 5
|
||||
assert len(content["results"]) == 3
|
||||
assert content["next"] == "http://testserver/api/v1.0/users/?page=2"
|
||||
assert content["previous"] is None
|
||||
|
||||
# Get page 2
|
||||
response = APIClient().get(
|
||||
"/api/v1.0/users/?page=2", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_200_OK
|
||||
content = response.json()
|
||||
|
||||
assert content["count"] == 5
|
||||
assert content["next"] is None
|
||||
assert content["previous"] == "http://testserver/api/v1.0/users/"
|
||||
|
||||
assert len(content["results"]) == 2
|
||||
|
||||
|
||||
def test_api_users_retrieve_me_anonymous():
|
||||
@@ -66,6 +248,7 @@ def test_api_users_retrieve_me_authenticated():
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": str(user.id),
|
||||
"email": str(user.email),
|
||||
"language": user.language,
|
||||
"timezone": str(user.timezone),
|
||||
"is_device": False,
|
||||
@@ -128,8 +311,10 @@ def test_api_users_create_anonymous():
|
||||
"password": "mypassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert "Not Found" in response.content.decode("utf-8")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication credentials were not provided." in response.content.decode(
|
||||
"utf-8"
|
||||
)
|
||||
assert models.User.objects.exists() is False
|
||||
|
||||
|
||||
@@ -148,8 +333,8 @@ def test_api_users_create_authenticated():
|
||||
format="json",
|
||||
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert "Not Found" in response.content.decode("utf-8")
|
||||
assert response.status_code == 405
|
||||
assert response.json() == {"detail": 'Method "POST" not allowed.'}
|
||||
assert models.User.objects.exclude(id=user.id).exists() is False
|
||||
|
||||
|
||||
@@ -322,7 +507,7 @@ def test_api_users_delete_list_anonymous():
|
||||
client = APIClient()
|
||||
response = client.delete("/api/v1.0/users/")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.status_code == 401
|
||||
assert models.User.objects.count() == 2
|
||||
|
||||
|
||||
@@ -337,7 +522,7 @@ def test_api_users_delete_list_authenticated():
|
||||
"/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.status_code == 405
|
||||
assert models.User.objects.count() == 3
|
||||
|
||||
|
||||
|
||||
36
src/backend/core/tests/test_throttle.py
Normal file
36
src/backend/core/tests/test_throttle.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Test Throttle in People's app.
|
||||
"""
|
||||
import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
|
||||
from people.settings import REST_FRAMEWORK # pylint: disable=E0611
|
||||
|
||||
from .utils import OIDCToken
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def test_throttle():
|
||||
"""
|
||||
Throttle protection should block requests if too many.
|
||||
"""
|
||||
identity = factories.IdentityFactory()
|
||||
user = identity.user
|
||||
jwt_token = OIDCToken.for_user(user)
|
||||
|
||||
client = APIClient()
|
||||
endpoint = "/api/v1.0/users/"
|
||||
|
||||
# loop to activate throttle protection
|
||||
throttle_limit = int(
|
||||
REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["burst"].replace("/minute", "")
|
||||
)
|
||||
for _ in range(0, throttle_limit):
|
||||
client.get(endpoint, HTTP_AUTHORIZATION=f"Bearer {jwt_token}")
|
||||
|
||||
# this call should err
|
||||
response = client.get(endpoint, HTTP_AUTHORIZATION=f"Bearer {jwt_token}")
|
||||
assert response.status_code == 429 # too many requests
|
||||
@@ -217,11 +217,13 @@ class Base(Configuration):
|
||||
"rest_framework.parsers.JSONParser",
|
||||
"nested_multipart_parser.drf.DrfNestedParser",
|
||||
],
|
||||
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
|
||||
"EXCEPTION_HANDLER": "core.api.exception_handler",
|
||||
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
|
||||
"PAGE_SIZE": 20,
|
||||
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
"DEFAULT_THROTTLE_RATES": {"sustained": "150/hour", "burst": "20/minute"},
|
||||
}
|
||||
|
||||
SPECTACULAR_SETTINGS = {
|
||||
|
||||
Reference in New Issue
Block a user