🛂(backend) harden AI proxy payload
The payload format expected can vary depending on the AI service used. To work with the Albert API: a `description` field is required in the payload for every tool call; if `stream` is set to false, `stream_options` must be omitted from the payload; and because Albert's responses sometimes did not respect the format expected by BlockNote, we added a system prompt to enforce it.
This commit is contained in:
@@ -2,12 +2,11 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Generator
|
from typing import Any, Dict, Generator
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
from openai import OpenAI as OpenAI_Client
|
|
||||||
from openai import OpenAIError
|
from openai import OpenAIError
|
||||||
|
|
||||||
from core import enums
|
from core import enums
|
||||||
@@ -15,11 +14,42 @@ from core import enums
|
|||||||
if settings.LANGFUSE_PUBLIC_KEY:
|
if settings.LANGFUSE_PUBLIC_KEY:
|
||||||
from langfuse.openai import OpenAI
|
from langfuse.openai import OpenAI
|
||||||
else:
|
else:
|
||||||
OpenAI = OpenAI_Client
|
from openai import OpenAI
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
BLOCKNOTE_TOOL_STRICT_PROMPT = """
|
||||||
|
You are editing a BlockNote document via the tool applyDocumentOperations.
|
||||||
|
|
||||||
|
You MUST respond ONLY by calling applyDocumentOperations.
|
||||||
|
The tool input MUST be valid JSON:
|
||||||
|
{ "operations": [ ... ] }
|
||||||
|
|
||||||
|
Each operation MUST include "type" and it MUST be one of:
|
||||||
|
- "update" (requires: id, block)
|
||||||
|
- "add" (requires: referenceId, position, blocks)
|
||||||
|
- "delete" (requires: id)
|
||||||
|
|
||||||
|
VALID SHAPES (FOLLOW EXACTLY):
|
||||||
|
|
||||||
|
Update:
|
||||||
|
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
|
||||||
|
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
|
||||||
|
|
||||||
|
Add:
|
||||||
|
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
|
||||||
|
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
|
||||||
|
Each item MUST be a STRING containing a SINGLE valid HTML element.
|
||||||
|
|
||||||
|
Delete:
|
||||||
|
{ "type":"delete", "id":"<id$>" }
|
||||||
|
|
||||||
|
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
|
||||||
|
|
||||||
|
Return ONLY the JSON tool input. No prose, no markdown.
|
||||||
|
"""
|
||||||
|
|
||||||
AI_ACTIONS = {
|
AI_ACTIONS = {
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||||
@@ -106,11 +136,58 @@ class AIService:
|
|||||||
system_content = AI_TRANSLATE.format(language=language_display)
|
system_content = AI_TRANSLATE.format(language=language_display)
|
||||||
return self.call_ai_api(system_content, text)
|
return self.call_ai_api(system_content, text)
|
||||||
|
|
||||||
|
def _normalize_tools(self, tools: list) -> list:
|
||||||
|
"""
|
||||||
|
Normalize tool definitions to ensure they have required fields.
|
||||||
|
"""
|
||||||
|
normalized = []
|
||||||
|
for tool in tools:
|
||||||
|
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||||
|
fn = tool.get("function") or {}
|
||||||
|
if isinstance(fn, dict) and not fn.get("description"):
|
||||||
|
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
|
||||||
|
tool["function"] = fn
|
||||||
|
normalized.append(tool)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def _harden_payload(
|
||||||
|
self, payload: Dict[str, Any], stream: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Harden the AI API payload to enforce compliance and tool usage."""
|
||||||
|
payload["stream"] = stream
|
||||||
|
|
||||||
|
# Remove stream_options if stream is False
|
||||||
|
if not stream and "stream_options" in payload:
|
||||||
|
payload.pop("stream_options")
|
||||||
|
|
||||||
|
# Tools normalization
|
||||||
|
if isinstance(payload.get("tools"), list):
|
||||||
|
payload["tools"] = self._normalize_tools(payload["tools"])
|
||||||
|
|
||||||
|
# Inject strict system prompt once
|
||||||
|
msgs = payload.get("messages")
|
||||||
|
if isinstance(msgs, list):
|
||||||
|
need = True
|
||||||
|
if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
|
||||||
|
c = msgs[0].get("content") or ""
|
||||||
|
if (
|
||||||
|
isinstance(c, str)
|
||||||
|
and "applyDocumentOperations" in c
|
||||||
|
and "blocks" in c
|
||||||
|
):
|
||||||
|
need = False
|
||||||
|
if need:
|
||||||
|
payload["messages"] = [
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
|
||||||
|
] + msgs
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
||||||
"""Proxy AI API requests to the configured AI provider."""
|
"""Proxy AI API requests to the configured AI provider."""
|
||||||
data["stream"] = stream
|
payload = self._harden_payload(data, stream=stream)
|
||||||
try:
|
try:
|
||||||
return self.client.chat.completions.create(**data)
|
return self.client.chat.completions.create(**payload)
|
||||||
except OpenAIError as e:
|
except OpenAIError as e:
|
||||||
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import pytest
|
|||||||
from rest_framework.test import APIClient
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
from core import factories
|
from core import factories
|
||||||
|
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT
|
||||||
from core.tests.conftest import TEAM, USER, VIA
|
from core.tests.conftest import TEAM, USER, VIA
|
||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
@@ -110,7 +111,10 @@ def test_api_documents_ai_proxy_anonymous_success(mock_create):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
model="llama",
|
model="llama",
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
@@ -228,7 +232,10 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
|||||||
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
|
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
model="llama",
|
model="llama",
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
@@ -315,7 +322,10 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams)
|
|||||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "Test message"}],
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Test message"},
|
||||||
|
],
|
||||||
model="llama",
|
model="llama",
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
@@ -412,6 +422,7 @@ def test_api_documents_ai_proxy_invalid_message_format():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_STREAM=False)
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
||||||
"""Stream should be automatically disabled in AI proxy requests."""
|
"""Stream should be automatically disabled in AI proxy requests."""
|
||||||
@@ -440,12 +451,16 @@ def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# Verify that stream was set to False
|
# Verify that stream was set to False
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
model="llama",
|
model="llama",
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_STREAM=False)
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
||||||
"""AI proxy should pass through additional parameters to the AI service."""
|
"""AI proxy should pass through additional parameters to the AI service."""
|
||||||
@@ -476,7 +491,10 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# Verify that additional parameters were passed through
|
# Verify that additional parameters were passed through
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
model="llama",
|
model="llama",
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
@@ -561,6 +579,7 @@ def test_api_documents_ai_proxy_complex_conversation(mock_create):
|
|||||||
mock_create.return_value = mock_response
|
mock_create.return_value = mock_response
|
||||||
|
|
||||||
complex_messages = [
|
complex_messages = [
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
{"role": "system", "content": "You are a helpful programming assistant."},
|
{"role": "system", "content": "You are a helpful programming assistant."},
|
||||||
{"role": "user", "content": "How do I write a for loop in Python?"},
|
{"role": "user", "content": "How do I write a for loop in Python?"},
|
||||||
{
|
{
|
||||||
@@ -688,3 +707,93 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
|||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.json() == ["AI feature is not enabled."]
|
assert response.json() == ["AI feature is not enabled."]
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_STREAM=False)
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_proxy_harden_payload(mock_create):
|
||||||
|
"""
|
||||||
|
AI proxy should harden the payload by default:
|
||||||
|
- it will remove stream_options if stream is False
|
||||||
|
- normalize tools by adding descriptions if missing
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.model_dump.return_value = {
|
||||||
|
"id": "chatcmpl-789",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": "llama",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "Success!"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
mock_create.return_value = mock_response
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
|
response = client.post(
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"model": "llama",
|
||||||
|
"stream": False,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "applyDocumentOperations",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"operations": {"type": "array", "items": {"anyOf": []}}
|
||||||
|
},
|
||||||
|
"additionalProperties": False,
|
||||||
|
"required": ["operations"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_data = response.json()
|
||||||
|
assert response_data["id"] == "chatcmpl-789"
|
||||||
|
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
||||||
|
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
|
model="llama",
|
||||||
|
stream=False,
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "applyDocumentOperations",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"operations": {"type": "array", "items": {"anyOf": []}}
|
||||||
|
},
|
||||||
|
"additionalProperties": False,
|
||||||
|
"required": ["operations"],
|
||||||
|
},
|
||||||
|
"description": "Tool applyDocumentOperations.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from django.test.utils import override_settings
|
|||||||
import pytest
|
import pytest
|
||||||
from openai import OpenAIError
|
from openai import OpenAIError
|
||||||
|
|
||||||
from core.services.ai_services import AIService
|
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService
|
||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
@@ -125,7 +125,11 @@ def test_services_ai_proxy_success(mock_create):
|
|||||||
}
|
}
|
||||||
assert response == expected_response
|
assert response == expected_response
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "hello"}], stream=False
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
],
|
||||||
|
stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -182,5 +186,9 @@ def test_services_ai_proxy_with_stream(mock_create):
|
|||||||
}
|
}
|
||||||
assert response == expected_response
|
assert response == expected_response
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
messages=[{"role": "user", "content": "hello"}], stream=True
|
messages=[
|
||||||
|
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user