🛂(backend) harden payload proxy ai
Standard can vary depending on the AI service used. To work with Albert API: - a description field is required in the payload for every tools call. - if stream is set to false, stream_options must be omitted from the payload. - the response from Albert sometimes didn't respect the format expected by Blocknote, so we added a system prompt to enforce it.
This commit is contained in:
@@ -2,12 +2,11 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from openai import OpenAI as OpenAI_Client
|
||||
from openai import OpenAIError
|
||||
|
||||
from core import enums
|
||||
@@ -15,11 +14,42 @@ from core import enums
|
||||
if settings.LANGFUSE_PUBLIC_KEY:
|
||||
from langfuse.openai import OpenAI
|
||||
else:
|
||||
OpenAI = OpenAI_Client
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BLOCKNOTE_TOOL_STRICT_PROMPT = """
|
||||
You are editing a BlockNote document via the tool applyDocumentOperations.
|
||||
|
||||
You MUST respond ONLY by calling applyDocumentOperations.
|
||||
The tool input MUST be valid JSON:
|
||||
{ "operations": [ ... ] }
|
||||
|
||||
Each operation MUST include "type" and it MUST be one of:
|
||||
- "update" (requires: id, block)
|
||||
- "add" (requires: referenceId, position, blocks)
|
||||
- "delete" (requires: id)
|
||||
|
||||
VALID SHAPES (FOLLOW EXACTLY):
|
||||
|
||||
Update:
|
||||
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
|
||||
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
|
||||
|
||||
Add:
|
||||
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
|
||||
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
|
||||
Each item MUST be a STRING containing a SINGLE valid HTML element.
|
||||
|
||||
Delete:
|
||||
{ "type":"delete", "id":"<id$>" }
|
||||
|
||||
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
|
||||
|
||||
Return ONLY the JSON tool input. No prose, no markdown.
|
||||
"""
|
||||
|
||||
AI_ACTIONS = {
|
||||
"prompt": (
|
||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||
@@ -106,11 +136,58 @@ class AIService:
|
||||
system_content = AI_TRANSLATE.format(language=language_display)
|
||||
return self.call_ai_api(system_content, text)
|
||||
|
||||
def _normalize_tools(self, tools: list) -> list:
|
||||
"""
|
||||
Normalize tool definitions to ensure they have required fields.
|
||||
"""
|
||||
normalized = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||
fn = tool.get("function") or {}
|
||||
if isinstance(fn, dict) and not fn.get("description"):
|
||||
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
|
||||
tool["function"] = fn
|
||||
normalized.append(tool)
|
||||
return normalized
|
||||
|
||||
def _harden_payload(
|
||||
self, payload: Dict[str, Any], stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Harden the AI API payload to enforce compliance and tool usage."""
|
||||
payload["stream"] = stream
|
||||
|
||||
# Remove stream_options if stream is False
|
||||
if not stream and "stream_options" in payload:
|
||||
payload.pop("stream_options")
|
||||
|
||||
# Tools normalization
|
||||
if isinstance(payload.get("tools"), list):
|
||||
payload["tools"] = self._normalize_tools(payload["tools"])
|
||||
|
||||
# Inject strict system prompt once
|
||||
msgs = payload.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
need = True
|
||||
if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
|
||||
c = msgs[0].get("content") or ""
|
||||
if (
|
||||
isinstance(c, str)
|
||||
and "applyDocumentOperations" in c
|
||||
and "blocks" in c
|
||||
):
|
||||
need = False
|
||||
if need:
|
||||
payload["messages"] = [
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
|
||||
] + msgs
|
||||
|
||||
return payload
|
||||
|
||||
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
||||
"""Proxy AI API requests to the configured AI provider."""
|
||||
data["stream"] = stream
|
||||
payload = self._harden_payload(data, stream=stream)
|
||||
try:
|
||||
return self.client.chat.completions.create(**data)
|
||||
return self.client.chat.completions.create(**payload)
|
||||
except OpenAIError as e:
|
||||
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT
|
||||
from core.tests.conftest import TEAM, USER, VIA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
@@ -110,7 +111,10 @@ def test_api_documents_ai_proxy_anonymous_success(mock_create):
|
||||
)
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
@@ -228,7 +232,10 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
||||
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
@@ -315,7 +322,10 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams)
|
||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Test message"}],
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Test message"},
|
||||
],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
@@ -412,6 +422,7 @@ def test_api_documents_ai_proxy_invalid_message_format():
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
||||
"""Stream should be automatically disabled in AI proxy requests."""
|
||||
@@ -440,12 +451,16 @@ def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
||||
assert response.status_code == 200
|
||||
# Verify that stream was set to False
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
||||
"""AI proxy should pass through additional parameters to the AI service."""
|
||||
@@ -476,7 +491,10 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
||||
assert response.status_code == 200
|
||||
# Verify that additional parameters were passed through
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
model="llama",
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
@@ -561,6 +579,7 @@ def test_api_documents_ai_proxy_complex_conversation(mock_create):
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
complex_messages = [
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "system", "content": "You are a helpful programming assistant."},
|
||||
{"role": "user", "content": "How do I write a for loop in Python?"},
|
||||
{
|
||||
@@ -688,3 +707,93 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == ["AI feature is not enabled."]
|
||||
|
||||
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_harden_payload(mock_create):
|
||||
"""
|
||||
AI proxy should harden the payload by default:
|
||||
- it will remove stream_options if stream is False
|
||||
- normalize tools by adding descriptions if missing
|
||||
"""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
"id": "chatcmpl-789",
|
||||
"object": "chat.completion",
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Success!"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
"stream": False,
|
||||
"stream_options": {"include_usage": True},
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "applyDocumentOperations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operations": {"type": "array", "items": {"anyOf": []}}
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": ["operations"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["id"] == "chatcmpl-789"
|
||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
model="llama",
|
||||
stream=False,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "applyDocumentOperations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operations": {"type": "array", "items": {"anyOf": []}}
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": ["operations"],
|
||||
},
|
||||
"description": "Tool applyDocumentOperations.",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from django.test.utils import override_settings
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from core.services.ai_services import AIService
|
||||
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
@@ -125,7 +125,11 @@ def test_services_ai_proxy_success(mock_create):
|
||||
}
|
||||
assert response == expected_response
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "hello"}], stream=False
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "hello"},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -182,5 +186,9 @@ def test_services_ai_proxy_with_stream(mock_create):
|
||||
}
|
||||
assert response == expected_response
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "hello"}], stream=True
|
||||
messages=[
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||
{"role": "user", "content": "hello"},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user