(api) search users by email (#16)

* (api) search users by email

The front end should be able to search users by email.
To that goal, we added a list method to the users viewset
thus creating the /users/ endpoint.
Results are filtered based on similarity with the query,
based on what preexisted for the /contacts/ endpoint.

* (api) test list users by email

Test search with complete and partial queries,
as well as accented and uppercase ones.
Also, lower similarity threshold for user search by email
as it was too high for some tests to pass.

* 💡(api) improve documentation and test comments

Improve user viewset documentation
and comments describing tests sections

Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com>
Co-authored-by: Anthony LC <anthony.le-courric@mail.numerique.gouv.fr>

* 🛂(api) set isAuthenticated as base requirements

Instead of checking permissions or adding decorators
to every viewset, isAuthenticated is set as base requirement.

* 🛂(api) define throttle limits in settings

Use of Django Rest Framework's throttle options, now set globally
to avoid duplicate code.

* 🩹(api) add email to user serializer

email field added to serializer. Tests modified accordingly.
I added the email field as "read only" to pass tests, but we need to discuss
that point in review.

* 🧱(api) move search logic to queryset

User viewset "list" method was overridden to allow search by email.
This removed the pagination. Instead of manually re-adding pagination at
the end of this method, I moved the search/filter logic to get_queryset,
to let DRF handle pagination.

* (api) test throttle protection

Test that throttle protection successfully blocks too many requests.

* 📝(tests) improve tests comment

Fix typos in comments and clarify which settings are tested in the test_throttle test
(setting import required disabling pylint false positive error)

Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com>

---------

Co-authored-by: aleb_the_flash <45729124+lebaudantoine@users.noreply.github.com>
Co-authored-by: Anthony LC <anthony.le-courric@mail.numerique.gouv.fr>
This commit is contained in:
Marie
2024-01-29 10:14:17 +01:00
committed by GitHub
parent 8f2f47d3b1
commit 269ba42204
5 changed files with 294 additions and 47 deletions

View File

@@ -36,13 +36,14 @@ class UserSerializer(serializers.ModelSerializer):
model = models.User
fields = [
"id",
"email",
"data",
"language",
"timezone",
"is_device",
"is_staff",
]
read_only_fields = ["id", "data", "is_device", "is_staff"]
read_only_fields = ["id", "email", "data", "is_device", "is_staff"]
def get_data(self, user) -> dict:
"""Return contact data for the user."""

View File

@@ -1,13 +1,6 @@
"""API endpoints"""
from django.contrib.postgres.search import TrigramSimilarity
from django.core.cache import cache
from django.db.models import (
Func,
OuterRef,
Q,
Subquery,
Value,
)
from django.db.models import Func, OuterRef, Q, Subquery, Value
from rest_framework import (
decorators,
@@ -17,11 +10,16 @@ from rest_framework import (
response,
viewsets,
)
from rest_framework.throttling import UserRateThrottle
from core import models
from . import permissions, serializers
EMAIL_SIMILARITY_THRESHOLD = 0.01
# TrigramSimilarity threshold is lower for searching email than for names,
# to improve matching results
class NestedGenericViewSet(viewsets.GenericViewSet):
"""
@@ -103,6 +101,22 @@ class Pagination(pagination.PageNumberPagination):
page_size_query_param = "page_size"
class BurstRateThrottle(UserRateThrottle):
    """
    Short-term throttle: limits the request rate over a window of minutes.

    The allowed rate is looked up under the "burst" scope of the throttle
    rates. See DRF section in settings for default value.
    """

    # Key under which the rate is declared in the throttle rates setting
    scope = "burst"
class SustainedRateThrottle(UserRateThrottle):
    """
    Long-term throttle: limits the request rate over a window of hours.

    The allowed rate is looked up under the "sustained" scope of the throttle
    rates. See DRF section in settings for default value.
    """

    # Key under which the rate is declared in the throttle rates setting
    scope = "sustained"
# pylint: disable=too-many-ancestors
class ContactViewSet(
mixins.CreateModelMixin,
@@ -116,15 +130,13 @@ class ContactViewSet(
permission_classes = [permissions.IsOwnedOrPublic]
queryset = models.Contact.objects.all()
serializer_class = serializers.ContactSerializer
throttle_classes = [BurstRateThrottle, SustainedRateThrottle]
def list(self, request, *args, **kwargs):
"""Limit listed users by a query with throttle protection."""
user = self.request.user
queryset = self.filter_queryset(self.get_queryset())
if not user.is_authenticated:
return queryset.none()
# Exclude contacts that:
queryset = queryset.filter(
# - belong to another user (keep public and owned contacts)
@@ -150,26 +162,6 @@ class ContactViewSet(
.order_by("-similarity")
)
# Throttle protection
key_base = f"throttle-contact-list-{user.id!s}"
key_minute = f"{key_base:s}-minute"
key_hour = f"{key_base:s}-hour"
try:
count_minute = cache.incr(key_minute)
except ValueError:
cache.set(key_minute, 1, 60)
count_minute = 1
try:
count_hour = cache.incr(key_hour)
except ValueError:
cache.set(key_hour, 1, 3600)
count_hour = 1
if count_minute > 20 or count_hour > 150:
raise exceptions.Throttled()
serializer = self.get_serializer(queryset, many=True)
return response.Response(serializer.data)
@@ -181,21 +173,52 @@ class ContactViewSet(
class UserViewSet(
mixins.UpdateModelMixin,
viewsets.GenericViewSet,
mixins.UpdateModelMixin, viewsets.GenericViewSet, mixins.ListModelMixin
):
"""User ViewSet"""
"""
User viewset for all interactions with user infos and teams.
GET /api/users/?q=query
Return a list of users whose email matches the query. Similarity is
calculated using trigram similarity, allowing for partial, case
insensitive matches and accentuated queries.
"""
permission_classes = [permissions.IsSelf]
queryset = models.User.objects.all().select_related("profile_contact")
serializer_class = serializers.UserSerializer
throttle_classes = [BurstRateThrottle, SustainedRateThrottle]
pagination_class = Pagination
def get_queryset(self):
    """
    Return the users the current action is allowed to work on.

    For the "list" action, only active users are kept. When a non-empty ``q``
    query parameter is provided, results are further limited to users whose
    unaccented email has a trigram similarity with the unaccented query of at
    least EMAIL_SIMILARITY_THRESHOLD, most similar first.

    For every other action the base queryset is returned unfiltered.
    Pagination and throttle protection are configured on the viewset and are
    not applied here.
    """
    # Go through the parent implementation rather than reading self.queryset:
    # it returns a fresh clone, so results evaluated for one request are not
    # cached on the class-level queryset and served to later requests.
    queryset = super().get_queryset()

    if self.action != "list":
        return queryset

    # Exclude inactive users
    queryset = queryset.filter(is_active=True)

    # Search by case-insensitive and accent-insensitive trigram similarity
    if query := self.request.query_params.get("q", ""):
        similarity = TrigramSimilarity(
            Func("email", function="unaccent"),
            Func(Value(query), function="unaccent"),
        )
        queryset = (
            queryset.annotate(similarity=similarity)
            .filter(similarity__gte=EMAIL_SIMILARITY_THRESHOLD)
            .order_by("-similarity")
        )

    return queryset
@decorators.action(
detail=False,
methods=["get"],
url_name="me",
url_path="me",
permission_classes=[permissions.IsAuthenticated],
)
def get_me(self, request):
"""

View File

@@ -1,11 +1,15 @@
"""
Test users API endpoints in the People core app.
"""
from unittest import mock
import pytest
from rest_framework.status import HTTP_200_OK
from rest_framework.test import APIClient
from core import factories, models
from core.api import serializers
from core.api.viewsets import Pagination
from .utils import OIDCToken
@@ -17,13 +21,15 @@ def test_api_users_list_anonymous():
factories.UserFactory()
client = APIClient()
response = client.get("/api/v1.0/users/")
assert response.status_code == 404
assert "Not Found" in response.content.decode("utf-8")
assert response.status_code == 401
assert "Authentication credentials were not provided." in response.content.decode(
"utf-8"
)
def test_api_users_list_authenticated():
"""
Authenticated users should not be able to list users.
Authenticated users should be able to list all users.
"""
identity = factories.IdentityFactory()
jwt_token = OIDCToken.for_user(identity.user)
@@ -32,8 +38,184 @@ def test_api_users_list_authenticated():
response = APIClient().get(
"/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
)
assert response.status_code == 404
assert "Not Found" in response.content.decode("utf-8")
assert response.status_code == 200
assert len(response.json()["results"]) == 3
def test_api_users_authenticated_list_by_email():
    """
    An authenticated user searching users by email should get matches for
    full and partial queries, ranked by trigram similarity.
    """
    requester = factories.UserFactory(email="tester@ministry.fr")
    factories.IdentityFactory(user=requester, email=requester.email)
    auth_header = f"Bearer {OIDCToken.for_user(requester)}"

    dave = factories.UserFactory(email="david.bowman@work.com")
    nicole = factories.UserFactory(email="nicole_foole@work.com")
    frank = factories.UserFactory(email="frank_poole@work.com")
    factories.UserFactory(email="heywood_floyd@work.com")

    def search_ids(query):
        """Search users with `query` and return the result ids, in order."""
        response = APIClient().get(
            f"/api/v1.0/users/?q={query}", HTTP_AUTHORIZATION=auth_header
        )
        assert response.status_code == HTTP_200_OK
        return [result["id"] for result in response.json()["results"]]

    # The complete address ranks its owner first
    assert search_ids("david.bowman@work.com")[0] == str(dave.id)

    # So does a fragment of the address
    assert search_ids("fran")[0] == str(frank.id)

    # A result matching a trigram twice outranks a result matching it once:
    # "nicole_foole" contains "ole" twice, "frank_poole" only once
    assert search_ids("ole") == [str(nicole.id), str(frank.id)]

    # Even with a low similarity threshold, only the expected emails match
    assert search_ids("ool") == [str(nicole.id), str(frank.id)]
def test_api_users_authenticated_list_uppercase_content():
    """An email stored in upper case should be found by a lower case query."""
    requester = factories.UserFactory(email="tester@ministry.fr")
    factories.IdentityFactory(user=requester, email=requester.email)
    auth_header = f"Bearer {OIDCToken.for_user(requester)}"

    dave = factories.UserFactory(email="DAVID.BOWMAN@INTENSEWORK.COM")
    expected_ids = [str(dave.id)]

    # First the full lower case address, then only a part of it
    for query in ("david.bowman@intensework.com", "david"):
        response = APIClient().get(
            f"/api/v1.0/users/?q={query}", HTTP_AUTHORIZATION=auth_header
        )

        assert response.status_code == HTTP_200_OK
        found_ids = [result["id"] for result in response.json()["results"]]
        assert found_ids == expected_ids
def test_api_users_list_authenticated_capital_query():
    """An email stored in lower case should be found by an upper case query."""
    requester = factories.UserFactory(email="tester@ministry.fr")
    factories.IdentityFactory(user=requester, email=requester.email)
    auth_header = f"Bearer {OIDCToken.for_user(requester)}"

    dave = factories.UserFactory(email="david.bowman@work.com")
    expected_ids = [str(dave.id)]

    # First the full upper case address, then only a part of it
    for query in ("DAVID.BOWMAN@WORK.COM", "DAVID"):
        response = APIClient().get(
            f"/api/v1.0/users/?q={query}", HTTP_AUTHORIZATION=auth_header
        )

        assert response.status_code == HTTP_200_OK
        found_ids = [result["id"] for result in response.json()["results"]]
        assert found_ids == expected_ids
def test_api_contacts_list_authenticated_accented_query():
    """An unaccented email should be found by an accented query on users."""
    requester = factories.UserFactory(email="tester@ministry.fr")
    factories.IdentityFactory(user=requester, email=requester.email)
    auth_header = f"Bearer {OIDCToken.for_user(requester)}"

    helene = factories.UserFactory(email="helene.bowman@work.com")
    expected_ids = [str(helene.id)]

    # First the full accented address, then only its accented local part
    for query in ("hélène.bowman@work.com", "hélène"):
        response = APIClient().get(
            f"/api/v1.0/users/?q={query}", HTTP_AUTHORIZATION=auth_header
        )

        assert response.status_code == HTTP_200_OK
        found_ids = [result["id"] for result in response.json()["results"]]
        assert found_ids == expected_ids
@mock.patch.object(Pagination, "get_page_size", return_value=3)
def test_api_users_list_pagination(
    _mock_page_size,
):
    """With a page size of 3, five users should be spread over two pages."""
    identity = factories.IdentityFactory()
    auth_header = f"Bearer {OIDCToken.for_user(identity.user)}"

    # Four more users on top of the one making the requests
    factories.UserFactory.create_batch(4)

    def fetch(url):
        """Request `url` as the authenticated user and return the JSON body."""
        response = APIClient().get(url, HTTP_AUTHORIZATION=auth_header)
        assert response.status_code == HTTP_200_OK
        return response.json()

    # First page: full, with a link to the next page only
    first_page = fetch("/api/v1.0/users/")
    assert first_page["count"] == 5
    assert len(first_page["results"]) == 3
    assert first_page["next"] == "http://testserver/api/v1.0/users/?page=2"
    assert first_page["previous"] is None

    # Second page: the remaining users, with a link to the previous page only
    second_page = fetch("/api/v1.0/users/?page=2")
    assert second_page["count"] == 5
    assert second_page["next"] is None
    assert second_page["previous"] == "http://testserver/api/v1.0/users/"
    assert len(second_page["results"]) == 2
def test_api_users_retrieve_me_anonymous():
@@ -66,6 +248,7 @@ def test_api_users_retrieve_me_authenticated():
assert response.status_code == 200
assert response.json() == {
"id": str(user.id),
"email": str(user.email),
"language": user.language,
"timezone": str(user.timezone),
"is_device": False,
@@ -128,8 +311,10 @@ def test_api_users_create_anonymous():
"password": "mypassword",
},
)
assert response.status_code == 404
assert "Not Found" in response.content.decode("utf-8")
assert response.status_code == 401
assert "Authentication credentials were not provided." in response.content.decode(
"utf-8"
)
assert models.User.objects.exists() is False
@@ -148,8 +333,8 @@ def test_api_users_create_authenticated():
format="json",
HTTP_AUTHORIZATION=f"Bearer {jwt_token}",
)
assert response.status_code == 404
assert "Not Found" in response.content.decode("utf-8")
assert response.status_code == 405
assert response.json() == {"detail": 'Method "POST" not allowed.'}
assert models.User.objects.exclude(id=user.id).exists() is False
@@ -322,7 +507,7 @@ def test_api_users_delete_list_anonymous():
client = APIClient()
response = client.delete("/api/v1.0/users/")
assert response.status_code == 404
assert response.status_code == 401
assert models.User.objects.count() == 2
@@ -337,7 +522,7 @@ def test_api_users_delete_list_authenticated():
"/api/v1.0/users/", HTTP_AUTHORIZATION=f"Bearer {jwt_token}"
)
assert response.status_code == 404
assert response.status_code == 405
assert models.User.objects.count() == 3

View File

@@ -0,0 +1,36 @@
"""
Test Throttle in People's app.
"""
import pytest
from rest_framework.test import APIClient
from core import factories
from people.settings import REST_FRAMEWORK # pylint: disable=E0611
from .utils import OIDCToken
pytestmark = pytest.mark.django_db
def test_throttle():
    """
    Throttle protection should block requests if too many.

    Sends as many requests as the "burst" rate allows, checking that each one
    is served, then checks that one more request is rejected.
    """
    identity = factories.IdentityFactory()
    jwt_token = OIDCToken.for_user(identity.user)

    client = APIClient()
    endpoint = "/api/v1.0/users/"
    auth = {"HTTP_AUTHORIZATION": f"Bearer {jwt_token}"}

    # Rates are written "<count>/<period>". Keep only the count so the test
    # does not depend on how the period is spelled in the settings.
    burst_rate = REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["burst"]
    throttle_limit = int(burst_rate.split("/")[0])

    # loop to activate throttle protection; requests within the allowance
    # must still be served
    for _ in range(throttle_limit):
        response = client.get(endpoint, **auth)
        assert response.status_code == 200

    # this call should err
    response = client.get(endpoint, **auth)
    assert response.status_code == 429  # too many requests

View File

@@ -217,11 +217,13 @@ class Base(Configuration):
"rest_framework.parsers.JSONParser",
"nested_multipart_parser.drf.DrfNestedParser",
],
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
"EXCEPTION_HANDLER": "core.api.exception_handler",
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
"PAGE_SIZE": 20,
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
"DEFAULT_THROTTLE_RATES": {"sustained": "150/hour", "burst": "20/minute"},
}
SPECTACULAR_SETTINGS = {