✨(backend) use pydantic AI to manage vercel data stream protocol
The frontend application is using Vercel AI SDK and it's data stream protocol. We decided to use the pydantic AI library to use it's vercel ai adapter. It will make the payload validation, use AsyncIterator and deal with vercel specification.
This commit is contained in:
committed by
Anthony LC
parent
050b106a8f
commit
8ce216f6e8
@@ -1,55 +1,38 @@
|
||||
"""AI services."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from openai import OpenAIError
|
||||
from langfuse import get_client
|
||||
from langfuse.openai import OpenAI as OpenAI_Langfuse
|
||||
from pydantic_ai import Agent, DeferredToolRequests
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.tools import ToolDefinition
|
||||
from pydantic_ai.toolsets.external import ExternalToolset
|
||||
from pydantic_ai.ui import SSE_CONTENT_TYPE
|
||||
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
|
||||
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
|
||||
from rest_framework.request import Request
|
||||
|
||||
from core import enums
|
||||
|
||||
if settings.LANGFUSE_PUBLIC_KEY:
|
||||
from langfuse.openai import OpenAI
|
||||
OpenAI = OpenAI_Langfuse
|
||||
else:
|
||||
from openai import OpenAI
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BLOCKNOTE_TOOL_STRICT_PROMPT = """
|
||||
You are editing a BlockNote document via the tool applyDocumentOperations.
|
||||
|
||||
You MUST respond ONLY by calling applyDocumentOperations.
|
||||
The tool input MUST be valid JSON:
|
||||
{ "operations": [ ... ] }
|
||||
|
||||
Each operation MUST include "type" and it MUST be one of:
|
||||
- "update" (requires: id, block)
|
||||
- "add" (requires: referenceId, position, blocks)
|
||||
- "delete" (requires: id)
|
||||
|
||||
VALID SHAPES (FOLLOW EXACTLY):
|
||||
|
||||
Update:
|
||||
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
|
||||
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.
|
||||
|
||||
Add:
|
||||
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
|
||||
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
|
||||
Each item MUST be a STRING containing a SINGLE valid HTML element.
|
||||
|
||||
Delete:
|
||||
{ "type":"delete", "id":"<id$>" }
|
||||
|
||||
IDs ALWAYS end with "$". Use ids EXACTLY as provided.
|
||||
|
||||
Return ONLY the JSON tool input. No prose, no markdown.
|
||||
"""
|
||||
|
||||
AI_ACTIONS = {
|
||||
"prompt": (
|
||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||
@@ -95,6 +78,40 @@ AI_TRANSLATE = (
|
||||
)
|
||||
|
||||
|
||||
def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[str]:
|
||||
"""Convert an async generator to a sync generator."""
|
||||
q: queue.Queue[str | object] = queue.Queue()
|
||||
sentinel = object()
|
||||
exc_sentinel = object()
|
||||
|
||||
async def run_async_gen():
|
||||
try:
|
||||
async for async_item in async_gen:
|
||||
q.put(async_item)
|
||||
except Exception as exc: # pylint: disable=broad-except #noqa: BLE001
|
||||
q.put((exc_sentinel, exc))
|
||||
finally:
|
||||
q.put(sentinel)
|
||||
|
||||
def start_async_loop():
|
||||
asyncio.run(run_async_gen())
|
||||
|
||||
thread = threading.Thread(target=start_async_loop, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is sentinel:
|
||||
break
|
||||
if isinstance(item, tuple) and item[0] is exc_sentinel:
|
||||
# re-raise the exception in the sync context
|
||||
raise item[1]
|
||||
yield item
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
|
||||
class AIService:
|
||||
"""Service class for AI-related operations."""
|
||||
|
||||
@@ -136,77 +153,171 @@ class AIService:
|
||||
system_content = AI_TRANSLATE.format(language=language_display)
|
||||
return self.call_ai_api(system_content, text)
|
||||
|
||||
def _normalize_tools(self, tools: list) -> list:
|
||||
@staticmethod
|
||||
def inject_document_state_messages(
|
||||
messages: list[UIMessage],
|
||||
) -> list[UIMessage]:
|
||||
"""Inject document state context before user messages.
|
||||
|
||||
Port of BlockNote's injectDocumentStateMessages.
|
||||
For each user message carrying documentState metadata, an assistant
|
||||
message describing the current document/selection state is prepended
|
||||
so the LLM sees it as context.
|
||||
"""
|
||||
Normalize tool definitions to ensure they have required fields.
|
||||
"""
|
||||
normalized = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||
fn = tool.get("function") or {}
|
||||
if isinstance(fn, dict) and not fn.get("description"):
|
||||
fn["description"] = f"Tool {fn.get('name', 'unknown')}."
|
||||
tool["function"] = fn
|
||||
normalized.append(tool)
|
||||
return normalized
|
||||
result: list[UIMessage] = []
|
||||
for message in messages:
|
||||
if (
|
||||
message.role == "user"
|
||||
and isinstance(message.metadata, dict)
|
||||
and "documentState" in message.metadata
|
||||
):
|
||||
doc_state = message.metadata["documentState"]
|
||||
selection = doc_state.get("selection")
|
||||
blocks = doc_state.get("blocks")
|
||||
|
||||
def _harden_payload(
|
||||
self, payload: Dict[str, Any], stream: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Harden the AI API payload to enforce compliance and tool usage."""
|
||||
payload["stream"] = stream
|
||||
|
||||
# Remove stream_options if stream is False
|
||||
if not stream and "stream_options" in payload:
|
||||
payload.pop("stream_options")
|
||||
|
||||
# Tools normalization
|
||||
if isinstance(payload.get("tools"), list):
|
||||
payload["tools"] = self._normalize_tools(payload["tools"])
|
||||
|
||||
# Inject strict system prompt once
|
||||
msgs = payload.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
need = True
|
||||
if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
|
||||
c = msgs[0].get("content") or ""
|
||||
if (
|
||||
isinstance(c, str)
|
||||
and "applyDocumentOperations" in c
|
||||
and "blocks" in c
|
||||
):
|
||||
need = False
|
||||
if need:
|
||||
payload["messages"] = [
|
||||
{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
|
||||
] + msgs
|
||||
|
||||
return payload
|
||||
|
||||
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
||||
"""Proxy AI API requests to the configured AI provider."""
|
||||
payload = self._harden_payload(data, stream=stream)
|
||||
try:
|
||||
return self.client.chat.completions.create(**payload)
|
||||
except OpenAIError as e:
|
||||
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
||||
|
||||
def stream(self, data: dict) -> Generator[str, None, None]:
|
||||
"""Stream AI API requests to the configured AI provider."""
|
||||
|
||||
try:
|
||||
stream = self.proxy(data, stream=True)
|
||||
for chunk in stream:
|
||||
try:
|
||||
chunk_dict = (
|
||||
chunk.model_dump() if hasattr(chunk, "model_dump") else chunk
|
||||
if selection:
|
||||
parts = [
|
||||
TextUIPart(
|
||||
text=(
|
||||
"This is the latest state of the selection "
|
||||
"(ignore previous selections, you MUST issue "
|
||||
"operations against this latest version of "
|
||||
"the selection):"
|
||||
),
|
||||
),
|
||||
TextUIPart(
|
||||
text=json.dumps(doc_state.get("selectedBlocks")),
|
||||
),
|
||||
TextUIPart(
|
||||
text=(
|
||||
"This is the latest state of the entire "
|
||||
"document (INCLUDING the selected text), you "
|
||||
"can use this to find the selected text to "
|
||||
"understand the context (but you MUST NOT "
|
||||
"issue operations against this document, you "
|
||||
"MUST issue operations against the selection):"
|
||||
),
|
||||
),
|
||||
TextUIPart(text=json.dumps(blocks)),
|
||||
]
|
||||
else:
|
||||
text = (
|
||||
"There is no active selection. This is the latest "
|
||||
"state of the document (ignore previous documents, "
|
||||
"you MUST issue operations against this latest "
|
||||
"version of the document). The cursor is BETWEEN "
|
||||
"two blocks as indicated by cursor: true."
|
||||
)
|
||||
chunk_json = json.dumps(chunk_dict)
|
||||
yield f"data: {chunk_json}\n\n"
|
||||
except (AttributeError, TypeError) as e:
|
||||
log.error("Error serializing chunk: %s, chunk: %s", e, chunk)
|
||||
continue
|
||||
except (OpenAIError, RuntimeError, OSError, ValueError) as e:
|
||||
log.error("Streaming error: %s", e)
|
||||
if doc_state.get("isEmptyDocument"):
|
||||
text += (
|
||||
"Because the document is empty, YOU MUST first "
|
||||
"update the empty block before adding new blocks."
|
||||
)
|
||||
else:
|
||||
text += (
|
||||
"Prefer updating existing blocks over removing "
|
||||
"and adding (but this also depends on the "
|
||||
"user's question)."
|
||||
)
|
||||
parts = [
|
||||
TextUIPart(text=text),
|
||||
TextUIPart(text=json.dumps(blocks)),
|
||||
]
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
result.append(
|
||||
UIMessage(
|
||||
role="assistant",
|
||||
id=f"assistant-document-state-{message.id}",
|
||||
parts=parts,
|
||||
)
|
||||
)
|
||||
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tool_definitions_to_toolset(
|
||||
tool_definitions: Dict[str, Any],
|
||||
) -> ExternalToolset:
|
||||
"""Convert serialized tool definitions to a pydantic-ai ExternalToolset.
|
||||
|
||||
Port of BlockNote's toolDefinitionsToToolSet.
|
||||
Builds ToolDefinition objects from the JSON-Schema-based definitions
|
||||
sent by the frontend and wraps them in an ExternalToolset so that
|
||||
pydantic-ai advertises them to the LLM without trying to execute them
|
||||
server-side (execution is deferred to the frontend).
|
||||
"""
|
||||
tool_defs = [
|
||||
ToolDefinition(
|
||||
name=name,
|
||||
description=defn.get("description", ""),
|
||||
parameters_json_schema=defn.get("inputSchema", {}),
|
||||
kind="external",
|
||||
metadata={
|
||||
"output_schema": defn.get("outputSchema"),
|
||||
},
|
||||
)
|
||||
for name, defn in tool_definitions.items()
|
||||
]
|
||||
return ExternalToolset(tool_defs)
|
||||
|
||||
def _build_async_stream(self, request: Request) -> AsyncIterator[str]:
|
||||
"""Build the async stream from the AI provider."""
|
||||
instrument_enabled = settings.LANGFUSE_PUBLIC_KEY is not None
|
||||
|
||||
if instrument_enabled:
|
||||
langfuse = get_client()
|
||||
langfuse.auth_check()
|
||||
Agent.instrument_all()
|
||||
|
||||
model = OpenAIChatModel(
|
||||
settings.AI_MODEL,
|
||||
provider=OpenAIProvider(
|
||||
base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY
|
||||
),
|
||||
)
|
||||
agent = Agent(model, instrument=instrument_enabled)
|
||||
|
||||
accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE)
|
||||
|
||||
run_input = VercelAIAdapter.build_run_input(request.raw_body)
|
||||
|
||||
# Inject document state context into the conversation
|
||||
run_input.messages = self.inject_document_state_messages(run_input.messages)
|
||||
|
||||
# Build an ExternalToolset from frontend-supplied tool definitions
|
||||
raw_tool_defs = (
|
||||
run_input.model_extra.get("toolDefinitions")
|
||||
if run_input.model_extra
|
||||
else None
|
||||
)
|
||||
toolset = (
|
||||
self.tool_definitions_to_toolset(raw_tool_defs) if raw_tool_defs else None
|
||||
)
|
||||
|
||||
adapter = VercelAIAdapter(
|
||||
agent=agent,
|
||||
run_input=run_input,
|
||||
accept=accept,
|
||||
sdk_version=settings.AI_VERCEL_SDK_VERSION,
|
||||
)
|
||||
|
||||
event_stream = adapter.run_stream(
|
||||
output_type=[str, DeferredToolRequests] if toolset else None,
|
||||
toolsets=[toolset] if toolset else None,
|
||||
)
|
||||
|
||||
return adapter.encode_stream(event_stream)
|
||||
|
||||
def stream(self, request: Request) -> Union[AsyncIterator[str], Iterator[str]]:
|
||||
"""Stream AI API requests to the configured AI provider.
|
||||
|
||||
Returns an async iterator when running in async mode (ASGI)
|
||||
or a sync iterator when running in sync mode (WSGI).
|
||||
"""
|
||||
async_stream = self._build_async_stream(request)
|
||||
|
||||
if os.environ.get("PYTHON_SERVER_MODE", "sync") == "async":
|
||||
return async_stream
|
||||
|
||||
return convert_async_generator_to_sync(async_stream)
|
||||
|
||||
Reference in New Issue
Block a user