✨(back) manage streaming with the ai service
We want to handle both streaming or not when interacting with the AI backend service.
This commit is contained in:
committed by
Anthony LC
parent
9d6fe5da8f
commit
6f0dac4f48
@@ -14,6 +14,7 @@ These are the environment variables you can set for the `impress-backend` contai
|
||||
| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" }
|
||||
| AI_FEATURE_ENABLED | Enable AI options | false |
|
||||
| AI_MODEL | AI Model to use | |
|
||||
| AI_STREAM | AI Stream to activate | true |
|
||||
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
|
||||
| API_USERS_LIST_LIMIT | Limit on API users | 5 |
|
||||
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |
|
||||
|
||||
@@ -838,9 +838,7 @@ class AIProxySerializer(serializers.Serializer):
|
||||
|
||||
messages = serializers.ListField(
|
||||
required=True,
|
||||
child=serializers.DictField(
|
||||
child=serializers.CharField(required=True),
|
||||
),
|
||||
child=serializers.DictField(),
|
||||
allow_empty=False,
|
||||
)
|
||||
model = serializers.CharField(required=True)
|
||||
|
||||
@@ -1857,8 +1857,20 @@ class DocumentViewSet(
|
||||
serializer = serializers.AIProxySerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
response = AIService().proxy(request.data)
|
||||
return drf.response.Response(response, status=drf.status.HTTP_200_OK)
|
||||
ai_service = AIService()
|
||||
|
||||
if settings.AI_STREAM:
|
||||
return StreamingHttpResponse(
|
||||
ai_service.stream(request.data),
|
||||
content_type="text/event-stream",
|
||||
status=drf.status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
ai_response = ai_service.proxy(request.data)
|
||||
return drf.response.Response(
|
||||
ai_response.model_dump(),
|
||||
status=drf.status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
@drf.decorators.action(
|
||||
detail=True,
|
||||
@@ -2557,6 +2569,7 @@ class ConfigView(drf.views.APIView):
|
||||
array_settings = [
|
||||
"AI_BOT",
|
||||
"AI_FEATURE_ENABLED",
|
||||
"AI_STREAM",
|
||||
"API_USERS_SEARCH_QUERY_MIN_LENGTH",
|
||||
"COLLABORATION_WS_URL",
|
||||
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY",
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
"""AI services."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from openai import OpenAI as OpenAI_Client
|
||||
from openai import OpenAIError
|
||||
|
||||
from core import enums
|
||||
|
||||
if settings.LANGFUSE_PUBLIC_KEY:
|
||||
from langfuse.openai import OpenAI
|
||||
else:
|
||||
from openai import OpenAI
|
||||
OpenAI = OpenAI_Client
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -101,9 +106,30 @@ class AIService:
|
||||
system_content = AI_TRANSLATE.format(language=language_display)
|
||||
return self.call_ai_api(system_content, text)
|
||||
|
||||
def proxy(self, data: dict) -> dict:
|
||||
def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]:
|
||||
"""Proxy AI API requests to the configured AI provider."""
|
||||
data["stream"] = False
|
||||
data["stream"] = stream
|
||||
try:
|
||||
return self.client.chat.completions.create(**data)
|
||||
except OpenAIError as e:
|
||||
raise RuntimeError(f"Failed to proxy AI request: {e}") from e
|
||||
|
||||
response = self.client.chat.completions.create(**data)
|
||||
return response.model_dump()
|
||||
def stream(self, data: dict) -> Generator[str, None, None]:
|
||||
"""Stream AI API requests to the configured AI provider."""
|
||||
|
||||
try:
|
||||
stream = self.proxy(data, stream=True)
|
||||
for chunk in stream:
|
||||
try:
|
||||
chunk_dict = (
|
||||
chunk.model_dump() if hasattr(chunk, "model_dump") else chunk
|
||||
)
|
||||
chunk_json = json.dumps(chunk_dict)
|
||||
yield f"data: {chunk_json}\n\n"
|
||||
except (AttributeError, TypeError) as e:
|
||||
log.error("Error serializing chunk: %s, chunk: %s", e, chunk)
|
||||
continue
|
||||
except (OpenAIError, RuntimeError, OSError, ValueError) as e:
|
||||
log.error("Streaming error: %s", e)
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -20,7 +20,7 @@ pytestmark = pytest.mark.django_db
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.AI_BASE_URL = "http://example.com"
|
||||
settings.AI_BASE_URL = "http://localhost-ai:12345/"
|
||||
settings.AI_API_KEY = "test-key"
|
||||
settings.AI_FEATURE_ENABLED = True
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM="public")
|
||||
@override_settings(AI_ALLOW_REACH_FROM="public", AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_anonymous_success(mock_create):
|
||||
"""
|
||||
@@ -183,6 +183,7 @@ def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
|
||||
("public", "editor"),
|
||||
],
|
||||
)
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
||||
"""
|
||||
@@ -265,6 +266,7 @@ def test_api_documents_ai_proxy_reader(via, mock_user_teams):
|
||||
|
||||
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
||||
@pytest.mark.parametrize("via", VIA)
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams):
|
||||
"""Users with sufficient permissions should be able to request AI proxy."""
|
||||
@@ -483,6 +485,7 @@ def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_STREAM=False)
|
||||
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_throttling_document(mock_create):
|
||||
@@ -528,6 +531,7 @@ def test_api_documents_ai_proxy_throttling_document(mock_create):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_STREAM=False)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_complex_conversation(mock_create):
|
||||
"""AI proxy should handle complex conversations with multiple messages."""
|
||||
|
||||
@@ -21,6 +21,7 @@ pytestmark = pytest.mark.django_db
|
||||
@override_settings(
|
||||
AI_BOT={"name": "Test Bot", "color": "#000000"},
|
||||
AI_FEATURE_ENABLED=False,
|
||||
AI_STREAM=True,
|
||||
API_USERS_SEARCH_QUERY_MIN_LENGTH=6,
|
||||
COLLABORATION_WS_URL="http://testcollab/",
|
||||
COLLABORATION_WS_NOT_CONNECTED_READY_ONLY=True,
|
||||
@@ -47,6 +48,7 @@ def test_api_config(is_authenticated):
|
||||
assert response.json() == {
|
||||
"AI_BOT": {"name": "Test Bot", "color": "#000000"},
|
||||
"AI_FEATURE_ENABLED": False,
|
||||
"AI_STREAM": True,
|
||||
"API_USERS_SEARCH_QUERY_MIN_LENGTH": 6,
|
||||
"COLLABORATION_WS_URL": "http://testcollab/",
|
||||
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY": True,
|
||||
|
||||
@@ -15,6 +15,15 @@ from core.services.ai_services import AIService
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.AI_BASE_URL = "http://example.com"
|
||||
settings.AI_API_KEY = "test-key"
|
||||
settings.AI_FEATURE_ENABLED = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setting_name, setting_value",
|
||||
[
|
||||
@@ -23,22 +32,22 @@ pytestmark = pytest.mark.django_db
|
||||
("AI_MODEL", None),
|
||||
],
|
||||
)
|
||||
def test_api_ai_setting_missing(setting_name, setting_value):
|
||||
def test_services_ai_setting_missing(setting_name, setting_value, settings):
|
||||
"""Setting should be set"""
|
||||
setattr(settings, setting_name, setting_value)
|
||||
|
||||
with override_settings(**{setting_name: setting_value}):
|
||||
with pytest.raises(
|
||||
ImproperlyConfigured,
|
||||
match="AI configuration not set",
|
||||
):
|
||||
AIService()
|
||||
with pytest.raises(
|
||||
ImproperlyConfigured,
|
||||
match="AI configuration not set",
|
||||
):
|
||||
AIService()
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||
)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_ai__client_error(mock_create):
|
||||
def test_services_ai_client_error(mock_create):
|
||||
"""Fail when the client raises an error"""
|
||||
|
||||
mock_create.side_effect = OpenAIError("Mocked client error")
|
||||
@@ -50,11 +59,24 @@ def test_api_ai__client_error(mock_create):
|
||||
AIService().transform("hello", "prompt")
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_services_ai_proxy_client_error(mock_create):
|
||||
"""Fail when the client raises an error"""
|
||||
|
||||
mock_create.side_effect = OpenAIError("Mocked client error")
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Failed to proxy AI request: Mocked client error",
|
||||
):
|
||||
AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||
)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_ai__client_invalid_response(mock_create):
|
||||
def test_services_ai_client_invalid_response(mock_create):
|
||||
"""Fail when the client response is invalid"""
|
||||
|
||||
mock_create.return_value = MagicMock(
|
||||
@@ -68,11 +90,50 @@ def test_api_ai__client_invalid_response(mock_create):
|
||||
AIService().transform("hello", "prompt")
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_services_ai_proxy_success(mock_create):
|
||||
"""The AI request should work as expect when called with valid arguments."""
|
||||
|
||||
mock_create.return_value = {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "test-model",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Salut"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = AIService().proxy({"messages": [{"role": "user", "content": "hello"}]})
|
||||
|
||||
expected_response = {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "test-model",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Salut"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
assert response == expected_response
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "hello"}], stream=False
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||
)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_ai__success(mock_create):
|
||||
def test_services_ai_success(mock_create):
|
||||
"""The AI request should work as expect when called with valid arguments."""
|
||||
|
||||
mock_create.return_value = MagicMock(
|
||||
@@ -82,3 +143,44 @@ def test_api_ai__success(mock_create):
|
||||
response = AIService().transform("hello", "prompt")
|
||||
|
||||
assert response == {"answer": "Salut"}
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_services_ai_proxy_with_stream(mock_create):
|
||||
"""The AI request should work as expect when called with valid arguments."""
|
||||
|
||||
mock_create.return_value = {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "test-model",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Salut"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = AIService().proxy(
|
||||
{"messages": [{"role": "user", "content": "hello"}]}, stream=True
|
||||
)
|
||||
|
||||
expected_response = {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "test-model",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Salut"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
assert response == expected_response
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "hello"}], stream=True
|
||||
)
|
||||
|
||||
@@ -713,6 +713,9 @@ class Base(Configuration):
|
||||
default=False, environ_name="AI_FEATURE_ENABLED", environ_prefix=None
|
||||
)
|
||||
AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None)
|
||||
AI_STREAM = values.BooleanValue(
|
||||
default=True, environ_name="AI_STREAM", environ_prefix=None
|
||||
)
|
||||
AI_USER_RATE_THROTTLE_RATES = {
|
||||
"minute": 3,
|
||||
"hour": 50,
|
||||
|
||||
@@ -14,6 +14,7 @@ export const CONFIG = {
|
||||
color: '#8bc6ff',
|
||||
},
|
||||
AI_FEATURE_ENABLED: true,
|
||||
AI_STREAM: true,
|
||||
API_USERS_SEARCH_QUERY_MIN_LENGTH: 3,
|
||||
CRISP_WEBSITE_ID: null,
|
||||
COLLABORATION_WS_URL: 'ws://localhost:4444/collaboration/ws/',
|
||||
|
||||
@@ -29,6 +29,7 @@ interface ThemeCustomization {
|
||||
export interface ConfigResponse {
|
||||
AI_BOT: { name: string; color: string };
|
||||
AI_FEATURE_ENABLED?: boolean;
|
||||
AI_STREAM: boolean;
|
||||
API_USERS_SEARCH_QUERY_MIN_LENGTH?: number;
|
||||
COLLABORATION_WS_URL?: string;
|
||||
COLLABORATION_WS_NOT_CONNECTED_READY_ONLY?: boolean;
|
||||
|
||||
Reference in New Issue
Block a user