🛂(backend) harden AI proxy payload

Standard can vary depending on the AI service used.
To work with Albert API:
- a description field is required in the payload
  for every tools call.
- if stream is set to false, stream_options must
  be omitted from the payload.
- Albert's responses did not always respect the
  format expected by BlockNote, so we added a
  system prompt to enforce it.
This commit is contained in:
Anthony LC
2026-02-02 15:43:00 +01:00
parent 6f0dac4f48
commit 09438a8941
3 changed files with 208 additions and 14 deletions

View File

@@ -2,12 +2,11 @@
import json import json
import logging import logging
from typing import Generator from typing import Any, Dict, Generator
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from openai import OpenAI as OpenAI_Client
from openai import OpenAIError from openai import OpenAIError
from core import enums from core import enums
@@ -15,11 +14,42 @@ from core import enums
if settings.LANGFUSE_PUBLIC_KEY: if settings.LANGFUSE_PUBLIC_KEY:
from langfuse.openai import OpenAI from langfuse.openai import OpenAI
else: else:
OpenAI = OpenAI_Client from openai import OpenAI
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Strict system prompt injected in front of proxied chat messages (see
# AIService._harden_payload): Albert's tool-call responses do not always
# match the JSON shape BlockNote expects, so this prompt pins the exact
# contract of the applyDocumentOperations tool input.
BLOCKNOTE_TOOL_STRICT_PROMPT = """
You are editing a BlockNote document via the tool applyDocumentOperations.
You MUST respond ONLY by calling applyDocumentOperations.
The tool input MUST be valid JSON:
{ "operations": [ ... ] }
Each operation MUST include "type" and it MUST be one of:
- "update" (requires: id, block)
- "add" (requires: referenceId, position, blocks)
- "delete" (requires: id)
VALID SHAPES (FOLLOW EXACTLY):
Update:
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
Add:
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
Each item MUST be a STRING containing a SINGLE valid HTML element.
Delete:
{ "type":"delete", "id":"<id$>" }
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
Return ONLY the JSON tool input. No prose, no markdown.
"""
AI_ACTIONS = { AI_ACTIONS = {
"prompt": ( "prompt": (
"Answer the prompt using markdown formatting for structure and emphasis. " "Answer the prompt using markdown formatting for structure and emphasis. "
@@ -106,11 +136,58 @@ class AIService:
system_content = AI_TRANSLATE.format(language=language_display) system_content = AI_TRANSLATE.format(language=language_display)
return self.call_ai_api(system_content, text) return self.call_ai_api(system_content, text)
def _normalize_tools(self, tools: list) -> list:
"""
Normalize tool definitions to ensure they have required fields.
"""
normalized = []
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "function":
fn = tool.get("function") or {}
if isinstance(fn, dict) and not fn.get("description"):
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
tool["function"] = fn
normalized.append(tool)
return normalized
def _harden_payload(
    self, payload: Dict[str, Any], stream: bool = False
) -> Dict[str, Any]:
    """Harden the AI API payload to enforce compliance and tool usage."""
    payload["stream"] = stream

    # Albert rejects "stream_options" when streaming is disabled.
    if not stream:
        payload.pop("stream_options", None)

    # Ensure every function tool carries the mandatory description field.
    tools = payload.get("tools")
    if isinstance(tools, list):
        payload["tools"] = self._normalize_tools(tools)

    # Prepend the strict BlockNote system prompt, unless an equivalent
    # system message already leads the conversation.
    messages = payload.get("messages")
    if isinstance(messages, list):
        head = messages[0] if messages else None
        content = head.get("content") or "" if isinstance(head, dict) else ""
        already_strict = (
            isinstance(head, dict)
            and head.get("role") == "system"
            and isinstance(content, str)
            and "applyDocumentOperations" in content
            and "blocks" in content
        )
        if not already_strict:
            payload["messages"] = [
                {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
                *messages,
            ]
    return payload
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
"""Proxy AI API requests to the configured AI provider.""" """Proxy AI API requests to the configured AI provider."""
data["stream"] = stream payload = self._harden_payload(data, stream=stream)
try: try:
return self.client.chat.completions.create(**data) return self.client.chat.completions.create(**payload)
except OpenAIError as e: except OpenAIError as e:
raise RuntimeError(f"Failed to proxy AI request: {e}") from e raise RuntimeError(f"Failed to proxy AI request: {e}") from e

View File

@@ -11,6 +11,7 @@ import pytest
from rest_framework.test import APIClient from rest_framework.test import APIClient
from core import factories from core import factories
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT
from core.tests.conftest import TEAM, USER, VIA from core.tests.conftest import TEAM, USER, VIA
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@@ -110,7 +111,10 @@ def test_api_documents_ai_proxy_anonymous_success(mock_create):
) )
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello"}], messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
],
model="llama", model="llama",
stream=False, stream=False,
) )
@@ -228,7 +232,10 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
assert response_data["choices"][0]["message"]["content"] == "Hi there!" assert response_data["choices"][0]["message"]["content"] == "Hi there!"
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello"}], messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
],
model="llama", model="llama",
stream=False, stream=False,
) )
@@ -315,7 +322,10 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams)
assert response_data["choices"][0]["message"]["content"] == "Success!" assert response_data["choices"][0]["message"]["content"] == "Success!"
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "Test message"}], messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Test message"},
],
model="llama", model="llama",
stream=False, stream=False,
) )
@@ -412,6 +422,7 @@ def test_api_documents_ai_proxy_invalid_message_format():
} }
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_stream_disabled(mock_create): def test_api_documents_ai_proxy_stream_disabled(mock_create):
"""Stream should be automatically disabled in AI proxy requests.""" """Stream should be automatically disabled in AI proxy requests."""
@@ -440,12 +451,16 @@ def test_api_documents_ai_proxy_stream_disabled(mock_create):
assert response.status_code == 200 assert response.status_code == 200
# Verify that stream was set to False # Verify that stream was set to False
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello"}], messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
],
model="llama", model="llama",
stream=False, stream=False,
) )
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_additional_parameters(mock_create): def test_api_documents_ai_proxy_additional_parameters(mock_create):
"""AI proxy should pass through additional parameters to the AI service.""" """AI proxy should pass through additional parameters to the AI service."""
@@ -476,7 +491,10 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create):
assert response.status_code == 200 assert response.status_code == 200
# Verify that additional parameters were passed through # Verify that additional parameters were passed through
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello"}], messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
],
model="llama", model="llama",
temperature=0.7, temperature=0.7,
max_tokens=100, max_tokens=100,
@@ -561,6 +579,7 @@ def test_api_documents_ai_proxy_complex_conversation(mock_create):
mock_create.return_value = mock_response mock_create.return_value = mock_response
complex_messages = [ complex_messages = [
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "system", "content": "You are a helpful programming assistant."}, {"role": "system", "content": "You are a helpful programming assistant."},
{"role": "user", "content": "How do I write a for loop in Python?"}, {"role": "user", "content": "How do I write a for loop in Python?"},
{ {
@@ -688,3 +707,93 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings):
assert response.status_code == 400 assert response.status_code == 400
assert response.json() == ["AI feature is not enabled."] assert response.json() == ["AI feature is not enabled."]
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_harden_payload(mock_create):
    """
    AI proxy should harden the payload by default:
    - it will remove stream_options if stream is False
    - normalize tools by adding descriptions if missing
    """
    user = factories.UserFactory()
    client = APIClient()
    client.force_login(user)

    # Public editor document so any logged-in user may hit the AI proxy.
    document = factories.DocumentFactory(link_reach="public", link_role="editor")

    # Canned upstream completion returned by the mocked OpenAI client.
    mock_response = MagicMock()
    mock_response.model_dump.return_value = {
        "id": "chatcmpl-789",
        "object": "chat.completion",
        "model": "llama",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Success!"},
                "finish_reason": "stop",
            }
        ],
    }
    mock_create.return_value = mock_response

    url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
    # The request deliberately includes "stream_options" (must be stripped
    # when stream is False) and a tool without a "description" (must be
    # injected by the hardening step).
    response = client.post(
        url,
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "llama",
            "stream": False,
            "stream_options": {"include_usage": True},
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "applyDocumentOperations",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "operations": {"type": "array", "items": {"anyOf": []}}
                            },
                            "additionalProperties": False,
                            "required": ["operations"],
                        },
                    },
                }
            ],
        },
        format="json",
    )

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["id"] == "chatcmpl-789"
    assert response_data["choices"][0]["message"]["content"] == "Success!"
    # Hardened upstream call: strict system prompt prepended, stream forced
    # to False, stream_options removed, tool description filled in.
    mock_create.assert_called_once_with(
        messages=[
            {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
            {"role": "user", "content": "Hello"},
        ],
        model="llama",
        stream=False,
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "applyDocumentOperations",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "operations": {"type": "array", "items": {"anyOf": []}}
                        },
                        "additionalProperties": False,
                        "required": ["operations"],
                    },
                    "description": "Tool applyDocumentOperations.",
                },
            }
        ],
    )

View File

@@ -10,7 +10,7 @@ from django.test.utils import override_settings
import pytest import pytest
from openai import OpenAIError from openai import OpenAIError
from core.services.ai_services import AIService from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@@ -125,7 +125,11 @@ def test_services_ai_proxy_success(mock_create):
} }
assert response == expected_response assert response == expected_response
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "hello"}], stream=False messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "hello"},
],
stream=False,
) )
@@ -182,5 +186,9 @@ def test_services_ai_proxy_with_stream(mock_create):
} }
assert response == expected_response assert response == expected_response
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "hello"}], stream=True messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "hello"},
],
stream=True,
) )