diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 8b2ed7a6..b028b06d 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1078,6 +1078,11 @@ class DocumentViewSet( {"id": str(duplicated_document.id)}, status=status.HTTP_201_CREATED ) + # TODO + # @drf.decorators.action(detail=False, methods=["get"]) + # def search(self, request, *args, **kwargs): + # index.search() + @drf.decorators.action(detail=True, methods=["get"], url_path="versions") def versions_list(self, request, *args, **kwargs): """ diff --git a/src/backend/core/management/commands/index.py b/src/backend/core/management/commands/index.py index 85aeec90..b7eab950 100644 --- a/src/backend/core/management/commands/index.py +++ b/src/backend/core/management/commands/index.py @@ -25,4 +25,4 @@ class Command(BaseCommand): FindDocumentIndexer().index() duration = time.perf_counter() - start - logger.info(f"Search index regenerated in {duration:.2f} seconds.") + logger.info("Search index regenerated in %.2f seconds.", duration) diff --git a/src/backend/core/models.py b/src/backend/core/models.py index 424d6b89..cf6e7c06 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -20,7 +20,9 @@ from django.core.files.base import ContentFile from django.core.files.storage import default_storage from django.core.mail import send_mail from django.db import models, transaction +from django.db.models import signals from django.db.models.functions import Left, Length +from django.dispatch import receiver from django.template.loader import render_to_string from django.utils import timezone from django.utils.functional import cached_property @@ -40,6 +42,7 @@ from .choices import ( get_equivalent_link_definition, ) from .validators import sub_validator +from .tasks.find import trigger_document_indexer logger = getLogger(__name__) @@ -952,6 +955,16 @@ class Document(MP_Node, BaseModel): ) +@receiver(signals.post_save, sender=Document) 
+def document_post_save(sender, instance, **kwargs): + """ + Asynchronous call to the document indexer at the end of the transaction. + Note : Within the transaction we can have an empty content and a serialization + error. + """ + trigger_document_indexer(instance, on_commit=True) + + class LinkTrace(BaseModel): """ Relation model to trace accesses to a document via a link by a logged-in user. @@ -1182,6 +1195,15 @@ class DocumentAccess(BaseAccess): } +@receiver(signals.post_save, sender=DocumentAccess) +def document_access_post_save(sender, instance, created, **kwargs): + """ + Asynchronous call to the document indexer at the end of the transaction. + """ + if not created: + trigger_document_indexer(instance.document, on_commit=True) + + class DocumentAskForAccess(BaseModel): """Relation model to ask for access to a document.""" diff --git a/src/backend/core/services/search_indexers.py b/src/backend/core/services/search_indexers.py index 742be87e..dda2b585 100644 --- a/src/backend/core/services/search_indexers.py +++ b/src/backend/core/services/search_indexers.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from django.conf import settings +from django.contrib.auth.models import AnonymousUser import requests @@ -18,11 +19,13 @@ def get_batch_accesses_by_users_and_teams(paths): Get accesses related to a list of document paths, grouped by users and teams, including all ancestor paths. 
""" - print("paths: ", paths) - ancestor_map = utils.get_ancestor_to_descendants_map(paths, steplen=models.Document.steplen) + # print("paths: ", paths) + ancestor_map = utils.get_ancestor_to_descendants_map( + paths, steplen=models.Document.steplen + ) ancestor_paths = list(ancestor_map.keys()) - print("ancestor map: ", ancestor_map) - print("ancestor paths: ", ancestor_paths) + # print("ancestor map: ", ancestor_map) + # print("ancestor paths: ", ancestor_paths) access_qs = models.DocumentAccess.objects.filter( document__path__in=ancestor_paths @@ -44,6 +47,22 @@ def get_batch_accesses_by_users_and_teams(paths): return dict(access_by_document_path) +def get_visited_document_ids_of(user): + if isinstance(user, AnonymousUser): + return [] + + # TODO : exclude links when user already have a specific access to the doc + qs = models.LinkTrace.objects.filter( + user=user + ).exclude( + document__accesses__user=user, + ) + + return list({ + str(id) for id in qs.values_list("document_id", flat=True) + }) + + class BaseDocumentIndexer(ABC): """ Base class for document indexers. @@ -84,6 +103,7 @@ class BaseDocumentIndexer(ABC): serialized_batch = [ self.serialize_document(document, accesses_by_document_path) for document in documents_batch + if document.content ] self.push(serialized_batch) @@ -103,6 +123,38 @@ class BaseDocumentIndexer(ABC): Must be implemented by subclasses. """ + def search(self, text, user, token): + """ + Search for documents in Find app. + """ + visited_ids = get_visited_document_ids_of(user) + + response = self.search_query(data={ + "q": text, + "visited": visited_ids, + "services": ["docs"], + }, token=token) + + print(response) + + return self.format_response(response) + + @abstractmethod + def search_query(self, data, token) -> dict: + """ + Retreive documents from the Find app API. + + Must be implemented by subclasses. 
+ """ + + @abstractmethod + def format_response(self, data: dict): + """ + Convert the JSON response from Find app as document queryset. + + Must be implemented by subclasses. + """ + class FindDocumentIndexer(BaseDocumentIndexer): """ @@ -121,10 +173,12 @@ class FindDocumentIndexer(BaseDocumentIndexer): dict: A JSON-serializable dictionary. """ doc_path = document.path - text_content = utils.base64_yjs_to_text(document.content) + doc_content = document.content + text_content = utils.base64_yjs_to_text(doc_content) if doc_content else "" + return { "id": str(document.id), - "title": document.title, + "title": document.title or "", "content": text_content, "depth": document.depth, "path": document.path, @@ -138,6 +192,46 @@ class FindDocumentIndexer(BaseDocumentIndexer): "is_active": not bool(document.ancestors_deleted_at), } + def search_query(self, data, token) -> requests.Response: + """ + Retrieve documents from the Find app API. + + Args: + data (dict): search data + token (str): OICD token + + Returns: + dict: A JSON-serializable dictionary. + """ + url = getattr(settings, "SEARCH_INDEXER_QUERY_URL", None) + + if not url: + raise RuntimeError( + "SEARCH_INDEXER_QUERY_URL must be set in Django settings before indexing." + ) + + try: + response = requests.post( + url, + json=data, + headers={"Authorization": f"Bearer {token}"}, + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.error("HTTPError: %s", e) + logger.error("Response content: %s", response.text) # type: ignore + raise + + def format_response(self, data: dict): + """ + Retrieve documents ids from Find app response and return a queryset. + """ + return models.Document.objects.filter(pk__in=[ + d['_id'] for d in data + ]) + def push(self, data): """ Push a batch of documents to the Find backend. 
@@ -156,6 +250,7 @@ class FindDocumentIndexer(BaseDocumentIndexer):
             raise RuntimeError(
                 "SEARCH_INDEXER_SECRET must be set in Django settings before indexing."
             )
+
         try:
             response = requests.post(
                 url,
diff --git a/src/backend/core/tasks/find.py b/src/backend/core/tasks/find.py
new file mode 100644
index 00000000..858c83ee
--- /dev/null
+++ b/src/backend/core/tasks/find.py
@@ -0,0 +1,96 @@
+"""Trigger document indexation using celery task."""
+
+from logging import getLogger
+
+from django.conf import settings
+from django.core.cache import cache
+from django.db import transaction
+
+from core import models
+from core.services.search_indexers import (
+    FindDocumentIndexer,
+    get_batch_accesses_by_users_and_teams,
+)
+
+from impress.celery_app import app
+
+logger = getLogger(__name__)
+
+
+def document_indexer_debounce_key(document_id):
+    """Returns debounce cache key"""
+    return f"doc-indexer-debounce-{document_id}"
+
+
+def incr_counter(key):
+    """Increase or reset counter"""
+    try:
+        return cache.incr(key)
+    except ValueError:
+        cache.set(key, 1)
+        return 1
+
+
+def decr_counter(key):
+    """Decrease or reset counter"""
+    try:
+        return cache.decr(key)
+    except ValueError:
+        cache.set(key, 0)
+        return 0
+
+
+@app.task
+def document_indexer_task(document_id):
+    """Send indexation query for a document using celery task."""
+    key = document_indexer_debounce_key(document_id)
+
+    # Check the counter: if it is still up, skip the task. Only the last task
+    # within the countdown delay will do the query.
+    if decr_counter(key) > 0:
+        logger.info("Skip document %s indexation", document_id)
+        return
+
+    doc = models.Document.objects.get(pk=document_id)
+    indexer = FindDocumentIndexer()
+    accesses = get_batch_accesses_by_users_and_teams((doc.path,))
+
+    data = indexer.serialize_document(document=doc, accesses=accesses)
+
+    logger.info("Start document %s indexation", document_id)
+    indexer.push(data)
+
+
+def trigger_document_indexer(document, on_commit=False):
+    """
+    Trigger the indexation task with a debounce delay set by the SEARCH_INDEXER_COUNTDOWN setting.
+
+    Args:
+        document (Document): The document instance.
+        on_commit (bool): Wait for the end of the transaction before starting the task
+            (some fields may be in wrong state within the transaction)
+    """
+
+    # NOTE(review): deleted documents are indexed too — serialize_document flags
+    # them via "is_active" so the search index can deactivate them.
+
+    if on_commit:
+
+        def _aux():
+            trigger_document_indexer(document, on_commit=False)
+
+        transaction.on_commit(_aux)
+    else:
+        key = document_indexer_debounce_key(document.pk)
+        countdown = getattr(settings, "SEARCH_INDEXER_COUNTDOWN", 1)
+
+        logger.info(
+            "Add task for document %s indexation in %.2f seconds",
+            document.pk, countdown
+        )
+
+        # Each time this method is called during the countdown, we increment the
+        # counter and each task decreases it, so the indexing runs only once.
+ incr_counter(key) + + document_indexer_task.apply_async(args=[document.pk], countdown=countdown) diff --git a/src/backend/core/tests/test_models_documents.py b/src/backend/core/tests/test_models_documents.py index 91b8abf9..cbe01f83 100644 --- a/src/backend/core/tests/test_models_documents.py +++ b/src/backend/core/tests/test_models_documents.py @@ -5,6 +5,7 @@ Unit tests for the Document model import random import smtplib +import time from logging import Logger from unittest import mock @@ -13,12 +14,15 @@ from django.core import mail from django.core.cache import cache from django.core.exceptions import ValidationError from django.core.files.storage import default_storage +from django.db import transaction from django.test.utils import override_settings from django.utils import timezone import pytest from core import factories, models +from core.services.search_indexers import FindDocumentIndexer +from core.tasks.find import document_indexer_debounce_key pytestmark = pytest.mark.django_db @@ -1617,3 +1621,125 @@ def test_models_documents_compute_ancestors_links_paths_mapping_structure( {"link_reach": sibling.link_reach, "link_role": sibling.link_role}, ], } + + +@mock.patch.object(FindDocumentIndexer, "push") +@pytest.mark.django_db(transaction=True) +def test_models_documents_post_save_indexer(mock_push, settings): + """Test indexation task on document creation""" + settings.SEARCH_INDEXER_COUNTDOWN = 0 + + user = factories.UserFactory() + + with transaction.atomic(): + doc1, doc2, doc3 = factories.DocumentFactory.create_batch(3) + + factories.UserDocumentAccessFactory(document=doc1, user=user) + factories.UserDocumentAccessFactory(document=doc2, user=user) + factories.UserDocumentAccessFactory(document=doc3, user=user) + + time.sleep(0.1) # waits for the end of the tasks + + accesses = { + str(doc1.path): {"users": [user.sub]}, + str(doc2.path): {"users": [user.sub]}, + str(doc3.path): {"users": [user.sub]}, + } + + data = [call.args[0] for call in 
mock_push.call_args_list] + + indexer = FindDocumentIndexer() + + def sortkey(d): + return d["id"] + + assert sorted(data, key=sortkey) == sorted( + [ + indexer.serialize_document(doc1, accesses), + indexer.serialize_document(doc2, accesses), + indexer.serialize_document(doc3, accesses), + ], + key=sortkey, + ) + + # The debounce counters should be reset + assert cache.get(document_indexer_debounce_key(doc1.pk)) == 0 + assert cache.get(document_indexer_debounce_key(doc2.pk)) == 0 + assert cache.get(document_indexer_debounce_key(doc3.pk)) == 0 + + +@pytest.mark.django_db(transaction=True) +def test_models_documents_post_save_indexer_debounce(settings): + """Test indexation task skipping on document update""" + settings.SEARCH_INDEXER_COUNTDOWN = 0 + + indexer = FindDocumentIndexer() + user = factories.UserFactory() + + with mock.patch.object(FindDocumentIndexer, "push"): + with transaction.atomic(): + doc = factories.DocumentFactory() + factories.UserDocumentAccessFactory(document=doc, user=user) + + accesses = { + str(doc.path): {"users": [user.sub]}, + } + + time.sleep(0.1) # waits for the end of the tasks + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + # Simulate 1 waiting task + cache.set(document_indexer_debounce_key(doc.pk), 1) + + # save doc to trigger the indexer, but nothing should be done since + # the counter is over 0 + with transaction.atomic(): + doc.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [] + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + # No waiting task + cache.set(document_indexer_debounce_key(doc.pk), 0) + + with transaction.atomic(): + doc = models.Document.objects.get(pk=doc.pk) + doc.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [ + indexer.serialize_document(doc, accesses), + ] + + +@pytest.mark.django_db(transaction=True) +def test_models_documents_access_post_save_indexer(settings): + """Test 
indexation task on DocumentAccess update""" + settings.SEARCH_INDEXER_COUNTDOWN = 0 + + indexer = FindDocumentIndexer() + user = factories.UserFactory() + + with mock.patch.object(FindDocumentIndexer, "push"): + with transaction.atomic(): + doc = factories.DocumentFactory() + doc_access = factories.UserDocumentAccessFactory(document=doc, user=user) + + accesses = { + str(doc.path): {"users": [user.sub]}, + } + + indexer = FindDocumentIndexer() + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + with transaction.atomic(): + doc_access.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [ + indexer.serialize_document(doc, accesses), + ] diff --git a/src/backend/core/tests/test_services_search_indexers.py b/src/backend/core/tests/test_services_search_indexers.py index 2ad37c38..0d74d845 100644 --- a/src/backend/core/tests/test_services_search_indexers.py +++ b/src/backend/core/tests/test_services_search_indexers.py @@ -107,7 +107,7 @@ def test_services_search_indexers_serialize_document_returns_expected_json(): "created_at": document.created_at.isoformat(), "updated_at": document.updated_at.isoformat(), "reach": document.link_reach, - 'size': 13, + "size": 13, "is_active": True, } @@ -126,6 +126,17 @@ def test_services_search_indexers_serialize_document_deleted(): assert result["is_active"] is False +def test_services_search_indexers_serialize_document_empty(): + """Empty documents returns empty content in the serialized json.""" + document = factories.DocumentFactory(content="", title=None) + + indexer = FindDocumentIndexer() + result = indexer.serialize_document(document, {}) + + assert result["content"] == "" + assert result["title"] == "" + + @patch.object(FindDocumentIndexer, "push") def test_services_search_indexers_batches_pass_only_batch_accesses(mock_push, settings): """ @@ -168,7 +179,9 @@ def test_services_search_indexers_batches_pass_only_batch_accesses(mock_push, se def 
test_services_search_indexers_ancestors_link_reach(mock_push): """Document accesses and reach should take into account ancestors link reaches.""" great_grand_parent = factories.DocumentFactory(link_reach="restricted") - grand_parent = factories.DocumentFactory(parent=great_grand_parent, link_reach="authenticated") + grand_parent = factories.DocumentFactory( + parent=great_grand_parent, link_reach="authenticated" + ) parent = factories.DocumentFactory(parent=grand_parent, link_reach="public") document = factories.DocumentFactory(parent=parent, link_reach="restricted") @@ -199,7 +212,11 @@ def test_services_search_indexers_ancestors_users(mock_push): assert len(results) == 3 assert results[str(grand_parent.id)]["users"] == [str(user_gp.sub)] assert set(results[str(parent.id)]["users"]) == {str(user_gp.sub), str(user_p.sub)} - assert set(results[str(document.id)]["users"]) == {str(user_gp.sub), str(user_p.sub), str(user_d.sub)} + assert set(results[str(document.id)]["users"]) == { + str(user_gp.sub), + str(user_p.sub), + str(user_d.sub), + } @patch.object(FindDocumentIndexer, "push") diff --git a/src/backend/core/tests/test_utils.py b/src/backend/core/tests/test_utils.py index 66285afe..42d588c5 100644 --- a/src/backend/core/tests/test_utils.py +++ b/src/backend/core/tests/test_utils.py @@ -79,24 +79,24 @@ def test_utils_extract_attachments(): def test_utils_get_ancestor_to_descendants_map_single_path(): """Test ancestor mapping of a single path.""" - paths = ['000100020005'] + paths = ["000100020005"] result = utils.get_ancestor_to_descendants_map(paths, steplen=4) assert result == { - '0001': {'000100020005'}, - '00010002': {'000100020005'}, - '000100020005': {'000100020005'}, + "0001": {"000100020005"}, + "00010002": {"000100020005"}, + "000100020005": {"000100020005"}, } def test_utils_get_ancestor_to_descendants_map_multiple_paths(): """Test ancestor mapping of multiple paths with shared prefixes.""" - paths = ['000100020005', '00010003'] + paths = 
["000100020005", "00010003"] result = utils.get_ancestor_to_descendants_map(paths, steplen=4) assert result == { - '0001': {'000100020005', '00010003'}, - '00010002': {'000100020005'}, - '000100020005': {'000100020005'}, - '00010003': {'00010003'}, + "0001": {"000100020005", "00010003"}, + "00010002": {"000100020005"}, + "000100020005": {"000100020005"}, + "00010003": {"00010003"}, } diff --git a/src/backend/core/utils.py b/src/backend/core/utils.py index 746a0dfe..357ede03 100644 --- a/src/backend/core/utils.py +++ b/src/backend/core/utils.py @@ -1,8 +1,8 @@ """Utils for the core app.""" import base64 -from collections import defaultdict import re +from collections import defaultdict import pycrdt from bs4 import BeautifulSoup