diff --git a/docs/env.md b/docs/env.md
index cd2aafe3..d983e44b 100644
--- a/docs/env.md
+++ b/docs/env.md
@@ -98,3 +98,5 @@ These are the environmental variables you can set for the impress-backend contai
 | DJANGO_CSRF_TRUSTED_ORIGINS | CSRF trusted origins | [] |
 | REDIS_URL | cache url | redis://redis:6379/1 |
 | CACHES_DEFAULT_TIMEOUT | cache default timeout | 30 |
+| MALWARE_DETECTION_BACKEND | The malware detection backend to use from the django-lasuite package | lasuite.malware_detection.backends.dummy.DummyBackend |
+| MALWARE_DETECTION_PARAMETERS | A dict containing all the parameters to initiate the malware detection backend | {"callback_path": "core.malware_detection.malware_detection_callback"} |
diff --git a/src/backend/core/enums.py b/src/backend/core/enums.py
index 78cac40a..46e62b2c 100644
--- a/src/backend/core/enums.py
+++ b/src/backend/core/enums.py
@@ -3,6 +3,7 @@ Core application enums declaration
 """
 
 import re
+from enum import StrEnum
 
 from django.conf import global_settings, settings
 from django.db import models
@@ -38,3 +39,10 @@ class MoveNodePositionChoices(models.TextChoices):
     LAST_SIBLING = "last-sibling", _("Last sibling")
     LEFT = "left", _("Left")
     RIGHT = "right", _("Right")
+
+
+class DocumentAttachmentStatus(StrEnum):
+    """Defines the possible statuses for an attachment."""
+
+    PROCESSING = "processing"
+    READY = "ready"
diff --git a/src/backend/core/malware_detection.py b/src/backend/core/malware_detection.py
new file mode 100644
index 00000000..e67617d8
--- /dev/null
+++ b/src/backend/core/malware_detection.py
@@ -0,0 +1,66 @@
+"""Malware detection callbacks"""
+
+import logging
+
+from django.core.files.storage import default_storage
+
+from lasuite.malware_detection.enums import ReportStatus
+
+from core.enums import DocumentAttachmentStatus
+from core.models import Document
+
+logger = logging.getLogger(__name__)
+
+
+def malware_detection_callback(file_path, status, error_info, **kwargs):
+    """
+    Handle the status reported by the malware detection backend for a file.
+
+    A safe file gets its storage metadata flagged as ready. For any other
+    status, the file is detached from its document (when the document still
+    references it) and is always deleted from the storage.
+    """
+
+    if status == ReportStatus.SAFE:
+        logger.info("File %s is safe", file_path)
+        # Get existing metadata
+        s3_client = default_storage.connection.meta.client
+        bucket_name = default_storage.bucket_name
+        head_resp = s3_client.head_object(Bucket=bucket_name, Key=file_path)
+        metadata = head_resp.get("Metadata", {})
+        metadata.update({"status": DocumentAttachmentStatus.READY})
+        # Update status in metadata
+        s3_client.copy_object(
+            Bucket=bucket_name,
+            CopySource={"Bucket": bucket_name, "Key": file_path},
+            Key=file_path,
+            ContentType=head_resp.get("ContentType"),
+            Metadata=metadata,
+            MetadataDirective="REPLACE",
+        )
+        return
+
+    document_id = kwargs.get("document_id")
+    logger.error(
+        "File %s for document %s is infected with malware. Error info: %s",
+        file_path,
+        document_id,
+        error_info,
+    )
+
+    try:
+        # Remove the file from the document
+        document = Document.objects.get(pk=document_id)
+    except Document.DoesNotExist:
+        # The document may have been deleted while the file was being analysed
+        logger.warning(
+            "Document %s not found, file %s is deleted anyway", document_id, file_path
+        )
+    else:
+        if file_path in document.attachments:
+            document.attachments.remove(file_path)
+            document.save(update_fields=["attachments"])
+    finally:
+        # Delete the file from the storage, whatever happened above: a file
+        # that was not reported as safe must never remain available
+        default_storage.delete(file_path)
diff --git a/src/backend/core/tests/test_malware_detection.py b/src/backend/core/tests/test_malware_detection.py
new file mode 100644
index 00000000..cf2822f0
--- /dev/null
+++ b/src/backend/core/tests/test_malware_detection.py
@@ -0,0 +1,98 @@
+"""Test malware detection callback."""
+
+import uuid
+
+from django.core.files.base import ContentFile
+from django.core.files.storage import default_storage
+
+import pytest
+from lasuite.malware_detection.enums import ReportStatus
+
+from core.enums import DocumentAttachmentStatus
+from core.factories import DocumentFactory
+from core.malware_detection import malware_detection_callback
+
+pytestmark = pytest.mark.django_db
+
+NOT_SAFE_STATUSES = [
+    status.value for status in ReportStatus if status != ReportStatus.SAFE
+]
+
+
+@pytest.fixture
+def safe_file():
+    """Create a safe file."""
+    file_path = "test.txt"
+    default_storage.save(file_path, ContentFile("test"))
+    yield file_path
+    default_storage.delete(file_path)
+
+
+@pytest.fixture
+def unsafe_file():
+    """Create an unsafe file."""
+    file_path = "unsafe.txt"
+    default_storage.save(file_path, ContentFile("test"))
+    yield file_path
+    # The callback is expected to delete the file; clean up anyway so that a
+    # failing test does not leak it into the next one
+    default_storage.delete(file_path)
+
+
+def test_malware_detection_callback_safe_status(safe_file):
+    """Test malware detection callback with safe status."""
+
+    document = DocumentFactory(attachments=[safe_file])
+
+    malware_detection_callback(
+        safe_file,
+        ReportStatus.SAFE,
+        error_info={},
+        document_id=document.id,
+    )
+
+    document.refresh_from_db()
+
+    assert safe_file in document.attachments
+    assert default_storage.exists(safe_file)
+
+    s3_client = default_storage.connection.meta.client
+    bucket_name = default_storage.bucket_name
+    head_resp = s3_client.head_object(Bucket=bucket_name, Key=safe_file)
+    metadata = head_resp.get("Metadata", {})
+    assert metadata["status"] == DocumentAttachmentStatus.READY
+
+
+@pytest.mark.parametrize("status", NOT_SAFE_STATUSES)
+def test_malware_detection_callback_unsafe_status(unsafe_file, status):
+    """Test malware detection callback with every non-safe status."""
+
+    document = DocumentFactory(attachments=[unsafe_file])
+
+    malware_detection_callback(
+        unsafe_file,
+        status,
+        error_info={"error": "test", "error_code": 4001},
+        document_id=document.id,
+    )
+
+    document.refresh_from_db()
+
+    assert unsafe_file not in document.attachments
+    assert not default_storage.exists(unsafe_file)
+
+
+@pytest.mark.parametrize("status", NOT_SAFE_STATUSES)
+def test_malware_detection_callback_unsafe_status_unknown_document(
+    unsafe_file, status
+):
+    """The file must be deleted even when its document no longer exists."""
+
+    malware_detection_callback(
+        unsafe_file,
+        status,
+        error_info={"error": "test", "error_code": 4001},
+        document_id=uuid.uuid4(),
+    )
+
+    assert not default_storage.exists(unsafe_file)
diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py
index e4932081..c831c3ba 100755
--- a/src/backend/impress/settings.py
+++ b/src/backend/impress/settings.py
@@ -317,6 +317,7 @@ class Base(Configuration):
         "django.contrib.staticfiles",
         # OIDC third party
         "mozilla_django_oidc",
+        "lasuite.malware_detection",
     ]
 
     # Cache
@@ -680,6 +681,21 @@ class Base(Configuration):
         },
     }
 
+    MALWARE_DETECTION = {
+        "BACKEND": values.Value(
+            "lasuite.malware_detection.backends.dummy.DummyBackend",
+            environ_name="MALWARE_DETECTION_BACKEND",
+            environ_prefix=None,
+        ),
+        "PARAMETERS": values.DictValue(
+            default={
+                "callback_path": "core.malware_detection.malware_detection_callback",
+            },
+            environ_name="MALWARE_DETECTION_PARAMETERS",
+            environ_prefix=None,
+        ),
+    }
+
     API_USERS_LIST_LIMIT = values.PositiveIntegerValue(
         default=5,
         environ_name="API_USERS_LIST_LIMIT",
diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml
index c0496b1f..4190e2f3 100644
--- a/src/backend/pyproject.toml
+++ b/src/backend/pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
     "django-cors-headers==4.7.0",
     "django-countries==7.6.1",
     "django-filter==25.1",
-    "django-lasuite==0.0.7",
+    "django-lasuite[all]==0.0.8",
     "django-parler==2.3",
     "django-redis==5.4.0",
     "django-storages[s3]==1.14.6",