diff --git a/src/backend/core/api/filters.py b/src/backend/core/api/filters.py index 38f465a3..42cb79cd 100644 --- a/src/backend/core/api/filters.py +++ b/src/backend/core/api/filters.py @@ -128,3 +128,11 @@ class ListDocumentFilter(DocumentFilter): queryset_method = queryset.filter if bool(value) else queryset.exclude return queryset_method(link_traces__user=user, link_traces__is_masked=True) + + +class UserSearchFilter(django_filters.FilterSet): + """ + Custom filter for searching users. + """ + + q = django_filters.CharFilter(min_length=5, max_length=254) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 50ff8baa..9b44be68 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -51,7 +51,7 @@ from core.tasks.mail import send_ask_for_access_mail from core.utils import extract_attachments, filter_descendants from . import permissions, serializers, utils -from .filters import DocumentFilter, ListDocumentFilter +from .filters import DocumentFilter, ListDocumentFilter, UserSearchFilter from .throttling import UserListThrottleBurst, UserListThrottleSustained logger = logging.getLogger(__name__) @@ -176,12 +176,18 @@ class UserViewSet( if self.action != "list": return queryset + filterset = UserSearchFilter( + self.request.GET, queryset=queryset, request=self.request + ) + if not filterset.is_valid(): + raise drf.exceptions.ValidationError(filterset.errors) + # Exclude all users already in the given document if document_id := self.request.query_params.get("document_id", ""): queryset = queryset.exclude(documentaccess__document_id=document_id) - if not (query := self.request.query_params.get("q", "")) or len(query) < 5: - return queryset.none() + filter_data = filterset.form.cleaned_data + query = filter_data["q"] # For emails, match emails by Levenstein distance to prevent typing errors if "@" in query: diff --git a/src/backend/core/tests/test_api_users.py b/src/backend/core/tests/test_api_users.py index 179b19ad..2968432a 100644 --- a/src/backend/core/tests/test_api_users.py +++ b/src/backend/core/tests/test_api_users.py @@ -194,18 +194,41 @@ def test_api_users_list_query_short_queries(): factories.UserFactory(email="john.lennon@example.com") response = client.get("/api/v1.0/users/?q=jo") - assert response.status_code == 200 - assert response.json() == [] + assert response.status_code == 400 + assert response.json() == { + "q": ["Ensure this value has at least 5 characters (it has 2)."] + } response = client.get("/api/v1.0/users/?q=john") - assert response.status_code == 200 - assert response.json() == [] + assert response.status_code == 400 + assert response.json() == { + "q": ["Ensure this value has at least 5 characters (it has 4)."] + } response = client.get("/api/v1.0/users/?q=john.") assert response.status_code == 200 assert len(response.json()) == 2 +def test_api_users_list_query_long_queries(): + """ + Queries longer than 255 characters should return an empty result set. + """ + user = factories.UserFactory(email="paul@example.com") + client = APIClient() + client.force_login(user) + + factories.UserFactory(email="john.doe@example.com") + factories.UserFactory(email="john.lennon@example.com") + + query = "a" * 244 + response = client.get(f"/api/v1.0/users/?q={query}@example.com") + assert response.status_code == 400 + assert response.json() == { + "q": ["Ensure this value has at most 254 characters (it has 256)."] + } + + def test_api_users_list_query_inactive(): """Inactive users should not be listed.""" user = factories.UserFactory()