✨(backend) add ai_proxy
Add AI proxy to handle AI related requests to the AI service.
This commit is contained in:
@@ -833,6 +833,41 @@ class AITranslateSerializer(serializers.Serializer):
|
||||
return value
|
||||
|
||||
|
||||
class AIProxySerializer(serializers.Serializer):
|
||||
"""Serializer for AI proxy requests."""
|
||||
|
||||
messages = serializers.ListField(
|
||||
required=True,
|
||||
child=serializers.DictField(
|
||||
child=serializers.CharField(required=True),
|
||||
),
|
||||
allow_empty=False,
|
||||
)
|
||||
model = serializers.CharField(required=True)
|
||||
|
||||
def validate_messages(self, messages):
|
||||
"""Validate messages structure."""
|
||||
# Ensure each message has the required fields
|
||||
for message in messages:
|
||||
if (
|
||||
not isinstance(message, dict)
|
||||
or "role" not in message
|
||||
or "content" not in message
|
||||
):
|
||||
raise serializers.ValidationError(
|
||||
"Each message must have 'role' and 'content' fields"
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def validate_model(self, value):
|
||||
"""Validate model value is the same than settings.AI_MODEL"""
|
||||
if value != settings.AI_MODEL:
|
||||
raise serializers.ValidationError(f"{value} is not a valid model")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class MoveDocumentSerializer(serializers.Serializer):
|
||||
"""
|
||||
Serializer for validating input data to move a document within the tree structure.
|
||||
|
||||
@@ -475,6 +475,9 @@ class DocumentViewSet(
|
||||
Returns: JSON response with the translated text.
|
||||
Throttled by: AIDocumentRateThrottle, AIUserRateThrottle.
|
||||
|
||||
12. **AI Proxy**: Proxy an AI request to an external AI service.
|
||||
Example: POST /api/v1.0/documents/<resource_id>/ai-proxy
|
||||
|
||||
### Ordering: created_at, updated_at, is_favorite, title
|
||||
|
||||
Example:
|
||||
@@ -1832,6 +1835,31 @@ class DocumentViewSet(
|
||||
|
||||
return drf.response.Response(body, status=drf.status.HTTP_200_OK)
|
||||
|
||||
@drf.decorators.action(
|
||||
detail=True,
|
||||
methods=["post"],
|
||||
name="Proxy AI requests to the AI provider",
|
||||
url_path="ai-proxy",
|
||||
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
|
||||
)
|
||||
def ai_proxy(self, request, *args, **kwargs):
|
||||
"""
|
||||
POST /api/v1.0/documents/<resource_id>/ai-proxy
|
||||
Proxy AI requests to the configured AI provider.
|
||||
This endpoint forwards requests to the AI provider and returns the complete response.
|
||||
"""
|
||||
# Check permissions first
|
||||
self.get_object()
|
||||
|
||||
if not settings.AI_FEATURE_ENABLED:
|
||||
raise ValidationError("AI feature is not enabled.")
|
||||
|
||||
serializer = serializers.AIProxySerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
response = AIService().proxy(request.data)
|
||||
return drf.response.Response(response, status=drf.status.HTTP_200_OK)
|
||||
|
||||
@drf.decorators.action(
|
||||
detail=True,
|
||||
methods=["post"],
|
||||
|
||||
@@ -1282,6 +1282,7 @@ class Document(MP_Node, BaseModel):
|
||||
return {
|
||||
"accesses_manage": is_owner_or_admin,
|
||||
"accesses_view": has_access_role,
|
||||
"ai_proxy": ai_access,
|
||||
"ai_transform": ai_access,
|
||||
"ai_translate": ai_access,
|
||||
"attachment_upload": can_update,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""AI services."""
|
||||
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
@@ -9,6 +11,8 @@ if settings.LANGFUSE_PUBLIC_KEY:
|
||||
from langfuse.openai import OpenAI
|
||||
else:
|
||||
from openai import OpenAI
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
AI_ACTIONS = {
|
||||
@@ -96,3 +100,10 @@ class AIService:
|
||||
language_display = enums.ALL_LANGUAGES.get(language, language)
|
||||
system_content = AI_TRANSLATE.format(language=language_display)
|
||||
return self.call_ai_api(system_content, text)
|
||||
|
||||
def proxy(self, data: dict) -> dict:
|
||||
"""Proxy AI API requests to the configured AI provider."""
|
||||
data["stream"] = False
|
||||
|
||||
response = self.client.chat.completions.create(**data)
|
||||
return response.model_dump()
|
||||
|
||||
686
src/backend/core/tests/documents/test_api_documents_ai_proxy.py
Normal file
686
src/backend/core/tests/documents/test_api_documents_ai_proxy.py
Normal file
@@ -0,0 +1,686 @@
|
||||
"""
|
||||
Test AI proxy API endpoint for users in impress's core app.
|
||||
"""
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
from core.tests.conftest import TEAM, USER, VIA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.AI_BASE_URL = "http://example.com"
|
||||
settings.AI_API_KEY = "test-key"
|
||||
settings.AI_FEATURE_ENABLED = True
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_ALLOW_REACH_FROM=random.choice(["public", "authenticated", "restricted"])
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"reach, role",
|
||||
[
|
||||
("restricted", "reader"),
|
||||
("restricted", "editor"),
|
||||
("authenticated", "reader"),
|
||||
("authenticated", "editor"),
|
||||
("public", "reader"),
|
||||
],
|
||||
)
|
||||
def test_api_documents_ai_proxy_anonymous_forbidden(reach, role):
|
||||
"""
|
||||
Anonymous users should not be able to request AI proxy if the link reach
|
||||
and role don't allow it.
|
||||
"""
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = APIClient().post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": "Authentication credentials were not provided."
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM="public")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_anonymous_success(mock_create):
|
||||
"""
|
||||
Anonymous users should be able to request AI proxy to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you?",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = APIClient().post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["id"] == "chatcmpl-123"
|
||||
assert response_data["model"] == "llama"
|
||||
assert len(response_data["choices"]) == 1
|
||||
assert (
|
||||
response_data["choices"][0]["message"]["content"]
|
||||
== "Hello! How can I help you?"
|
||||
)
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_anonymous_limited_by_setting(mock_create):
|
||||
"""
|
||||
Anonymous users should not be able to request AI proxy to a document
|
||||
if AI_ALLOW_REACH_FROM setting restricts it.
|
||||
"""
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {"content": "Hello!"}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = APIClient().post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reach, role",
|
||||
[
|
||||
("restricted", "reader"),
|
||||
("restricted", "editor"),
|
||||
("authenticated", "reader"),
|
||||
("public", "reader"),
|
||||
],
|
||||
)
|
||||
def test_api_documents_ai_proxy_authenticated_forbidden(reach, role):
|
||||
"""
|
||||
Users who are not related to a document can't request AI proxy if the
|
||||
link reach and role don't allow it.
|
||||
"""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reach, role",
|
||||
[
|
||||
("authenticated", "editor"),
|
||||
("public", "editor"),
|
||||
],
|
||||
)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_authenticated_success(mock_create, reach, role):
|
||||
"""
|
||||
Authenticated users should be able to request AI proxy to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
"id": "chatcmpl-456",
|
||||
"object": "chat.completion",
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Hi there!"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["id"] == "chatcmpl-456"
|
||||
assert response_data["choices"][0]["message"]["content"] == "Hi there!"
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("via", VIA)
|
||||
def test_api_documents_ai_proxy_reader(via, mock_user_teams):
|
||||
"""Users with reader access should not be able to request AI proxy."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="restricted")
|
||||
if via == USER:
|
||||
factories.UserDocumentAccessFactory(document=document, user=user, role="reader")
|
||||
elif via == TEAM:
|
||||
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||
factories.TeamDocumentAccessFactory(
|
||||
document=document, team="lasuite", role="reader"
|
||||
)
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
||||
@pytest.mark.parametrize("via", VIA)
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_success(mock_create, via, role, mock_user_teams):
|
||||
"""Users with sufficient permissions should be able to request AI proxy."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="restricted")
|
||||
if via == USER:
|
||||
factories.UserDocumentAccessFactory(document=document, user=user, role=role)
|
||||
elif via == TEAM:
|
||||
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||
factories.TeamDocumentAccessFactory(
|
||||
document=document, team="lasuite", role=role
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
"id": "chatcmpl-789",
|
||||
"object": "chat.completion",
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "Success!"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Test message"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["id"] == "chatcmpl-789"
|
||||
assert response_data["choices"][0]["message"]["content"] == "Success!"
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Test message"}],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_api_documents_ai_proxy_empty_messages():
|
||||
"""The messages should not be empty when requesting AI proxy."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(url, {"messages": [], "model": "llama"}, format="json")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"messages": ["This list may not be empty."]}
|
||||
|
||||
|
||||
def test_api_documents_ai_proxy_missing_model():
|
||||
"""The model should be required when requesting AI proxy."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url, {"messages": [{"role": "user", "content": "Hello"}]}, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"model": ["This field is required."]}
|
||||
|
||||
|
||||
def test_api_documents_ai_proxy_invalid_message_format():
|
||||
"""Messages should have the correct format when requesting AI proxy."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
|
||||
# Test with invalid message format (missing role)
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {
|
||||
"messages": ["Each message must have 'role' and 'content' fields"]
|
||||
}
|
||||
|
||||
# Test with invalid message format (missing content)
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {
|
||||
"messages": ["Each message must have 'role' and 'content' fields"]
|
||||
}
|
||||
|
||||
# Test with non-dict message
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": ["invalid"],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {
|
||||
"messages": {"0": ['Expected a dictionary of items but got type "str".']}
|
||||
}
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_stream_disabled(mock_create):
|
||||
"""Stream should be automatically disabled in AI proxy requests."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
"stream": True, # This should be overridden to False
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Verify that stream was set to False
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_additional_parameters(mock_create):
|
||||
"""AI proxy should pass through additional parameters to the AI service."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Verify that additional parameters were passed through
|
||||
mock_create.assert_called_once_with(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama",
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
top_p=0.9,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_throttling_document(mock_create):
|
||||
"""
|
||||
Throttling per document should be triggered on the AI transform endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||
"""
|
||||
client = APIClient()
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
for _ in range(3):
|
||||
user = factories.UserFactory()
|
||||
client.force_login(user)
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Test message"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"content": "Success!"}
|
||||
|
||||
user = factories.UserFactory()
|
||||
client.force_login(user)
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Test message"}],
|
||||
"model": "llama",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 429
|
||||
assert response.json() == {
|
||||
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||
}
|
||||
|
||||
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_complex_conversation(mock_create):
|
||||
"""AI proxy should handle complex conversations with multiple messages."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {
|
||||
"id": "chatcmpl-complex",
|
||||
"object": "chat.completion",
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I understand your question about Python.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
complex_messages = [
|
||||
{"role": "system", "content": "You are a helpful programming assistant."},
|
||||
{"role": "user", "content": "How do I write a for loop in Python?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "You can write a for loop using: for item in iterable:",
|
||||
},
|
||||
{"role": "user", "content": "Can you give me a concrete example?"},
|
||||
]
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": complex_messages,
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["id"] == "chatcmpl-complex"
|
||||
assert (
|
||||
response_data["choices"][0]["message"]["content"]
|
||||
== "I understand your question about Python."
|
||||
)
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
messages=complex_messages,
|
||||
model="llama",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_proxy_throttling_user(mock_create):
|
||||
"""
|
||||
Throttling per user should be triggered on the AI proxy endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||
"""
|
||||
user = factories.UserFactory()
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.model_dump.return_value = {"content": "Success!"}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
for _ in range(3):
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-proxy/"
|
||||
response = client.post(
|
||||
url,
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 429
|
||||
assert response.json() == {
|
||||
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 10, "hour": 6, "day": 10})
|
||||
def test_api_documents_ai_proxy_different_models():
|
||||
"""AI proxy should work with different AI models."""
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
models_to_test = ["gpt-3.5-turbo", "gpt-4", "claude-3", "llama-2"]
|
||||
|
||||
for model_name in models_to_test:
|
||||
response = client.post(
|
||||
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": model_name,
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"model": [f"{model_name} is not a valid model"]}
|
||||
|
||||
|
||||
def test_api_documents_ai_proxy_ai_feature_disabled(settings):
|
||||
"""When the settings AI_FEATURE_ENABLED is set to False, the endpoint is not reachable."""
|
||||
settings.AI_FEATURE_ENABLED = False
|
||||
|
||||
user = factories.UserFactory()
|
||||
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1.0/documents/{document.id!s}/ai-proxy/",
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "llama",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == ["AI feature is not enabled."]
|
||||
@@ -29,6 +29,7 @@ def test_api_documents_retrieve_anonymous_public_standalone():
|
||||
"abilities": {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": document.link_role == "editor",
|
||||
@@ -107,6 +108,7 @@ def test_api_documents_retrieve_anonymous_public_parent():
|
||||
"abilities": {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": grand_parent.link_role == "editor",
|
||||
@@ -215,6 +217,7 @@ def test_api_documents_retrieve_authenticated_unrelated_public_or_authenticated(
|
||||
"abilities": {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": document.link_role == "editor",
|
||||
"ai_transform": document.link_role == "editor",
|
||||
"ai_translate": document.link_role == "editor",
|
||||
"attachment_upload": document.link_role == "editor",
|
||||
@@ -300,6 +303,7 @@ def test_api_documents_retrieve_authenticated_public_or_authenticated_parent(rea
|
||||
"abilities": {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": grand_parent.link_role == "editor",
|
||||
"ai_transform": grand_parent.link_role == "editor",
|
||||
"ai_translate": grand_parent.link_role == "editor",
|
||||
"attachment_upload": grand_parent.link_role == "editor",
|
||||
@@ -498,6 +502,7 @@ def test_api_documents_retrieve_authenticated_related_parent():
|
||||
"abilities": {
|
||||
"accesses_manage": access.role in ["administrator", "owner"],
|
||||
"accesses_view": True,
|
||||
"ai_proxy": access.role not in ["reader", "commenter"],
|
||||
"ai_transform": access.role not in ["reader", "commenter"],
|
||||
"ai_translate": access.role not in ["reader", "commenter"],
|
||||
"attachment_upload": access.role not in ["reader", "commenter"],
|
||||
|
||||
@@ -72,6 +72,7 @@ def test_api_documents_trashbin_format():
|
||||
"abilities": {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
|
||||
@@ -155,6 +155,7 @@ def test_models_documents_get_abilities_forbidden(
|
||||
expected_abilities = {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
@@ -220,6 +221,7 @@ def test_models_documents_get_abilities_reader(
|
||||
expected_abilities = {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
@@ -290,6 +292,7 @@ def test_models_documents_get_abilities_commenter(
|
||||
expected_abilities = {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
@@ -357,6 +360,7 @@ def test_models_documents_get_abilities_editor(
|
||||
expected_abilities = {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": is_authenticated,
|
||||
"ai_transform": is_authenticated,
|
||||
"ai_translate": is_authenticated,
|
||||
"attachment_upload": True,
|
||||
@@ -413,6 +417,7 @@ def test_models_documents_get_abilities_owner(django_assert_num_queries):
|
||||
expected_abilities = {
|
||||
"accesses_manage": True,
|
||||
"accesses_view": True,
|
||||
"ai_proxy": True,
|
||||
"ai_transform": True,
|
||||
"ai_translate": True,
|
||||
"attachment_upload": True,
|
||||
@@ -455,6 +460,7 @@ def test_models_documents_get_abilities_owner(django_assert_num_queries):
|
||||
assert document.get_abilities(user) == {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": False,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
@@ -501,6 +507,7 @@ def test_models_documents_get_abilities_administrator(django_assert_num_queries)
|
||||
expected_abilities = {
|
||||
"accesses_manage": True,
|
||||
"accesses_view": True,
|
||||
"ai_proxy": True,
|
||||
"ai_transform": True,
|
||||
"ai_translate": True,
|
||||
"attachment_upload": True,
|
||||
@@ -557,6 +564,7 @@ def test_models_documents_get_abilities_editor_user(django_assert_num_queries):
|
||||
expected_abilities = {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": True,
|
||||
"ai_proxy": True,
|
||||
"ai_transform": True,
|
||||
"ai_translate": True,
|
||||
"attachment_upload": True,
|
||||
@@ -620,6 +628,7 @@ def test_models_documents_get_abilities_reader_user(
|
||||
"accesses_view": True,
|
||||
# If you get your editor rights from the link role and not your access role
|
||||
# You should not access AI if it's restricted to users with specific access
|
||||
"ai_proxy": access_from_link and ai_access_setting != "restricted",
|
||||
"ai_transform": access_from_link and ai_access_setting != "restricted",
|
||||
"ai_translate": access_from_link and ai_access_setting != "restricted",
|
||||
"attachment_upload": access_from_link,
|
||||
@@ -686,6 +695,7 @@ def test_models_documents_get_abilities_commenter_user(
|
||||
"accesses_view": True,
|
||||
# If you get your editor rights from the link role and not your access role
|
||||
# You should not access AI if it's restricted to users with specific access
|
||||
"ai_proxy": access_from_link and ai_access_setting != "restricted",
|
||||
"ai_transform": access_from_link and ai_access_setting != "restricted",
|
||||
"ai_translate": access_from_link and ai_access_setting != "restricted",
|
||||
"attachment_upload": access_from_link,
|
||||
@@ -747,6 +757,7 @@ def test_models_documents_get_abilities_preset_role(django_assert_num_queries):
|
||||
assert abilities == {
|
||||
"accesses_manage": False,
|
||||
"accesses_view": True,
|
||||
"ai_proxy": False,
|
||||
"ai_transform": False,
|
||||
"ai_translate": False,
|
||||
"attachment_upload": False,
|
||||
@@ -878,6 +889,7 @@ def test_models_document_get_abilities_ai_access_authenticated(is_authenticated,
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role="editor")
|
||||
|
||||
abilities = document.get_abilities(user)
|
||||
assert abilities["ai_proxy"] is True
|
||||
assert abilities["ai_transform"] is True
|
||||
assert abilities["ai_translate"] is True
|
||||
|
||||
@@ -897,6 +909,7 @@ def test_models_document_get_abilities_ai_access_public(is_authenticated, reach)
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role="editor")
|
||||
|
||||
abilities = document.get_abilities(user)
|
||||
assert abilities["ai_proxy"] == is_authenticated
|
||||
assert abilities["ai_transform"] == is_authenticated
|
||||
assert abilities["ai_translate"] == is_authenticated
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ export interface Doc {
|
||||
abilities: {
|
||||
accesses_manage: boolean;
|
||||
accesses_view: boolean;
|
||||
ai_proxy: boolean;
|
||||
ai_transform: boolean;
|
||||
ai_translate: boolean;
|
||||
attachment_upload: boolean;
|
||||
|
||||
@@ -182,6 +182,7 @@ export class ApiPlugin implements WorkboxPlugin {
|
||||
abilities: {
|
||||
accesses_manage: true,
|
||||
accesses_view: true,
|
||||
ai_proxy: true,
|
||||
ai_transform: true,
|
||||
ai_translate: true,
|
||||
attachment_upload: true,
|
||||
|
||||
@@ -29,6 +29,7 @@ interface Doc {
|
||||
abilities: {
|
||||
accesses_manage: boolean;
|
||||
accesses_view: boolean;
|
||||
ai_proxy: boolean;
|
||||
ai_transform: boolean;
|
||||
ai_translate: boolean;
|
||||
attachment_upload: boolean;
|
||||
|
||||
Reference in New Issue
Block a user