(backend) use pydantic AI to manage vercel data stream protocol

The frontend application uses the Vercel AI SDK and its data stream
protocol. We decided to use the pydantic AI library for its Vercel
AI adapter. It handles payload validation, uses AsyncIterator and
deals with the Vercel specification.
This commit is contained in:
Manuel Raynaud
2026-02-17 11:11:13 +01:00
committed by Anthony LC
parent 050b106a8f
commit 8ce216f6e8
9 changed files with 850 additions and 812 deletions

View File

@@ -15,6 +15,7 @@ These are the environment variables you can set for the `impress-backend` contai
| AI_FEATURE_ENABLED | Enable AI options | false | | AI_FEATURE_ENABLED | Enable AI options | false |
| AI_MODEL | AI Model to use | | | AI_MODEL | AI Model to use | |
| AI_STREAM | AI Stream to activate | true | | AI_STREAM | AI Stream to activate | true |
| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 |
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
| API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_LIMIT | Limit on API users | 5 |
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |
@@ -166,4 +167,3 @@ Packages with licences incompatible with the MIT licence:
In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features. In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features.
⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations. ⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations.

View File

@@ -833,39 +833,6 @@ class AITranslateSerializer(serializers.Serializer):
return value return value
class AIProxySerializer(serializers.Serializer):
    """Serializer for AI proxy requests.

    Validates that the payload carries a non-empty list of chat messages
    and that the requested model matches the configured ``AI_MODEL``.
    """

    messages = serializers.ListField(
        required=True,
        child=serializers.DictField(),
        allow_empty=False,
    )
    model = serializers.CharField(required=True)

    def validate_messages(self, messages):
        """Validate messages structure."""
        # Every entry must be a mapping exposing both mandatory keys.
        for entry in messages:
            is_well_formed = (
                isinstance(entry, dict)
                and "role" in entry
                and "content" in entry
            )
            if not is_well_formed:
                raise serializers.ValidationError(
                    "Each message must have 'role' and 'content' fields"
                )
        return messages

    def validate_model(self, value):
        """Validate model value is the same than settings.AI_MODEL"""
        if value == settings.AI_MODEL:
            return value
        raise serializers.ValidationError(f"{value} is not a valid model")
class MoveDocumentSerializer(serializers.Serializer): class MoveDocumentSerializer(serializers.Serializer):
""" """
Serializer for validating input data to move a document within the tree structure. Serializer for validating input data to move a document within the tree structure.

View File

@@ -38,6 +38,7 @@ from csp.decorators import csp_update
from lasuite.malware_detection import malware_detection from lasuite.malware_detection import malware_detection
from lasuite.oidc_login.decorators import refresh_oidc_access_token from lasuite.oidc_login.decorators import refresh_oidc_access_token
from lasuite.tools.email import get_domain_from_email from lasuite.tools.email import get_domain_from_email
from pydantic import ValidationError as PydanticValidationError
from rest_framework import filters, status, viewsets from rest_framework import filters, status, viewsets
from rest_framework import response as drf_response from rest_framework import response as drf_response
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
@@ -1854,22 +1855,24 @@ class DocumentViewSet(
if not settings.AI_FEATURE_ENABLED: if not settings.AI_FEATURE_ENABLED:
raise ValidationError("AI feature is not enabled.") raise ValidationError("AI feature is not enabled.")
serializer = serializers.AIProxySerializer(data=request.data)
serializer.is_valid(raise_exception=True)
ai_service = AIService() ai_service = AIService()
if settings.AI_STREAM: try:
return StreamingHttpResponse( stream = ai_service.stream(request)
ai_service.stream(request.data), except PydanticValidationError as err:
content_type="text/event-stream", logger.info("pydantic validation error: %s", err)
status=drf.status.HTTP_200_OK, return drf.response.Response(
{"detail": "Invalid submitted payload"},
status=drf.status.HTTP_400_BAD_REQUEST,
) )
ai_response = ai_service.proxy(request.data) return StreamingHttpResponse(
return drf.response.Response( stream,
ai_response.model_dump(), content_type="text/event-stream",
status=drf.status.HTTP_200_OK, headers={
"x-vercel-ai-data-stream": "v1", # This header is used for Vercel AI streaming,
"X-Accel-Buffering": "no", # Prevent nginx buffering
},
) )
@drf.decorators.action( @drf.decorators.action(

View File

@@ -19,3 +19,21 @@ class ForceSessionMiddleware:
response = self.get_response(request) response = self.get_response(request)
return response return response
class SaveRawBodyMiddleware:
    """
    Save the raw request body to use it later.
    """

    def __init__(self, get_response):
        """Initialize the middleware."""
        self.get_response = get_response

    def __call__(self, request):
        """Save the raw request body in the request to use it later."""
        # Reading ``request.body`` here caches the raw payload before any
        # downstream parser consumes the stream, so the AI proxy view can
        # later hand the untouched bytes to the Vercel AI adapter.
        targets_ai_proxy = request.path.endswith(("/ai-proxy/", "/ai-proxy"))
        if targets_ai_proxy:
            request.raw_body = request.body
        return self.get_response(request)

View File

@@ -1,55 +1,38 @@
"""AI services.""" """AI services."""
import asyncio
import json import json
import logging import logging
from typing import Any, Dict, Generator import os
import queue
import threading
from collections.abc import AsyncIterator, Iterator
from typing import Any, Dict, Union
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from openai import OpenAIError from langfuse import get_client
from langfuse.openai import OpenAI as OpenAI_Langfuse
from pydantic_ai import Agent, DeferredToolRequests
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.external import ExternalToolset
from pydantic_ai.ui import SSE_CONTENT_TYPE
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
from rest_framework.request import Request
from core import enums from core import enums
if settings.LANGFUSE_PUBLIC_KEY: if settings.LANGFUSE_PUBLIC_KEY:
from langfuse.openai import OpenAI OpenAI = OpenAI_Langfuse
else: else:
from openai import OpenAI from openai import OpenAI
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
BLOCKNOTE_TOOL_STRICT_PROMPT = """
You are editing a BlockNote document via the tool applyDocumentOperations.
You MUST respond ONLY by calling applyDocumentOperations.
The tool input MUST be valid JSON:
{ "operations": [ ... ] }
Each operation MUST include "type" and it MUST be one of:
- "update" (requires: id, block)
- "add" (requires: referenceId, position, blocks)
- "delete" (requires: id)
VALID SHAPES (FOLLOW EXACTLY):
Update:
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
Add:
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
Each item MUST be a STRING containing a SINGLE valid HTML element.
Delete:
{ "type":"delete", "id":"<id$>" }
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
Return ONLY the JSON tool input. No prose, no markdown.
"""
AI_ACTIONS = { AI_ACTIONS = {
"prompt": ( "prompt": (
"Answer the prompt using markdown formatting for structure and emphasis. " "Answer the prompt using markdown formatting for structure and emphasis. "
@@ -95,6 +78,40 @@ AI_TRANSLATE = (
) )
def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[str]:
    """Drive an async generator from synchronous code.

    Runs the async generator on its own event loop in a daemon thread and
    relays each produced item through a queue so the caller can iterate
    synchronously. Any exception raised by the async generator is captured
    in the worker thread and re-raised in the consuming (sync) context.
    """
    items: queue.Queue[str | object] = queue.Queue()
    done_marker = object()
    error_marker = object()

    async def pump():
        try:
            async for produced in async_gen:
                items.put(produced)
        except Exception as error:  # pylint: disable=broad-except #noqa: BLE001
            # Ship the exception to the consumer side as a tagged tuple.
            items.put((error_marker, error))
        finally:
            # Always signal completion, even after a failure.
            items.put(done_marker)

    worker = threading.Thread(target=lambda: asyncio.run(pump()), daemon=True)
    worker.start()

    try:
        while (received := items.get()) is not done_marker:
            if isinstance(received, tuple) and received[0] is error_marker:
                # Re-raise the worker's exception in the sync context.
                raise received[1]
            yield received
    finally:
        worker.join()
class AIService: class AIService:
"""Service class for AI-related operations.""" """Service class for AI-related operations."""
@@ -136,77 +153,171 @@ class AIService:
system_content = AI_TRANSLATE.format(language=language_display) system_content = AI_TRANSLATE.format(language=language_display)
return self.call_ai_api(system_content, text) return self.call_ai_api(system_content, text)
def _normalize_tools(self, tools: list) -> list: @staticmethod
def inject_document_state_messages(
messages: list[UIMessage],
) -> list[UIMessage]:
"""Inject document state context before user messages.
Port of BlockNote's injectDocumentStateMessages.
For each user message carrying documentState metadata, an assistant
message describing the current document/selection state is prepended
so the LLM sees it as context.
""" """
Normalize tool definitions to ensure they have required fields. result: list[UIMessage] = []
""" for message in messages:
normalized = []
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "function":
fn = tool.get("function") or {}
if isinstance(fn, dict) and not fn.get("description"):
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
tool["function"] = fn
normalized.append(tool)
return normalized
def _harden_payload(
self, payload: Dict[str, Any], stream: bool = False
) -> Dict[str, Any]:
"""Harden the AI API payload to enforce compliance and tool usage."""
payload["stream"] = stream
# Remove stream_options if stream is False
if not stream and "stream_options" in payload:
payload.pop("stream_options")
# Tools normalization
if isinstance(payload.get("tools"), list):
payload["tools"] = self._normalize_tools(payload["tools"])
# Inject strict system prompt once
msgs = payload.get("messages")
if isinstance(msgs, list):
need = True
if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
c = msgs[0].get("content") or ""
if ( if (
isinstance(c, str) message.role == "user"
and "applyDocumentOperations" in c and isinstance(message.metadata, dict)
and "blocks" in c and "documentState" in message.metadata
): ):
need = False doc_state = message.metadata["documentState"]
if need: selection = doc_state.get("selection")
payload["messages"] = [ blocks = doc_state.get("blocks")
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
] + msgs
return payload if selection:
parts = [
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: TextUIPart(
"""Proxy AI API requests to the configured AI provider.""" text=(
payload = self._harden_payload(data, stream=stream) "This is the latest state of the selection "
try: "(ignore previous selections, you MUST issue "
return self.client.chat.completions.create(**payload) "operations against this latest version of "
except OpenAIError as e: "the selection):"
raise RuntimeError(f"Failed to proxy AI request: {e}") from e ),
),
def stream(self, data: dict) -> Generator[str, None, None]: TextUIPart(
"""Stream AI API requests to the configured AI provider.""" text=json.dumps(doc_state.get("selectedBlocks")),
),
try: TextUIPart(
stream = self.proxy(data, stream=True) text=(
for chunk in stream: "This is the latest state of the entire "
try: "document (INCLUDING the selected text), you "
chunk_dict = ( "can use this to find the selected text to "
chunk.model_dump() if hasattr(chunk, "model_dump") else chunk "understand the context (but you MUST NOT "
"issue operations against this document, you "
"MUST issue operations against the selection):"
),
),
TextUIPart(text=json.dumps(blocks)),
]
else:
text = (
"There is no active selection. This is the latest "
"state of the document (ignore previous documents, "
"you MUST issue operations against this latest "
"version of the document). The cursor is BETWEEN "
"two blocks as indicated by cursor: true."
) )
chunk_json = json.dumps(chunk_dict) if doc_state.get("isEmptyDocument"):
yield f"data: {chunk_json}\n\n" text += (
except (AttributeError, TypeError) as e: "Because the document is empty, YOU MUST first "
log.error("Error serializing chunk: %s, chunk: %s", e, chunk) "update the empty block before adding new blocks."
continue )
except (OpenAIError, RuntimeError, OSError, ValueError) as e: else:
log.error("Streaming error: %s", e) text += (
"Prefer updating existing blocks over removing "
"and adding (but this also depends on the "
"user's question)."
)
parts = [
TextUIPart(text=text),
TextUIPart(text=json.dumps(blocks)),
]
yield "data: [DONE]\n\n" result.append(
UIMessage(
role="assistant",
id=f"assistant-document-state-{message.id}",
parts=parts,
)
)
result.append(message)
return result
@staticmethod
def tool_definitions_to_toolset(
tool_definitions: Dict[str, Any],
) -> ExternalToolset:
"""Convert serialized tool definitions to a pydantic-ai ExternalToolset.
Port of BlockNote's toolDefinitionsToToolSet.
Builds ToolDefinition objects from the JSON-Schema-based definitions
sent by the frontend and wraps them in an ExternalToolset so that
pydantic-ai advertises them to the LLM without trying to execute them
server-side (execution is deferred to the frontend).
"""
tool_defs = [
ToolDefinition(
name=name,
description=defn.get("description", ""),
parameters_json_schema=defn.get("inputSchema", {}),
kind="external",
metadata={
"output_schema": defn.get("outputSchema"),
},
)
for name, defn in tool_definitions.items()
]
return ExternalToolset(tool_defs)
def _build_async_stream(self, request: Request) -> AsyncIterator[str]:
"""Build the async stream from the AI provider."""
instrument_enabled = settings.LANGFUSE_PUBLIC_KEY is not None
if instrument_enabled:
langfuse = get_client()
langfuse.auth_check()
Agent.instrument_all()
model = OpenAIChatModel(
settings.AI_MODEL,
provider=OpenAIProvider(
base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY
),
)
agent = Agent(model, instrument=instrument_enabled)
accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE)
run_input = VercelAIAdapter.build_run_input(request.raw_body)
# Inject document state context into the conversation
run_input.messages = self.inject_document_state_messages(run_input.messages)
# Build an ExternalToolset from frontend-supplied tool definitions
raw_tool_defs = (
run_input.model_extra.get("toolDefinitions")
if run_input.model_extra
else None
)
toolset = (
self.tool_definitions_to_toolset(raw_tool_defs) if raw_tool_defs else None
)
adapter = VercelAIAdapter(
agent=agent,
run_input=run_input,
accept=accept,
sdk_version=settings.AI_VERCEL_SDK_VERSION,
)
event_stream = adapter.run_stream(
output_type=[str, DeferredToolRequests] if toolset else None,
toolsets=[toolset] if toolset else None,
)
return adapter.encode_stream(event_stream)
def stream(self, request: Request) -> Union[AsyncIterator[str], Iterator[str]]:
"""Stream AI API requests to the configured AI provider.
Returns an async iterator when running in async mode (ASGI)
or a sync iterator when running in sync mode (WSGI).
"""
async_stream = self._build_async_stream(request)
if os.environ.get("PYTHON_SERVER_MODE", "sync") == "async":
return async_stream
return convert_async_generator_to_sync(async_stream)

View File

@@ -3,7 +3,7 @@ Test AI proxy API endpoint for users in impress's core app.
""" """
import random import random
from unittest.mock import MagicMock, patch from unittest.mock import patch
from django.test import override_settings from django.test import override_settings
@@ -11,7 +11,6 @@ import pytest
from rest_framework.test import APIClient from rest_framework.test import APIClient
from core import factories from core import factories
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT
from core.tests.conftest import TEAM, USER, VIA from core.tests.conftest import TEAM, USER, VIA
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@@ -24,6 +23,8 @@ def ai_settings(settings):
settings.AI_BASE_URL = "http://localhost-ai:12345/" settings.AI_BASE_URL = "http://localhost-ai:12345/"
settings.AI_API_KEY = "test-key" settings.AI_API_KEY = "test-key"
settings.AI_FEATURE_ENABLED = True settings.AI_FEATURE_ENABLED = True
settings.LANGFUSE_PUBLIC_KEY = None
settings.AI_VERCEL_SDK_VERSION = 6
@override_settings( @override_settings(
@@ -51,7 +52,6 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
url, url,
{ {
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"model": "llama",
}, },
format="json", format="json",
) )
@@ -62,85 +62,48 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
} }
@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False) @override_settings(AI_ALLOW_REACH_FROM="public")
@patch("openai.resources.chat.completions.Completions.create") @patch("core.services.ai_services.AIService.stream")
def test_api_documents_ai_proxy_anonymous_success(mock_create): def test_api_documents_ai_proxy_anonymous_success(mock_stream):
""" """
Anonymous users should be able to request AI proxy to a document Anonymous users should be able to request AI proxy to a document
if the link reach and role permit it. if the link reach and role permit it.
""" """
document = factories.DocumentFactory(link_reach="public", link_role="editor") document = factories.DocumentFactory(link_reach="public", link_role="editor")
mock_response = MagicMock() mock_stream.return_value = iter(["data: chunk1\n", "data: chunk2\n"])
mock_response.model_dump.return_value = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "llama",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you?",
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
}
mock_create.return_value = mock_response
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = APIClient().post( response = APIClient().post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 200 assert response.status_code == 200
response_data = response.json() assert response["Content-Type"] == "text/event-stream"
assert response_data["id"] == "chatcmpl-123" assert response["x-vercel-ai-data-stream"] == "v1"
assert response_data["model"] == "llama" assert response["X-Accel-Buffering"] == "no"
assert len(response_data["choices"]) == 1
assert (
response_data["choices"][0]["message"]["content"]
== "Hello! How can I help you?"
)
mock_create.assert_called_once_with( content = b"".join(response.streaming_content).decode()
messages=[ assert "chunk1" in content
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, assert "chunk2" in content
{"role": "user", "content": "Hello"}, mock_stream.assert_called_once()
],
model="llama",
stream=False,
)
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"])) @override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
@patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_anonymous_limited_by_setting():
def test_api_documents_ai_proxy_anonymous_limited_by_setting(mock_create):
""" """
Anonymous users should not be able to request AI proxy to a document Anonymous users should not be able to request AI proxy to a document
if AI_ALLOW_REACH_FROM setting restricts it. if AI_ALLOW_REACH_FROM setting restricts it.
""" """
document = factories.DocumentFactory(link_reach="public", link_role="editor") document = factories.DocumentFactory(link_reach="public", link_role="editor")
mock_response = MagicMock()
mock_response.model_dump.return_value = {"content": "Hello!"}
mock_create.return_value = mock_response
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = APIClient().post( response = APIClient().post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -170,11 +133,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post( response = client.post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -187,9 +147,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
("public", "editor"), ("public", "editor"),
], ],
) )
@override_settings(AI_STREAM=False) @patch("core.services.ai_services.AIService.stream")
@patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_authenticated_success(mock_stream, reach, role):
def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
""" """
Authenticated users should be able to request AI proxy to a document Authenticated users should be able to request AI proxy to a document
if the link reach and role permit it. if the link reach and role permit it.
@@ -201,44 +160,18 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
document = factories.DocumentFactory(link_reach=reach, link_role=role) document = factories.DocumentFactory(link_reach=reach, link_role=role)
mock_response = MagicMock() mock_stream.return_value = iter(["data: response\n"])
mock_response.model_dump.return_value = {
"id": "chatcmpl-456",
"object": "chat.completion",
"model": "llama",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Hi there!"},
"finish_reason": "stop",
}
],
}
mock_create.return_value = mock_response
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post( response = client.post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 200 assert response.status_code == 200
response_data = response.json() assert response["Content-Type"] == "text/event-stream"
assert response_data["id"] == "chatcmpl-456" mock_stream.assert_called_once()
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
mock_create.assert_called_once_with(
messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
],
model="llama",
stream=False,
)
@pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize("via", VIA)
@@ -261,11 +194,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post( response = client.post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -273,9 +203,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"]) @pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
@pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize("via", VIA)
@override_settings(AI_STREAM=False) @patch("core.services.ai_services.AIService.stream")
@patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_success(mock_stream, via, role, mock_user_teams):
def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams):
"""Users with sufficient permissions should be able to request AI proxy.""" """Users with sufficient permissions should be able to request AI proxy."""
user = factories.UserFactory() user = factories.UserFactory()
@@ -291,402 +220,27 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams)
document=document, team="lasuite", role=role document=document, team="lasuite", role=role
) )
mock_response = MagicMock() mock_stream.return_value = iter(["data: success\n"])
mock_response.model_dump.return_value = {
"id": "chatcmpl-789",
"object": "chat.completion",
"model": "llama",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Success!"},
"finish_reason": "stop",
}
],
}
mock_create.return_value = mock_response
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post( response = client.post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Test message"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 200 assert response.status_code == 200
response_data = response.json() assert response["Content-Type"] == "text/event-stream"
assert response_data["id"] == "chatcmpl-789" assert response["x-vercel-ai-data-stream"] == "v1"
assert response_data["choices"][0]["message"]["content"] == "Success!" assert response["X-Accel-Buffering"] == "no"
mock_create.assert_called_once_with( content = b"".join(response.streaming_content).decode()
messages=[ assert "success" in content
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}, mock_stream.assert_called_once()
{"role": "user", "content": "Test message"},
],
model="llama",
stream=False,
)
def test_api_documents_ai_proxy_empty_messages():
    """The messages should not be empty when requesting AI proxy."""
    api_client = APIClient()
    api_client.force_login(factories.UserFactory())
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")

    response = api_client.post(
        f"/api/v1.0/documents/{doc.id!s}/ai-proxy/",
        {"messages": [], "model": "llama"},
        format="json",
    )

    assert response.status_code == 400
    assert response.json() == {"messages": ["This list may not be empty."]}
def test_api_documents_ai_proxy_missing_model():
    """The model should be required when requesting AI proxy."""
    api_client = APIClient()
    api_client.force_login(factories.UserFactory())
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")

    response = api_client.post(
        f"/api/v1.0/documents/{doc.id!s}/ai-proxy/",
        {"messages": [{"role": "user", "content": "Hello"}]},
        format="json",
    )

    assert response.status_code == 400
    assert response.json() == {"model": ["This field is required."]}
def test_api_documents_ai_proxy_invalid_message_format():
    """Messages should have the correct format when requesting AI proxy."""
    api_client = APIClient()
    api_client.force_login(factories.UserFactory())
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")
    url = f"/api/v1.0/documents/{doc.id!s}/ai-proxy/"

    # Exercise each malformed payload: missing role, missing content,
    # and a non-dict message.
    cases = [
        (
            [{"content": "Hello"}],
            {"messages": ["Each message must have 'role' and 'content' fields"]},
        ),
        (
            [{"role": "user"}],
            {"messages": ["Each message must have 'role' and 'content' fields"]},
        ),
        (
            ["invalid"],
            {"messages": {"0": ['Expected a dictionary of items but got type "str".']}},
        ),
    ]
    for bad_messages, expected_error in cases:
        response = api_client.post(
            url,
            {"messages": bad_messages, "model": "llama"},
            format="json",
        )
        assert response.status_code == 400
        assert response.json() == expected_error
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_stream_disabled(completions_mock):
    """Stream should be automatically disabled in AI proxy requests."""
    api_client = APIClient()
    api_client.force_login(factories.UserFactory())
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")

    fake_completion = MagicMock()
    fake_completion.model_dump.return_value = {"content": "Success!"}
    completions_mock.return_value = fake_completion

    response = api_client.post(
        f"/api/v1.0/documents/{doc.id!s}/ai-proxy/",
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "llama",
            "stream": True,  # This should be overridden to False
        },
        format="json",
    )

    assert response.status_code == 200
    # Verify that stream was set to False
    completions_mock.assert_called_once_with(
        messages=[
            {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
            {"role": "user", "content": "Hello"},
        ],
        model="llama",
        stream=False,
    )
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_additional_parameters(completions_mock):
    """AI proxy should pass through additional parameters to the AI service."""
    api_client = APIClient()
    api_client.force_login(factories.UserFactory())
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")

    fake_completion = MagicMock()
    fake_completion.model_dump.return_value = {"content": "Success!"}
    completions_mock.return_value = fake_completion

    response = api_client.post(
        f"/api/v1.0/documents/{doc.id!s}/ai-proxy/",
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "llama",
            "temperature": 0.7,
            "max_tokens": 100,
            "top_p": 0.9,
        },
        format="json",
    )

    assert response.status_code == 200
    # Verify that additional parameters were passed through
    completions_mock.assert_called_once_with(
        messages=[
            {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
            {"role": "user", "content": "Hello"},
        ],
        model="llama",
        temperature=0.7,
        max_tokens=100,
        top_p=0.9,
        stream=False,
    )
@override_settings(AI_STREAM=False)
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_throttling_document(completions_mock):
    """
    Throttling per document should be triggered on the AI transform endpoint.
    For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
    """
    api_client = APIClient()
    doc = factories.DocumentFactory(link_reach="public", link_role="editor")

    fake_completion = MagicMock()
    fake_completion.model_dump.return_value = {"content": "Success!"}
    completions_mock.return_value = fake_completion

    url = f"/api/v1.0/documents/{doc.id!s}/ai-proxy/"
    payload = {
        "messages": [{"role": "user", "content": "Test message"}],
        "model": "llama",
    }

    # The first three requests (distinct users, same document) succeed.
    for _ in range(3):
        api_client.force_login(factories.UserFactory())
        response = api_client.post(url, payload, format="json")
        assert response.status_code == 200
        assert response.json() == {"content": "Success!"}

    # The fourth request on the same document is throttled.
    api_client.force_login(factories.UserFactory())
    response = api_client.post(url, payload)

    assert response.status_code == 429
    assert response.json() == {
        "detail": "Request was throttled. Expected available in 60 seconds."
    }
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_complex_conversation(mock_create):
    """AI proxy should handle complex conversations with multiple messages."""
    user = factories.UserFactory()
    client = APIClient()
    client.force_login(user)

    document = factories.DocumentFactory(link_reach="public", link_role="editor")

    # Canned multi-field completion returned by the mocked OpenAI client.
    completion = {
        "id": "chatcmpl-complex",
        "object": "chat.completion",
        "model": "llama",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "I understand your question about Python.",
                },
                "finish_reason": "stop",
            }
        ],
    }
    mock_response = MagicMock()
    mock_response.model_dump.return_value = completion
    mock_create.return_value = mock_response

    # A realistic multi-turn conversation, including the strict system prompt.
    complex_messages = [
        {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
        {"role": "system", "content": "You are a helpful programming assistant."},
        {"role": "user", "content": "How do I write a for loop in Python?"},
        {
            "role": "assistant",
            "content": "You can write a for loop using: for item in iterable:",
        },
        {"role": "user", "content": "Can you give me a concrete example?"},
    ]

    response = client.post(
        f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
        {"messages": complex_messages, "model": "llama"},
        format="json",
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["id"] == "chatcmpl-complex"
    assert (
        payload["choices"][0]["message"]["content"]
        == "I understand your question about Python."
    )
    # The whole conversation must be forwarded verbatim to the AI client.
    mock_create.assert_called_once_with(
        messages=complex_messages,
        model="llama",
        stream=False,
    )
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
@patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_throttling_user(mock_create):
    """
    Throttling per user should be triggered on the AI proxy endpoint.
    For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
    """
    user = factories.UserFactory()
    client = APIClient()
    client.force_login(user)

    mock_response = MagicMock()
    mock_response.model_dump.return_value = {"content": "Success!"}
    mock_create.return_value = mock_response

    def call_ai_proxy():
        """Post a minimal AI payload against a freshly created public document."""
        target = factories.DocumentFactory(link_reach="public", link_role="editor")
        return client.post(
            f"/api/v1.0/documents/{target.id!s}/ai-proxy/",
            {
                "messages": [{"role": "user", "content": "Hello"}],
                "model": "llama",
            },
            format="json",
        )

    # The user rate allows 3 requests/minute. Using a new document per call
    # keeps the per-document throttle out of the picture.
    for _ in range(3):
        assert call_ai_proxy().status_code == 200

    # Fourth request from the same user is rejected.
    response = call_ai_proxy()
    assert response.status_code == 429
    assert response.json() == {
        "detail": "Request was throttled. Expected available in 60 seconds."
    }
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 10, "hour": 6, "day": 10})
def test_api_documents_ai_proxy_different_models():
    """AI proxy should work with different AI models."""
    user = factories.UserFactory()
    client = APIClient()
    client.force_login(user)

    document = factories.DocumentFactory(link_reach="public", link_role="editor")
    url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"

    # None of these models is configured server-side, so the serializer
    # rejects each request with a model-specific validation error.
    for model_name in ("gpt-3.5-turbo", "gpt-4", "claude-3", "llama-2"):
        response = client.post(
            url,
            {
                "messages": [{"role": "user", "content": "Hello"}],
                "model": model_name,
            },
            format="json",
        )
        assert response.status_code == 400
        assert response.json() == {"model": [f"{model_name} is not a valid model"]}
def test_api_documents_ai_proxy_ai_feature_disabled(settings): def test_api_documents_ai_proxy_ai_feature_disabled(settings):
"""When the settings AI_FEATURE_ENABLED is set to False, the endpoint is not reachable.""" """When AI_FEATURE_ENABLED is False, the endpoint returns 400."""
settings.AI_FEATURE_ENABLED = False settings.AI_FEATURE_ENABLED = False
user = factories.UserFactory() user = factories.UserFactory()
@@ -698,25 +252,91 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings):
response = client.post( response = client.post(
f"/api/v1.0/documents/{document.id!s}/ai-proxy/", f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
},
format="json",
) )
assert response.status_code == 400 assert response.status_code == 400
assert response.json() == ["AI feature is not enabled."] assert response.json() == ["AI feature is not enabled."]
@override_settings(AI_STREAM=False) @override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
@patch("openai.resources.chat.completions.Completions.create") @patch("core.services.ai_services.AIService.stream")
def test_api_documents_ai_proxy_harden_payload(mock_create): def test_api_documents_ai_proxy_throttling_document(mock_stream):
""" """
AI proxy should harden the payload by default: Throttling per document should be triggered on the AI proxy endpoint.
- it will remove stream_options if stream is False For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
- normalize tools by adding descriptions if missing
""" """
client = APIClient()
document = factories.DocumentFactory(link_reach="public", link_role="editor")
mock_stream.return_value = iter(["data: ok\n"])
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
for _ in range(3):
mock_stream.return_value = iter(["data: ok\n"])
user = factories.UserFactory()
client.force_login(user)
response = client.post(
url,
b"{}",
content_type="application/json",
)
assert response.status_code == 200
user = factories.UserFactory()
client.force_login(user)
response = client.post(
url,
b"{}",
content_type="application/json",
)
assert response.status_code == 429
assert response.json() == {
"detail": "Request was throttled. Expected available in 60 seconds."
}
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
@patch("core.services.ai_services.AIService.stream")
def test_api_documents_ai_proxy_throttling_user(mock_stream):
"""
Throttling per user should be triggered on the AI proxy endpoint.
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
"""
user = factories.UserFactory()
client = APIClient()
client.force_login(user)
for _ in range(3):
mock_stream.return_value = iter(["data: ok\n"])
document = factories.DocumentFactory(link_reach="public", link_role="editor")
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post(
url,
b"{}",
content_type="application/json",
)
assert response.status_code == 200
document = factories.DocumentFactory(link_reach="public", link_role="editor")
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post(
url,
b"{}",
content_type="application/json",
)
assert response.status_code == 429
assert response.json() == {
"detail": "Request was throttled. Expected available in 60 seconds."
}
@patch("core.services.ai_services.AIService.stream")
def test_api_documents_ai_proxy_returns_streaming_response(mock_stream):
"""AI proxy should return a StreamingHttpResponse with correct headers."""
user = factories.UserFactory() user = factories.UserFactory()
client = APIClient() client = APIClient()
@@ -724,76 +344,39 @@ def test_api_documents_ai_proxy_harden_payload(mock_create):
document = factories.DocumentFactory(link_reach="public", link_role="editor") document = factories.DocumentFactory(link_reach="public", link_role="editor")
mock_response = MagicMock() mock_stream.return_value = iter(["data: part1\n", "data: part2\n", "data: part3\n"])
mock_response.model_dump.return_value = {
"id": "chatcmpl-789",
"object": "chat.completion",
"model": "llama",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Success!"},
"finish_reason": "stop",
}
],
}
mock_create.return_value = mock_response
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/" url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
response = client.post( response = client.post(
url, url,
{ b"{}",
"messages": [{"role": "user", "content": "Hello"}], content_type="application/json",
"model": "llama",
"stream": False,
"stream_options": {"include_usage": True},
"tools": [
{
"type": "function",
"function": {
"name": "applyDocumentOperations",
"parameters": {
"type": "object",
"properties": {
"operations": {"type": "array", "items": {"anyOf": []}}
},
"additionalProperties": False,
"required": ["operations"],
},
},
}
],
},
format="json",
) )
assert response.status_code == 200 assert response.status_code == 200
response_data = response.json() assert response["Content-Type"] == "text/event-stream"
assert response_data["id"] == "chatcmpl-789" assert response["x-vercel-ai-data-stream"] == "v1"
assert response_data["choices"][0]["message"]["content"] == "Success!" assert response["X-Accel-Buffering"] == "no"
mock_create.assert_called_once_with( chunks = list(response.streaming_content)
messages=[ assert len(chunks) == 3
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "Hello"},
], def test_api_documents_ai_proxy_invalid_payload():
model="llama", """AI Proxy should return a 400 if the payload is invalid."""
stream=False,
tools=[ user = factories.UserFactory()
{
"type": "function", document = factories.DocumentFactory(users=[(user, "owner")])
"function": {
"name": "applyDocumentOperations", client = APIClient()
"parameters": { client.force_login(user)
"type": "object",
"properties": { response = client.post(
"operations": {"type": "array", "items": {"anyOf": []}} f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
}, b'{"foo": "bar", "trigger": "submit-message"}',
"additionalProperties": False, content_type="application/json",
"required": ["operations"],
},
"description": "Tool applyDocumentOperations.",
},
}
],
) )
assert response.status_code == 400
assert response.json() == {"detail": "Invalid submitted payload"}

View File

@@ -1,7 +1,9 @@
""" """
Test ai API endpoints in the impress core app. Test AI services in the impress core app.
""" """
# pylint: disable=protected-access
from collections.abc import AsyncIterator
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@@ -9,8 +11,12 @@ from django.test.utils import override_settings
import pytest import pytest
from openai import OpenAIError from openai import OpenAIError
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService from core.services.ai_services import (
AIService,
convert_async_generator_to_sync,
)
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@@ -22,6 +28,11 @@ def ai_settings(settings):
settings.AI_BASE_URL = "http://example.com" settings.AI_BASE_URL = "http://example.com"
settings.AI_API_KEY = "test-key" settings.AI_API_KEY = "test-key"
settings.AI_FEATURE_ENABLED = True settings.AI_FEATURE_ENABLED = True
settings.LANGFUSE_PUBLIC_KEY = None
settings.AI_VERCEL_SDK_VERSION = 6
# -- AIService.__init__ --
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -43,6 +54,9 @@ def test_services_ai_setting_missing(setting_name, setting_value, settings):
AIService() AIService()
# -- AIService.transform --
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@@ -59,19 +73,6 @@ def test_services_ai_client_error(mock_create):
AIService().transform("hello", "prompt") AIService().transform("hello", "prompt")
@patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_client_error(mock_create):
"""Fail when the client raises an error"""
mock_create.side_effect = OpenAIError("Mocked client error")
with pytest.raises(
RuntimeError,
match="Failed to proxy AI request: Mocked client error",
):
AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@@ -90,49 +91,6 @@ def test_services_ai_client_invalid_response(mock_create):
AIService().transform("hello", "prompt") AIService().transform("hello", "prompt")
@patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_success(mock_create):
"""The AI request should work as expect when called with valid arguments."""
mock_create.return_value = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
expected_response = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
assert response == expected_response
mock_create.assert_called_once_with(
messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "hello"},
],
stream=False,
)
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@@ -149,46 +107,438 @@ def test_services_ai_success(mock_create):
assert response == {"answer": "Salut"} assert response == {"answer": "Salut"}
# -- AIService.translate --
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_with_stream(mock_create): def test_services_ai_translate_success(mock_create):
"""The AI request should work as expect when called with valid arguments.""" """Translate should call the AI API with the correct language prompt."""
mock_create.return_value = { mock_create.return_value = MagicMock(
"id": "chatcmpl-test", choices=[MagicMock(message=MagicMock(content="Bonjour"))]
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
response = AIService().proxy(
{"messages": [{"role": "user", "content": "hello"}]}, stream=True
) )
expected_response = { response = AIService().translate("<p>Hello</p>", "fr")
"id": "chatcmpl-test",
"object": "chat.completion", assert response == {"answer": "Bonjour"}
"created": 1234567890, call_args = mock_create.call_args
"model": "test-model", system_content = call_args[1]["messages"][0]["content"]
"choices": [ assert "French" in system_content or "fr" in system_content
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"}, @patch("openai.resources.chat.completions.Completions.create")
"finish_reason": "stop", def test_services_ai_translate_unknown_language(mock_create):
} """Translate with an unknown language code should use the code as-is."""
],
} mock_create.return_value = MagicMock(
assert response == expected_response choices=[MagicMock(message=MagicMock(content="Translated"))]
mock_create.assert_called_once_with(
messages=[
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
{"role": "user", "content": "hello"},
],
stream=True,
) )
response = AIService().translate("<p>Hello</p>", "xx-unknown")
assert response == {"answer": "Translated"}
call_args = mock_create.call_args
system_content = call_args[1]["messages"][0]["content"]
assert "xx-unknown" in system_content
# -- convert_async_generator_to_sync --
def test_convert_async_generator_to_sync_basic():
    """Should convert an async generator yielding items to a sync iterator."""

    async def produce():
        yield "hello"
        yield "world"
        yield "!"

    assert list(convert_async_generator_to_sync(produce())) == ["hello", "world", "!"]
def test_convert_async_generator_to_sync_empty():
    """Should handle an empty async generator."""

    async def produce():
        return
        yield  # unreachable: only here to make this an async generator

    assert list(convert_async_generator_to_sync(produce())) == []
def test_convert_async_generator_to_sync_exception():
    """Should propagate exceptions from the async generator."""

    async def produce():
        yield "first"
        raise ValueError("async error")

    iterator = convert_async_generator_to_sync(produce())
    # The first item is delivered normally...
    assert next(iterator) == "first"
    # ...and the async-side error surfaces on the next pull.
    with pytest.raises(ValueError, match="async error"):
        next(iterator)
# -- AIService.inject_document_state_messages --
def test_inject_document_state_messages_no_metadata():
    """Messages without documentState metadata should pass through unchanged."""
    original = UIMessage(role="user", id="msg-1", parts=[TextUIPart(text="Hello")])

    result = AIService.inject_document_state_messages([original])

    assert len(result) == 1
    assert result[0].id == "msg-1"
def test_inject_document_state_messages_with_selection():
    """A user message with documentState and selection should get an
    assistant context message prepended."""
    document_state = {
        "selection": {"start": 0, "end": 5},
        "selectedBlocks": [{"type": "paragraph", "content": "Hello"}],
        "blocks": [
            {"type": "paragraph", "content": "Hello"},
            {"type": "paragraph", "content": "World"},
        ],
    }
    user_message = UIMessage(
        role="user",
        id="msg-1",
        parts=[TextUIPart(text="Fix this")],
        metadata={"documentState": document_state},
    )

    result = AIService.inject_document_state_messages([user_message])

    assert len(result) == 2
    injected, original = result
    # The injected assistant context comes first...
    assert injected.role == "assistant"
    assert injected.id == "assistant-document-state-msg-1"
    assert len(injected.parts) == 4
    assert "selection" in injected.parts[0].text.lower()
    # ...followed by the untouched user message.
    assert original.id == "msg-1"
def test_inject_document_state_messages_without_selection():
    """A user message with documentState but no selection should describe
    the full document context."""
    user_message = UIMessage(
        role="user",
        id="msg-1",
        parts=[TextUIPart(text="Summarize")],
        metadata={
            "documentState": {
                "selection": None,
                "blocks": [{"type": "paragraph", "content": "Hello"}],
                "isEmptyDocument": False,
            }
        },
    )

    result = AIService.inject_document_state_messages([user_message])

    assert len(result) == 2
    context_message = result[0]
    assert context_message.role == "assistant"
    assert len(context_message.parts) == 2
    # Without a selection the context must mention the full-document mode.
    context_text = context_message.parts[0].text.lower()
    assert "no active selection" in context_text
    assert "prefer updating" in context_text
def test_inject_document_state_messages_empty_document():
    """When the document is empty, the injected message should instruct
    updating the empty block first."""
    user_message = UIMessage(
        role="user",
        id="msg-1",
        parts=[TextUIPart(text="Write something")],
        metadata={
            "documentState": {
                "selection": None,
                "blocks": [{"type": "paragraph", "content": ""}],
                "isEmptyDocument": True,
            }
        },
    )

    result = AIService.inject_document_state_messages([user_message])

    assert len(result) == 2
    assert "update the empty block" in result[0].parts[0].text.lower()
def test_inject_document_state_messages_mixed():
    """Only user messages with documentState get assistant context;
    other messages pass through unchanged."""
    plain_assistant = UIMessage(
        role="assistant",
        id="msg-0",
        parts=[TextUIPart(text="Previous response")],
    )
    plain_user = UIMessage(role="user", id="msg-1", parts=[TextUIPart(text="Hello")])
    stateful_user = UIMessage(
        role="user",
        id="msg-2",
        parts=[TextUIPart(text="Fix this")],
        metadata={
            "documentState": {
                "selection": {"start": 0, "end": 5},
                "selectedBlocks": [{"type": "paragraph", "content": "Hello"}],
                "blocks": [{"type": "paragraph", "content": "Hello"}],
            }
        },
    )

    result = AIService.inject_document_state_messages(
        [plain_assistant, plain_user, stateful_user]
    )

    # 3 original messages + 1 injected assistant message right before msg-2.
    assert len(result) == 4
    assert [message.id for message in result[:2]] == ["msg-0", "msg-1"]
    assert result[2].role == "assistant"
    assert result[2].id == "assistant-document-state-msg-2"
    assert result[3].id == "msg-2"
# -- AIService.tool_definitions_to_toolset --
def test_tool_definitions_to_toolset():
    """Should convert frontend tool definitions to an ExternalToolset."""
    definitions = {
        "applyOperations": {
            "description": "Apply operations to the document",
            "inputSchema": {
                "type": "object",
                "properties": {"operations": {"type": "array"}},
            },
            "outputSchema": {"type": "object"},
        },
        "insertBlocks": {
            "description": "Insert blocks",
            "inputSchema": {"type": "object"},
        },
    }

    toolset = AIService.tool_definitions_to_toolset(definitions)

    # The ExternalToolset wraps one ToolDefinition per frontend tool.
    assert toolset is not None
    tool_defs = toolset.tool_defs
    assert len(tool_defs) == 2
    assert {definition.name for definition in tool_defs} == {
        "applyOperations",
        "insertBlocks",
    }
    for definition in tool_defs:
        assert definition.kind == "external"
    apply_tool = next(td for td in tool_defs if td.name == "applyOperations")
    assert apply_tool.description == "Apply operations to the document"
    # The output schema is stashed in the tool metadata for later use.
    assert apply_tool.metadata == {"output_schema": {"type": "object"}}
def test_tool_definitions_to_toolset_missing_fields():
    """Should handle tool definitions with missing optional fields."""
    toolset = AIService.tool_definitions_to_toolset({"myTool": {}})

    # Exactly one definition, filled with safe defaults for missing fields.
    (tool_def,) = toolset.tool_defs
    assert tool_def.name == "myTool"
    assert tool_def.description == ""
    assert tool_def.parameters_json_schema == {}
    assert tool_def.metadata == {"output_schema": None}
# -- AIService.stream --
@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_sync_mode(mock_build, monkeypatch):
    """In sync mode, stream() should return a sync iterator."""
    monkeypatch.setenv("PYTHON_SERVER_MODE", "sync")

    async def fake_stream():
        yield "chunk1"
        yield "chunk2"

    mock_build.return_value = fake_stream()

    request = MagicMock()
    result = AIService().stream(request)

    # A plain (sync) iterator is expected, not an async one.
    assert not isinstance(result, AsyncIterator)
    assert list(result) == ["chunk1", "chunk2"]
    mock_build.assert_called_once_with(request)
@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_async_mode(mock_build, monkeypatch):
    """In async mode, stream() should return the async iterator directly."""
    monkeypatch.setenv("PYTHON_SERVER_MODE", "async")

    async def fake_stream():
        yield "chunk1"
        yield "chunk2"

    async_iterator = fake_stream()
    mock_build.return_value = async_iterator

    request = MagicMock()
    result = AIService().stream(request)

    # The async iterator must be handed back untouched.
    assert result is async_iterator
    mock_build.assert_called_once_with(request)
@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_defaults_to_sync(mock_build, monkeypatch):
    """When PYTHON_SERVER_MODE is not set, stream() should default to sync."""
    monkeypatch.delenv("PYTHON_SERVER_MODE", raising=False)

    async def fake_stream():
        yield "data"

    mock_build.return_value = fake_stream()

    request = MagicMock()
    result = AIService().stream(request)

    # Without the env var, the sync behaviour is the default.
    assert not isinstance(result, AsyncIterator)
    assert list(result) == ["data"]
# -- AIService._build_async_stream --
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream(mock_adapter_cls):
    """_build_async_stream should build the pydantic-ai streaming pipeline."""

    async def mock_encode():
        yield "event-data"

    # Simulated run input as produced by VercelAIAdapter.build_run_input:
    # no extra fields (model_extra is None) and no conversation messages.
    mock_run_input = MagicMock()
    mock_run_input.model_extra = None
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input
    # The adapter instance wires run_stream -> encode_stream; encode_stream
    # returns an async generator, which is what the caller will iterate.
    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance
    service = AIService()
    request = MagicMock()
    request.META = {"HTTP_ACCEPT": "text/event-stream"}
    # raw_body is set by SaveRawBodyMiddleware and fed to build_run_input.
    request.raw_body = b'{"messages": []}'
    result = service._build_async_stream(request)
    assert isinstance(result, AsyncIterator)
    # The raw request body must be parsed exactly once, and the full
    # run -> encode pipeline must have been assembled.
    mock_adapter_cls.build_run_input.assert_called_once_with(b'{"messages": []}')
    mock_adapter_instance.run_stream.assert_called_once()
    mock_adapter_instance.encode_stream.assert_called_once()
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream_with_tool_definitions(mock_adapter_cls):
    """_build_async_stream should build an ExternalToolset when
    toolDefinitions are present in the request."""

    async def mock_encode():
        yield "event-data"

    # Simulated run input whose extra (non-schema) fields carry the
    # frontend-declared tool definitions.
    mock_run_input = MagicMock()
    mock_run_input.model_extra = {
        "toolDefinitions": {
            "myTool": {
                "description": "A tool",
                "inputSchema": {"type": "object"},
            }
        }
    }
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input
    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance
    service = AIService()
    request = MagicMock()
    request.META = {}
    request.raw_body = b"{}"
    service._build_async_stream(request)
    # run_stream should have been called with a toolset
    call_kwargs = mock_adapter_instance.run_stream.call_args[1]
    assert call_kwargs["toolsets"] is not None
    assert len(call_kwargs["toolsets"]) == 1
@patch("core.services.ai_services.Agent")
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream_langfuse_enabled(
    mock_adapter_cls, mock_agent_cls, settings
):
    """When LANGFUSE_PUBLIC_KEY is set, instrument should be enabled."""
    settings.LANGFUSE_PUBLIC_KEY = "pk-test-123"

    async def mock_encode():
        yield "data"

    # Minimal run input: no extra fields, no messages.
    mock_run_input = MagicMock()
    mock_run_input.model_extra = None
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input
    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance
    service = AIService()
    request = MagicMock()
    request.META = {}
    request.raw_body = b"{}"
    service._build_async_stream(request)
    # Global pydantic-ai instrumentation must be switched on once.
    mock_agent_cls.instrument_all.assert_called_once()
    # Agent should be created with instrument=True
    mock_agent_cls.assert_called_once()
    assert mock_agent_cls.call_args[1]["instrument"] is True

View File

@@ -326,6 +326,7 @@ class Base(Configuration):
"django.middleware.csrf.CsrfViewMiddleware", "django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware",
"core.middleware.ForceSessionMiddleware", "core.middleware.ForceSessionMiddleware",
"core.middleware.SaveRawBodyMiddleware",
"django.contrib.messages.middleware.MessageMiddleware", "django.contrib.messages.middleware.MessageMiddleware",
"dockerflow.django.middleware.DockerflowMiddleware", "dockerflow.django.middleware.DockerflowMiddleware",
"csp.middleware.CSPMiddleware", "csp.middleware.CSPMiddleware",
@@ -716,6 +717,9 @@ class Base(Configuration):
AI_STREAM = values.BooleanValue( AI_STREAM = values.BooleanValue(
default=True, environ_name="AI_STREAM", environ_prefix=None default=True, environ_name="AI_STREAM", environ_prefix=None
) )
AI_VERCEL_SDK_VERSION = values.IntegerValue(
6, environ_name="AI_VERCEL_SDK_VERSION", environ_prefix=None
)
AI_USER_RATE_THROTTLE_RATES = { AI_USER_RATE_THROTTLE_RATES = {
"minute": 3, "minute": 3,
"hour": 50, "hour": 50,

View File

@@ -53,9 +53,11 @@ dependencies = [
"markdown==3.10", "markdown==3.10",
"mozilla-django-oidc==5.0.2", "mozilla-django-oidc==5.0.2",
"nested-multipart-parser==1.6.0", "nested-multipart-parser==1.6.0",
"openai==2.14.0", "openai==2.21.0",
"psycopg[binary]==3.3.2", "psycopg[binary]==3.3.2",
"pycrdt==0.12.44", "pycrdt==0.12.44",
"pydantic==2.12.5",
"pydantic-ai-slim[openai,logfire,web]==1.58.0",
"PyJWT==2.10.1", "PyJWT==2.10.1",
"python-magic==0.4.27", "python-magic==0.4.27",
"redis<6.0.0", "redis<6.0.0",