diff --git a/CHANGELOG.md b/CHANGELOG.md index 4300efba..51a2d78c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to ## [Unreleased] +### Changed + +- 🔒️(backend) configure throttle on every viewsets #1343 + ## [3.6.0] - 2025-09-04 ### Added diff --git a/env.d/development/common.e2e b/env.d/development/common.e2e index 4fd5efd5..15434a68 100644 --- a/env.d/development/common.e2e +++ b/env.d/development/common.e2e @@ -3,3 +3,7 @@ BURST_THROTTLE_RATES="200/minute" COLLABORATION_API_URL=http://y-provider:4444/collaboration/api/ SUSTAINED_THROTTLE_RATES="200/hour" Y_PROVIDER_API_BASE_URL=http://y-provider:4444/api/ + +# Throttle +API_DOCUMENT_THROTTLE_RATE=1000/min +API_CONFIG_THROTTLE_RATE=1000/min \ No newline at end of file diff --git a/src/backend/core/api/throttling.py b/src/backend/core/api/throttling.py new file mode 100644 index 00000000..adfb6d57 --- /dev/null +++ b/src/backend/core/api/throttling.py @@ -0,0 +1,21 @@ +"""Throttling modules for the API.""" + +from rest_framework.throttling import UserRateThrottle +from sentry_sdk import capture_message + + +def sentry_monitoring_throttle_failure(message): + """Log when a failure occurs to detect rate limiting issues.""" + capture_message(message, "warning") + + +class UserListThrottleBurst(UserRateThrottle): + """Throttle for the user list endpoint.""" + + scope = "user_list_burst" + + +class UserListThrottleSustained(UserRateThrottle): + """Throttle for the user list endpoint.""" + + scope = "user_list_sustained" diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index f98cbfc6..19841185 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -33,7 +33,6 @@ from lasuite.malware_detection import malware_detection from rest_framework import filters, status, viewsets from rest_framework import response as drf_response from rest_framework.permissions import AllowAny -from rest_framework.throttling import UserRateThrottle from core import authentication, choices, enums, models from core.services.ai_services import AIService @@ -43,6 +42,7 @@ from core.utils import extract_attachments, filter_descendants from . import permissions, serializers, utils from .filters import DocumentFilter, ListDocumentFilter +from .throttling import UserListThrottleBurst, UserListThrottleSustained logger = logging.getLogger(__name__) @@ -136,18 +136,6 @@ class Pagination(drf.pagination.PageNumberPagination): page_size_query_param = "page_size" -class UserListThrottleBurst(UserRateThrottle): - """Throttle for the user list endpoint.""" - - scope = "user_list_burst" - - -class UserListThrottleSustained(UserRateThrottle): - """Throttle for the user list endpoint.""" - - scope = "user_list_sustained" - - class UserViewSet( drf.mixins.UpdateModelMixin, viewsets.GenericViewSet, drf.mixins.ListModelMixin ): @@ -360,6 +348,7 @@ class DocumentViewSet( permission_classes = [ permissions.DocumentPermission, ] + throttle_scope = "document" queryset = models.Document.objects.select_related("creator").all() serializer_class = serializers.DocumentSerializer ai_translate_serializer_class = serializers.AITranslateSerializer @@ -1555,6 +1544,7 @@ class DocumentAccessViewSet( "document__depth", ) resource_field_name = "document" + throttle_scope = "document_access" @cached_property def document(self): @@ -1714,6 +1704,7 @@ class TemplateViewSet( permissions.IsAuthenticatedOrSafe, permissions.ResourceWithAccessPermission, ] + throttle_scope = "template" ordering = ["-created_at"] ordering_fields = ["created_at", "updated_at", "title"] serializer_class = serializers.TemplateSerializer @@ -1804,6 +1795,7 @@ class TemplateAccessViewSet( lookup_field = "pk" permission_classes = [permissions.ResourceAccessPermission] + throttle_scope = "template_access" queryset = models.TemplateAccess.objects.select_related("user").all() resource_field_name = "template" serializer_class = serializers.TemplateAccessSerializer @@ -1886,6 +1878,7 @@ class InvitationViewset( permissions.CanCreateInvitationPermission, permissions.ResourceWithAccessPermission, ] + throttle_scope = "invitation" queryset = ( models.Invitation.objects.all() .select_related("document") @@ -1964,6 +1957,7 @@ class DocumentAskForAccessViewSet( permissions.IsAuthenticated, permissions.ResourceWithAccessPermission, ] + throttle_scope = "document_ask_for_access" queryset = models.DocumentAskForAccess.objects.all() serializer_class = serializers.DocumentAskForAccessSerializer _document = None @@ -2036,6 +2030,7 @@ class ConfigView(drf.views.APIView): """API ViewSet for sharing some public settings.""" permission_classes = [AllowAny] + throttle_scope = "config" def get(self, request): """ diff --git a/src/backend/core/tests/documents/test_api_document_accesses.py b/src/backend/core/tests/documents/test_api_document_accesses.py index eb6fa92b..280b2bc9 100644 --- a/src/backend/core/tests/documents/test_api_document_accesses.py +++ b/src/backend/core/tests/documents/test_api_document_accesses.py @@ -4,6 +4,7 @@ Test document accesses API endpoints for users in impress's core app. # pylint: disable=too-many-lines import random +from unittest import mock from uuid import uuid4 import pytest @@ -1344,3 +1345,24 @@ def test_api_document_accesses_delete_owners_last_owner_child_team( assert response.status_code == 204 assert models.DocumentAccess.objects.count() == 1 + + +def test_api_document_accesses_throttling(settings): + """Test api document accesses throttling.""" + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document_access"] = "2/minute" + user = factories.UserFactory() + document = factories.DocumentFactory() + factories.UserDocumentAccessFactory( + document=document, user=user, role="administrator" + ) + client = APIClient() + client.force_login(user) + for _i in range(2): + response = client.get(f"/api/v1.0/documents/{document.id!s}/accesses/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get(f"/api/v1.0/documents/{document.id!s}/accesses/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope document_access", "warning" + ) diff --git a/src/backend/core/tests/documents/test_api_document_invitations.py b/src/backend/core/tests/documents/test_api_document_invitations.py index 16090b7d..33d99d8a 100644 --- a/src/backend/core/tests/documents/test_api_document_invitations.py +++ b/src/backend/core/tests/documents/test_api_document_invitations.py @@ -824,3 +824,29 @@ def test_api_document_invitations_delete_readers_or_editors(via, role, mock_user response.json()["detail"] == "You do not have permission to perform this action." ) + + +def test_api_document_invitations_throttling(settings): + """Test api document ask for access throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["invitation"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["invitation"] = "2/minute" + user = factories.UserFactory() + document = factories.DocumentFactory() + + factories.UserDocumentAccessFactory(document=document, user=user, role="owner") + + factories.InvitationFactory(document=document, issuer=user) + + client = APIClient() + client.force_login(user) + + for _i in range(2): + response = client.get(f"/api/v1.0/documents/{document.id}/invitations/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get(f"/api/v1.0/documents/{document.id}/invitations/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope invitation", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["invitation"] = current_rate diff --git a/src/backend/core/tests/documents/test_api_documents_ask_for_access.py b/src/backend/core/tests/documents/test_api_documents_ask_for_access.py index 23bb8d50..4e09dc5d 100644 --- a/src/backend/core/tests/documents/test_api_documents_ask_for_access.py +++ b/src/backend/core/tests/documents/test_api_documents_ask_for_access.py @@ -1,6 +1,7 @@ """Test API for document ask for access.""" import uuid +from unittest import mock from django.core import mail @@ -768,3 +769,35 @@ def test_api_documents_ask_for_access_accept_authenticated_non_root_document(rol f"/api/v1.0/documents/{child.id}/ask-for-access/{document_ask_for_access.id}/accept/" ) assert response.status_code == 404 + + +def test_api_document_ask_for_access_throttling(settings): + """Test api document ask for access throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"][ + "document_ask_for_access" + ] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document_ask_for_access"] = ( + "2/minute" + ) + document = DocumentFactory() + DocumentAskForAccessFactory.create_batch( + 3, document=document, role=RoleChoices.READER + ) + + user = UserFactory() + + client = APIClient() + client.force_login(user) + + for _i in range(2): + response = client.get(f"/api/v1.0/documents/{document.id}/ask-for-access/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get(f"/api/v1.0/documents/{document.id}/ask-for-access/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope document_ask_for_access", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document_ask_for_access"] = ( + current_rate + ) diff --git a/src/backend/core/tests/documents/test_api_documents_list.py b/src/backend/core/tests/documents/test_api_documents_list.py index cfaa3e0a..1fe23594 100644 --- a/src/backend/core/tests/documents/test_api_documents_list.py +++ b/src/backend/core/tests/documents/test_api_documents_list.py @@ -427,3 +427,20 @@ def test_api_documents_list_favorites_no_extra_queries(django_assert_num_queries assert result["is_favorite"] is True else: assert result["is_favorite"] is False + + +def test_api_documents_list_throttling(settings): + """Test api documents throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = "2/minute" + client = APIClient() + for _i in range(2): + response = client.get("/api/v1.0/documents/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get("/api/v1.0/documents/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope document", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["document"] = current_rate diff --git a/src/backend/core/tests/templates/test_api_template_accesses.py b/src/backend/core/tests/templates/test_api_template_accesses.py index 6d110776..01bb6a12 100644 --- a/src/backend/core/tests/templates/test_api_template_accesses.py +++ b/src/backend/core/tests/templates/test_api_template_accesses.py @@ -3,6 +3,7 @@ Test template accesses API endpoints for users in impress's core app. """ import random +from unittest import mock from uuid import uuid4 import pytest @@ -773,3 +774,26 @@ def test_api_template_accesses_delete_owners_last_owner(via, mock_user_teams): assert response.status_code == 403 assert models.TemplateAccess.objects.count() == 2 + + +def test_api_template_accesses_throttling(settings): + """Test api template accesses throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template_access"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template_access"] = "2/minute" + template = factories.TemplateFactory() + user = factories.UserFactory() + factories.UserTemplateAccessFactory( + template=template, user=user, role="administrator" + ) + client = APIClient() + client.force_login(user) + for _i in range(2): + response = client.get(f"/api/v1.0/templates/{template.id!s}/accesses/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get(f"/api/v1.0/templates/{template.id!s}/accesses/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope template_access", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template_access"] = current_rate diff --git a/src/backend/core/tests/templates/test_api_templates_list.py b/src/backend/core/tests/templates/test_api_templates_list.py index 11df4fa9..f4980354 100644 --- a/src/backend/core/tests/templates/test_api_templates_list.py +++ b/src/backend/core/tests/templates/test_api_templates_list.py @@ -218,3 +218,20 @@ def test_api_templates_list_order_param(): assert response_template_ids == templates_ids, ( "created_at values are not sorted from oldest to newest" ) + + +def test_api_template_throttling(settings): + """Test api template throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template"] = "2/minute" + client = APIClient() + for _i in range(2): + response = client.get("/api/v1.0/templates/") + assert response.status_code == 200 + with mock.patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get("/api/v1.0/templates/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope template", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["template"] = current_rate diff --git a/src/backend/core/tests/test_api_config.py b/src/backend/core/tests/test_api_config.py index cdc4a9cb..cac6bc07 100644 --- a/src/backend/core/tests/test_api_config.py +++ b/src/backend/core/tests/test_api_config.py @@ -3,6 +3,7 @@ Test config API endpoints in the Impress core app. """ import json +from unittest.mock import patch from django.test import override_settings @@ -174,3 +175,20 @@ def test_api_config_with_original_theme_customization(is_authenticated, settings theme_customization = json.load(f) assert content["theme_customization"] == theme_customization + + +def test_api_config_throttling(settings): + """Test api config throttling.""" + current_rate = settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["config"] + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["config"] = "2/minute" + client = APIClient() + for _i in range(2): + response = client.get("/api/v1.0/config/") + assert response.status_code == 200 + with patch("core.api.throttling.capture_message") as mock_capture_message: + response = client.get("/api/v1.0/config/") + assert response.status_code == 429 + mock_capture_message.assert_called_once_with( + "Rate limit exceeded for scope config", "warning" + ) + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["config"] = current_rate diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index 730574e3..9059dd29 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -356,6 +356,9 @@ class Base(Configuration): "PAGE_SIZE": 20, "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning", "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_THROTTLE_CLASSES": [ + "lasuite.drf.throttling.MonitoredScopedRateThrottle", + ], "DEFAULT_THROTTLE_RATES": { "user_list_sustained": values.Value( default="180/hour", @@ -367,8 +370,46 @@ class Base(Configuration): environ_name="API_USERS_LIST_THROTTLE_RATE_BURST", environ_prefix=None, ), + "document": values.Value( + default="80/minute", + environ_name="API_DOCUMENT_THROTTLE_RATE", + environ_prefix=None, + ), + "document_access": values.Value( + default="50/minute", + environ_name="API_DOCUMENT_ACCESS_THROTTLE_RATE", + environ_prefix=None, + ), + "template": values.Value( + default="30/minute", + environ_name="API_TEMPLATE_THROTTLE_RATE", + environ_prefix=None, + ), + "template_access": values.Value( + default="30/minute", + environ_name="API_TEMPLATE_ACCESS_THROTTLE_RATE", + environ_prefix=None, + ), + "invitation": values.Value( + default="60/minute", + environ_name="API_INVITATION_THROTTLE_RATE", + environ_prefix=None, + ), + "document_ask_for_access": values.Value( + default="30/minute", + environ_name="API_DOCUMENT_ASK_FOR_ACCESS_THROTTLE_RATE", + environ_prefix=None, + ), + "config": values.Value( + default="30/minute", + environ_name="API_CONFIG_THROTTLE_RATE", + environ_prefix=None, + ), }, } + MONITORED_THROTTLE_FAILURE_CALLBACK = ( + "core.api.throttling.sentry_monitoring_throttle_failure" + ) SPECTACULAR_SETTINGS = { "TITLE": "Impress API",