(back) manage streaming with the ai service

We want to handle both streaming and non-streaming responses when
interacting with the AI backend service.
This commit is contained in:
Manuel Raynaud
2025-06-12 11:08:23 +02:00
committed by Anthony LC
parent 9d6fe5da8f
commit 6f0dac4f48
10 changed files with 173 additions and 22 deletions

View File

@@ -14,6 +14,7 @@ These are the environment variables you can set for the `impress-backend` contai
| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } | AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" }
| AI_FEATURE_ENABLED | Enable AI options | false | | AI_FEATURE_ENABLED | Enable AI options | false |
| AI_MODEL | AI Model to use | | | AI_MODEL | AI Model to use | |
| AI_STREAM | Enable streaming responses from the AI service | true |
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
| API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_LIMIT | Limit on API users | 5 |
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |

View File

@@ -838,9 +838,7 @@ class AIProxySerializer(serializers.Serializer):
messages = serializers.ListField( messages = serializers.ListField(
required=True, required=True,
child=serializers.DictField( child=serializers.DictField(),
child=serializers.CharField(required=True),
),
allow_empty=False, allow_empty=False,
) )
model = serializers.CharField(required=True) model = serializers.CharField(required=True)

View File

@@ -1857,8 +1857,20 @@ class DocumentViewSet(
serializer = serializers.AIProxySerializer(data=request.data) serializer = serializers.AIProxySerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
response = AIService().proxy(request.data) ai_service = AIService()
return drf.response.Response(response, status=drf.status.HTTP_200_OK)
if settings.AI_STREAM:
return StreamingHttpResponse(
ai_service.stream(request.data),
content_type="text/event-stream",
status=drf.status.HTTP_200_OK,
)
ai_response = ai_service.proxy(request.data)
return drf.response.Response(
ai_response.model_dump(),
status=drf.status.HTTP_200_OK,
)
@drf.decorators.action( @drf.decorators.action(
detail=True, detail=True,
@@ -2557,6 +2569,7 @@ class ConfigView(drf.views.APIView):
array_settings = [ array_settings = [
"AI_BOT", "AI_BOT",
"AI_FEATURE_ENABLED", "AI_FEATURE_ENABLED",
"AI_STREAM",
"API_USERS_SEARCH_QUERY_MIN_LENGTH", "API_USERS_SEARCH_QUERY_MIN_LENGTH",
"COLLABORATION_WS_URL", "COLLABORATION_WS_URL",
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY", "COLLABORATION_WS_NOT_CONNECTED_READY_ONLY",

View File

@@ -1,16 +1,21 @@
"""AI services.""" """AI services."""
import json
import logging import logging
from typing import Generator
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from openai import OpenAI as OpenAI_Client
from openai import OpenAIError
from core import enums from core import enums
if settings.LANGFUSE_PUBLIC_KEY: if settings.LANGFUSE_PUBLIC_KEY:
from langfuse.openai import OpenAI from langfuse.openai import OpenAI
else: else:
from openai import OpenAI OpenAI = OpenAI_Client
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -101,9 +106,30 @@ class AIService:
system_content = AI_TRANSLATE.format(language=language_display) system_content = AI_TRANSLATE.format(language=language_display)
return self.call_ai_api(system_content, text) return self.call_ai_api(system_content, text)
def proxy(self, data: dict) -> dict: def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
"""Proxy AI API requests to the configured AI provider.""" """Proxy AI API requests to the configured AI provider."""
data["stream"] = False data["stream"] = stream
try:
return self.client.chat.completions.create(**data)
except OpenAIError as e:
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
response = self.client.chat.completions.create(**data) def stream(self, data: dict) -> Generator[str, None, None]:
return response.model_dump() """Stream AI API requests to the configured AI provider."""
try:
stream = self.proxy(data, stream=True)
for chunk in stream:
try:
chunk_dict = (
chunk.model_dump() if hasattr(chunk, "model_dump") else chunk
)
chunk_json = json.dumps(chunk_dict)
yield f"data: {chunk_json}\n\n"
except (AttributeError, TypeError) as e:
log.error("Error serializing chunk: %s, chunk: %s", e, chunk)
continue
except (OpenAIError, RuntimeError, OSError, ValueError) as e:
log.error("Streaming error: %s", e)
yield "data: [DONE]\n\n"

View File

@@ -20,7 +20,7 @@ pytestmark = pytest.mark.django_db
def ai_settings(settings): def ai_settings(settings):
"""Fixture to set AI settings.""" """Fixture to set AI settings."""
settings.AI_MODEL = "llama" settings.AI_MODEL = "llama"
settings.AI_BASE_URL = "http://example.com" settings.AI_BASE_URL = "http://localhost-ai:12345/"
settings.AI_API_KEY = "test-key" settings.AI_API_KEY = "test-key"
settings.AI_FEATURE_ENABLED = True settings.AI_FEATURE_ENABLED = True
@@ -61,7 +61,7 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
} }
@override_settings(AI_ALLOW_REACH_FROM="public") @override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_anonymous_success(mock_create): def test_api_documents_ai_proxy_anonymous_success(mock_create):
""" """
@@ -183,6 +183,7 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
("public", "editor"), ("public", "editor"),
], ],
) )
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
""" """
@@ -265,6 +266,7 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"]) @pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
@pytest.mark.parametrize("via", VIA) @pytest.mark.parametrize("via", VIA)
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams): def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams):
"""Users with sufficient permissions should be able to request AI proxy.""" """Users with sufficient permissions should be able to request AI proxy."""
@@ -483,6 +485,7 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create):
) )
@override_settings(AI_STREAM=False)
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_throttling_document(mock_create): def test_api_documents_ai_proxy_throttling_document(mock_create):
@@ -528,6 +531,7 @@ def test_api_documents_ai_proxy_throttling_document(mock_create):
} }
@override_settings(AI_STREAM=False)
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_documents_ai_proxy_complex_conversation(mock_create): def test_api_documents_ai_proxy_complex_conversation(mock_create):
"""AI proxy should handle complex conversations with multiple messages.""" """AI proxy should handle complex conversations with multiple messages."""

View File

@@ -21,6 +21,7 @@ pytestmark = pytest.mark.django_db
@override_settings( @override_settings(
AI_BOT={"name": "Test Bot", "color": "#000000"}, AI_BOT={"name": "Test Bot", "color": "#000000"},
AI_FEATURE_ENABLED=False, AI_FEATURE_ENABLED=False,
AI_STREAM=True,
API_USERS_SEARCH_QUERY_MIN_LENGTH=6, API_USERS_SEARCH_QUERY_MIN_LENGTH=6,
COLLABORATION_WS_URL="http://testcollab/", COLLABORATION_WS_URL="http://testcollab/",
COLLABORATION_WS_NOT_CONNECTED_READY_ONLY=True, COLLABORATION_WS_NOT_CONNECTED_READY_ONLY=True,
@@ -47,6 +48,7 @@ def test_api_config(is_authenticated):
assert response.json() == { assert response.json() == {
"AI_BOT": {"name": "Test Bot", "color": "#000000"}, "AI_BOT": {"name": "Test Bot", "color": "#000000"},
"AI_FEATURE_ENABLED": False, "AI_FEATURE_ENABLED": False,
"AI_STREAM": True,
"API_USERS_SEARCH_QUERY_MIN_LENGTH": 6, "API_USERS_SEARCH_QUERY_MIN_LENGTH": 6,
"COLLABORATION_WS_URL": "http://testcollab/", "COLLABORATION_WS_URL": "http://testcollab/",
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY": True, "COLLABORATION_WS_NOT_CONNECTED_READY_ONLY": True,

View File

@@ -15,6 +15,15 @@ from core.services.ai_services import AIService
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@pytest.fixture(autouse=True)
def ai_settings(settings):
"""Fixture to set AI settings."""
settings.AI_MODEL = "llama"
settings.AI_BASE_URL = "http://example.com"
settings.AI_API_KEY = "test-key"
settings.AI_FEATURE_ENABLED = True
@pytest.mark.parametrize( @pytest.mark.parametrize(
"setting_name, setting_value", "setting_name, setting_value",
[ [
@@ -23,22 +32,22 @@ pytestmark = pytest.mark.django_db
("AI_MODEL", None), ("AI_MODEL", None),
], ],
) )
def test_api_ai_setting_missing(setting_name, setting_value): def test_services_ai_setting_missing(setting_name, setting_value, settings):
"""Setting should be set""" """Setting should be set"""
setattr(settings, setting_name, setting_value)
with override_settings(**{setting_name: setting_value}): with pytest.raises(
with pytest.raises( ImproperlyConfigured,
ImproperlyConfigured, match="AI configuration not set",
match="AI configuration not set", ):
): AIService()
AIService()
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_ai__client_error(mock_create): def test_services_ai_client_error(mock_create):
"""Fail when the client raises an error""" """Fail when the client raises an error"""
mock_create.side_effect = OpenAIError("Mocked client error") mock_create.side_effect = OpenAIError("Mocked client error")
@@ -50,11 +59,24 @@ def test_api_ai__client_error(mock_create):
AIService().transform("hello", "prompt") AIService().transform("hello", "prompt")
@patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_client_error(mock_create):
"""Fail when the client raises an error"""
mock_create.side_effect = OpenAIError("Mocked client error")
with pytest.raises(
RuntimeError,
match="Failed to proxy AI request: Mocked client error",
):
AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_ai__client_invalid_response(mock_create): def test_services_ai_client_invalid_response(mock_create):
"""Fail when the client response is invalid""" """Fail when the client response is invalid"""
mock_create.return_value = MagicMock( mock_create.return_value = MagicMock(
@@ -68,11 +90,50 @@ def test_api_ai__client_invalid_response(mock_create):
AIService().transform("hello", "prompt") AIService().transform("hello", "prompt")
@patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_success(mock_create):
"""The AI request should work as expected when called with valid arguments."""
mock_create.return_value = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
expected_response = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
assert response == expected_response
mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "hello"}], stream=False
)
@override_settings( @override_settings(
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
) )
@patch("openai.resources.chat.completions.Completions.create") @patch("openai.resources.chat.completions.Completions.create")
def test_api_ai__success(mock_create): def test_services_ai_success(mock_create):
"""The AI request should work as expected when called with valid arguments.""" """The AI request should work as expected when called with valid arguments."""
mock_create.return_value = MagicMock( mock_create.return_value = MagicMock(
@@ -82,3 +143,44 @@ def test_api_ai__success(mock_create):
response = AIService().transform("hello", "prompt") response = AIService().transform("hello", "prompt")
assert response == {"answer": "Salut"} assert response == {"answer": "Salut"}
@patch("openai.resources.chat.completions.Completions.create")
def test_services_ai_proxy_with_stream(mock_create):
"""The AI request should work as expected when called with valid arguments."""
mock_create.return_value = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
response = AIService().proxy(
{"messages": [{"role": "user", "content": "hello"}]}, stream=True
)
expected_response = {
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "test-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Salut"},
"finish_reason": "stop",
}
],
}
assert response == expected_response
mock_create.assert_called_once_with(
messages=[{"role": "user", "content": "hello"}], stream=True
)

View File

@@ -713,6 +713,9 @@ class Base(Configuration):
default=False, environ_name="AI_FEATURE_ENABLED", environ_prefix=None default=False, environ_name="AI_FEATURE_ENABLED", environ_prefix=None
) )
AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None) AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None)
AI_STREAM = values.BooleanValue(
default=True, environ_name="AI_STREAM", environ_prefix=None
)
AI_USER_RATE_THROTTLE_RATES = { AI_USER_RATE_THROTTLE_RATES = {
"minute": 3, "minute": 3,
"hour": 50, "hour": 50,

View File

@@ -14,6 +14,7 @@ export const CONFIG = {
color: '#8bc6ff', color: '#8bc6ff',
}, },
AI_FEATURE_ENABLED: true, AI_FEATURE_ENABLED: true,
AI_STREAM: true,
API_USERS_SEARCH_QUERY_MIN_LENGTH: 3, API_USERS_SEARCH_QUERY_MIN_LENGTH: 3,
CRISP_WEBSITE_ID: null, CRISP_WEBSITE_ID: null,
COLLABORATION_WS_URL: 'ws://localhost:4444/collaboration/ws/', COLLABORATION_WS_URL: 'ws://localhost:4444/collaboration/ws/',

View File

@@ -29,6 +29,7 @@ interface ThemeCustomization {
export interface ConfigResponse { export interface ConfigResponse {
AI_BOT: { name: string; color: string }; AI_BOT: { name: string; color: string };
AI_FEATURE_ENABLED?: boolean; AI_FEATURE_ENABLED?: boolean;
AI_STREAM: boolean;
API_USERS_SEARCH_QUERY_MIN_LENGTH?: number; API_USERS_SEARCH_QUERY_MIN_LENGTH?: number;
COLLABORATION_WS_URL?: string; COLLABORATION_WS_URL?: string;
COLLABORATION_WS_NOT_CONNECTED_READY_ONLY?: boolean; COLLABORATION_WS_NOT_CONNECTED_READY_ONLY?: boolean;