✨(backend) use pydantic AI to manage vercel data stream protocol
The frontend application is using Vercel AI SDK and it's data stream protocol. We decided to use the pydantic AI library to use it's vercel ai adapter. It will make the payload validation, use AsyncIterator and deal with vercel specification.
This commit is contained in:
committed by
Anthony LC
parent
050b106a8f
commit
8ce216f6e8
@@ -15,6 +15,7 @@ These are the environment variables you can set for the `impress-backend` contai
|
|||||||
| AI_FEATURE_ENABLED | Enable AI options | false |
|
| AI_FEATURE_ENABLED | Enable AI options | false |
|
||||||
| AI_MODEL | AI Model to use | |
|
| AI_MODEL | AI Model to use | |
|
||||||
| AI_STREAM | AI Stream to activate | true |
|
| AI_STREAM | AI Stream to activate | true |
|
||||||
|
| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 |
|
||||||
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
|
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
|
||||||
| API_USERS_LIST_LIMIT | Limit on API users | 5 |
|
| API_USERS_LIST_LIMIT | Limit on API users | 5 |
|
||||||
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |
|
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |
|
||||||
@@ -166,4 +167,3 @@ Packages with licences incompatible with the MIT licence:
|
|||||||
In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features.
|
In `.env.development`, `PUBLISH_AS_MIT` is set to `false`, allowing developers to test Docs with all its features.
|
||||||
|
|
||||||
⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations.
|
⚠️ If you run Docs in production with `PUBLISH_AS_MIT` set to `false` make sure you fulfill your BlockNote licensing or [subscription](https://www.blocknotejs.org/about#partner-with-us) obligations.
|
||||||
|
|
||||||
|
|||||||
@@ -833,39 +833,6 @@ class AITranslateSerializer(serializers.Serializer):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
class AIProxySerializer(serializers.Serializer):
|
|
||||||
"""Serializer for AI proxy requests."""
|
|
||||||
|
|
||||||
messages = serializers.ListField(
|
|
||||||
required=True,
|
|
||||||
child=serializers.DictField(),
|
|
||||||
allow_empty=False,
|
|
||||||
)
|
|
||||||
model = serializers.CharField(required=True)
|
|
||||||
|
|
||||||
def validate_messages(self, messages):
|
|
||||||
"""Validate messages structure."""
|
|
||||||
# Ensure each message has the required fields
|
|
||||||
for message in messages:
|
|
||||||
if (
|
|
||||||
not isinstance(message, dict)
|
|
||||||
or "role" not in message
|
|
||||||
or "content" not in message
|
|
||||||
):
|
|
||||||
raise serializers.ValidationError(
|
|
||||||
"Each message must have 'role' and 'content' fields"
|
|
||||||
)
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def validate_model(self, value):
|
|
||||||
"""Validate model value is the same than settings.AI_MODEL"""
|
|
||||||
if value != settings.AI_MODEL:
|
|
||||||
raise serializers.ValidationError(f"{value} is not a valid model")
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class MoveDocumentSerializer(serializers.Serializer):
|
class MoveDocumentSerializer(serializers.Serializer):
|
||||||
"""
|
"""
|
||||||
Serializer for validating input data to move a document within the tree structure.
|
Serializer for validating input data to move a document within the tree structure.
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from csp.decorators import csp_update
|
|||||||
from lasuite.malware_detection import malware_detection
|
from lasuite.malware_detection import malware_detection
|
||||||
from lasuite.oidc_login.decorators import refresh_oidc_access_token
|
from lasuite.oidc_login.decorators import refresh_oidc_access_token
|
||||||
from lasuite.tools.email import get_domain_from_email
|
from lasuite.tools.email import get_domain_from_email
|
||||||
|
from pydantic import ValidationError as PydanticValidationError
|
||||||
from rest_framework import filters, status, viewsets
|
from rest_framework import filters, status, viewsets
|
||||||
from rest_framework import response as drf_response
|
from rest_framework import response as drf_response
|
||||||
from rest_framework.permissions import AllowAny
|
from rest_framework.permissions import AllowAny
|
||||||
@@ -1854,22 +1855,24 @@ class DocumentViewSet(
|
|||||||
if not settings.AI_FEATURE_ENABLED:
|
if not settings.AI_FEATURE_ENABLED:
|
||||||
raise ValidationError("AI feature is not enabled.")
|
raise ValidationError("AI feature is not enabled.")
|
||||||
|
|
||||||
serializer = serializers.AIProxySerializer(data=request.data)
|
|
||||||
serializer.is_valid(raise_exception=True)
|
|
||||||
|
|
||||||
ai_service = AIService()
|
ai_service = AIService()
|
||||||
|
|
||||||
if settings.AI_STREAM:
|
try:
|
||||||
return StreamingHttpResponse(
|
stream = ai_service.stream(request)
|
||||||
ai_service.stream(request.data),
|
except PydanticValidationError as err:
|
||||||
content_type="text/event-stream",
|
logger.info("pydantic validation error: %s", err)
|
||||||
status=drf.status.HTTP_200_OK,
|
return drf.response.Response(
|
||||||
|
{"detail": "Invalid submitted payload"},
|
||||||
|
status=drf.status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
ai_response = ai_service.proxy(request.data)
|
return StreamingHttpResponse(
|
||||||
return drf.response.Response(
|
stream,
|
||||||
ai_response.model_dump(),
|
content_type="text/event-stream",
|
||||||
status=drf.status.HTTP_200_OK,
|
headers={
|
||||||
|
"x-vercel-ai-data-stream": "v1", # This header is used for Vercel AI streaming,
|
||||||
|
"X-Accel-Buffering": "no", # Prevent nginx buffering
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@drf.decorators.action(
|
@drf.decorators.action(
|
||||||
|
|||||||
@@ -19,3 +19,21 @@ class ForceSessionMiddleware:
|
|||||||
|
|
||||||
response = self.get_response(request)
|
response = self.get_response(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class SaveRawBodyMiddleware:
|
||||||
|
"""
|
||||||
|
Save the raw request body to use it later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, get_response):
|
||||||
|
"""Initialize the middleware."""
|
||||||
|
self.get_response = get_response
|
||||||
|
|
||||||
|
def __call__(self, request):
|
||||||
|
"""Save the raw request body in the request to use it later."""
|
||||||
|
if request.path.endswith(("/ai-proxy/", "/ai-proxy")):
|
||||||
|
request.raw_body = request.body
|
||||||
|
|
||||||
|
response = self.get_response(request)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -1,55 +1,38 @@
|
|||||||
"""AI services."""
|
"""AI services."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Generator
|
import os
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
from openai import OpenAIError
|
from langfuse import get_client
|
||||||
|
from langfuse.openai import OpenAI as OpenAI_Langfuse
|
||||||
|
from pydantic_ai import Agent, DeferredToolRequests
|
||||||
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
from pydantic_ai.tools import ToolDefinition
|
||||||
|
from pydantic_ai.toolsets.external import ExternalToolset
|
||||||
|
from pydantic_ai.ui import SSE_CONTENT_TYPE
|
||||||
|
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
|
||||||
|
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
|
||||||
|
from rest_framework.request import Request
|
||||||
|
|
||||||
from core import enums
|
from core import enums
|
||||||
|
|
||||||
if settings.LANGFUSE_PUBLIC_KEY:
|
if settings.LANGFUSE_PUBLIC_KEY:
|
||||||
from langfuse.openai import OpenAI
|
OpenAI = OpenAI_Langfuse
|
||||||
else:
|
else:
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
BLOCKNOTE_TOOL_STRICT_PROMPT = """
|
|
||||||
You are editing a BlockNote document via the tool applyDocumentOperations.
|
|
||||||
|
|
||||||
You MUST respond ONLY by calling applyDocumentOperations.
|
|
||||||
The tool input MUST be valid JSON:
|
|
||||||
{ "operations": [ ... ] }
|
|
||||||
|
|
||||||
Each operation MUST include "type" and it MUST be one of:
|
|
||||||
- "update" (requires: id, block)
|
|
||||||
- "add" (requires: referenceId, position, blocks)
|
|
||||||
- "delete" (requires: id)
|
|
||||||
|
|
||||||
VALID SHAPES (FOLLOW EXACTLY):
|
|
||||||
|
|
||||||
Update:
|
|
||||||
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
|
|
||||||
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
|
|
||||||
|
|
||||||
Add:
|
|
||||||
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
|
|
||||||
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
|
|
||||||
Each item MUST be a STRING containing a SINGLE valid HTML element.
|
|
||||||
|
|
||||||
Delete:
|
|
||||||
{ "type":"delete", "id":"<id$>" }
|
|
||||||
|
|
||||||
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
|
|
||||||
|
|
||||||
Return ONLY the JSON tool input. No prose, no markdown.
|
|
||||||
"""
|
|
||||||
|
|
||||||
AI_ACTIONS = {
|
AI_ACTIONS = {
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||||
@@ -95,6 +78,40 @@ AI_TRANSLATE = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[str]:
|
||||||
|
"""Convert an async generator to a sync generator."""
|
||||||
|
q: queue.Queue[str | object] = queue.Queue()
|
||||||
|
sentinel = object()
|
||||||
|
exc_sentinel = object()
|
||||||
|
|
||||||
|
async def run_async_gen():
|
||||||
|
try:
|
||||||
|
async for async_item in async_gen:
|
||||||
|
q.put(async_item)
|
||||||
|
except Exception as exc: # pylint: disable=broad-except #noqa: BLE001
|
||||||
|
q.put((exc_sentinel, exc))
|
||||||
|
finally:
|
||||||
|
q.put(sentinel)
|
||||||
|
|
||||||
|
def start_async_loop():
|
||||||
|
asyncio.run(run_async_gen())
|
||||||
|
|
||||||
|
thread = threading.Thread(target=start_async_loop, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
item = q.get()
|
||||||
|
if item is sentinel:
|
||||||
|
break
|
||||||
|
if isinstance(item, tuple) and item[0] is exc_sentinel:
|
||||||
|
# re-raise the exception in the sync context
|
||||||
|
raise item[1]
|
||||||
|
yield item
|
||||||
|
finally:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
|
||||||
class AIService:
|
class AIService:
|
||||||
"""Service class for AI-related operations."""
|
"""Service class for AI-related operations."""
|
||||||
|
|
||||||
@@ -136,77 +153,171 @@ class AIService:
|
|||||||
system_content = AI_TRANSLATE.format(language=language_display)
|
system_content = AI_TRANSLATE.format(language=language_display)
|
||||||
return self.call_ai_api(system_content, text)
|
return self.call_ai_api(system_content, text)
|
||||||
|
|
||||||
def _normalize_tools(self, tools: list) -> list:
|
@staticmethod
|
||||||
|
def inject_document_state_messages(
|
||||||
|
messages: list[UIMessage],
|
||||||
|
) -> list[UIMessage]:
|
||||||
|
"""Inject document state context before user messages.
|
||||||
|
|
||||||
|
Port of BlockNote's injectDocumentStateMessages.
|
||||||
|
For each user message carrying documentState metadata, an assistant
|
||||||
|
message describing the current document/selection state is prepended
|
||||||
|
so the LLM sees it as context.
|
||||||
"""
|
"""
|
||||||
Normalize tool definitions to ensure they have required fields.
|
result: list[UIMessage] = []
|
||||||
"""
|
for message in messages:
|
||||||
normalized = []
|
|
||||||
for tool in tools:
|
|
||||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
|
||||||
fn = tool.get("function") or {}
|
|
||||||
if isinstance(fn, dict) and not fn.get("description"):
|
|
||||||
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
|
|
||||||
tool["function"] = fn
|
|
||||||
normalized.append(tool)
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def _harden_payload(
|
|
||||||
self, payload: Dict[str, Any], stream: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Harden the AI API payload to enforce compliance and tool usage."""
|
|
||||||
payload["stream"] = stream
|
|
||||||
|
|
||||||
# Remove stream_options if stream is False
|
|
||||||
if not stream and "stream_options" in payload:
|
|
||||||
payload.pop("stream_options")
|
|
||||||
|
|
||||||
# Tools normalization
|
|
||||||
if isinstance(payload.get("tools"), list):
|
|
||||||
payload["tools"] = self._normalize_tools(payload["tools"])
|
|
||||||
|
|
||||||
# Inject strict system prompt once
|
|
||||||
msgs = payload.get("messages")
|
|
||||||
if isinstance(msgs, list):
|
|
||||||
need = True
|
|
||||||
if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
|
|
||||||
c = msgs[0].get("content") or ""
|
|
||||||
if (
|
if (
|
||||||
isinstance(c, str)
|
message.role == "user"
|
||||||
and "applyDocumentOperations" in c
|
and isinstance(message.metadata, dict)
|
||||||
and "blocks" in c
|
and "documentState" in message.metadata
|
||||||
):
|
):
|
||||||
need = False
|
doc_state = message.metadata["documentState"]
|
||||||
if need:
|
selection = doc_state.get("selection")
|
||||||
payload["messages"] = [
|
blocks = doc_state.get("blocks")
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
|
|
||||||
] + msgs
|
|
||||||
|
|
||||||
return payload
|
if selection:
|
||||||
|
parts = [
|
||||||
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
TextUIPart(
|
||||||
"""Proxy AI API requests to the configured AI provider."""
|
text=(
|
||||||
payload = self._harden_payload(data, stream=stream)
|
"This is the latest state of the selection "
|
||||||
try:
|
"(ignore previous selections, you MUST issue "
|
||||||
return self.client.chat.completions.create(**payload)
|
"operations against this latest version of "
|
||||||
except OpenAIError as e:
|
"the selection):"
|
||||||
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
),
|
||||||
|
),
|
||||||
def stream(self, data: dict) -> Generator[str, None, None]:
|
TextUIPart(
|
||||||
"""Stream AI API requests to the configured AI provider."""
|
text=json.dumps(doc_state.get("selectedBlocks")),
|
||||||
|
),
|
||||||
try:
|
TextUIPart(
|
||||||
stream = self.proxy(data, stream=True)
|
text=(
|
||||||
for chunk in stream:
|
"This is the latest state of the entire "
|
||||||
try:
|
"document (INCLUDING the selected text), you "
|
||||||
chunk_dict = (
|
"can use this to find the selected text to "
|
||||||
chunk.model_dump() if hasattr(chunk, "model_dump") else chunk
|
"understand the context (but you MUST NOT "
|
||||||
|
"issue operations against this document, you "
|
||||||
|
"MUST issue operations against the selection):"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TextUIPart(text=json.dumps(blocks)),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
text = (
|
||||||
|
"There is no active selection. This is the latest "
|
||||||
|
"state of the document (ignore previous documents, "
|
||||||
|
"you MUST issue operations against this latest "
|
||||||
|
"version of the document). The cursor is BETWEEN "
|
||||||
|
"two blocks as indicated by cursor: true."
|
||||||
)
|
)
|
||||||
chunk_json = json.dumps(chunk_dict)
|
if doc_state.get("isEmptyDocument"):
|
||||||
yield f"data: {chunk_json}\n\n"
|
text += (
|
||||||
except (AttributeError, TypeError) as e:
|
"Because the document is empty, YOU MUST first "
|
||||||
log.error("Error serializing chunk: %s, chunk: %s", e, chunk)
|
"update the empty block before adding new blocks."
|
||||||
continue
|
)
|
||||||
except (OpenAIError, RuntimeError, OSError, ValueError) as e:
|
else:
|
||||||
log.error("Streaming error: %s", e)
|
text += (
|
||||||
|
"Prefer updating existing blocks over removing "
|
||||||
|
"and adding (but this also depends on the "
|
||||||
|
"user's question)."
|
||||||
|
)
|
||||||
|
parts = [
|
||||||
|
TextUIPart(text=text),
|
||||||
|
TextUIPart(text=json.dumps(blocks)),
|
||||||
|
]
|
||||||
|
|
||||||
yield "data: [DONE]\n\n"
|
result.append(
|
||||||
|
UIMessage(
|
||||||
|
role="assistant",
|
||||||
|
id=f"assistant-document-state-{message.id}",
|
||||||
|
parts=parts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result.append(message)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tool_definitions_to_toolset(
|
||||||
|
tool_definitions: Dict[str, Any],
|
||||||
|
) -> ExternalToolset:
|
||||||
|
"""Convert serialized tool definitions to a pydantic-ai ExternalToolset.
|
||||||
|
|
||||||
|
Port of BlockNote's toolDefinitionsToToolSet.
|
||||||
|
Builds ToolDefinition objects from the JSON-Schema-based definitions
|
||||||
|
sent by the frontend and wraps them in an ExternalToolset so that
|
||||||
|
pydantic-ai advertises them to the LLM without trying to execute them
|
||||||
|
server-side (execution is deferred to the frontend).
|
||||||
|
"""
|
||||||
|
tool_defs = [
|
||||||
|
ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=defn.get("description", ""),
|
||||||
|
parameters_json_schema=defn.get("inputSchema", {}),
|
||||||
|
kind="external",
|
||||||
|
metadata={
|
||||||
|
"output_schema": defn.get("outputSchema"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for name, defn in tool_definitions.items()
|
||||||
|
]
|
||||||
|
return ExternalToolset(tool_defs)
|
||||||
|
|
||||||
|
def _build_async_stream(self, request: Request) -> AsyncIterator[str]:
|
||||||
|
"""Build the async stream from the AI provider."""
|
||||||
|
instrument_enabled = settings.LANGFUSE_PUBLIC_KEY is not None
|
||||||
|
|
||||||
|
if instrument_enabled:
|
||||||
|
langfuse = get_client()
|
||||||
|
langfuse.auth_check()
|
||||||
|
Agent.instrument_all()
|
||||||
|
|
||||||
|
model = OpenAIChatModel(
|
||||||
|
settings.AI_MODEL,
|
||||||
|
provider=OpenAIProvider(
|
||||||
|
base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY
|
||||||
|
),
|
||||||
|
)
|
||||||
|
agent = Agent(model, instrument=instrument_enabled)
|
||||||
|
|
||||||
|
accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE)
|
||||||
|
|
||||||
|
run_input = VercelAIAdapter.build_run_input(request.raw_body)
|
||||||
|
|
||||||
|
# Inject document state context into the conversation
|
||||||
|
run_input.messages = self.inject_document_state_messages(run_input.messages)
|
||||||
|
|
||||||
|
# Build an ExternalToolset from frontend-supplied tool definitions
|
||||||
|
raw_tool_defs = (
|
||||||
|
run_input.model_extra.get("toolDefinitions")
|
||||||
|
if run_input.model_extra
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
toolset = (
|
||||||
|
self.tool_definitions_to_toolset(raw_tool_defs) if raw_tool_defs else None
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = VercelAIAdapter(
|
||||||
|
agent=agent,
|
||||||
|
run_input=run_input,
|
||||||
|
accept=accept,
|
||||||
|
sdk_version=settings.AI_VERCEL_SDK_VERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_stream = adapter.run_stream(
|
||||||
|
output_type=[str, DeferredToolRequests] if toolset else None,
|
||||||
|
toolsets=[toolset] if toolset else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return adapter.encode_stream(event_stream)
|
||||||
|
|
||||||
|
def stream(self, request: Request) -> Union[AsyncIterator[str], Iterator[str]]:
|
||||||
|
"""Stream AI API requests to the configured AI provider.
|
||||||
|
|
||||||
|
Returns an async iterator when running in async mode (ASGI)
|
||||||
|
or a sync iterator when running in sync mode (WSGI).
|
||||||
|
"""
|
||||||
|
async_stream = self._build_async_stream(request)
|
||||||
|
|
||||||
|
if os.environ.get("PYTHON_SERVER_MODE", "sync") == "async":
|
||||||
|
return async_stream
|
||||||
|
|
||||||
|
return convert_async_generator_to_sync(async_stream)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ Test AI proxy API endpoint for users in impress's core app.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
@@ -11,7 +11,6 @@ import pytest
|
|||||||
from rest_framework.test import APIClient
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
from core import factories
|
from core import factories
|
||||||
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT
|
|
||||||
from core.tests.conftest import TEAM, USER, VIA
|
from core.tests.conftest import TEAM, USER, VIA
|
||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
@@ -24,6 +23,8 @@ def ai_settings(settings):
|
|||||||
settings.AI_BASE_URL = "http://localhost-ai:12345/"
|
settings.AI_BASE_URL = "http://localhost-ai:12345/"
|
||||||
settings.AI_API_KEY = "test-key"
|
settings.AI_API_KEY = "test-key"
|
||||||
settings.AI_FEATURE_ENABLED = True
|
settings.AI_FEATURE_ENABLED = True
|
||||||
|
settings.LANGFUSE_PUBLIC_KEY = None
|
||||||
|
settings.AI_VERCEL_SDK_VERSION = 6
|
||||||
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
@@ -51,7 +52,6 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
|
|||||||
url,
|
url,
|
||||||
{
|
{
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
"model": "llama",
|
|
||||||
},
|
},
|
||||||
format="json",
|
format="json",
|
||||||
)
|
)
|
||||||
@@ -62,85 +62,48 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False)
|
@override_settings(AI_ALLOW_REACH_FROM="public")
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
def test_api_documents_ai_proxy_anonymous_success(mock_create):
|
def test_api_documents_ai_proxy_anonymous_success(mock_stream):
|
||||||
"""
|
"""
|
||||||
Anonymous users should be able to request AI proxy to a document
|
Anonymous users should be able to request AI proxy to a document
|
||||||
if the link reach and role permit it.
|
if the link reach and role permit it.
|
||||||
"""
|
"""
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_stream.return_value = iter(["data: chunk1\n", "data: chunk2\n"])
|
||||||
mock_response.model_dump.return_value = {
|
|
||||||
"id": "chatcmpl-123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1677652288,
|
|
||||||
"model": "llama",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Hello! How can I help you?",
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = APIClient().post(
|
response = APIClient().post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
response_data = response.json()
|
assert response["Content-Type"] == "text/event-stream"
|
||||||
assert response_data["id"] == "chatcmpl-123"
|
assert response["x-vercel-ai-data-stream"] == "v1"
|
||||||
assert response_data["model"] == "llama"
|
assert response["X-Accel-Buffering"] == "no"
|
||||||
assert len(response_data["choices"]) == 1
|
|
||||||
assert (
|
|
||||||
response_data["choices"][0]["message"]["content"]
|
|
||||||
== "Hello! How can I help you?"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
content = b"".join(response.streaming_content).decode()
|
||||||
messages=[
|
assert "chunk1" in content
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
assert "chunk2" in content
|
||||||
{"role": "user", "content": "Hello"},
|
mock_stream.assert_called_once()
|
||||||
],
|
|
||||||
model="llama",
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
|
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
def test_api_documents_ai_proxy_anonymous_limited_by_setting():
|
||||||
def test_api_documents_ai_proxy_anonymous_limited_by_setting(mock_create):
|
|
||||||
"""
|
"""
|
||||||
Anonymous users should not be able to request AI proxy to a document
|
Anonymous users should not be able to request AI proxy to a document
|
||||||
if AI_ALLOW_REACH_FROM setting restricts it.
|
if AI_ALLOW_REACH_FROM setting restricts it.
|
||||||
"""
|
"""
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {"content": "Hello!"}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = APIClient().post(
|
response = APIClient().post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -170,11 +133,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
|
|||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
@@ -187,9 +147,8 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
|
|||||||
("public", "editor"),
|
("public", "editor"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@override_settings(AI_STREAM=False)
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
def test_api_documents_ai_proxy_authenticated_success(mock_stream, reach, role):
|
||||||
def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
|
||||||
"""
|
"""
|
||||||
Authenticated users should be able to request AI proxy to a document
|
Authenticated users should be able to request AI proxy to a document
|
||||||
if the link reach and role permit it.
|
if the link reach and role permit it.
|
||||||
@@ -201,44 +160,18 @@ def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
|||||||
|
|
||||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_stream.return_value = iter(["data: response\n"])
|
||||||
mock_response.model_dump.return_value = {
|
|
||||||
"id": "chatcmpl-456",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"model": "llama",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Hi there!"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
response_data = response.json()
|
assert response["Content-Type"] == "text/event-stream"
|
||||||
assert response_data["id"] == "chatcmpl-456"
|
mock_stream.assert_called_once()
|
||||||
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
|
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
model="llama",
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("via", VIA)
|
@pytest.mark.parametrize("via", VIA)
|
||||||
@@ -261,11 +194,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
|
|||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
@@ -273,9 +203,8 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
||||||
@pytest.mark.parametrize("via", VIA)
|
@pytest.mark.parametrize("via", VIA)
|
||||||
@override_settings(AI_STREAM=False)
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
def test_api_documents_ai_proxy_success(mock_stream, via, role, mock_user_teams):
|
||||||
def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams):
|
|
||||||
"""Users with sufficient permissions should be able to request AI proxy."""
|
"""Users with sufficient permissions should be able to request AI proxy."""
|
||||||
user = factories.UserFactory()
|
user = factories.UserFactory()
|
||||||
|
|
||||||
@@ -291,402 +220,27 @@ def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams)
|
|||||||
document=document, team="lasuite", role=role
|
document=document, team="lasuite", role=role
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_stream.return_value = iter(["data: success\n"])
|
||||||
mock_response.model_dump.return_value = {
|
|
||||||
"id": "chatcmpl-789",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"model": "llama",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Success!"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Test message"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
response_data = response.json()
|
assert response["Content-Type"] == "text/event-stream"
|
||||||
assert response_data["id"] == "chatcmpl-789"
|
assert response["x-vercel-ai-data-stream"] == "v1"
|
||||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
assert response["X-Accel-Buffering"] == "no"
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
content = b"".join(response.streaming_content).decode()
|
||||||
messages=[
|
assert "success" in content
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
mock_stream.assert_called_once()
|
||||||
{"role": "user", "content": "Test message"},
|
|
||||||
],
|
|
||||||
model="llama",
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_documents_ai_proxy_empty_messages():
|
|
||||||
"""The messages should not be empty when requesting AI proxy."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(url, {"messages": [], "model": "llama"}, format="json")
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {"messages": ["This list may not be empty."]}
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_documents_ai_proxy_missing_model():
|
|
||||||
"""The model should be required when requesting AI proxy."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url, {"messages": [{"role": "user", "content": "Hello"}]}, format="json"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {"model": ["This field is required."]}
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_documents_ai_proxy_invalid_message_format():
|
|
||||||
"""Messages should have the correct format when requesting AI proxy."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
|
|
||||||
# Test with invalid message format (missing role)
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"content": "Hello"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {
|
|
||||||
"messages": ["Each message must have 'role' and 'content' fields"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Test with invalid message format (missing content)
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {
|
|
||||||
"messages": ["Each message must have 'role' and 'content' fields"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Test with non-dict message
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": ["invalid"],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {
|
|
||||||
"messages": {"0": ['Expected a dictionary of items but got type "str".']}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_STREAM=False)
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
|
||||||
"""Stream should be automatically disabled in AI proxy requests."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"model": "llama",
|
|
||||||
"stream": True, # This should be overridden to False
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
# Verify that stream was set to False
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
model="llama",
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_STREAM=False)
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
|
||||||
"""AI proxy should pass through additional parameters to the AI service."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"model": "llama",
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 100,
|
|
||||||
"top_p": 0.9,
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
# Verify that additional parameters were passed through
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
|
||||||
model="llama",
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100,
|
|
||||||
top_p=0.9,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_STREAM=False)
|
|
||||||
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_api_documents_ai_proxy_throttling_document(mock_create):
|
|
||||||
"""
|
|
||||||
Throttling per document should be triggered on the AI transform endpoint.
|
|
||||||
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
|
||||||
"""
|
|
||||||
client = APIClient()
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
for _ in range(3):
|
|
||||||
user = factories.UserFactory()
|
|
||||||
client.force_login(user)
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Test message"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"content": "Success!"}
|
|
||||||
|
|
||||||
user = factories.UserFactory()
|
|
||||||
client.force_login(user)
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Test message"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 429
|
|
||||||
assert response.json() == {
|
|
||||||
"detail": "Request was throttled. Expected available in 60 seconds."
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_STREAM=False)
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_api_documents_ai_proxy_complex_conversation(mock_create):
|
|
||||||
"""AI proxy should handle complex conversations with multiple messages."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {
|
|
||||||
"id": "chatcmpl-complex",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"model": "llama",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "I understand your question about Python.",
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
complex_messages = [
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "system", "content": "You are a helpful programming assistant."},
|
|
||||||
{"role": "user", "content": "How do I write a for loop in Python?"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "You can write a for loop using: for item in iterable:",
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "Can you give me a concrete example?"},
|
|
||||||
]
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": complex_messages,
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
response_data = response.json()
|
|
||||||
assert response_data["id"] == "chatcmpl-complex"
|
|
||||||
assert (
|
|
||||||
response_data["choices"][0]["message"]["content"]
|
|
||||||
== "I understand your question about Python."
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=complex_messages,
|
|
||||||
model="llama",
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_api_documents_ai_proxy_throttling_user(mock_create):
|
|
||||||
"""
|
|
||||||
Throttling per user should be triggered on the AI proxy endpoint.
|
|
||||||
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
|
||||||
"""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
|
||||||
response = client.post(
|
|
||||||
url,
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 429
|
|
||||||
assert response.json() == {
|
|
||||||
"detail": "Request was throttled. Expected available in 60 seconds."
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 10, "hour": 6, "day": 10})
|
|
||||||
def test_api_documents_ai_proxy_different_models():
|
|
||||||
"""AI proxy should work with different AI models."""
|
|
||||||
user = factories.UserFactory()
|
|
||||||
|
|
||||||
client = APIClient()
|
|
||||||
client.force_login(user)
|
|
||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
|
||||||
|
|
||||||
models_to_test = ["gpt-3.5-turbo", "gpt-4", "claude-3", "llama-2"]
|
|
||||||
|
|
||||||
for model_name in models_to_test:
|
|
||||||
response = client.post(
|
|
||||||
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"model": model_name,
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert response.json() == {"model": [f"{model_name} is not a valid model"]}
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
||||||
"""When the settings AI_FEATURE_ENABLED is set to False, the endpoint is not reachable."""
|
"""When AI_FEATURE_ENABLED is False, the endpoint returns 400."""
|
||||||
settings.AI_FEATURE_ENABLED = False
|
settings.AI_FEATURE_ENABLED = False
|
||||||
|
|
||||||
user = factories.UserFactory()
|
user = factories.UserFactory()
|
||||||
@@ -698,25 +252,91 @@ def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
|||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.json() == ["AI feature is not enabled."]
|
assert response.json() == ["AI feature is not enabled."]
|
||||||
|
|
||||||
|
|
||||||
@override_settings(AI_STREAM=False)
|
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
def test_api_documents_ai_proxy_harden_payload(mock_create):
|
def test_api_documents_ai_proxy_throttling_document(mock_stream):
|
||||||
"""
|
"""
|
||||||
AI proxy should harden the payload by default:
|
Throttling per document should be triggered on the AI proxy endpoint.
|
||||||
- it will remove stream_options if stream is False
|
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||||
- normalize tools by adding descriptions if missing
|
|
||||||
"""
|
"""
|
||||||
|
client = APIClient()
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
mock_stream.return_value = iter(["data: ok\n"])
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
|
for _ in range(3):
|
||||||
|
mock_stream.return_value = iter(["data: ok\n"])
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(
|
||||||
|
url,
|
||||||
|
b"{}",
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(
|
||||||
|
url,
|
||||||
|
b"{}",
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
|
def test_api_documents_ai_proxy_throttling_user(mock_stream):
|
||||||
|
"""
|
||||||
|
Throttling per user should be triggered on the AI proxy endpoint.
|
||||||
|
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
mock_stream.return_value = iter(["data: ok\n"])
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
|
response = client.post(
|
||||||
|
url,
|
||||||
|
b"{}",
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
|
response = client.post(
|
||||||
|
url,
|
||||||
|
b"{}",
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.services.ai_services.AIService.stream")
|
||||||
|
def test_api_documents_ai_proxy_returns_streaming_response(mock_stream):
|
||||||
|
"""AI proxy should return a StreamingHttpResponse with correct headers."""
|
||||||
user = factories.UserFactory()
|
user = factories.UserFactory()
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
@@ -724,76 +344,39 @@ def test_api_documents_ai_proxy_harden_payload(mock_create):
|
|||||||
|
|
||||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_stream.return_value = iter(["data: part1\n", "data: part2\n", "data: part3\n"])
|
||||||
mock_response.model_dump.return_value = {
|
|
||||||
"id": "chatcmpl-789",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"model": "llama",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Success!"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||||
response = client.post(
|
response = client.post(
|
||||||
url,
|
url,
|
||||||
{
|
b"{}",
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
content_type="application/json",
|
||||||
"model": "llama",
|
|
||||||
"stream": False,
|
|
||||||
"stream_options": {"include_usage": True},
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "applyDocumentOperations",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"operations": {"type": "array", "items": {"anyOf": []}}
|
|
||||||
},
|
|
||||||
"additionalProperties": False,
|
|
||||||
"required": ["operations"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
format="json",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
response_data = response.json()
|
assert response["Content-Type"] == "text/event-stream"
|
||||||
assert response_data["id"] == "chatcmpl-789"
|
assert response["x-vercel-ai-data-stream"] == "v1"
|
||||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
assert response["X-Accel-Buffering"] == "no"
|
||||||
|
|
||||||
mock_create.assert_called_once_with(
|
chunks = list(response.streaming_content)
|
||||||
messages=[
|
assert len(chunks) == 3
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
],
|
def test_api_documents_ai_proxy_invalid_payload():
|
||||||
model="llama",
|
"""AI Proxy should return a 400 if the payload is invalid."""
|
||||||
stream=False,
|
|
||||||
tools=[
|
user = factories.UserFactory()
|
||||||
{
|
|
||||||
"type": "function",
|
document = factories.DocumentFactory(users=[(user, "owner")])
|
||||||
"function": {
|
|
||||||
"name": "applyDocumentOperations",
|
client = APIClient()
|
||||||
"parameters": {
|
client.force_login(user)
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
response = client.post(
|
||||||
"operations": {"type": "array", "items": {"anyOf": []}}
|
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
||||||
},
|
b'{"foo": "bar", "trigger": "submit-message"}',
|
||||||
"additionalProperties": False,
|
content_type="application/json",
|
||||||
"required": ["operations"],
|
|
||||||
},
|
|
||||||
"description": "Tool applyDocumentOperations.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"detail": "Invalid submitted payload"}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Test ai API endpoints in the impress core app.
|
Test AI services in the impress core app.
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
@@ -9,8 +11,12 @@ from django.test.utils import override_settings
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from openai import OpenAIError
|
from openai import OpenAIError
|
||||||
|
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
|
||||||
|
|
||||||
from core.services.ai_services import BLOCKNOTE_TOOL_STRICT_PROMPT, AIService
|
from core.services.ai_services import (
|
||||||
|
AIService,
|
||||||
|
convert_async_generator_to_sync,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
@@ -22,6 +28,11 @@ def ai_settings(settings):
|
|||||||
settings.AI_BASE_URL = "http://example.com"
|
settings.AI_BASE_URL = "http://example.com"
|
||||||
settings.AI_API_KEY = "test-key"
|
settings.AI_API_KEY = "test-key"
|
||||||
settings.AI_FEATURE_ENABLED = True
|
settings.AI_FEATURE_ENABLED = True
|
||||||
|
settings.LANGFUSE_PUBLIC_KEY = None
|
||||||
|
settings.AI_VERCEL_SDK_VERSION = 6
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.__init__ --
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -43,6 +54,9 @@ def test_services_ai_setting_missing(setting_name, setting_value, settings):
|
|||||||
AIService()
|
AIService()
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.transform --
|
||||||
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
)
|
)
|
||||||
@@ -59,19 +73,6 @@ def test_services_ai_client_error(mock_create):
|
|||||||
AIService().transform("hello", "prompt")
|
AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_services_ai_proxy_client_error(mock_create):
|
|
||||||
"""Fail when the client raises an error"""
|
|
||||||
|
|
||||||
mock_create.side_effect = OpenAIError("Mocked client error")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
RuntimeError,
|
|
||||||
match="Failed to proxy AI request: Mocked client error",
|
|
||||||
):
|
|
||||||
AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
)
|
)
|
||||||
@@ -90,49 +91,6 @@ def test_services_ai_client_invalid_response(mock_create):
|
|||||||
AIService().transform("hello", "prompt")
|
AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
|
||||||
def test_services_ai_proxy_success(mock_create):
|
|
||||||
"""The AI request should work as expect when called with valid arguments."""
|
|
||||||
|
|
||||||
mock_create.return_value = {
|
|
||||||
"id": "chatcmpl-test",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "test-model",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Salut"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
|
|
||||||
|
|
||||||
expected_response = {
|
|
||||||
"id": "chatcmpl-test",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "test-model",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Salut"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
assert response == expected_response
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "hello"},
|
|
||||||
],
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
)
|
)
|
||||||
@@ -149,46 +107,438 @@ def test_services_ai_success(mock_create):
|
|||||||
assert response == {"answer": "Salut"}
|
assert response == {"answer": "Salut"}
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.translate --
|
||||||
|
|
||||||
|
|
||||||
@patch("openai.resources.chat.completions.Completions.create")
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
def test_services_ai_proxy_with_stream(mock_create):
|
def test_services_ai_translate_success(mock_create):
|
||||||
"""The AI request should work as expect when called with valid arguments."""
|
"""Translate should call the AI API with the correct language prompt."""
|
||||||
|
|
||||||
mock_create.return_value = {
|
mock_create.return_value = MagicMock(
|
||||||
"id": "chatcmpl-test",
|
choices=[MagicMock(message=MagicMock(content="Bonjour"))]
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "test-model",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Salut"},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
response = AIService().proxy(
|
|
||||||
{"messages": [{"role": "user", "content": "hello"}]}, stream=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_response = {
|
response = AIService().translate("<p>Hello</p>", "fr")
|
||||||
"id": "chatcmpl-test",
|
|
||||||
"object": "chat.completion",
|
assert response == {"answer": "Bonjour"}
|
||||||
"created": 1234567890,
|
call_args = mock_create.call_args
|
||||||
"model": "test-model",
|
system_content = call_args[1]["messages"][0]["content"]
|
||||||
"choices": [
|
assert "French" in system_content or "fr" in system_content
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": "Salut"},
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
"finish_reason": "stop",
|
def test_services_ai_translate_unknown_language(mock_create):
|
||||||
}
|
"""Translate with an unknown language code should use the code as-is."""
|
||||||
],
|
|
||||||
}
|
mock_create.return_value = MagicMock(
|
||||||
assert response == expected_response
|
choices=[MagicMock(message=MagicMock(content="Translated"))]
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT},
|
|
||||||
{"role": "user", "content": "hello"},
|
|
||||||
],
|
|
||||||
stream=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = AIService().translate("<p>Hello</p>", "xx-unknown")
|
||||||
|
|
||||||
|
assert response == {"answer": "Translated"}
|
||||||
|
call_args = mock_create.call_args
|
||||||
|
system_content = call_args[1]["messages"][0]["content"]
|
||||||
|
assert "xx-unknown" in system_content
|
||||||
|
|
||||||
|
|
||||||
|
# -- convert_async_generator_to_sync --
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_async_generator_to_sync_basic():
|
||||||
|
"""Should convert an async generator yielding items to a sync iterator."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
for item in ["hello", "world", "!"]:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
result = list(convert_async_generator_to_sync(async_gen()))
|
||||||
|
assert result == ["hello", "world", "!"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_async_generator_to_sync_empty():
|
||||||
|
"""Should handle an empty async generator."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
result = list(convert_async_generator_to_sync(async_gen()))
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_async_generator_to_sync_exception():
    """Should propagate exceptions from the async generator."""

    async def async_gen():
        yield "first"
        raise ValueError("async error")

    sync_iter = convert_async_generator_to_sync(async_gen())
    # The first item is delivered before the exception is raised.
    assert next(sync_iter) == "first"

    with pytest.raises(ValueError, match="async error"):
        next(sync_iter)
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.inject_document_state_messages --


def test_inject_document_state_messages_no_metadata():
    """Messages without documentState metadata should pass through unchanged."""
    messages = [
        UIMessage(role="user", id="msg-1", parts=[TextUIPart(text="Hello")]),
    ]

    result = AIService.inject_document_state_messages(messages)

    assert len(result) == 1
    assert result[0].id == "msg-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_document_state_messages_with_selection():
    """A user message with documentState and selection should get an
    assistant context message prepended."""
    messages = [
        UIMessage(
            role="user",
            id="msg-1",
            parts=[TextUIPart(text="Fix this")],
            metadata={
                "documentState": {
                    "selection": {"start": 0, "end": 5},
                    "selectedBlocks": [{"type": "paragraph", "content": "Hello"}],
                    "blocks": [
                        {"type": "paragraph", "content": "Hello"},
                        {"type": "paragraph", "content": "World"},
                    ],
                }
            },
        ),
    ]

    result = AIService.inject_document_state_messages(messages)

    assert len(result) == 2
    # First message should be the injected assistant context
    assert result[0].role == "assistant"
    assert result[0].id == "assistant-document-state-msg-1"
    assert len(result[0].parts) == 4
    assert "selection" in result[0].parts[0].text.lower()
    # Second message should be the original user message
    assert result[1].id == "msg-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_document_state_messages_without_selection():
    """A user message with documentState but no selection should describe
    the full document context."""
    messages = [
        UIMessage(
            role="user",
            id="msg-1",
            parts=[TextUIPart(text="Summarize")],
            metadata={
                "documentState": {
                    "selection": None,
                    "blocks": [
                        {"type": "paragraph", "content": "Hello"},
                    ],
                    "isEmptyDocument": False,
                }
            },
        ),
    ]

    result = AIService.inject_document_state_messages(messages)

    assert len(result) == 2
    assistant_msg = result[0]
    assert assistant_msg.role == "assistant"
    assert len(assistant_msg.parts) == 2
    # Without a selection, the injected context should describe the whole
    # document and steer the model toward updating existing blocks.
    assert "no active selection" in assistant_msg.parts[0].text.lower()
    assert "prefer updating" in assistant_msg.parts[0].text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_document_state_messages_empty_document():
    """When the document is empty, the injected message should instruct
    updating the empty block first."""
    messages = [
        UIMessage(
            role="user",
            id="msg-1",
            parts=[TextUIPart(text="Write something")],
            metadata={
                "documentState": {
                    "selection": None,
                    "blocks": [{"type": "paragraph", "content": ""}],
                    "isEmptyDocument": True,
                }
            },
        ),
    ]

    result = AIService.inject_document_state_messages(messages)

    assert len(result) == 2
    assistant_msg = result[0]
    assert "update the empty block" in assistant_msg.parts[0].text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_document_state_messages_mixed():
    """Only user messages with documentState get assistant context;
    other messages pass through unchanged."""
    messages = [
        UIMessage(
            role="assistant",
            id="msg-0",
            parts=[TextUIPart(text="Previous response")],
        ),
        UIMessage(
            role="user",
            id="msg-1",
            parts=[TextUIPart(text="Hello")],
        ),
        UIMessage(
            role="user",
            id="msg-2",
            parts=[TextUIPart(text="Fix this")],
            metadata={
                "documentState": {
                    "selection": {"start": 0, "end": 5},
                    "selectedBlocks": [{"type": "paragraph", "content": "Hello"}],
                    "blocks": [{"type": "paragraph", "content": "Hello"}],
                }
            },
        ),
    ]

    result = AIService.inject_document_state_messages(messages)

    # 3 original + 1 injected assistant message before msg-2
    assert len(result) == 4
    assert result[0].id == "msg-0"
    assert result[1].id == "msg-1"
    assert result[2].role == "assistant"
    assert result[2].id == "assistant-document-state-msg-2"
    assert result[3].id == "msg-2"
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.tool_definitions_to_toolset --


def test_tool_definitions_to_toolset():
    """Should convert frontend tool definitions to an ExternalToolset."""
    tool_definitions = {
        "applyOperations": {
            "description": "Apply operations to the document",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "operations": {"type": "array"},
                },
            },
            "outputSchema": {"type": "object"},
        },
        "insertBlocks": {
            "description": "Insert blocks",
            "inputSchema": {"type": "object"},
        },
    }

    toolset = AIService.tool_definitions_to_toolset(tool_definitions)

    # The ExternalToolset wraps ToolDefinition objects
    assert toolset is not None
    # Access internal tool definitions
    tool_defs = toolset.tool_defs
    assert len(tool_defs) == 2

    names = {td.name for td in tool_defs}
    assert names == {"applyOperations", "insertBlocks"}

    for td in tool_defs:
        assert td.kind == "external"
        if td.name == "applyOperations":
            assert td.description == "Apply operations to the document"
            # outputSchema is carried through in the tool metadata.
            assert td.metadata == {"output_schema": {"type": "object"}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_definitions_to_toolset_missing_fields():
    """Should handle tool definitions with missing optional fields."""
    tool_definitions = {
        "myTool": {},
    }

    toolset = AIService.tool_definitions_to_toolset(tool_definitions)

    tool_defs = toolset.tool_defs
    assert len(tool_defs) == 1
    assert tool_defs[0].name == "myTool"
    # Missing optional fields fall back to empty/None defaults.
    assert tool_defs[0].description == ""
    assert tool_defs[0].parameters_json_schema == {}
    assert tool_defs[0].metadata == {"output_schema": None}
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService.stream --


@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_sync_mode(mock_build, monkeypatch):
    """In sync mode, stream() should return a sync iterator."""

    async def mock_async_gen():
        yield "chunk1"
        yield "chunk2"

    mock_build.return_value = mock_async_gen()
    monkeypatch.setenv("PYTHON_SERVER_MODE", "sync")

    service = AIService()
    request = MagicMock()
    result = service.stream(request)

    # Should be a regular (sync) iterator, not async
    assert not isinstance(result, AsyncIterator)
    assert list(result) == ["chunk1", "chunk2"]
    mock_build.assert_called_once_with(request)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_async_mode(mock_build, monkeypatch):
    """In async mode, stream() should return the async iterator directly."""

    async def mock_async_gen():
        yield "chunk1"
        yield "chunk2"

    mock_async_iter = mock_async_gen()
    mock_build.return_value = mock_async_iter
    monkeypatch.setenv("PYTHON_SERVER_MODE", "async")

    service = AIService()
    request = MagicMock()
    result = service.stream(request)

    # In async mode the iterator is passed through without wrapping.
    assert result is mock_async_iter
    mock_build.assert_called_once_with(request)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.object(AIService, "_build_async_stream")
def test_services_ai_stream_defaults_to_sync(mock_build, monkeypatch):
    """When PYTHON_SERVER_MODE is not set, stream() should default to sync."""

    async def mock_async_gen():
        yield "data"

    mock_build.return_value = mock_async_gen()
    monkeypatch.delenv("PYTHON_SERVER_MODE", raising=False)

    service = AIService()
    request = MagicMock()
    result = service.stream(request)

    # Default should be sync mode
    assert not isinstance(result, AsyncIterator)
    assert list(result) == ["data"]
|
||||||
|
|
||||||
|
|
||||||
|
# -- AIService._build_async_stream --


@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream(mock_adapter_cls):
    """_build_async_stream should build the pydantic-ai streaming pipeline."""

    async def mock_encode():
        yield "event-data"

    mock_run_input = MagicMock()
    mock_run_input.model_extra = None
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input

    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance

    service = AIService()
    request = MagicMock()
    request.META = {"HTTP_ACCEPT": "text/event-stream"}
    request.raw_body = b'{"messages": []}'

    result = service._build_async_stream(request)
    assert isinstance(result, AsyncIterator)
    # The run input is parsed from the raw request body.
    mock_adapter_cls.build_run_input.assert_called_once_with(b'{"messages": []}')
    mock_adapter_instance.run_stream.assert_called_once()
    mock_adapter_instance.encode_stream.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream_with_tool_definitions(mock_adapter_cls):
    """_build_async_stream should build an ExternalToolset when
    toolDefinitions are present in the request."""

    async def mock_encode():
        yield "event-data"

    mock_run_input = MagicMock()
    # toolDefinitions arrive as extra (non-schema) fields on the run input.
    mock_run_input.model_extra = {
        "toolDefinitions": {
            "myTool": {
                "description": "A tool",
                "inputSchema": {"type": "object"},
            }
        }
    }
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input

    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance

    service = AIService()
    request = MagicMock()
    request.META = {}
    request.raw_body = b"{}"

    service._build_async_stream(request)
    # run_stream should have been called with a toolset
    call_kwargs = mock_adapter_instance.run_stream.call_args[1]
    assert call_kwargs["toolsets"] is not None
    assert len(call_kwargs["toolsets"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.services.ai_services.Agent")
@patch("core.services.ai_services.VercelAIAdapter")
def test_services_ai_build_async_stream_langfuse_enabled(
    mock_adapter_cls, mock_agent_cls, settings
):
    """When LANGFUSE_PUBLIC_KEY is set, instrument should be enabled."""
    settings.LANGFUSE_PUBLIC_KEY = "pk-test-123"

    async def mock_encode():
        yield "data"

    mock_run_input = MagicMock()
    mock_run_input.model_extra = None
    mock_run_input.messages = []
    mock_adapter_cls.build_run_input.return_value = mock_run_input

    mock_adapter_instance = MagicMock()
    mock_adapter_instance.run_stream.return_value = MagicMock()
    mock_adapter_instance.encode_stream.return_value = mock_encode()
    mock_adapter_cls.return_value = mock_adapter_instance

    service = AIService()
    request = MagicMock()
    request.META = {}
    request.raw_body = b"{}"

    service._build_async_stream(request)
    mock_agent_cls.instrument_all.assert_called_once()
    # Agent should be created with instrument=True
    mock_agent_cls.assert_called_once()
    assert mock_agent_cls.call_args[1]["instrument"] is True
|
||||||
|
|||||||
@@ -326,6 +326,7 @@ class Base(Configuration):
|
|||||||
"django.middleware.csrf.CsrfViewMiddleware",
|
"django.middleware.csrf.CsrfViewMiddleware",
|
||||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
"core.middleware.ForceSessionMiddleware",
|
"core.middleware.ForceSessionMiddleware",
|
||||||
|
"core.middleware.SaveRawBodyMiddleware",
|
||||||
"django.contrib.messages.middleware.MessageMiddleware",
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
"dockerflow.django.middleware.DockerflowMiddleware",
|
"dockerflow.django.middleware.DockerflowMiddleware",
|
||||||
"csp.middleware.CSPMiddleware",
|
"csp.middleware.CSPMiddleware",
|
||||||
@@ -716,6 +717,9 @@ class Base(Configuration):
|
|||||||
AI_STREAM = values.BooleanValue(
|
AI_STREAM = values.BooleanValue(
|
||||||
default=True, environ_name="AI_STREAM", environ_prefix=None
|
default=True, environ_name="AI_STREAM", environ_prefix=None
|
||||||
)
|
)
|
||||||
|
AI_VERCEL_SDK_VERSION = values.IntegerValue(
|
||||||
|
6, environ_name="AI_VERCEL_SDK_VERSION", environ_prefix=None
|
||||||
|
)
|
||||||
AI_USER_RATE_THROTTLE_RATES = {
|
AI_USER_RATE_THROTTLE_RATES = {
|
||||||
"minute": 3,
|
"minute": 3,
|
||||||
"hour": 50,
|
"hour": 50,
|
||||||
|
|||||||
@@ -53,9 +53,11 @@ dependencies = [
|
|||||||
"markdown==3.10",
|
"markdown==3.10",
|
||||||
"mozilla-django-oidc==5.0.2",
|
"mozilla-django-oidc==5.0.2",
|
||||||
"nested-multipart-parser==1.6.0",
|
"nested-multipart-parser==1.6.0",
|
||||||
"openai==2.14.0",
|
"openai==2.21.0",
|
||||||
"psycopg[binary]==3.3.2",
|
"psycopg[binary]==3.3.2",
|
||||||
"pycrdt==0.12.44",
|
"pycrdt==0.12.44",
|
||||||
|
"pydantic==2.12.5",
|
||||||
|
"pydantic-ai-slim[openai,logfire,web]==1.58.0",
|
||||||
"PyJWT==2.10.1",
|
"PyJWT==2.10.1",
|
||||||
"python-magic==0.4.27",
|
"python-magic==0.4.27",
|
||||||
"redis<6.0.0",
|
"redis<6.0.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user