diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py
index ee0c594e..90267961 100644
--- a/src/backend/core/api/viewsets.py
+++ b/src/backend/core/api/viewsets.py
@@ -13,6 +13,7 @@ from django.contrib.postgres.search import TrigramSimilarity
 from django.core.cache import cache
 from django.core.exceptions import ValidationError
 from django.core.files.storage import default_storage
+from django.core.validators import URLValidator
 from django.db import connection, transaction
 from django.db import models as db
 from django.db.models.expressions import RawSQL
@@ -1441,6 +1442,17 @@ class DocumentViewSet(
 
         url = unquote(url)
 
+        # Only proxy absolute http(s) URLs. URLValidator raises Django's
+        # ValidationError (not DRF's), so that is the class to catch here.
+        url_validator = URLValidator(schemes=["http", "https"])
+        try:
+            url_validator(url)
+        except ValidationError as e:
+            return drf.response.Response(
+                e.messages,
+                status=drf.status.HTTP_400_BAD_REQUEST,
+            )
+
         try:
             response = requests.get(
                 url,
diff --git a/src/backend/core/tests/documents/test_api_documents_cors_proxy.py b/src/backend/core/tests/documents/test_api_documents_cors_proxy.py
index 5d8b02c2..3356e93c 100644
--- a/src/backend/core/tests/documents/test_api_documents_cors_proxy.py
+++ b/src/backend/core/tests/documents/test_api_documents_cors_proxy.py
@@ -149,3 +149,24 @@ def test_api_docs_cors_proxy_unsupported_media_type():
         f"/api/v1.0/documents/{document.id!s}/cors-proxy/?url={url_to_fetch}"
     )
     assert response.status_code == 415
+
+
+@pytest.mark.parametrize(
+    "url_to_fetch",
+    [
+        "ftp://external-url.com/assets/index.html",
+        "ftps://external-url.com/assets/index.html",
+        "invalid-url.com",
+        "ssh://external-url.com/assets/index.html",
+    ],
+)
+def test_api_docs_cors_proxy_invalid_url(url_to_fetch):
+    """The CORS proxy must reject URLs that are not valid http(s) URLs."""
+    document = factories.DocumentFactory(link_reach="public")
+
+    client = APIClient()
+    response = client.get(
+        f"/api/v1.0/documents/{document.id!s}/cors-proxy/?url={url_to_fetch}"
+    )
+    assert response.status_code == 400
+    assert response.json() == ["Enter a valid URL."]