From 6f0dac4f485ad765e7436fe08c3fdcc564b9d727 Mon Sep 17 00:00:00 2001 From: Manuel Raynaud Date: Thu, 12 Jun 2025 11:08:23 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(back)=20manage=20streaming=20with=20t?= =?UTF-8?q?he=20ai=20service?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We want to handle both streaming or not when interacting with the AI backend service. --- docs/env.md | 1 + src/backend/core/api/serializers.py | 4 +- src/backend/core/api/viewsets.py | 17 ++- src/backend/core/services/ai_services.py | 36 +++++- .../documents/test_api_documents_ai_proxy.py | 8 +- src/backend/core/tests/test_api_config.py | 2 + .../core/tests/test_services_ai_services.py | 122 ++++++++++++++++-- src/backend/impress/settings.py | 3 + .../e2e/__tests__/app-impress/utils-common.ts | 1 + .../impress/src/core/config/api/useConfig.tsx | 1 + 10 files changed, 173 insertions(+), 22 deletions(-) diff --git a/docs/env.md b/docs/env.md index 1aee0dd8..725cfb07 100644 --- a/docs/env.md +++ b/docs/env.md @@ -14,6 +14,7 @@ These are the environment variables you can set for the `impress-backend` contai | AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } | AI_FEATURE_ENABLED | Enable AI options | false | | AI_MODEL | AI Model to use | | +| AI_STREAM | AI Stream to activate | true | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index aaf72bb8..75eaa5d3 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -838,9 +838,7 @@ class AIProxySerializer(serializers.Serializer): messages = serializers.ListField( required=True, - child=serializers.DictField( - child=serializers.CharField(required=True), - ), + child=serializers.DictField(), allow_empty=False, ) model = serializers.CharField(required=True) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 6fde3fa1..9f715139 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1857,8 +1857,20 @@ class DocumentViewSet( serializer = serializers.AIProxySerializer(data=request.data) serializer.is_valid(raise_exception=True) - response = AIService().proxy(request.data) - return drf.response.Response(response, status=drf.status.HTTP_200_OK) + ai_service = AIService() + + if settings.AI_STREAM: + return StreamingHttpResponse( + ai_service.stream(request.data), + content_type="text/event-stream", + status=drf.status.HTTP_200_OK, + ) + + ai_response = ai_service.proxy(request.data) + return drf.response.Response( + ai_response.model_dump(), + status=drf.status.HTTP_200_OK, + ) @drf.decorators.action( detail=True, @@ -2557,6 +2569,7 @@ class ConfigView(drf.views.APIView): array_settings = [ "AI_BOT", "AI_FEATURE_ENABLED", + "AI_STREAM", "API_USERS_SEARCH_QUERY_MIN_LENGTH", "COLLABORATION_WS_URL", "COLLABORATION_WS_NOT_CONNECTED_READY_ONLY", diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index 68e20b9e..72583aa6 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -1,16 +1,21 @@ """AI services.""" +import json import logging +from typing import Generator from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from openai import OpenAI as OpenAI_Client +from openai import OpenAIError + from core import enums if settings.LANGFUSE_PUBLIC_KEY: from langfuse.openai import OpenAI else: - from openai import OpenAI + OpenAI = OpenAI_Client log = logging.getLogger(__name__) @@ -101,9 +106,30 @@ class AIService: system_content = AI_TRANSLATE.format(language=language_display) return self.call_ai_api(system_content, text) - def proxy(self, data: dict) -> dict: + def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: """Proxy AI API requests to the configured AI provider.""" - data["stream"] = False + data["stream"] = stream + try: + return self.client.chat.completions.create(**data) + except OpenAIError as e: + raise RuntimeError(f"Failed to proxy AI request: {e}") from e - response = self.client.chat.completions.create(**data) - return response.model_dump() + def stream(self, data: dict) -> Generator[str, None, None]: + """Stream AI API requests to the configured AI provider.""" + + try: + stream = self.proxy(data, stream=True) + for chunk in stream: + try: + chunk_dict = ( + chunk.model_dump() if hasattr(chunk, "model_dump") else chunk + ) + chunk_json = json.dumps(chunk_dict) + yield f"data: {chunk_json}\n\n" + except (AttributeError, TypeError) as e: + log.error("Error serializing chunk: %s, chunk: %s", e, chunk) + continue + except (OpenAIError, RuntimeError, OSError, ValueError) as e: + log.error("Streaming error: %s", e) + + yield "data: [DONE]\n\n" diff --git a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py index 055dcf5b..f99d01a6 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py @@ -20,7 +20,7 @@ pytestmark = pytest.mark.django_db def ai_settings(settings): """Fixture to set AI settings.""" settings.AI_MODEL = "llama" - settings.AI_BASE_URL = "http://example.com" + settings.AI_BASE_URL = "http://localhost-ai:12345/" settings.AI_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True @@ -61,7 +61,7 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role): } -@override_settings(AI_ALLOW_REACH_FROM="public") +@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_anonymous_success(mock_create): """ @@ -183,6 +183,7 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role): ("public", "editor"), ], ) +@override_settings(AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role): """ @@ -265,6 +266,7 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams): @pytest.mark.parametrize("role", ["editor", "administrator", "owner"]) @pytest.mark.parametrize("via", VIA) +@override_settings(AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams): """Users with sufficient permissions should be able to request AI proxy.""" @@ -483,6 +485,7 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create): ) +@override_settings(AI_STREAM=False) @override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_throttling_document(mock_create): @@ -528,6 +531,7 @@ def test_api_documents_ai_proxy_throttling_document(mock_create): } +@override_settings(AI_STREAM=False) @patch("openai.resources.chat.completions.Completions.create") def test_api_documents_ai_proxy_complex_conversation(mock_create): """AI proxy should handle complex conversations with multiple messages.""" diff --git a/src/backend/core/tests/test_api_config.py b/src/backend/core/tests/test_api_config.py index 41b5493f..e98e5848 100644 --- a/src/backend/core/tests/test_api_config.py +++ b/src/backend/core/tests/test_api_config.py @@ -21,6 +21,7 @@ pytestmark = pytest.mark.django_db @override_settings( AI_BOT={"name": "Test Bot", "color": "#000000"}, AI_FEATURE_ENABLED=False, + AI_STREAM=True, API_USERS_SEARCH_QUERY_MIN_LENGTH=6, COLLABORATION_WS_URL="http://testcollab/", COLLABORATION_WS_NOT_CONNECTED_READY_ONLY=True, @@ -47,6 +48,7 @@ def test_api_config(is_authenticated): assert response.json() == { "AI_BOT": {"name": "Test Bot", "color": "#000000"}, "AI_FEATURE_ENABLED": False, + "AI_STREAM": True, "API_USERS_SEARCH_QUERY_MIN_LENGTH": 6, "COLLABORATION_WS_URL": "http://testcollab/", "COLLABORATION_WS_NOT_CONNECTED_READY_ONLY": True, diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index ffa5c170..edfe7e1a 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -15,6 +15,15 @@ from core.services.ai_services import AIService pytestmark = pytest.mark.django_db +@pytest.fixture(autouse=True) +def ai_settings(settings): + """Fixture to set AI settings.""" + settings.AI_MODEL = "llama" + settings.AI_BASE_URL = "http://example.com" + settings.AI_API_KEY = "test-key" + settings.AI_FEATURE_ENABLED = True + + @pytest.mark.parametrize( "setting_name, setting_value", [ @@ -23,22 +32,22 @@ pytestmark = pytest.mark.django_db ("AI_MODEL", None), ], ) -def test_api_ai_setting_missing(setting_name, setting_value): +def test_services_ai_setting_missing(setting_name, setting_value, settings): """Setting should be set""" + setattr(settings, setting_name, setting_value) - with override_settings(**{setting_name: setting_value}): - with pytest.raises( - ImproperlyConfigured, - match="AI configuration not set", - ): - AIService() + with pytest.raises( + ImproperlyConfigured, + match="AI configuration not set", + ): + AIService() @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @patch("openai.resources.chat.completions.Completions.create") -def test_api_ai__client_error(mock_create): +def test_services_ai_client_error(mock_create): """Fail when the client raises an error""" mock_create.side_effect = OpenAIError("Mocked client error") @@ -50,11 +59,24 @@ def test_api_ai__client_error(mock_create): AIService().transform("hello", "prompt") +@patch("openai.resources.chat.completions.Completions.create") +def test_services_ai_proxy_client_error(mock_create): + """Fail when the client raises an error""" + + mock_create.side_effect = OpenAIError("Mocked client error") + + with pytest.raises( + RuntimeError, + match="Failed to proxy AI request: Mocked client error", + ): + AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) + + @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @patch("openai.resources.chat.completions.Completions.create") -def test_api_ai__client_invalid_response(mock_create): +def test_services_ai_client_invalid_response(mock_create): """Fail when the client response is invalid""" mock_create.return_value = MagicMock( @@ -68,11 +90,50 @@ def test_api_ai__client_invalid_response(mock_create): AIService().transform("hello", "prompt") +@patch("openai.resources.chat.completions.Completions.create") +def test_services_ai_proxy_success(mock_create): + """The AI request should work as expect when called with valid arguments.""" + + mock_create.return_value = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Salut"}, + "finish_reason": "stop", + } + ], + } + + response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]}) + + expected_response = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Salut"}, + "finish_reason": "stop", + } + ], + } + assert response == expected_response + mock_create.assert_called_once_with( + messages=[{"role": "user", "content": "hello"}], stream=False + ) + + @override_settings( AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" ) @patch("openai.resources.chat.completions.Completions.create") -def test_api_ai__success(mock_create): +def test_services_ai_success(mock_create): """The AI request should work as expect when called with valid arguments.""" mock_create.return_value = MagicMock( @@ -82,3 +143,44 @@ def test_api_ai__success(mock_create): response = AIService().transform("hello", "prompt") assert response == {"answer": "Salut"} + + +@patch("openai.resources.chat.completions.Completions.create") +def test_services_ai_proxy_with_stream(mock_create): + """The AI request should work as expect when called with valid arguments.""" + + mock_create.return_value = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Salut"}, + "finish_reason": "stop", + } + ], + } + + response = AIService().proxy( + {"messages": [{"role": "user", "content": "hello"}]}, stream=True + ) + + expected_response = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Salut"}, + "finish_reason": "stop", + } + ], + } + assert response == expected_response + mock_create.assert_called_once_with( + messages=[{"role": "user", "content": "hello"}], stream=True + ) diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index 31758292..de4c6212 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -713,6 +713,9 @@ class Base(Configuration): default=False, environ_name="AI_FEATURE_ENABLED", environ_prefix=None ) AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None) + AI_STREAM = values.BooleanValue( + default=True, environ_name="AI_STREAM", environ_prefix=None + ) AI_USER_RATE_THROTTLE_RATES = { "minute": 3, "hour": 50, diff --git a/src/frontend/apps/e2e/__tests__/app-impress/utils-common.ts b/src/frontend/apps/e2e/__tests__/app-impress/utils-common.ts index 78a97982..e2637bac 100644 --- a/src/frontend/apps/e2e/__tests__/app-impress/utils-common.ts +++ b/src/frontend/apps/e2e/__tests__/app-impress/utils-common.ts @@ -14,6 +14,7 @@ export const CONFIG = { color: '#8bc6ff', }, AI_FEATURE_ENABLED: true, + AI_STREAM: true, API_USERS_SEARCH_QUERY_MIN_LENGTH: 3, CRISP_WEBSITE_ID: null, COLLABORATION_WS_URL: 'ws://localhost:4444/collaboration/ws/', diff --git a/src/frontend/apps/impress/src/core/config/api/useConfig.tsx b/src/frontend/apps/impress/src/core/config/api/useConfig.tsx index 03f74209..68340e23 100644 --- a/src/frontend/apps/impress/src/core/config/api/useConfig.tsx +++ b/src/frontend/apps/impress/src/core/config/api/useConfig.tsx @@ -29,6 +29,7 @@ interface ThemeCustomization { export interface ConfigResponse { AI_BOT: { name: string; color: string }; AI_FEATURE_ENABLED?: boolean; + AI_STREAM: boolean; API_USERS_SEARCH_QUERY_MIN_LENGTH?: number; COLLABORATION_WS_URL?: string; COLLABORATION_WS_NOT_CONNECTED_READY_ONLY?: boolean;