From 8ce216f6e879ab9ec90984a22ea107d07f6808ee Mon Sep 17 00:00:00 2001 From: Manuel Raynaud Date: Tue, 17 Feb 2026 11:11:13 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(backend)=20use=20pydantic=20AI=20to?= =?UTF-8?q?=20manage=20vercel=20data=20stream=20protocol?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The frontend application is using Vercel AI SDK and it's data stream protocol. We decided to use the pydantic AI library to use it's vercel ai adapter. It will make the payload validation, use AsyncIterator and deal with vercel specification. --- docs/env.md | 2 +- src/backend/core/api/serializers.py | 33 - src/backend/core/api/viewsets.py | 27 +- src/backend/core/middleware.py | 18 + src/backend/core/services/ai_services.py | 321 +++++--- .../documents/test_api_documents_ai_proxy.py | 711 ++++-------------- .../core/tests/test_services_ai_services.py | 542 ++++++++++--- src/backend/impress/settings.py | 4 + src/backend/pyproject.toml | 4 +- 9 files changed, 850 insertions(+), 812 deletions(-) diff --git a/docs/env.md b/docs/env.md index 725cfb07..241bf6ae 100644 --- a/docs/env.md +++ b/docs/env.md @@ -15,6 +15,7 @@ These are the environment variables you can set for the `impress-backend` contai | AI_FEATURE_ENABLED | Enable AI options | false | | AI_MODEL | AI Model to use | | | AI_STREAM | AI Stream to activate | true | +| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | @@ -166,4 +167,3 @@ Packages with licences incompatible with the MIT licence: In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features. ⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations. - diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 75eaa5d3..a7f3d401 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -833,39 +833,6 @@ class AITranslateSerializer(serializers.Serializer): return value -class AIProxySerializer(serializers.Serializer): - """Serializer for AI proxy requests.""" - - messages = serializers.ListField( - required=True, - child=serializers.DictField(), - allow_empty=False, - ) - model = serializers.CharField(required=True) - - def validate_messages(self, messages): - """Validate messages structure.""" - # Ensure each message has the required fields - for message in messages: - if ( - not isinstance(message, dict) - or "role" not in message - or "content" not in message - ): - raise serializers.ValidationError( - "Each message must have 'role' and 'content' fields" - ) - - return messages - - def validate_model(self, value): - """Validate model value is the same than settings.AI_MODEL""" - if value != settings.AI_MODEL: - raise serializers.ValidationError(f"{value} is not a valid model") - - return value - - class MoveDocumentSerializer(serializers.Serializer): """ Serializer for validating input data to move a document within the tree structure. diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 9f715139..26b174ed 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -38,6 +38,7 @@ from csp.decorators import csp_update from lasuite.malware_detection import malware_detection from lasuite.oidc_login.decorators import refresh_oidc_access_token from lasuite.tools.email import get_domain_from_email +from pydantic import ValidationError as PydanticValidationError from rest_framework import filters, status, viewsets from rest_framework import response as drf_response from rest_framework.permissions import AllowAny @@ -1854,22 +1855,24 @@ class DocumentViewSet( if not settings.AI_FEATURE_ENABLED: raise ValidationError("AI feature is not enabled.") - serializer = serializers.AIProxySerializer(data=request.data) - serializer.is_valid(raise_exception=True) - ai_service = AIService() - if settings.AI_STREAM: - return StreamingHttpResponse( - ai_service.stream(request.data), - content_type="text/event-stream", - status=drf.status.HTTP_200_OK, + try: + stream = ai_service.stream(request) + except PydanticValidationError as err: + logger.info("pydantic validation error: %s", err) + return drf.response.Response( + {"detail": "Invalid submitted payload"}, + status=drf.status.HTTP_400_BAD_REQUEST, ) - ai_response = ai_service.proxy(request.data) - return drf.response.Response( - ai_response.model_dump(), - status=drf.status.HTTP_200_OK, + return StreamingHttpResponse( + stream, + content_type="text/event-stream", + headers={ + "x-vercel-ai-data-stream": "v1", # This header is used for Vercel AI streaming, + "X-Accel-Buffering": "no", # Prevent nginx buffering + }, ) @drf.decorators.action( diff --git a/src/backend/core/middleware.py b/src/backend/core/middleware.py index afb5a100..410a461d 100644 --- a/src/backend/core/middleware.py +++ b/src/backend/core/middleware.py @@ -19,3 +19,21 @@ class ForceSessionMiddleware: response = self.get_response(request) return response + + +class SaveRawBodyMiddleware: + """ + Save the raw request body to use it later. + """ + + def __init__(self, get_response): + """Initialize the middleware.""" + self.get_response = get_response + + def __call__(self, request): + """Save the raw request body in the request to use it later.""" + if request.path.endswith(("/ai-proxy/", "/ai-proxy")): + request.raw_body = request.body + + response = self.get_response(request) + return response diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index cad97dfb..d325decc 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -1,55 +1,38 @@ """AI services.""" +import asyncio import json import logging -from typing import Any, Dict, Generator +import os +import queue +import threading +from collections.abc import AsyncIterator, Iterator +from typing import Any, Dict, Union from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from openai import OpenAIError +from langfuse import get_client +from langfuse.openai import OpenAI as OpenAI_Langfuse +from pydantic_ai import Agent, DeferredToolRequests +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.openai import OpenAIProvider +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.external import ExternalToolset +from pydantic_ai.ui import SSE_CONTENT_TYPE +from pydantic_ai.ui.vercel_ai import VercelAIAdapter +from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage +from rest_framework.request import Request from core import enums if settings.LANGFUSE_PUBLIC_KEY: - from langfuse.openai import OpenAI + OpenAI = OpenAI_Langfuse else: from openai import OpenAI log = logging.getLogger(__name__) - -BLOCKNOTE_TOOL_STRICT_PROMPT = """ -You are editing a BlockNote document via the tool applyDocumentOperations. - -You MUST respond ONLY by calling applyDocumentOperations. -The tool input MUST be valid JSON: -{ "operations": [ ... ] } - -Each operation MUST include "type" and it MUST be one of: -- "update" (requires: id, block) -- "add" (requires: referenceId, position, blocks) -- "delete" (requires: id) - -VALID SHAPES (FOLLOW EXACTLY): - -Update: -{ "type":"update", "id":"", "block":"

...

" } -IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element. - -Add: -{ "type":"add", "referenceId":"", "position":"before|after", "blocks":["

...

"] } -IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS. -Each item MUST be a STRING containing a SINGLE valid HTML element. - -Delete: -{ "type":"delete", "id":"" } - -IDs ALWAYS end with "$". Use ids EXACTLY as provided. - -Return ONLY the JSON tool input. No prose, no markdown. -""" - AI_ACTIONS = { "prompt": ( "Answer the prompt using markdown formatting for structure and emphasis. " @@ -95,6 +78,40 @@ AI_TRANSLATE = ( ) +def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[str]: + """Convert an async generator to a sync generator.""" + q: queue.Queue[str | object] = queue.Queue() + sentinel = object() + exc_sentinel = object() + + async def run_async_gen(): + try: + async for async_item in async_gen: + q.put(async_item) + except Exception as exc: # pylint: disable=broad-except #noqa: BLE001 + q.put((exc_sentinel, exc)) + finally: + q.put(sentinel) + + def start_async_loop(): + asyncio.run(run_async_gen()) + + thread = threading.Thread(target=start_async_loop, daemon=True) + thread.start() + + try: + while True: + item = q.get() + if item is sentinel: + break + if isinstance(item, tuple) and item[0] is exc_sentinel: + # re-raise the exception in the sync context + raise item[1] + yield item + finally: + thread.join() + + class AIService: """Service class for AI-related operations.""" @@ -136,77 +153,171 @@ class AIService: system_content = AI_TRANSLATE.format(language=language_display) return self.call_ai_api(system_content, text) - def _normalize_tools(self, tools: list) -> list: + @staticmethod + def inject_document_state_messages( + messages: list[UIMessage], + ) -> list[UIMessage]: + """Inject document state context before user messages. + + Port of BlockNote's injectDocumentStateMessages. + For each user message carrying documentState metadata, an assistant + message describing the current document/selection state is prepended + so the LLM sees it as context. """ - Normalize tool definitions to ensure they have required fields. - """ - normalized = [] - for tool in tools: - if isinstance(tool, dict) and tool.get("type") == "function": - fn = tool.get("function") or {} - if isinstance(fn, dict) and not fn.get("description"): - fn["description"] = f"Tool {fn.get('name', 'unknown')}." - tool["function"] = fn - normalized.append(tool) - return normalized + result: list[UIMessage] = [] + for message in messages: + if ( + message.role == "user" + and isinstance(message.metadata, dict) + and "documentState" in message.metadata + ): + doc_state = message.metadata["documentState"] + selection = doc_state.get("selection") + blocks = doc_state.get("blocks") - def _harden_payload( - self, payload: Dict[str, Any], stream: bool = False - ) -> Dict[str, Any]: - """Harden the AI API payload to enforce compliance and tool usage.""" - payload["stream"] = stream - - # Remove stream_options if stream is False - if not stream and "stream_options" in payload: - payload.pop("stream_options") - - # Tools normalization - if isinstance(payload.get("tools"), list): - payload["tools"] = self._normalize_tools(payload["tools"]) - - # Inject strict system prompt once - msgs = payload.get("messages") - if isinstance(msgs, list): - need = True - if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system": - c = msgs[0].get("content") or "" - if ( - isinstance(c, str) - and "applyDocumentOperations" in c - and "blocks" in c - ): - need = False - if need: - payload["messages"] = [ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT} - ] + msgs - - return payload - - def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: - """Proxy AI API requests to the configured AI provider.""" - payload = self._harden_payload(data, stream=stream) - try: - return self.client.chat.completions.create(**payload) - except OpenAIError as e: - raise RuntimeError(f"Failed to proxy AI request: {e}") from e - - def stream(self, data: dict) -> Generator[str, None, None]: - """Stream AI API requests to the configured AI provider.""" - - try: - stream = self.proxy(data, stream=True) - for chunk in stream: - try: - chunk_dict = ( - chunk.model_dump() if hasattr(chunk, "model_dump") else chunk + if selection: + parts = [ + TextUIPart( + text=( + "This is the latest state of the selection " + "(ignore previous selections, you MUST issue " + "operations against this latest version of " + "the selection):" + ), + ), + TextUIPart( + text=json.dumps(doc_state.get("selectedBlocks")), + ), + TextUIPart( + text=( + "This is the latest state of the entire " + "document (INCLUDING the selected text), you " + "can use this to find the selected text to " + "understand the context (but you MUST NOT " + "issue operations against this document, you " + "MUST issue operations against the selection):" + ), + ), + TextUIPart(text=json.dumps(blocks)), + ] + else: + text = ( + "There is no active selection. This is the latest " + "state of the document (ignore previous documents, " + "you MUST issue operations against this latest " + "version of the document). The cursor is BETWEEN " + "two blocks as indicated by cursor: true." ) - chunk_json = json.dumps(chunk_dict) - yield f"data: {chunk_json}\n\n" - except (AttributeError, TypeError) as e: - log.error("Error serializing chunk: %s, chunk: %s", e, chunk) - continue - except (OpenAIError, RuntimeError, OSError, ValueError) as e: - log.error("Streaming error: %s", e) + if doc_state.get("isEmptyDocument"): + text += ( + "Because the document is empty, YOU MUST first " + "update the empty block before adding new blocks." + ) + else: + text += ( + "Prefer updating existing blocks over removing " + "and adding (but this also depends on the " + "user's question)." + ) + parts = [ + TextUIPart(text=text), + TextUIPart(text=json.dumps(blocks)), + ] - yield "data: [DONE]\n\n" + result.append( + UIMessage( + role="assistant", + id=f"assistant-document-state-{message.id}", + parts=parts, + ) + ) + + result.append(message) + return result + + @staticmethod + def tool_definitions_to_toolset( + tool_definitions: Dict[str, Any], + ) -> ExternalToolset: + """Convert serialized tool definitions to a pydantic-ai ExternalToolset. + + Port of BlockNote's toolDefinitionsToToolSet. + Builds ToolDefinition objects from the JSON-Schema-based definitions + sent by the frontend and wraps them in an ExternalToolset so that + pydantic-ai advertises them to the LLM without trying to execute them + server-side (execution is deferred to the frontend). + """ + tool_defs = [ + ToolDefinition( + name=name, + description=defn.get("description", ""), + parameters_json_schema=defn.get("inputSchema", {}), + kind="external", + metadata={ + "output_schema": defn.get("outputSchema"), + }, + ) + for name, defn in tool_definitions.items() + ] + return ExternalToolset(tool_defs) + + def _build_async_stream(self, request: Request) -> AsyncIterator[str]: + """Build the async stream from the AI provider.""" + instrument_enabled = settings.LANGFUSE_PUBLIC_KEY is not None + + if instrument_enabled: + langfuse = get_client() + langfuse.auth_check() + Agent.instrument_all() + + model = OpenAIChatModel( + settings.AI_MODEL, + provider=OpenAIProvider( + base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY + ), + ) + agent = Agent(model, instrument=instrument_enabled) + + accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE) + + run_input = VercelAIAdapter.build_run_input(request.raw_body) + + # Inject document state context into the conversation + run_input.messages = self.inject_document_state_messages(run_input.messages) + + # Build an ExternalToolset from frontend-supplied tool definitions + raw_tool_defs = ( + run_input.model_extra.get("toolDefinitions") + if run_input.model_extra + else None + ) + toolset = ( + self.tool_definitions_to_toolset(raw_tool_defs) if raw_tool_defs else None + ) + + adapter = VercelAIAdapter( + agent=agent, + run_input=run_input, + accept=accept, + sdk_version=settings.AI_VERCEL_SDK_VERSION, + ) + + event_stream = adapter.run_stream( + output_type=[str, DeferredToolRequests] if toolset else None, + toolsets=[toolset] if toolset else None, + ) + + return adapter.encode_stream(event_stream) + + def stream(self, request: Request) -> Union[AsyncIterator[str], Iterator[str]]: + """Stream AI API requests to the configured AI provider. + + Returns an async iterator when running in async mode (ASGI) + or a sync iterator when running in sync mode (WSGI). + """ + async_stream = self._build_async_stream(request) + + if os.environ.get("PYTHON_SERVER_MODE", "sync") == "async": + return async_stream + + return convert_async_generator_to_sync(async_stream) diff --git a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py index 1046cb81..5229bdc3 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py @@ -3,7 +3,7 @@ Test AI proxy API endpoint for users in impress's core app. """ import random -from unittest.mock import MagicMock, patch +from unittest.mock import patch from django.test import override_settings @@ -11,7 +11,6 @@ import pytest from rest_framework.test import APIClient from core import factories -from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @@ -24,6 +23,8 @@ def ai_settings(settings): settings.AI_BASE_URL = "http://localhost-ai:12345/" settings.AI_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True + settings.LANGFUSE_PUBLIC_KEY = None + settings.AI_VERCEL_SDK_VERSION = 6 @override_settings( @@ -51,7 +52,6 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role): url, { "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", }, format="json", ) @@ -62,85 +62,48 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role): } -@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_anonymous_success(mock_create): +@override_settings(AI_ALLOW_REACH_FROM="public") +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_anonymous_success(mock_stream): """ Anonymous users should be able to request AI proxy to a document if the link reach and role permit it. """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-123", - "object": "chat.completion", - "created": 1677652288, - "model": "llama", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I help you?", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: chunk1\n", "data: chunk2\n"]) + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = APIClient().post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-123" - assert response_data["model"] == "llama" - assert len(response_data["choices"]) == 1 - assert ( - response_data["choices"][0]["message"]["content"] - == "Hello! How can I help you?" - ) + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) + content = b"".join(response.streaming_content).decode() + assert "chunk1" in content + assert "chunk2" in content + mock_stream.assert_called_once() @override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"])) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_anonymous_limited_by_setting(mock_create): +def test_api_documents_ai_proxy_anonymous_limited_by_setting(): """ Anonymous users should not be able to request AI proxy to a document if AI_ALLOW_REACH_FROM setting restricts it. """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Hello!"} - mock_create.return_value = mock_response - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = APIClient().post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 401 @@ -170,11 +133,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role): url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 403 @@ -187,9 +147,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role): ("public", "editor"), ], ) -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_authenticated_success(mock_stream, reach, role): """ Authenticated users should be able to request AI proxy to a document if the link reach and role permit it. @@ -201,44 +160,18 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): document = factories.DocumentFactory(link_reach=reach, link_role=role) - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-456", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Hi there!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: response\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-456" - assert response_data["choices"][0]["message"]["content"] == "Hi there!" - - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) + assert response["Content-Type"] == "text/event-stream" + mock_stream.assert_called_once() @pytest.mark.parametrize("via", VIA) @@ -261,11 +194,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams): url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 403 @@ -273,9 +203,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams): @pytest.mark.parametrize("role", ["editor", "administrator", "owner"]) @pytest.mark.parametrize("via", VIA) -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams): +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_success(mock_stream, via, role, mock_user_teams): """Users with sufficient permissions should be able to request AI proxy.""" user = factories.UserFactory() @@ -291,402 +220,27 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams) document=document, team="lasuite", role=role ) - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-789", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Success!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: success\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-789" - assert response_data["choices"][0]["message"]["content"] == "Success!" + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Test message"}, - ], - model="llama", - stream=False, - ) - - -def test_api_documents_ai_proxy_empty_messages(): - """The messages should not be empty when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post(url, {"messages": [], "model": "llama"}, format="json") - - assert response.status_code == 400 - assert response.json() == {"messages": ["This list may not be empty."]} - - -def test_api_documents_ai_proxy_missing_model(): - """The model should be required when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, {"messages": [{"role": "user", "content": "Hello"}]}, format="json" - ) - - assert response.status_code == 400 - assert response.json() == {"model": ["This field is required."]} - - -def test_api_documents_ai_proxy_invalid_message_format(): - """Messages should have the correct format when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - - # Test with invalid message format (missing role) - response = client.post( - url, - { - "messages": [{"content": "Hello"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": ["Each message must have 'role' and 'content' fields"] - } - - # Test with invalid message format (missing content) - response = client.post( - url, - { - "messages": [{"role": "user"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": ["Each message must have 'role' and 'content' fields"] - } - - # Test with non-dict message - response = client.post( - url, - { - "messages": ["invalid"], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": {"0": ['Expected a dictionary of items but got type "str".']} - } - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_stream_disabled(mock_create): - """Stream should be automatically disabled in AI proxy requests.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "stream": True, # This should be overridden to False - }, - format="json", - ) - - assert response.status_code == 200 - # Verify that stream was set to False - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_additional_parameters(mock_create): - """AI proxy should pass through additional parameters to the AI service.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "temperature": 0.7, - "max_tokens": 100, - "top_p": 0.9, - }, - format="json", - ) - - assert response.status_code == 200 - # Verify that additional parameters were passed through - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - temperature=0.7, - max_tokens=100, - top_p=0.9, - stream=False, - ) - - -@override_settings(AI_STREAM=False) -@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_throttling_document(mock_create): - """ - Throttling per document should be triggered on the AI transform endpoint. - For full throttle class test see: `test_api_utils_ai_document_rate_throttles` - """ - client = APIClient() - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - for _ in range(3): - user = factories.UserFactory() - client.force_login(user) - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - format="json", - ) - assert response.status_code == 200 - assert response.json() == {"content": "Success!"} - - user = factories.UserFactory() - client.force_login(user) - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - ) - - assert response.status_code == 429 - assert response.json() == { - "detail": "Request was throttled. Expected available in 60 seconds." - } - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_complex_conversation(mock_create): - """AI proxy should handle complex conversations with multiple messages.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-complex", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "I understand your question about Python.", - }, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response - - complex_messages = [ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "system", "content": "You are a helpful programming assistant."}, - {"role": "user", "content": "How do I write a for loop in Python?"}, - { - "role": "assistant", - "content": "You can write a for loop using: for item in iterable:", - }, - {"role": "user", "content": "Can you give me a concrete example?"}, - ] - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": complex_messages, - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-complex" - assert ( - response_data["choices"][0]["message"]["content"] - == "I understand your question about Python." - ) - - mock_create.assert_called_once_with( - messages=complex_messages, - model="llama", - stream=False, - ) - - -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_throttling_user(mock_create): - """ - Throttling per user should be triggered on the AI proxy endpoint. - For full throttle class test see: `test_api_utils_ai_user_rate_throttles` - """ - user = factories.UserFactory() - client = APIClient() - client.force_login(user) - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - for _ in range(3): - document = factories.DocumentFactory(link_reach="public", link_role="editor") - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", - ) - assert response.status_code == 200 - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 429 - assert response.json() == { - "detail": "Request was throttled. Expected available in 60 seconds." - } - - -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 10, "hour": 6, "day": 10}) -def test_api_documents_ai_proxy_different_models(): - """AI proxy should work with different AI models.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - models_to_test = ["gpt-3.5-turbo", "gpt-4", "claude-3", "llama-2"] - - for model_name in models_to_test: - response = client.post( - f"/api/v1.0/documents/{document.id!s}/ai-proxy/", - { - "messages": [{"role": "user", "content": "Hello"}], - "model": model_name, - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == {"model": [f"{model_name} is not a valid model"]} + content = b"".join(response.streaming_content).decode() + assert "success" in content + mock_stream.assert_called_once() def test_api_documents_ai_proxy_ai_feature_disabled(settings): - """When the settings AI_FEATURE_ENABLED is set to False, the endpoint is not reachable.""" + """When AI_FEATURE_ENABLED is False, the endpoint returns 400.""" settings.AI_FEATURE_ENABLED = False user = factories.UserFactory() @@ -698,25 +252,91 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings): response = client.post( f"/api/v1.0/documents/{document.id!s}/ai-proxy/", - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 400 assert response.json() == ["AI feature is not enabled."] -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_harden_payload(mock_create): +@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_throttling_document(mock_stream): """ - AI proxy should harden the payload by default: - - it will remove stream_options if stream is False - - normalize tools by adding descriptions if missing + Throttling per document should be triggered on the AI proxy endpoint. + For full throttle class test see: `test_api_utils_ai_document_rate_throttles` """ + client = APIClient() + document = factories.DocumentFactory(link_reach="public", link_role="editor") + + mock_stream.return_value = iter(["data: ok\n"]) + + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + for _ in range(3): + mock_stream.return_value = iter(["data: ok\n"]) + user = factories.UserFactory() + client.force_login(user) + response = client.post( + url, + b"{}", + content_type="application/json", + ) + assert response.status_code == 200 + + user = factories.UserFactory() + client.force_login(user) + response = client.post( + url, + b"{}", + content_type="application/json", + ) + + assert response.status_code == 429 + assert response.json() == { + "detail": "Request was throttled. Expected available in 60 seconds." + } + + +@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_throttling_user(mock_stream): + """ + Throttling per user should be triggered on the AI proxy endpoint. + For full throttle class test see: `test_api_utils_ai_user_rate_throttles` + """ + user = factories.UserFactory() + client = APIClient() + client.force_login(user) + + for _ in range(3): + mock_stream.return_value = iter(["data: ok\n"]) + document = factories.DocumentFactory(link_reach="public", link_role="editor") + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + response = client.post( + url, + b"{}", + content_type="application/json", + ) + assert response.status_code == 200 + + document = factories.DocumentFactory(link_reach="public", link_role="editor") + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + response = client.post( + url, + b"{}", + content_type="application/json", + ) + + assert response.status_code == 429 + assert response.json() == { + "detail": "Request was throttled. Expected available in 60 seconds." + } + + +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_returns_streaming_response(mock_stream): + """AI proxy should return a StreamingHttpResponse with correct headers.""" user = factories.UserFactory() client = APIClient() @@ -724,76 +344,39 @@ def test_api_documents_ai_proxy_harden_payload(mock_create): document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-789", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Success!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: part1\n", "data: part2\n", "data: part3\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "stream": False, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": "applyDocumentOperations", - "parameters": { - "type": "object", - "properties": { - "operations": {"type": "array", "items": {"anyOf": []}} - }, - "additionalProperties": False, - "required": ["operations"], - }, - }, - } - ], - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-789" - assert response_data["choices"][0]["message"]["content"] == "Success!" + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - tools=[ - { - "type": "function", - "function": { - "name": "applyDocumentOperations", - "parameters": { - "type": "object", - "properties": { - "operations": {"type": "array", "items": {"anyOf": []}} - }, - "additionalProperties": False, - "required": ["operations"], - }, - "description": "Tool applyDocumentOperations.", - }, - } - ], + chunks = list(response.streaming_content) + assert len(chunks) == 3 + + +def test_api_documents_ai_proxy_invalid_payload(): + """AI Proxy should return a 400 if the payload is invalid.""" + + user = factories.UserFactory() + + document = factories.DocumentFactory(users=[(user, "owner")]) + + client = APIClient() + client.force_login(user) + + response = client.post( + f"/api/v1.0/documents/{document.id!s}/ai-proxy/", + b'{"foo": "bar", "trigger": "submit-message"}', + content_type="application/json", ) + + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid submitted payload"} diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index 5bc845d0..d1966f5a 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -1,7 +1,9 @@ """ -Test ai API endpoints in the impress core app. +Test AI services in the impress core app. """ +# pylint: disable=protected-access +from collections.abc import AsyncIterator from unittest.mock import MagicMock, patch from django.core.exceptions import ImproperlyConfigured @@ -9,8 +11,12 @@ from django.test.utils import override_settings import pytest from openai import OpenAIError +from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage -from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService +from core.services.ai_services import ( + AIService, + convert_async_generator_to_sync, +) pytestmark = pytest.mark.django_db @@ -22,6 +28,11 @@ def ai_settings(settings): settings.AI_BASE_URL = "http://example.com" settings.AI_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True + settings.LANGFUSE_PUBLIC_KEY = None + settings.AI_VERCEL_SDK_VERSION = 6 + + +# -- AIService.__init__ -- @pytest.mark.parametrize( @@ -43,6 +54,9 @@ def test_services_ai_setting_missing(setting_name, setting_value, settings): AIService() +# -- AIService.transform -- + + @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -59,19 +73,6 @@ def test_services_ai_client_error(mock_create): AIService().transform("hello", "prompt") -@patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_client_error(mock_create): - """Fail when the client raises an error""" - - mock_create.side_effect = OpenAIError("Mocked client error") - - with pytest.raises( - RuntimeError, - match="Failed to proxy AI request: Mocked client error", - ): - AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) - - @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -90,49 +91,6 @@ def test_services_ai_client_invalid_response(mock_create): AIService().transform("hello", "prompt") -@patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_success(mock_create): - """The AI request should work as expect when called with valid arguments.""" - - mock_create.return_value = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - - response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) - - expected_response = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - assert response == expected_response - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "hello"}, - ], - stream=False, - ) - - @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -149,46 +107,438 @@ def test_services_ai_success(mock_create): assert response == {"answer": "Salut"} +# -- AIService.translate -- + + @patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_with_stream(mock_create): - """The AI request should work as expect when called with valid arguments.""" +def test_services_ai_translate_success(mock_create): + """Translate should call the AI API with the correct language prompt.""" - mock_create.return_value = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - - response = AIService().proxy( - {"messages": [{"role": "user", "content": "hello"}]}, stream=True + mock_create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="Bonjour"))] ) - expected_response = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - assert response == expected_response - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "hello"}, - ], - stream=True, + response = AIService().translate("

Hello

", "fr") + + assert response == {"answer": "Bonjour"} + call_args = mock_create.call_args + system_content = call_args[1]["messages"][0]["content"] + assert "French" in system_content or "fr" in system_content + + +@patch("openai.resources.chat.completions.Completions.create") +def test_services_ai_translate_unknown_language(mock_create): + """Translate with an unknown language code should use the code as-is.""" + + mock_create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="Translated"))] ) + + response = AIService().translate("

Hello

", "xx-unknown") + + assert response == {"answer": "Translated"} + call_args = mock_create.call_args + system_content = call_args[1]["messages"][0]["content"] + assert "xx-unknown" in system_content + + +# -- convert_async_generator_to_sync -- + + +def test_convert_async_generator_to_sync_basic(): + """Should convert an async generator yielding items to a sync iterator.""" + + async def async_gen(): + for item in ["hello", "world", "!"]: + yield item + + result = list(convert_async_generator_to_sync(async_gen())) + assert result == ["hello", "world", "!"] + + +def test_convert_async_generator_to_sync_empty(): + """Should handle an empty async generator.""" + + async def async_gen(): + return + yield + + result = list(convert_async_generator_to_sync(async_gen())) + assert not result + + +def test_convert_async_generator_to_sync_exception(): + """Should propagate exceptions from the async generator.""" + + async def async_gen(): + yield "first" + raise ValueError("async error") + + sync_iter = convert_async_generator_to_sync(async_gen()) + assert next(sync_iter) == "first" + + with pytest.raises(ValueError, match="async error"): + next(sync_iter) + + +# -- AIService.inject_document_state_messages -- + + +def test_inject_document_state_messages_no_metadata(): + """Messages without documentState metadata should pass through unchanged.""" + messages = [ + UIMessage(role="user", id="msg-1", parts=[TextUIPart(text="Hello")]), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 1 + assert result[0].id == "msg-1" + + +def test_inject_document_state_messages_with_selection(): + """A user message with documentState and selection should get an + assistant context message prepended.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Fix this")], + metadata={ + "documentState": { + "selection": {"start": 0, "end": 5}, + "selectedBlocks": [{"type": "paragraph", "content": "Hello"}], + "blocks": [ + {"type": "paragraph", "content": "Hello"}, + {"type": "paragraph", "content": "World"}, + ], + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + # First message should be the injected assistant context + assert result[0].role == "assistant" + assert result[0].id == "assistant-document-state-msg-1" + assert len(result[0].parts) == 4 + assert "selection" in result[0].parts[0].text.lower() + # Second message should be the original user message + assert result[1].id == "msg-1" + + +def test_inject_document_state_messages_without_selection(): + """A user message with documentState but no selection should describe + the full document context.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Summarize")], + metadata={ + "documentState": { + "selection": None, + "blocks": [ + {"type": "paragraph", "content": "Hello"}, + ], + "isEmptyDocument": False, + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + assistant_msg = result[0] + assert assistant_msg.role == "assistant" + assert len(assistant_msg.parts) == 2 + assert "no active selection" in assistant_msg.parts[0].text.lower() + assert "prefer updating" in assistant_msg.parts[0].text.lower() + + +def test_inject_document_state_messages_empty_document(): + """When the document is empty, the injected message should instruct + updating the empty block first.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Write something")], + metadata={ + "documentState": { + "selection": None, + "blocks": [{"type": "paragraph", "content": ""}], + "isEmptyDocument": True, + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + assistant_msg = result[0] + assert "update the empty block" in assistant_msg.parts[0].text.lower() + + +def test_inject_document_state_messages_mixed(): + """Only user messages with documentState get assistant context; + other messages pass through unchanged.""" + messages = [ + UIMessage( + role="assistant", + id="msg-0", + parts=[TextUIPart(text="Previous response")], + ), + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Hello")], + ), + UIMessage( + role="user", + id="msg-2", + parts=[TextUIPart(text="Fix this")], + metadata={ + "documentState": { + "selection": {"start": 0, "end": 5}, + "selectedBlocks": [{"type": "paragraph", "content": "Hello"}], + "blocks": [{"type": "paragraph", "content": "Hello"}], + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + # 3 original + 1 injected assistant message before msg-2 + assert len(result) == 4 + assert result[0].id == "msg-0" + assert result[1].id == "msg-1" + assert result[2].role == "assistant" + assert result[2].id == "assistant-document-state-msg-2" + assert result[3].id == "msg-2" + + +# -- AIService.tool_definitions_to_toolset -- + + +def test_tool_definitions_to_toolset(): + """Should convert frontend tool definitions to an ExternalToolset.""" + tool_definitions = { + "applyOperations": { + "description": "Apply operations to the document", + "inputSchema": { + "type": "object", + "properties": { + "operations": {"type": "array"}, + }, + }, + "outputSchema": {"type": "object"}, + }, + "insertBlocks": { + "description": "Insert blocks", + "inputSchema": {"type": "object"}, + }, + } + + toolset = AIService.tool_definitions_to_toolset(tool_definitions) + + # The ExternalToolset wraps ToolDefinition objects + assert toolset is not None + # Access internal tool definitions + tool_defs = toolset.tool_defs + assert len(tool_defs) == 2 + + names = {td.name for td in tool_defs} + assert names == {"applyOperations", "insertBlocks"} + + for td in tool_defs: + assert td.kind == "external" + if td.name == "applyOperations": + assert td.description == "Apply operations to the document" + assert td.metadata == {"output_schema": {"type": "object"}} + + +def test_tool_definitions_to_toolset_missing_fields(): + """Should handle tool definitions with missing optional fields.""" + tool_definitions = { + "myTool": {}, + } + + toolset = AIService.tool_definitions_to_toolset(tool_definitions) + + tool_defs = toolset.tool_defs + assert len(tool_defs) == 1 + assert tool_defs[0].name == "myTool" + assert tool_defs[0].description == "" + assert tool_defs[0].parameters_json_schema == {} + assert tool_defs[0].metadata == {"output_schema": None} + + +# -- AIService.stream -- + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_sync_mode(mock_build, monkeypatch): + """In sync mode, stream() should return a sync iterator.""" + + async def mock_async_gen(): + yield "chunk1" + yield "chunk2" + + mock_build.return_value = mock_async_gen() + monkeypatch.setenv("PYTHON_SERVER_MODE", "sync") + + service = AIService() + request = MagicMock() + result = service.stream(request) + + # Should be a regular (sync) iterator, not async + assert not isinstance(result, AsyncIterator) + assert list(result) == ["chunk1", "chunk2"] + mock_build.assert_called_once_with(request) + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_async_mode(mock_build, monkeypatch): + """In async mode, stream() should return the async iterator directly.""" + + async def mock_async_gen(): + yield "chunk1" + yield "chunk2" + + mock_async_iter = mock_async_gen() + mock_build.return_value = mock_async_iter + monkeypatch.setenv("PYTHON_SERVER_MODE", "async") + + service = AIService() + request = MagicMock() + result = service.stream(request) + + assert result is mock_async_iter + mock_build.assert_called_once_with(request) + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_defaults_to_sync(mock_build, monkeypatch): + """When PYTHON_SERVER_MODE is not set, stream() should default to sync.""" + + async def mock_async_gen(): + yield "data" + + mock_build.return_value = mock_async_gen() + monkeypatch.delenv("PYTHON_SERVER_MODE", raising=False) + + service = AIService() + request = MagicMock() + result = service.stream(request) + + # Default should be sync mode + assert not isinstance(result, AsyncIterator) + assert list(result) == ["data"] + + +# -- AIService._build_async_stream -- + + +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream(mock_adapter_cls): + """_build_async_stream should build the pydantic-ai streaming pipeline.""" + + async def mock_encode(): + yield "event-data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = None + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {"HTTP_ACCEPT": "text/event-stream"} + request.raw_body = b'{"messages": []}' + + result = service._build_async_stream(request) + assert isinstance(result, AsyncIterator) + mock_adapter_cls.build_run_input.assert_called_once_with(b'{"messages": []}') + mock_adapter_instance.run_stream.assert_called_once() + mock_adapter_instance.encode_stream.assert_called_once() + + +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream_with_tool_definitions(mock_adapter_cls): + """_build_async_stream should build an ExternalToolset when + toolDefinitions are present in the request.""" + + async def mock_encode(): + yield "event-data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = { + "toolDefinitions": { + "myTool": { + "description": "A tool", + "inputSchema": {"type": "object"}, + } + } + } + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {} + request.raw_body = b"{}" + + service._build_async_stream(request) + # run_stream should have been called with a toolset + call_kwargs = mock_adapter_instance.run_stream.call_args[1] + assert call_kwargs["toolsets"] is not None + assert len(call_kwargs["toolsets"]) == 1 + + +@patch("core.services.ai_services.Agent") +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream_langfuse_enabled( + mock_adapter_cls, mock_agent_cls, settings +): + """When LANGFUSE_PUBLIC_KEY is set, instrument should be enabled.""" + settings.LANGFUSE_PUBLIC_KEY = "pk-test-123" + + async def mock_encode(): + yield "data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = None + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {} + request.raw_body = b"{}" + + service._build_async_stream(request) + mock_agent_cls.instrument_all.assert_called_once() + # Agent should be created with instrument=True + mock_agent_cls.assert_called_once() + assert mock_agent_cls.call_args[1]["instrument"] is True diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index de4c6212..45c2bf9b 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -326,6 +326,7 @@ class Base(Configuration): "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "core.middleware.ForceSessionMiddleware", + "core.middleware.SaveRawBodyMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "dockerflow.django.middleware.DockerflowMiddleware", "csp.middleware.CSPMiddleware", @@ -716,6 +717,9 @@ class Base(Configuration): AI_STREAM = values.BooleanValue( default=True, environ_name="AI_STREAM", environ_prefix=None ) + AI_VERCEL_SDK_VERSION = values.IntegerValue( + 6, environ_name="AI_VERCEL_SDK_VERSION", environ_prefix=None + ) AI_USER_RATE_THROTTLE_RATES = { "minute": 3, "hour": 50, diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 33eaf4ed..ca3d7f71 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -53,9 +53,11 @@ dependencies = [ "markdown==3.10", "mozilla-django-oidc==5.0.2", "nested-multipart-parser==1.6.0", - "openai==2.14.0", + "openai==2.21.0", "psycopg[binary]==3.3.2", "pycrdt==0.12.44", + "pydantic==2.12.5", + "pydantic-ai-slim[openai,logfire,web]==1.58.0", "PyJWT==2.10.1", "python-magic==0.4.27", "redis<6.0.0",