diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index d325decc..89e322ab 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -21,7 +21,7 @@ from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.external import ExternalToolset from pydantic_ai.ui import SSE_CONTENT_TYPE from pydantic_ai.ui.vercel_ai import VercelAIAdapter -from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage +from pydantic_ai.ui.vercel_ai.request_types import RequestData, TextUIPart, UIMessage from rest_framework.request import Request from core import enums @@ -33,6 +33,37 @@ else: log = logging.getLogger(__name__) +BLOCKNOTE_TOOL_STRICT_PROMPT = """ +You are editing a BlockNote document via the tool applyDocumentOperations. + +You MUST respond ONLY by calling applyDocumentOperations. +The tool input MUST be valid JSON: +{ "operations": [ ... ] } + +Each operation MUST include "type" and it MUST be one of: +- "update" (requires: id, block) +- "add" (requires: referenceId, position, blocks) +- "delete" (requires: id) + +VALID SHAPES (FOLLOW EXACTLY): + +Update: +{ "type":"update", "id":"", "block":"

...

" } +IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element. + +Add: +{ "type":"add", "referenceId":"", "position":"before|after", "blocks":["

...

def _harden_messages(
    self, run_input: RequestData, tool_definitions: Dict[str, Any]
) -> None:
    """
    Prepend the strict BlockNote system prompt when the client declares the
    applyDocumentOperations tool.

    The prompt would ideally be given as the system_prompt of the Agent, but
    with UI adapters such as the Vercel one the agent ignores that property
    (see https://github.com/pydantic/pydantic-ai/issues/3315), so it is
    injected as the first entry of run_input.messages instead.

    Args:
        run_input: request data whose ``messages`` list is mutated in place.
        tool_definitions: tool definitions keyed by tool name; only the
            presence of the "applyDocumentOperations" key is checked.
    """
    # Only the presence of the tool matters, so a key lookup replaces the
    # scan over items() whose values were never used.
    if "applyDocumentOperations" not in tool_definitions:
        return

    # Index 0 so the instruction precedes every message sent by the client.
    run_input.messages.insert(
        0,
        UIMessage(
            id="system-force-tool-usage",
            role="system",
            parts=[TextUIPart(text=BLOCKNOTE_TOOL_STRICT_PROMPT)],
        ),
    )
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream_with_tool_definitions_required_system_prompt(
    mock_adapter_cls,
):
    """Declaring the applyDocumentOperations tool must result in the strict
    system prompt being injected as the first message."""

    async def fake_event_stream():
        yield "event-data"

    # Run input announcing the tool that triggers the hardening.
    tool_definitions = {
        "applyDocumentOperations": {
            "description": "A tool",
            "inputSchema": {"type": "object"},
        }
    }
    run_input = MagicMock()
    run_input.model_extra = {"toolDefinitions": tool_definitions}
    run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = run_input

    # Adapter instance handed back by the patched class.
    adapter = MagicMock()
    adapter.run_stream.return_value = MagicMock()
    adapter.encode_stream.return_value = fake_event_stream()
    mock_adapter_cls.return_value = adapter

    request = MagicMock()
    request.META = {}
    request.raw_body = b"{}"

    AIService()._build_async_stream(request)

    # The declared tool must still reach run_stream as exactly one toolset.
    toolsets = adapter.run_stream.call_args[1]["toolsets"]
    assert toolsets is not None
    assert len(toolsets) == 1

    # The message list started empty, so the injected prompt is its only entry.
    assert len(run_input.messages) == 1
    injected = run_input.messages[0]
    assert injected.id == "system-force-tool-usage"
    assert injected.role == "system"
    assert injected.parts[0].text == BLOCKNOTE_TOOL_STRICT_PROMPT