From 09438a8941749d8879bbe81b1a909f73ca827fa6 Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Mon, 2 Feb 2026 15:43:00 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=82(backend)=20harden=20payload=20prox?= =?UTF-8?q?y=20ai?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Standard can vary depending on the AI service used. To work with Albert API: - a description field is required in the payload for every tools call. - if stream is set to false, stream_options must be omitted from the payload. - the response from Albert sometimes didn't respect the format expected by Blocknote, so we added a system prompt to enforce it. --- src/backend/core/services/ai_services.py | 89 ++++++++++++- .../documents/test_api_documents_ai_proxy.py | 119 +++++++++++++++++- .../core/tests/test_services_ai_services.py | 14 ++- 3 files changed, 208 insertions(+), 14 deletions(-) diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index 72583aa6..cad97dfb 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -2,12 +2,11 @@ import json import logging -from typing import Generator +from typing import Any, Dict, Generator from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from openai import OpenAI as OpenAI_Client from openai import OpenAIError from core import enums @@ -15,11 +14,42 @@ from core import enums if settings.LANGFUSE_PUBLIC_KEY: from langfuse.openai import OpenAI else: - OpenAI = OpenAI_Client - + from openai import OpenAI + log = logging.getLogger(__name__) +BLOCKNOTE_TOOL_STRICT_PROMPT = """ +You are editing a BlockNote document via the tool applyDocumentOperations. + +You MUST respond ONLY by calling applyDocumentOperations. +The tool input MUST be valid JSON: +{ "operations": [ ... ] } + +Each operation MUST include "type" and it MUST be one of: +- "update" (requires: id, block) +- "add" (requires: referenceId, position, blocks) +- "delete" (requires: id) + +VALID SHAPES (FOLLOW EXACTLY): + +Update: +{ "type":"update", "id":"", "block":"

...

" } +IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element. + +Add: +{ "type":"add", "referenceId":"", "position":"before|after", "blocks":["

...

"] } +IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS. +Each item MUST be a STRING containing a SINGLE valid HTML element. + +Delete: +{ "type":"delete", "id":"" } + +IDs ALWAYS end with "$". Use ids EXACTLY as provided. + +Return ONLY the JSON tool input. No prose, no markdown. +""" + AI_ACTIONS = { "prompt": ( "Answer the prompt using markdown formatting for structure and emphasis. " @@ -106,11 +136,58 @@ class AIService: system_content = AI_TRANSLATE.format(language=language_display) return self.call_ai_api(system_content, text) + def _normalize_tools(self, tools: list) -> list: + """ + Normalize tool definitions to ensure they have required fields. + """ + normalized = [] + for tool in tools: + if isinstance(tool, dict) and tool.get("type") == "function": + fn = tool.get("function") or {} + if isinstance(fn, dict) and not fn.get("description"): + fn["description"] = f"Tool {fn.get('name', 'unknown')}." + tool["function"] = fn + normalized.append(tool) + return normalized + + def _harden_payload( + self, payload: Dict[str, Any], stream: bool = False + ) -> Dict[str, Any]: + """Harden the AI API payload to enforce compliance and tool usage.""" + payload["stream"] = stream + + # Remove stream_options if stream is False + if not stream and "stream_options" in payload: + payload.pop("stream_options") + + # Tools normalization + if isinstance(payload.get("tools"), list): + payload["tools"] = self._normalize_tools(payload["tools"]) + + # Inject strict system prompt once + msgs = payload.get("messages") + if isinstance(msgs, list): + need = True + if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system": + c = msgs[0].get("content") or "" + if ( + isinstance(c, str) + and "applyDocumentOperations" in c + and "blocks" in c + ): + need = False + if need: + payload["messages"] = [ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT} + ] + msgs + + return payload + def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: """Proxy AI API requests to the configured AI provider.""" - data["stream"] = stream + payload = self._harden_payload(data, stream=stream) try: - return self.client.chat.completions.create(**data) + return self.client.chat.completions.create(**payload) except OpenAIError as e: raise RuntimeError(f"Failed to proxy AI request: {e}") from e diff --git a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py index f99d01a6..1046cb81 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py @@ -11,6 +11,7 @@ import pytest from rest_framework.test import APIClient from core import factories +from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @@ -110,7 +111,10 @@ def test_api_documents_ai_proxy_anonymous_success(mock_create): ) mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "Hello"}], + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Hello"}, + ], model="llama", stream=False, ) @@ -228,7 +232,10 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): assert response_data["choices"][0]["message"]["content"] == "Hi there!" mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "Hello"}], + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Hello"}, + ], model="llama", stream=False, ) @@ -315,7 +322,10 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams) assert response_data["choices"][0]["message"]["content"] == "Success!" mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "Test message"}], + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Test message"}, + ], model="llama", stream=False, ) @@ -412,6 +422,7 @@ def test_api_documents_ai_proxy_invalid_message_format(): } +@override_settings(AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_stream_disabled(mock_create): """Stream should be automatically disabled in AI proxy requests.""" @@ -440,12 +451,16 @@ def test_api_documents_ai_proxy_stream_disabled(mock_create): assert response.status_code == 200 # Verify that stream was set to False mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "Hello"}], + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Hello"}, + ], model="llama", stream=False, ) +@override_settings(AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_additional_parameters(mock_create): """AI proxy should pass through additional parameters to the AI service.""" @@ -476,7 +491,10 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create): assert response.status_code == 200 # Verify that additional parameters were passed through mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "Hello"}], + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Hello"}, + ], model="llama", temperature=0.7, max_tokens=100, @@ -561,6 +579,7 @@ def test_api_documents_ai_proxy_complex_conversation(mock_create): mock_create.return_value = mock_response complex_messages = [ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, {"role": "system", "content": "You are a helpful programming assistant."}, {"role": "user", "content": "How do I write a for loop in Python?"}, { @@ -688,3 +707,93 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings): assert response.status_code == 400 assert response.json() == ["AI feature is not enabled."] + + +@override_settings(AI_STREAM=False) +@patch("openai.resources.chat.completions.Completions.create") +def test_api_documents_ai_proxy_harden_payload(mock_create): + """ + AI proxy should harden the payload by default: + - it will remove stream_options if stream is False + - normalize tools by adding descriptions if missing + """ + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + + document = factories.DocumentFactory(link_reach="public", link_role="editor") + + mock_response = MagicMock() + mock_response.model_dump.return_value = { + "id": "chatcmpl-789", + "object": "chat.completion", + "model": "llama", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Success!"}, + "finish_reason": "stop", + } + ], + } + mock_create.return_value = mock_response + + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + response = client.post( + url, + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "llama", + "stream": False, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": "applyDocumentOperations", + "parameters": { + "type": "object", + "properties": { + "operations": {"type": "array", "items": {"anyOf": []}} + }, + "additionalProperties": False, + "required": ["operations"], + }, + }, + } + ], + }, + format="json", + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["id"] == "chatcmpl-789" + assert response_data["choices"][0]["message"]["content"] == "Success!" + + mock_create.assert_called_once_with( + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "Hello"}, + ], + model="llama", + stream=False, + tools=[ + { + "type": "function", + "function": { + "name": "applyDocumentOperations", + "parameters": { + "type": "object", + "properties": { + "operations": {"type": "array", "items": {"anyOf": []}} + }, + "additionalProperties": False, + "required": ["operations"], + }, + "description": "Tool applyDocumentOperations.", + }, + } + ], + ) diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index edfe7e1a..5bc845d0 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -10,7 +10,7 @@ from django.test.utils import override_settings import pytest from openai import OpenAIError -from core.services.ai_services import AIService +from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService pytestmark = pytest.mark.django_db @@ -125,7 +125,11 @@ def test_services_ai_proxy_success(mock_create): } assert response == expected_response mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "hello"}], stream=False + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "hello"}, + ], + stream=False, ) @@ -182,5 +186,9 @@ def test_services_ai_proxy_with_stream(mock_create): } assert response == expected_response mock_create.assert_called_once_with( - messages=[{"role": "user", "content": "hello"}], stream=True + messages=[ + {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, + {"role": "user", "content": "hello"}, + ], + stream=True, )