diff --git a/docs/env.md b/docs/env.md index 725cfb07..241bf6ae 100644 --- a/docs/env.md +++ b/docs/env.md @@ -15,6 +15,7 @@ These are the environment variables you can set for the `impress-backend` contai | AI_FEATURE_ENABLED | Enable AI options | false | | AI_MODEL | AI Model to use | | | AI_STREAM | AI Stream to activate | true | +| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | @@ -166,4 +167,3 @@ Packages with licences incompatible with the MIT licence: In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features. ⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations. - diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 75eaa5d3..a7f3d401 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -833,39 +833,6 @@ class AITranslateSerializer(serializers.Serializer): return value -class AIProxySerializer(serializers.Serializer): - """Serializer for AI proxy requests.""" - - messages = serializers.ListField( - required=True, - child=serializers.DictField(), - allow_empty=False, - ) - model = serializers.CharField(required=True) - - def validate_messages(self, messages): - """Validate messages structure.""" - # Ensure each message has the required fields - for message in messages: - if ( - not isinstance(message, dict) - or "role" not in message - or "content" not in message - ): - raise serializers.ValidationError( - "Each message must have 'role' and 'content' fields" - ) - - return messages - - def validate_model(self, value): - """Validate model value is the same than settings.AI_MODEL""" - if value != settings.AI_MODEL: - raise serializers.ValidationError(f"{value} is not a valid model") - - return value - - class MoveDocumentSerializer(serializers.Serializer): """ Serializer for validating input data to move a document within the tree structure. diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 9f715139..26b174ed 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -38,6 +38,7 @@ from csp.decorators import csp_update from lasuite.malware_detection import malware_detection from lasuite.oidc_login.decorators import refresh_oidc_access_token from lasuite.tools.email import get_domain_from_email +from pydantic import ValidationError as PydanticValidationError from rest_framework import filters, status, viewsets from rest_framework import response as drf_response from rest_framework.permissions import AllowAny @@ -1854,22 +1855,24 @@ class DocumentViewSet( if not settings.AI_FEATURE_ENABLED: raise ValidationError("AI feature is not enabled.") - serializer = serializers.AIProxySerializer(data=request.data) - serializer.is_valid(raise_exception=True) - ai_service = AIService() - if settings.AI_STREAM: - return StreamingHttpResponse( - ai_service.stream(request.data), - content_type="text/event-stream", - status=drf.status.HTTP_200_OK, + try: + stream = ai_service.stream(request) + except PydanticValidationError as err: + logger.info("pydantic validation error: %s", err) + return drf.response.Response( + {"detail": "Invalid submitted payload"}, + status=drf.status.HTTP_400_BAD_REQUEST, ) - ai_response = ai_service.proxy(request.data) - return drf.response.Response( - ai_response.model_dump(), - status=drf.status.HTTP_200_OK, + return StreamingHttpResponse( + stream, + content_type="text/event-stream", + headers={ + "x-vercel-ai-data-stream": "v1", # This header is used for Vercel AI streaming, + "X-Accel-Buffering": "no", # Prevent nginx buffering + }, ) @drf.decorators.action( diff --git a/src/backend/core/middleware.py b/src/backend/core/middleware.py index afb5a100..410a461d 100644 --- a/src/backend/core/middleware.py +++ b/src/backend/core/middleware.py @@ -19,3 +19,21 @@ class ForceSessionMiddleware: response = self.get_response(request) return response + + +class SaveRawBodyMiddleware: + """ + Save the raw request body to use it later. + """ + + def __init__(self, get_response): + """Initialize the middleware.""" + self.get_response = get_response + + def __call__(self, request): + """Save the raw request body in the request to use it later.""" + if request.path.endswith(("/ai-proxy/", "/ai-proxy")): + request.raw_body = request.body + + response = self.get_response(request) + return response diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index cad97dfb..d325decc 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -1,55 +1,38 @@ """AI services.""" +import asyncio import json import logging -from typing import Any, Dict, Generator +import os +import queue +import threading +from collections.abc import AsyncIterator, Iterator +from typing import Any, Dict, Union from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from openai import OpenAIError +from langfuse import get_client +from langfuse.openai import OpenAI as OpenAI_Langfuse +from pydantic_ai import Agent, DeferredToolRequests +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.openai import OpenAIProvider +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.external import ExternalToolset +from pydantic_ai.ui import SSE_CONTENT_TYPE +from pydantic_ai.ui.vercel_ai import VercelAIAdapter +from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage +from rest_framework.request import Request from core import enums if settings.LANGFUSE_PUBLIC_KEY: - from langfuse.openai import OpenAI + OpenAI = OpenAI_Langfuse else: from openai import OpenAI log = logging.getLogger(__name__) - -BLOCKNOTE_TOOL_STRICT_PROMPT = """ -You are editing a BlockNote document via the tool applyDocumentOperations. - -You MUST respond ONLY by calling applyDocumentOperations. -The tool input MUST be valid JSON: -{ "operations": [ ... ] } - -Each operation MUST include "type" and it MUST be one of: -- "update" (requires: id, block) -- "add" (requires: referenceId, position, blocks) -- "delete" (requires: id) - -VALID SHAPES (FOLLOW EXACTLY): - -Update: -{ "type":"update", "id":"", "block":"

...

" } -IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element. - -Add: -{ "type":"add", "referenceId":"", "position":"before|after", "blocks":["

...

"] } -IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS. -Each item MUST be a STRING containing a SINGLE valid HTML element. - -Delete: -{ "type":"delete", "id":"" } - -IDs ALWAYS end with "$". Use ids EXACTLY as provided. - -Return ONLY the JSON tool input. No prose, no markdown. -""" - AI_ACTIONS = { "prompt": ( "Answer the prompt using markdown formatting for structure and emphasis. " @@ -95,6 +78,40 @@ AI_TRANSLATE = ( ) +def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[str]: + """Convert an async generator to a sync generator.""" + q: queue.Queue[str | object] = queue.Queue() + sentinel = object() + exc_sentinel = object() + + async def run_async_gen(): + try: + async for async_item in async_gen: + q.put(async_item) + except Exception as exc: # pylint: disable=broad-except #noqa: BLE001 + q.put((exc_sentinel, exc)) + finally: + q.put(sentinel) + + def start_async_loop(): + asyncio.run(run_async_gen()) + + thread = threading.Thread(target=start_async_loop, daemon=True) + thread.start() + + try: + while True: + item = q.get() + if item is sentinel: + break + if isinstance(item, tuple) and item[0] is exc_sentinel: + # re-raise the exception in the sync context + raise item[1] + yield item + finally: + thread.join() + + class AIService: """Service class for AI-related operations.""" @@ -136,77 +153,171 @@ class AIService: system_content = AI_TRANSLATE.format(language=language_display) return self.call_ai_api(system_content, text) - def _normalize_tools(self, tools: list) -> list: + @staticmethod + def inject_document_state_messages( + messages: list[UIMessage], + ) -> list[UIMessage]: + """Inject document state context before user messages. + + Port of BlockNote's injectDocumentStateMessages. + For each user message carrying documentState metadata, an assistant + message describing the current document/selection state is prepended + so the LLM sees it as context. """ - Normalize tool definitions to ensure they have required fields. - """ - normalized = [] - for tool in tools: - if isinstance(tool, dict) and tool.get("type") == "function": - fn = tool.get("function") or {} - if isinstance(fn, dict) and not fn.get("description"): - fn["description"] = f"Tool {fn.get('name', 'unknown')}." - tool["function"] = fn - normalized.append(tool) - return normalized + result: list[UIMessage] = [] + for message in messages: + if ( + message.role == "user" + and isinstance(message.metadata, dict) + and "documentState" in message.metadata + ): + doc_state = message.metadata["documentState"] + selection = doc_state.get("selection") + blocks = doc_state.get("blocks") - def _harden_payload( - self, payload: Dict[str, Any], stream: bool = False - ) -> Dict[str, Any]: - """Harden the AI API payload to enforce compliance and tool usage.""" - payload["stream"] = stream - - # Remove stream_options if stream is False - if not stream and "stream_options" in payload: - payload.pop("stream_options") - - # Tools normalization - if isinstance(payload.get("tools"), list): - payload["tools"] = self._normalize_tools(payload["tools"]) - - # Inject strict system prompt once - msgs = payload.get("messages") - if isinstance(msgs, list): - need = True - if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system": - c = msgs[0].get("content") or "" - if ( - isinstance(c, str) - and "applyDocumentOperations" in c - and "blocks" in c - ): - need = False - if need: - payload["messages"] = [ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT} - ] + msgs - - return payload - - def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: - """Proxy AI API requests to the configured AI provider.""" - payload = self._harden_payload(data, stream=stream) - try: - return self.client.chat.completions.create(**payload) - except OpenAIError as e: - raise RuntimeError(f"Failed to proxy AI request: {e}") from e - - def stream(self, data: dict) -> Generator[str, None, None]: - """Stream AI API requests to the configured AI provider.""" - - try: - stream = self.proxy(data, stream=True) - for chunk in stream: - try: - chunk_dict = ( - chunk.model_dump() if hasattr(chunk, "model_dump") else chunk + if selection: + parts = [ + TextUIPart( + text=( + "This is the latest state of the selection " + "(ignore previous selections, you MUST issue " + "operations against this latest version of " + "the selection):" + ), + ), + TextUIPart( + text=json.dumps(doc_state.get("selectedBlocks")), + ), + TextUIPart( + text=( + "This is the latest state of the entire " + "document (INCLUDING the selected text), you " + "can use this to find the selected text to " + "understand the context (but you MUST NOT " + "issue operations against this document, you " + "MUST issue operations against the selection):" + ), + ), + TextUIPart(text=json.dumps(blocks)), + ] + else: + text = ( + "There is no active selection. This is the latest " + "state of the document (ignore previous documents, " + "you MUST issue operations against this latest " + "version of the document). The cursor is BETWEEN " + "two blocks as indicated by cursor: true." ) - chunk_json = json.dumps(chunk_dict) - yield f"data: {chunk_json}\n\n" - except (AttributeError, TypeError) as e: - log.error("Error serializing chunk: %s, chunk: %s", e, chunk) - continue - except (OpenAIError, RuntimeError, OSError, ValueError) as e: - log.error("Streaming error: %s", e) + if doc_state.get("isEmptyDocument"): + text += ( + "Because the document is empty, YOU MUST first " + "update the empty block before adding new blocks." + ) + else: + text += ( + "Prefer updating existing blocks over removing " + "and adding (but this also depends on the " + "user's question)." + ) + parts = [ + TextUIPart(text=text), + TextUIPart(text=json.dumps(blocks)), + ] - yield "data: [DONE]\n\n" + result.append( + UIMessage( + role="assistant", + id=f"assistant-document-state-{message.id}", + parts=parts, + ) + ) + + result.append(message) + return result + + @staticmethod + def tool_definitions_to_toolset( + tool_definitions: Dict[str, Any], + ) -> ExternalToolset: + """Convert serialized tool definitions to a pydantic-ai ExternalToolset. + + Port of BlockNote's toolDefinitionsToToolSet. + Builds ToolDefinition objects from the JSON-Schema-based definitions + sent by the frontend and wraps them in an ExternalToolset so that + pydantic-ai advertises them to the LLM without trying to execute them + server-side (execution is deferred to the frontend). + """ + tool_defs = [ + ToolDefinition( + name=name, + description=defn.get("description", ""), + parameters_json_schema=defn.get("inputSchema", {}), + kind="external", + metadata={ + "output_schema": defn.get("outputSchema"), + }, + ) + for name, defn in tool_definitions.items() + ] + return ExternalToolset(tool_defs) + + def _build_async_stream(self, request: Request) -> AsyncIterator[str]: + """Build the async stream from the AI provider.""" + instrument_enabled = settings.LANGFUSE_PUBLIC_KEY is not None + + if instrument_enabled: + langfuse = get_client() + langfuse.auth_check() + Agent.instrument_all() + + model = OpenAIChatModel( + settings.AI_MODEL, + provider=OpenAIProvider( + base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY + ), + ) + agent = Agent(model, instrument=instrument_enabled) + + accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE) + + run_input = VercelAIAdapter.build_run_input(request.raw_body) + + # Inject document state context into the conversation + run_input.messages = self.inject_document_state_messages(run_input.messages) + + # Build an ExternalToolset from frontend-supplied tool definitions + raw_tool_defs = ( + run_input.model_extra.get("toolDefinitions") + if run_input.model_extra + else None + ) + toolset = ( + self.tool_definitions_to_toolset(raw_tool_defs) if raw_tool_defs else None + ) + + adapter = VercelAIAdapter( + agent=agent, + run_input=run_input, + accept=accept, + sdk_version=settings.AI_VERCEL_SDK_VERSION, + ) + + event_stream = adapter.run_stream( + output_type=[str, DeferredToolRequests] if toolset else None, + toolsets=[toolset] if toolset else None, + ) + + return adapter.encode_stream(event_stream) + + def stream(self, request: Request) -> Union[AsyncIterator[str], Iterator[str]]: + """Stream AI API requests to the configured AI provider. + + Returns an async iterator when running in async mode (ASGI) + or a sync iterator when running in sync mode (WSGI). + """ + async_stream = self._build_async_stream(request) + + if os.environ.get("PYTHON_SERVER_MODE", "sync") == "async": + return async_stream + + return convert_async_generator_to_sync(async_stream) diff --git a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py index 1046cb81..5229bdc3 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py @@ -3,7 +3,7 @@ Test AI proxy API endpoint for users in impress's core app. """ import random -from unittest.mock import MagicMock, patch +from unittest.mock import patch from django.test import override_settings @@ -11,7 +11,6 @@ import pytest from rest_framework.test import APIClient from core import factories -from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @@ -24,6 +23,8 @@ def ai_settings(settings): settings.AI_BASE_URL = "http://localhost-ai:12345/" settings.AI_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True + settings.LANGFUSE_PUBLIC_KEY = None + settings.AI_VERCEL_SDK_VERSION = 6 @override_settings( @@ -51,7 +52,6 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role): url, { "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", }, format="json", ) @@ -62,85 +62,48 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role): } -@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_anonymous_success(mock_create): +@override_settings(AI_ALLOW_REACH_FROM="public") +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_anonymous_success(mock_stream): """ Anonymous users should be able to request AI proxy to a document if the link reach and role permit it. """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-123", - "object": "chat.completion", - "created": 1677652288, - "model": "llama", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I help you?", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: chunk1\n", "data: chunk2\n"]) + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = APIClient().post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-123" - assert response_data["model"] == "llama" - assert len(response_data["choices"]) == 1 - assert ( - response_data["choices"][0]["message"]["content"] - == "Hello! How can I help you?" - ) + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) + content = b"".join(response.streaming_content).decode() + assert "chunk1" in content + assert "chunk2" in content + mock_stream.assert_called_once() @override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"])) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_anonymous_limited_by_setting(mock_create): +def test_api_documents_ai_proxy_anonymous_limited_by_setting(): """ Anonymous users should not be able to request AI proxy to a document if AI_ALLOW_REACH_FROM setting restricts it. """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Hello!"} - mock_create.return_value = mock_response - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = APIClient().post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 401 @@ -170,11 +133,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role): url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 403 @@ -187,9 +147,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role): ("public", "editor"), ], ) -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_authenticated_success(mock_stream, reach, role): """ Authenticated users should be able to request AI proxy to a document if the link reach and role permit it. @@ -201,44 +160,18 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): document = factories.DocumentFactory(link_reach=reach, link_role=role) - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-456", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Hi there!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: response\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-456" - assert response_data["choices"][0]["message"]["content"] == "Hi there!" - - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) + assert response["Content-Type"] == "text/event-stream" + mock_stream.assert_called_once() @pytest.mark.parametrize("via", VIA) @@ -261,11 +194,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams): url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 403 @@ -273,9 +203,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams): @pytest.mark.parametrize("role", ["editor", "administrator", "owner"]) @pytest.mark.parametrize("via", VIA) -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams): +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_success(mock_stream, via, role, mock_user_teams): """Users with sufficient permissions should be able to request AI proxy.""" user = factories.UserFactory() @@ -291,402 +220,27 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams) document=document, team="lasuite", role=role ) - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-789", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Success!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: success\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-789" - assert response_data["choices"][0]["message"]["content"] == "Success!" + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Test message"}, - ], - model="llama", - stream=False, - ) - - -def test_api_documents_ai_proxy_empty_messages(): - """The messages should not be empty when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post(url, {"messages": [], "model": "llama"}, format="json") - - assert response.status_code == 400 - assert response.json() == {"messages": ["This list may not be empty."]} - - -def test_api_documents_ai_proxy_missing_model(): - """The model should be required when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, {"messages": [{"role": "user", "content": "Hello"}]}, format="json" - ) - - assert response.status_code == 400 - assert response.json() == {"model": ["This field is required."]} - - -def test_api_documents_ai_proxy_invalid_message_format(): - """Messages should have the correct format when requesting AI proxy.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - - # Test with invalid message format (missing role) - response = client.post( - url, - { - "messages": [{"content": "Hello"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": ["Each message must have 'role' and 'content' fields"] - } - - # Test with invalid message format (missing content) - response = client.post( - url, - { - "messages": [{"role": "user"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": ["Each message must have 'role' and 'content' fields"] - } - - # Test with non-dict message - response = client.post( - url, - { - "messages": ["invalid"], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == { - "messages": {"0": ['Expected a dictionary of items but got type "str".']} - } - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_stream_disabled(mock_create): - """Stream should be automatically disabled in AI proxy requests.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "stream": True, # This should be overridden to False - }, - format="json", - ) - - assert response.status_code == 200 - # Verify that stream was set to False - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - ) - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_additional_parameters(mock_create): - """AI proxy should pass through additional parameters to the AI service.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "temperature": 0.7, - "max_tokens": 100, - "top_p": 0.9, - }, - format="json", - ) - - assert response.status_code == 200 - # Verify that additional parameters were passed through - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - temperature=0.7, - max_tokens=100, - top_p=0.9, - stream=False, - ) - - -@override_settings(AI_STREAM=False) -@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_throttling_document(mock_create): - """ - Throttling per document should be triggered on the AI transform endpoint. - For full throttle class test see: `test_api_utils_ai_document_rate_throttles` - """ - client = APIClient() - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - for _ in range(3): - user = factories.UserFactory() - client.force_login(user) - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - format="json", - ) - assert response.status_code == 200 - assert response.json() == {"content": "Success!"} - - user = factories.UserFactory() - client.force_login(user) - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Test message"}], - "model": "llama", - }, - ) - - assert response.status_code == 429 - assert response.json() == { - "detail": "Request was throttled. Expected available in 60 seconds." - } - - -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_complex_conversation(mock_create): - """AI proxy should handle complex conversations with multiple messages.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-complex", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "I understand your question about Python.", - }, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response - - complex_messages = [ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "system", "content": "You are a helpful programming assistant."}, - {"role": "user", "content": "How do I write a for loop in Python?"}, - { - "role": "assistant", - "content": "You can write a for loop using: for item in iterable:", - }, - {"role": "user", "content": "Can you give me a concrete example?"}, - ] - - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": complex_messages, - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-complex" - assert ( - response_data["choices"][0]["message"]["content"] - == "I understand your question about Python." - ) - - mock_create.assert_called_once_with( - messages=complex_messages, - model="llama", - stream=False, - ) - - -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_throttling_user(mock_create): - """ - Throttling per user should be triggered on the AI proxy endpoint. - For full throttle class test see: `test_api_utils_ai_user_rate_throttles` - """ - user = factories.UserFactory() - client = APIClient() - client.force_login(user) - - mock_response = MagicMock() - mock_response.model_dump.return_value = {"content": "Success!"} - mock_create.return_value = mock_response - - for _ in range(3): - document = factories.DocumentFactory(link_reach="public", link_role="editor") - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", - ) - assert response.status_code == 200 - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" - response = client.post( - url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", - ) - - assert response.status_code == 429 - assert response.json() == { - "detail": "Request was throttled. Expected available in 60 seconds." - } - - -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 10, "hour": 6, "day": 10}) -def test_api_documents_ai_proxy_different_models(): - """AI proxy should work with different AI models.""" - user = factories.UserFactory() - - client = APIClient() - client.force_login(user) - - document = factories.DocumentFactory(link_reach="public", link_role="editor") - - models_to_test = ["gpt-3.5-turbo", "gpt-4", "claude-3", "llama-2"] - - for model_name in models_to_test: - response = client.post( - f"/api/v1.0/documents/{document.id!s}/ai-proxy/", - { - "messages": [{"role": "user", "content": "Hello"}], - "model": model_name, - }, - format="json", - ) - - assert response.status_code == 400 - assert response.json() == {"model": [f"{model_name} is not a valid model"]} + content = b"".join(response.streaming_content).decode() + assert "success" in content + mock_stream.assert_called_once() def test_api_documents_ai_proxy_ai_feature_disabled(settings): - """When the settings AI_FEATURE_ENABLED is set to False, the endpoint is not reachable.""" + """When AI_FEATURE_ENABLED is False, the endpoint returns 400.""" settings.AI_FEATURE_ENABLED = False user = factories.UserFactory() @@ -698,25 +252,91 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings): response = client.post( f"/api/v1.0/documents/{document.id!s}/ai-proxy/", - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 400 assert response.json() == ["AI feature is not enabled."] -@override_settings(AI_STREAM=False) -@patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_proxy_harden_payload(mock_create): +@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_throttling_document(mock_stream): """ - AI proxy should harden the payload by default: - - it will remove stream_options if stream is False - - normalize tools by adding descriptions if missing + Throttling per document should be triggered on the AI proxy endpoint. + For full throttle class test see: `test_api_utils_ai_document_rate_throttles` """ + client = APIClient() + document = factories.DocumentFactory(link_reach="public", link_role="editor") + + mock_stream.return_value = iter(["data: ok\n"]) + + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + for _ in range(3): + mock_stream.return_value = iter(["data: ok\n"]) + user = factories.UserFactory() + client.force_login(user) + response = client.post( + url, + b"{}", + content_type="application/json", + ) + assert response.status_code == 200 + + user = factories.UserFactory() + client.force_login(user) + response = client.post( + url, + b"{}", + content_type="application/json", + ) + + assert response.status_code == 429 + assert response.json() == { + "detail": "Request was throttled. Expected available in 60 seconds." + } + + +@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_throttling_user(mock_stream): + """ + Throttling per user should be triggered on the AI proxy endpoint. + For full throttle class test see: `test_api_utils_ai_user_rate_throttles` + """ + user = factories.UserFactory() + client = APIClient() + client.force_login(user) + + for _ in range(3): + mock_stream.return_value = iter(["data: ok\n"]) + document = factories.DocumentFactory(link_reach="public", link_role="editor") + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + response = client.post( + url, + b"{}", + content_type="application/json", + ) + assert response.status_code == 200 + + document = factories.DocumentFactory(link_reach="public", link_role="editor") + url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" + response = client.post( + url, + b"{}", + content_type="application/json", + ) + + assert response.status_code == 429 + assert response.json() == { + "detail": "Request was throttled. Expected available in 60 seconds." + } + + +@patch("core.services.ai_services.AIService.stream") +def test_api_documents_ai_proxy_returns_streaming_response(mock_stream): + """AI proxy should return a StreamingHttpResponse with correct headers.""" user = factories.UserFactory() client = APIClient() @@ -724,76 +344,39 @@ def test_api_documents_ai_proxy_harden_payload(mock_create): document = factories.DocumentFactory(link_reach="public", link_role="editor") - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": "chatcmpl-789", - "object": "chat.completion", - "model": "llama", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Success!"}, - "finish_reason": "stop", - } - ], - } - mock_create.return_value = mock_response + mock_stream.return_value = iter(["data: part1\n", "data: part2\n", "data: part3\n"]) url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" response = client.post( url, - { - "messages": [{"role": "user", "content": "Hello"}], - "model": "llama", - "stream": False, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": "applyDocumentOperations", - "parameters": { - "type": "object", - "properties": { - "operations": {"type": "array", "items": {"anyOf": []}} - }, - "additionalProperties": False, - "required": ["operations"], - }, - }, - } - ], - }, - format="json", + b"{}", + content_type="application/json", ) assert response.status_code == 200 - response_data = response.json() - assert response_data["id"] == "chatcmpl-789" - assert response_data["choices"][0]["message"]["content"] == "Success!" + assert response["Content-Type"] == "text/event-stream" + assert response["x-vercel-ai-data-stream"] == "v1" + assert response["X-Accel-Buffering"] == "no" - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "Hello"}, - ], - model="llama", - stream=False, - tools=[ - { - "type": "function", - "function": { - "name": "applyDocumentOperations", - "parameters": { - "type": "object", - "properties": { - "operations": {"type": "array", "items": {"anyOf": []}} - }, - "additionalProperties": False, - "required": ["operations"], - }, - "description": "Tool applyDocumentOperations.", - }, - } - ], + chunks = list(response.streaming_content) + assert len(chunks) == 3 + + +def test_api_documents_ai_proxy_invalid_payload(): + """AI Proxy should return a 400 if the payload is invalid.""" + + user = factories.UserFactory() + + document = factories.DocumentFactory(users=[(user, "owner")]) + + client = APIClient() + client.force_login(user) + + response = client.post( + f"/api/v1.0/documents/{document.id!s}/ai-proxy/", + b'{"foo": "bar", "trigger": "submit-message"}', + content_type="application/json", ) + + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid submitted payload"} diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index 5bc845d0..d1966f5a 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -1,7 +1,9 @@ """ -Test ai API endpoints in the impress core app. +Test AI services in the impress core app. """ +# pylint: disable=protected-access +from collections.abc import AsyncIterator from unittest.mock import MagicMock, patch from django.core.exceptions import ImproperlyConfigured @@ -9,8 +11,12 @@ from django.test.utils import override_settings import pytest from openai import OpenAIError +from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage -from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService +from core.services.ai_services import ( + AIService, + convert_async_generator_to_sync, +) pytestmark = pytest.mark.django_db @@ -22,6 +28,11 @@ def ai_settings(settings): settings.AI_BASE_URL = "http://example.com" settings.AI_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True + settings.LANGFUSE_PUBLIC_KEY = None + settings.AI_VERCEL_SDK_VERSION = 6 + + +# -- AIService.__init__ -- @pytest.mark.parametrize( @@ -43,6 +54,9 @@ def test_services_ai_setting_missing(setting_name, setting_value, settings): AIService() +# -- AIService.transform -- + + @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -59,19 +73,6 @@ def test_services_ai_client_error(mock_create): AIService().transform("hello", "prompt") -@patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_client_error(mock_create): - """Fail when the client raises an error""" - - mock_create.side_effect = OpenAIError("Mocked client error") - - with pytest.raises( - RuntimeError, - match="Failed to proxy AI request: Mocked client error", - ): - AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) - - @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -90,49 +91,6 @@ def test_services_ai_client_invalid_response(mock_create): AIService().transform("hello", "prompt") -@patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_success(mock_create): - """The AI request should work as expect when called with valid arguments.""" - - mock_create.return_value = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - - response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) - - expected_response = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - assert response == expected_response - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "hello"}, - ], - stream=False, - ) - - @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @@ -149,46 +107,438 @@ def test_services_ai_success(mock_create): assert response == {"answer": "Salut"} +# -- AIService.translate -- + + @patch("openai.resources.chat.completions.Completions.create") -def test_services_ai_proxy_with_stream(mock_create): - """The AI request should work as expect when called with valid arguments.""" +def test_services_ai_translate_success(mock_create): + """Translate should call the AI API with the correct language prompt.""" - mock_create.return_value = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - - response = AIService().proxy( - {"messages": [{"role": "user", "content": "hello"}]}, stream=True + mock_create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="Bonjour"))] ) - expected_response = { - "id": "chatcmpl-test", - "object": "chat.completion", - "created": 1234567890, - "model": "test-model", - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "Salut"}, - "finish_reason": "stop", - } - ], - } - assert response == expected_response - mock_create.assert_called_once_with( - messages=[ - {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, - {"role": "user", "content": "hello"}, - ], - stream=True, + response = AIService().translate("

Hello

", "fr") + + assert response == {"answer": "Bonjour"} + call_args = mock_create.call_args + system_content = call_args[1]["messages"][0]["content"] + assert "French" in system_content or "fr" in system_content + + +@patch("openai.resources.chat.completions.Completions.create") +def test_services_ai_translate_unknown_language(mock_create): + """Translate with an unknown language code should use the code as-is.""" + + mock_create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="Translated"))] ) + + response = AIService().translate("

Hello

", "xx-unknown") + + assert response == {"answer": "Translated"} + call_args = mock_create.call_args + system_content = call_args[1]["messages"][0]["content"] + assert "xx-unknown" in system_content + + +# -- convert_async_generator_to_sync -- + + +def test_convert_async_generator_to_sync_basic(): + """Should convert an async generator yielding items to a sync iterator.""" + + async def async_gen(): + for item in ["hello", "world", "!"]: + yield item + + result = list(convert_async_generator_to_sync(async_gen())) + assert result == ["hello", "world", "!"] + + +def test_convert_async_generator_to_sync_empty(): + """Should handle an empty async generator.""" + + async def async_gen(): + return + yield + + result = list(convert_async_generator_to_sync(async_gen())) + assert not result + + +def test_convert_async_generator_to_sync_exception(): + """Should propagate exceptions from the async generator.""" + + async def async_gen(): + yield "first" + raise ValueError("async error") + + sync_iter = convert_async_generator_to_sync(async_gen()) + assert next(sync_iter) == "first" + + with pytest.raises(ValueError, match="async error"): + next(sync_iter) + + +# -- AIService.inject_document_state_messages -- + + +def test_inject_document_state_messages_no_metadata(): + """Messages without documentState metadata should pass through unchanged.""" + messages = [ + UIMessage(role="user", id="msg-1", parts=[TextUIPart(text="Hello")]), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 1 + assert result[0].id == "msg-1" + + +def test_inject_document_state_messages_with_selection(): + """A user message with documentState and selection should get an + assistant context message prepended.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Fix this")], + metadata={ + "documentState": { + "selection": {"start": 0, "end": 5}, + "selectedBlocks": [{"type": "paragraph", "content": "Hello"}], + "blocks": [ + {"type": "paragraph", "content": "Hello"}, + {"type": "paragraph", "content": "World"}, + ], + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + # First message should be the injected assistant context + assert result[0].role == "assistant" + assert result[0].id == "assistant-document-state-msg-1" + assert len(result[0].parts) == 4 + assert "selection" in result[0].parts[0].text.lower() + # Second message should be the original user message + assert result[1].id == "msg-1" + + +def test_inject_document_state_messages_without_selection(): + """A user message with documentState but no selection should describe + the full document context.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Summarize")], + metadata={ + "documentState": { + "selection": None, + "blocks": [ + {"type": "paragraph", "content": "Hello"}, + ], + "isEmptyDocument": False, + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + assistant_msg = result[0] + assert assistant_msg.role == "assistant" + assert len(assistant_msg.parts) == 2 + assert "no active selection" in assistant_msg.parts[0].text.lower() + assert "prefer updating" in assistant_msg.parts[0].text.lower() + + +def test_inject_document_state_messages_empty_document(): + """When the document is empty, the injected message should instruct + updating the empty block first.""" + messages = [ + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Write something")], + metadata={ + "documentState": { + "selection": None, + "blocks": [{"type": "paragraph", "content": ""}], + "isEmptyDocument": True, + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + assert len(result) == 2 + assistant_msg = result[0] + assert "update the empty block" in assistant_msg.parts[0].text.lower() + + +def test_inject_document_state_messages_mixed(): + """Only user messages with documentState get assistant context; + other messages pass through unchanged.""" + messages = [ + UIMessage( + role="assistant", + id="msg-0", + parts=[TextUIPart(text="Previous response")], + ), + UIMessage( + role="user", + id="msg-1", + parts=[TextUIPart(text="Hello")], + ), + UIMessage( + role="user", + id="msg-2", + parts=[TextUIPart(text="Fix this")], + metadata={ + "documentState": { + "selection": {"start": 0, "end": 5}, + "selectedBlocks": [{"type": "paragraph", "content": "Hello"}], + "blocks": [{"type": "paragraph", "content": "Hello"}], + } + }, + ), + ] + + result = AIService.inject_document_state_messages(messages) + + # 3 original + 1 injected assistant message before msg-2 + assert len(result) == 4 + assert result[0].id == "msg-0" + assert result[1].id == "msg-1" + assert result[2].role == "assistant" + assert result[2].id == "assistant-document-state-msg-2" + assert result[3].id == "msg-2" + + +# -- AIService.tool_definitions_to_toolset -- + + +def test_tool_definitions_to_toolset(): + """Should convert frontend tool definitions to an ExternalToolset.""" + tool_definitions = { + "applyOperations": { + "description": "Apply operations to the document", + "inputSchema": { + "type": "object", + "properties": { + "operations": {"type": "array"}, + }, + }, + "outputSchema": {"type": "object"}, + }, + "insertBlocks": { + "description": "Insert blocks", + "inputSchema": {"type": "object"}, + }, + } + + toolset = AIService.tool_definitions_to_toolset(tool_definitions) + + # The ExternalToolset wraps ToolDefinition objects + assert toolset is not None + # Access internal tool definitions + tool_defs = toolset.tool_defs + assert len(tool_defs) == 2 + + names = {td.name for td in tool_defs} + assert names == {"applyOperations", "insertBlocks"} + + for td in tool_defs: + assert td.kind == "external" + if td.name == "applyOperations": + assert td.description == "Apply operations to the document" + assert td.metadata == {"output_schema": {"type": "object"}} + + +def test_tool_definitions_to_toolset_missing_fields(): + """Should handle tool definitions with missing optional fields.""" + tool_definitions = { + "myTool": {}, + } + + toolset = AIService.tool_definitions_to_toolset(tool_definitions) + + tool_defs = toolset.tool_defs + assert len(tool_defs) == 1 + assert tool_defs[0].name == "myTool" + assert tool_defs[0].description == "" + assert tool_defs[0].parameters_json_schema == {} + assert tool_defs[0].metadata == {"output_schema": None} + + +# -- AIService.stream -- + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_sync_mode(mock_build, monkeypatch): + """In sync mode, stream() should return a sync iterator.""" + + async def mock_async_gen(): + yield "chunk1" + yield "chunk2" + + mock_build.return_value = mock_async_gen() + monkeypatch.setenv("PYTHON_SERVER_MODE", "sync") + + service = AIService() + request = MagicMock() + result = service.stream(request) + + # Should be a regular (sync) iterator, not async + assert not isinstance(result, AsyncIterator) + assert list(result) == ["chunk1", "chunk2"] + mock_build.assert_called_once_with(request) + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_async_mode(mock_build, monkeypatch): + """In async mode, stream() should return the async iterator directly.""" + + async def mock_async_gen(): + yield "chunk1" + yield "chunk2" + + mock_async_iter = mock_async_gen() + mock_build.return_value = mock_async_iter + monkeypatch.setenv("PYTHON_SERVER_MODE", "async") + + service = AIService() + request = MagicMock() + result = service.stream(request) + + assert result is mock_async_iter + mock_build.assert_called_once_with(request) + + +@patch.object(AIService, "_build_async_stream") +def test_services_ai_stream_defaults_to_sync(mock_build, monkeypatch): + """When PYTHON_SERVER_MODE is not set, stream() should default to sync.""" + + async def mock_async_gen(): + yield "data" + + mock_build.return_value = mock_async_gen() + monkeypatch.delenv("PYTHON_SERVER_MODE", raising=False) + + service = AIService() + request = MagicMock() + result = service.stream(request) + + # Default should be sync mode + assert not isinstance(result, AsyncIterator) + assert list(result) == ["data"] + + +# -- AIService._build_async_stream -- + + +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream(mock_adapter_cls): + """_build_async_stream should build the pydantic-ai streaming pipeline.""" + + async def mock_encode(): + yield "event-data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = None + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {"HTTP_ACCEPT": "text/event-stream"} + request.raw_body = b'{"messages": []}' + + result = service._build_async_stream(request) + assert isinstance(result, AsyncIterator) + mock_adapter_cls.build_run_input.assert_called_once_with(b'{"messages": []}') + mock_adapter_instance.run_stream.assert_called_once() + mock_adapter_instance.encode_stream.assert_called_once() + + +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream_with_tool_definitions(mock_adapter_cls): + """_build_async_stream should build an ExternalToolset when + toolDefinitions are present in the request.""" + + async def mock_encode(): + yield "event-data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = { + "toolDefinitions": { + "myTool": { + "description": "A tool", + "inputSchema": {"type": "object"}, + } + } + } + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {} + request.raw_body = b"{}" + + service._build_async_stream(request) + # run_stream should have been called with a toolset + call_kwargs = mock_adapter_instance.run_stream.call_args[1] + assert call_kwargs["toolsets"] is not None + assert len(call_kwargs["toolsets"]) == 1 + + +@patch("core.services.ai_services.Agent") +@patch("core.services.ai_services.VercelAIAdapter") +def test_services_ai_build_async_stream_langfuse_enabled( + mock_adapter_cls, mock_agent_cls, settings +): + """When LANGFUSE_PUBLIC_KEY is set, instrument should be enabled.""" + settings.LANGFUSE_PUBLIC_KEY = "pk-test-123" + + async def mock_encode(): + yield "data" + + mock_run_input = MagicMock() + mock_run_input.model_extra = None + mock_run_input.messages = [] + mock_adapter_cls.build_run_input.return_value = mock_run_input + + mock_adapter_instance = MagicMock() + mock_adapter_instance.run_stream.return_value = MagicMock() + mock_adapter_instance.encode_stream.return_value = mock_encode() + mock_adapter_cls.return_value = mock_adapter_instance + + service = AIService() + request = MagicMock() + request.META = {} + request.raw_body = b"{}" + + service._build_async_stream(request) + mock_agent_cls.instrument_all.assert_called_once() + # Agent should be created with instrument=True + mock_agent_cls.assert_called_once() + assert mock_agent_cls.call_args[1]["instrument"] is True diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index de4c6212..45c2bf9b 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -326,6 +326,7 @@ class Base(Configuration): "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "core.middleware.ForceSessionMiddleware", + "core.middleware.SaveRawBodyMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "dockerflow.django.middleware.DockerflowMiddleware", "csp.middleware.CSPMiddleware", @@ -716,6 +717,9 @@ class Base(Configuration): AI_STREAM = values.BooleanValue( default=True, environ_name="AI_STREAM", environ_prefix=None ) + AI_VERCEL_SDK_VERSION = values.IntegerValue( + 6, environ_name="AI_VERCEL_SDK_VERSION", environ_prefix=None + ) AI_USER_RATE_THROTTLE_RATES = { "minute": 3, "hour": 50, diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 33eaf4ed..ca3d7f71 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -53,9 +53,11 @@ dependencies = [ "markdown==3.10", "mozilla-django-oidc==5.0.2", "nested-multipart-parser==1.6.0", - "openai==2.14.0", + "openai==2.21.0", "psycopg[binary]==3.3.2", "pycrdt==0.12.44", + "pydantic==2.12.5", + "pydantic-ai-slim[openai,logfire,web]==1.58.0", "PyJWT==2.10.1", "python-magic==0.4.27", "redis<6.0.0",