diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index dc1ba541..2a21b670 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _ import magic from rest_framework import exceptions, serializers -from core import enums, models, utils +from core import choices, enums, models, utils from core.services.ai_services import AI_ACTIONS from core.services.converter_services import ( ConversionError, @@ -97,7 +97,7 @@ class BaseAccessSerializer(serializers.ModelSerializer): if not self.Meta.model.objects.filter( # pylint: disable=no-member Q(user=user) | Q(team__in=user.teams), - role__in=models.PRIVILEGED_ROLES, + role__in=choices.PRIVILEGED_ROLES, **{self.Meta.resource_field_name: resource_id}, # pylint: disable=no-member ).exists(): raise exceptions.PermissionDenied( diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 8f8b5816..efc06fe7 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -32,7 +32,7 @@ from rest_framework import response as drf_response from rest_framework.permissions import AllowAny from rest_framework.throttling import UserRateThrottle -from core import authentication, enums, models +from core import authentication, choices, enums, models from core.services.ai_services import AIService from core.services.collaboration_services import CollaborationService from core.tasks.mail import send_ask_for_access_mail @@ -1539,12 +1539,12 @@ class DocumentAccessViewSet( document__in=ancestors.filter(depth__gte=highest_readable.depth) ) - is_privileged = bool(roles.intersection(set(models.PRIVILEGED_ROLES))) + is_privileged = bool(roles.intersection(set(choices.PRIVILEGED_ROLES))) if is_privileged: serializer_class = serializers.DocumentAccessSerializer else: # Return only the document's privileged accesses - queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES) + queryset = queryset.filter(role__in=choices.PRIVILEGED_ROLES) serializer_class = serializers.DocumentAccessLightSerializer queryset = queryset.distinct() @@ -1786,11 +1786,11 @@ class InvitationViewset( queryset.filter( db.Q( document__accesses__user=user, - document__accesses__role__in=models.PRIVILEGED_ROLES, + document__accesses__role__in=choices.PRIVILEGED_ROLES, ) | db.Q( document__accesses__team__in=teams, - document__accesses__role__in=models.PRIVILEGED_ROLES, + document__accesses__role__in=choices.PRIVILEGED_ROLES, ), ) # Abilities are computed based on logged-in user's role and diff --git a/src/backend/core/choices.py b/src/backend/core/choices.py new file mode 100644 index 00000000..5bfd61c3 --- /dev/null +++ b/src/backend/core/choices.py @@ -0,0 +1,123 @@ +"""Declare and configure choices for Docs' core application.""" + +from django.db.models import TextChoices +from django.utils.translation import gettext_lazy as _ + + +class PriorityTextChoices(TextChoices): + """ + This class inherits from Django's TextChoices and provides a method to get the priority + of a given value based on its position in the class. + """ + + @classmethod + def get_priority(cls, value): + """Returns the priority of the given value based on its order in the class.""" + members = list(cls.__members__.values()) + return members.index(value) + 1 if value in members else 0 + + @classmethod + def max(cls, *roles): + """ + Return the highest-priority role among the given roles, using get_priority(). + If no valid roles are provided, returns None. + """ + + valid_roles = [role for role in roles if cls.get_priority(role) is not None] + if not valid_roles: + return None + return max(valid_roles, key=cls.get_priority) + + +class LinkRoleChoices(PriorityTextChoices): + """Defines the possible roles a link can offer on a document.""" + + READER = "reader", _("Reader") # Can read + EDITOR = "editor", _("Editor") # Can read and edit + + +class RoleChoices(PriorityTextChoices): + """Defines the possible roles a user can have in a resource.""" + + READER = "reader", _("Reader") # Can read + EDITOR = "editor", _("Editor") # Can read and edit + ADMIN = "administrator", _("Administrator") # Can read, edit, delete and share + OWNER = "owner", _("Owner") + + +PRIVILEGED_ROLES = [RoleChoices.ADMIN, RoleChoices.OWNER] + + +class LinkReachChoices(PriorityTextChoices): + """Defines types of access for links""" + + RESTRICTED = ( + "restricted", + _("Restricted"), + ) # Only users with a specific access can read/edit the document + AUTHENTICATED = ( + "authenticated", + _("Authenticated"), + ) # Any authenticated user can access the document + PUBLIC = "public", _("Public") # Even anonymous users can access the document + + @classmethod + def get_select_options(cls, ancestors_links): + """ + Determines the valid select options for link reach and link role depending on the + list of ancestors' link reach/role. + Args: + ancestors_links: List of dictionaries, each with 'link_reach' and 'link_role' keys + representing the reach and role of ancestors links. + Returns: + Dictionary mapping possible reach levels to their corresponding possible roles. + """ + # If no ancestors, return all options + if not ancestors_links: + return { + reach: LinkRoleChoices.values if reach != cls.RESTRICTED else None + for reach in cls.values + } + + # Initialize result with all possible reaches and role options as sets + result = { + reach: set(LinkRoleChoices.values) if reach != cls.RESTRICTED else None + for reach in cls.values + } + + # Group roles by reach level + reach_roles = defaultdict(set) + for link in ancestors_links: + reach_roles[link["link_reach"]].add(link["link_role"]) + + # Rule 1: public/editor → override everything + if LinkRoleChoices.EDITOR in reach_roles.get(cls.PUBLIC, set()): + return {cls.PUBLIC: [LinkRoleChoices.EDITOR]} + + # Rule 2: authenticated/editor + if LinkRoleChoices.EDITOR in reach_roles.get(cls.AUTHENTICATED, set()): + result[cls.AUTHENTICATED].discard(LinkRoleChoices.READER) + result.pop(cls.RESTRICTED, None) + + # Rule 3: public/reader + if LinkRoleChoices.READER in reach_roles.get(cls.PUBLIC, set()): + result.pop(cls.AUTHENTICATED, None) + result.pop(cls.RESTRICTED, None) + + # Rule 4: authenticated/reader + if LinkRoleChoices.READER in reach_roles.get(cls.AUTHENTICATED, set()): + result.pop(cls.RESTRICTED, None) + + # Clean up: remove empty entries and convert sets to ordered lists + cleaned = {} + for reach in cls.values: + if reach in result: + if result[reach]: + cleaned[reach] = [ + r for r in LinkRoleChoices.values if r in result[reach] + ] + else: + # Could be [] or None (for RESTRICTED reach) + cleaned[reach] = result[reach] + + return cleaned diff --git a/src/backend/core/models.py b/src/backend/core/models.py index e49a8d97..b01f5eab 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -33,6 +33,8 @@ from rest_framework.exceptions import ValidationError from timezone_field import TimeZoneField from treebeard.mp_tree import MP_Node, MP_NodeManager, MP_NodeQuerySet +from .choices import PRIVILEGED_ROLES, LinkReachChoices, LinkRoleChoices, RoleChoices + logger = getLogger(__name__) @@ -50,100 +52,6 @@ def get_trashbin_cutoff(): return timezone.now() - timedelta(days=settings.TRASHBIN_CUTOFF_DAYS) -class LinkRoleChoices(models.TextChoices): - """Defines the possible roles a link can offer on a document.""" - - READER = "reader", _("Reader") # Can read - EDITOR = "editor", _("Editor") # Can read and edit - - -class RoleChoices(models.TextChoices): - """Defines the possible roles a user can have in a resource.""" - - READER = "reader", _("Reader") # Can read - EDITOR = "editor", _("Editor") # Can read and edit - ADMIN = "administrator", _("Administrator") # Can read, edit, delete and share - OWNER = "owner", _("Owner") - - -PRIVILEGED_ROLES = [RoleChoices.ADMIN, RoleChoices.OWNER] - - -class LinkReachChoices(models.TextChoices): - """Defines types of access for links""" - - RESTRICTED = ( - "restricted", - _("Restricted"), - ) # Only users with a specific access can read/edit the document - AUTHENTICATED = ( - "authenticated", - _("Authenticated"), - ) # Any authenticated user can access the document - PUBLIC = "public", _("Public") # Even anonymous users can access the document - - @classmethod - def get_select_options(cls, ancestors_links): - """ - Determines the valid select options for link reach and link role depending on the - list of ancestors' link reach/role. - Args: - ancestors_links: List of dictionaries, each with 'link_reach' and 'link_role' keys - representing the reach and role of ancestors links. - Returns: - Dictionary mapping possible reach levels to their corresponding possible roles. - """ - # If no ancestors, return all options - if not ancestors_links: - return { - reach: LinkRoleChoices.values if reach != cls.RESTRICTED else None - for reach in cls.values - } - - # Initialize result with all possible reaches and role options as sets - result = { - reach: set(LinkRoleChoices.values) if reach != cls.RESTRICTED else None - for reach in cls.values - } - - # Group roles by reach level - reach_roles = defaultdict(set) - for link in ancestors_links: - reach_roles[link["link_reach"]].add(link["link_role"]) - - # Rule 1: public/editor → override everything - if LinkRoleChoices.EDITOR in reach_roles.get(cls.PUBLIC, set()): - return {cls.PUBLIC: [LinkRoleChoices.EDITOR]} - - # Rule 2: authenticated/editor - if LinkRoleChoices.EDITOR in reach_roles.get(cls.AUTHENTICATED, set()): - result[cls.AUTHENTICATED].discard(LinkRoleChoices.READER) - result.pop(cls.RESTRICTED, None) - - # Rule 3: public/reader - if LinkRoleChoices.READER in reach_roles.get(cls.PUBLIC, set()): - result.pop(cls.AUTHENTICATED, None) - result.pop(cls.RESTRICTED, None) - - # Rule 4: authenticated/reader - if LinkRoleChoices.READER in reach_roles.get(cls.AUTHENTICATED, set()): - result.pop(cls.RESTRICTED, None) - - # Clean up: remove empty entries and convert sets to ordered lists - cleaned = {} - for reach in cls.values: - if reach in result: - if result[reach]: - cleaned[reach] = [ - r for r in LinkRoleChoices.values if r in result[reach] - ] - else: - # Could be [] or None (for RESTRICTED reach) - cleaned[reach] = result[reach] - - return cleaned - - class DuplicateEmailError(Exception): """Raised when an email is already associated with a pre-existing user.""" diff --git a/src/backend/core/tests/documents/test_api_document_accesses.py b/src/backend/core/tests/documents/test_api_document_accesses.py index e30a6c36..bc6dcb51 100644 --- a/src/backend/core/tests/documents/test_api_document_accesses.py +++ b/src/backend/core/tests/documents/test_api_document_accesses.py @@ -8,7 +8,7 @@ from uuid import uuid4 import pytest from rest_framework.test import APIClient -from core import factories, models +from core import choices, factories, models from core.api import serializers from core.tests.conftest import TEAM, USER, VIA from core.tests.test_services_collaboration_services import ( # pylint: disable=unused-import @@ -70,7 +70,8 @@ def test_api_document_accesses_list_unexisting_document(): @pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize( - "role", [role for role in models.RoleChoices if role not in models.PRIVILEGED_ROLES] + "role", + [role for role in choices.RoleChoices if role not in choices.PRIVILEGED_ROLES], ) def test_api_document_accesses_list_authenticated_related_non_privileged( via, role, mock_user_teams @@ -131,7 +132,7 @@ def test_api_document_accesses_list_authenticated_related_non_privileged( # Make sure only privileged roles are returned accesses = [grand_parent_access, parent_access, document_access, access1, access2] privileged_accesses = [ - acc for acc in accesses if acc.role in models.PRIVILEGED_ROLES + acc for acc in accesses if acc.role in choices.PRIVILEGED_ROLES ] assert len(content) == len(privileged_accesses) @@ -159,7 +160,7 @@ def test_api_document_accesses_list_authenticated_related_non_privileged( @pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize( - "role", [role for role in models.RoleChoices if role in models.PRIVILEGED_ROLES] + "role", [role for role in choices.RoleChoices if role in choices.PRIVILEGED_ROLES] ) def test_api_document_accesses_list_authenticated_related_privileged( via, role, mock_user_teams @@ -335,7 +336,7 @@ def test_api_document_accesses_retrieve_authenticated_related( f"/api/v1.0/documents/{document.id!s}/accesses/{access.id!s}/", ) - if not role in models.PRIVILEGED_ROLES: + if not role in choices.PRIVILEGED_ROLES: assert response.status_code == 403 else: access_user = serializers.UserSerializer(instance=access.user).data