✨(backend) create ai endpoint
We created 2 new action endpoints on the document
to perform AI operations:
- POST /api/v1.0/documents/{uuid}/ai-transform
- POST /api/v1.0/documents/{uuid}/ai-translate
This commit is contained in:
committed by
Samuel Paccoud
parent
e8d95facdf
commit
aff3b43c9d
@@ -39,3 +39,8 @@ LOGOUT_REDIRECT_URL=http://localhost:3000
|
|||||||
|
|
||||||
OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"]
|
OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"]
|
||||||
OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"}
|
OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"}
|
||||||
|
|
||||||
|
# AI
|
||||||
|
AI_BASE_URL=https://openaiendpoint.com
|
||||||
|
AI_API_KEY=password
|
||||||
|
AI_MODEL=llama
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from django.utils.translation import gettext_lazy as _
|
|||||||
import magic
|
import magic
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
|
|
||||||
from core import models
|
from core import enums, models
|
||||||
|
from core.services.ai_services import AI_ACTIONS
|
||||||
|
|
||||||
|
|
||||||
class UserSerializer(serializers.ModelSerializer):
|
class UserSerializer(serializers.ModelSerializer):
|
||||||
@@ -378,3 +379,33 @@ class VersionFilterSerializer(serializers.Serializer):
|
|||||||
page_size = serializers.IntegerField(
|
page_size = serializers.IntegerField(
|
||||||
required=False, min_value=1, max_value=50, default=20
|
required=False, min_value=1, max_value=50, default=20
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AITransformSerializer(serializers.Serializer):
|
||||||
|
"""Serializer for AI transform requests."""
|
||||||
|
|
||||||
|
action = serializers.ChoiceField(choices=AI_ACTIONS, required=True)
|
||||||
|
text = serializers.CharField(required=True)
|
||||||
|
|
||||||
|
def validate_text(self, value):
|
||||||
|
"""Ensure the text field is not empty."""
|
||||||
|
|
||||||
|
if len(value.strip()) == 0:
|
||||||
|
raise serializers.ValidationError("Text field cannot be empty.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class AITranslateSerializer(serializers.Serializer):
|
||||||
|
"""Serializer for AI translate requests."""
|
||||||
|
|
||||||
|
language = serializers.ChoiceField(
|
||||||
|
choices=tuple(enums.ALL_LANGUAGES.items()), required=True
|
||||||
|
)
|
||||||
|
text = serializers.CharField(required=True)
|
||||||
|
|
||||||
|
def validate_text(self, value):
|
||||||
|
"""Ensure the text field is not empty."""
|
||||||
|
|
||||||
|
if len(value.strip()) == 0:
|
||||||
|
raise serializers.ValidationError("Text field cannot be empty.")
|
||||||
|
return value
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
"""Util to generate S3 authorization headers for object storage access control"""
|
"""Util to generate S3 authorization headers for object storage access control"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.cache import cache
|
||||||
from django.core.files.storage import default_storage
|
from django.core.files.storage import default_storage
|
||||||
|
|
||||||
import botocore
|
import botocore
|
||||||
|
from rest_framework.throttling import BaseThrottle
|
||||||
|
|
||||||
|
|
||||||
def generate_s3_authorization_headers(key):
|
def generate_s3_authorization_headers(key):
|
||||||
@@ -31,3 +37,93 @@ def generate_s3_authorization_headers(key):
|
|||||||
auth.add_auth(request)
|
auth.add_auth(request)
|
||||||
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
class AIBaseRateThrottle(BaseThrottle, ABC):
|
||||||
|
"""Base throttle class for AI-related rate limiting with backoff."""
|
||||||
|
|
||||||
|
def __init__(self, rates):
|
||||||
|
"""Initialize instance attributes with configurable rates."""
|
||||||
|
super().__init__()
|
||||||
|
self.rates = rates
|
||||||
|
self.cache_key = None
|
||||||
|
self.recent_requests_minute = 0
|
||||||
|
self.recent_requests_hour = 0
|
||||||
|
self.recent_requests_day = 0
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_cache_key(self, request, view):
|
||||||
|
"""Abstract method to generate cache key for throttling."""
|
||||||
|
|
||||||
|
def allow_request(self, request, view):
|
||||||
|
"""Check if the request is allowed based on rate limits."""
|
||||||
|
self.cache_key = self.get_cache_key(request, view)
|
||||||
|
if not self.cache_key:
|
||||||
|
return True # Allow if no cache key is generated
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
history = cache.get(self.cache_key, [])
|
||||||
|
# Keep requests within the last 24 hours
|
||||||
|
history = [req for req in history if req > now - 86400]
|
||||||
|
|
||||||
|
# Calculate recent requests
|
||||||
|
self.recent_requests_minute = len([req for req in history if req > now - 60])
|
||||||
|
self.recent_requests_hour = len([req for req in history if req > now - 3600])
|
||||||
|
self.recent_requests_day = len(history)
|
||||||
|
|
||||||
|
# Check rate limits
|
||||||
|
if self.recent_requests_minute >= self.rates["minute"]:
|
||||||
|
return False
|
||||||
|
if self.recent_requests_hour >= self.rates["hour"]:
|
||||||
|
return False
|
||||||
|
if self.recent_requests_day >= self.rates["day"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Log the request
|
||||||
|
history.append(now)
|
||||||
|
cache.set(self.cache_key, history, timeout=86400)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""Implement a backoff strategy by increasing wait time based on limits hit."""
|
||||||
|
if self.recent_requests_day >= self.rates["day"]:
|
||||||
|
return 86400
|
||||||
|
if self.recent_requests_hour >= self.rates["hour"]:
|
||||||
|
return 3600
|
||||||
|
if self.recent_requests_minute >= self.rates["minute"]:
|
||||||
|
return 60
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class AIDocumentRateThrottle(AIBaseRateThrottle):
|
||||||
|
"""Throttle for limiting AI requests per document with backoff."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES)
|
||||||
|
|
||||||
|
def get_cache_key(self, request, view):
|
||||||
|
"""Include document ID in the cache key."""
|
||||||
|
document_id = view.kwargs["pk"]
|
||||||
|
return f"document_{document_id}_throttle_ai"
|
||||||
|
|
||||||
|
|
||||||
|
class AIUserRateThrottle(AIBaseRateThrottle):
|
||||||
|
"""Throttle that limits requests per user or IP with backoff and rate limits."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(settings.AI_USER_RATE_THROTTLE_RATES)
|
||||||
|
|
||||||
|
def get_cache_key(self, request, view=None):
|
||||||
|
"""Generate a cache key based on the user ID or IP for anonymous users."""
|
||||||
|
if request.user.is_authenticated:
|
||||||
|
return f"user_{request.user.id!s}_throttle_ai"
|
||||||
|
return f"anonymous_{self.get_ident(request)}_throttle_ai"
|
||||||
|
|
||||||
|
def get_ident(self, request):
|
||||||
|
"""Return the request IP address."""
|
||||||
|
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
|
||||||
|
return (
|
||||||
|
x_forwarded_for.split(",")[0]
|
||||||
|
if x_forwarded_for
|
||||||
|
else request.META.get("REMOTE_ADDR")
|
||||||
|
)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from rest_framework import (
|
|||||||
decorators,
|
decorators,
|
||||||
exceptions,
|
exceptions,
|
||||||
filters,
|
filters,
|
||||||
|
metadata,
|
||||||
mixins,
|
mixins,
|
||||||
pagination,
|
pagination,
|
||||||
status,
|
status,
|
||||||
@@ -30,7 +31,8 @@ from rest_framework import (
|
|||||||
response as drf_response,
|
response as drf_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
from core import models
|
from core import enums, models
|
||||||
|
from core.services.ai_services import AIService
|
||||||
|
|
||||||
from . import permissions, serializers, utils
|
from . import permissions, serializers, utils
|
||||||
|
|
||||||
@@ -302,6 +304,23 @@ class ResourceAccessViewsetMixin:
|
|||||||
serializer.save()
|
serializer.save()
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentMetadata(metadata.SimpleMetadata):
|
||||||
|
"""Custom metadata class to add information"""
|
||||||
|
|
||||||
|
def determine_metadata(self, request, view):
|
||||||
|
"""Add language choices only for the list endpoint."""
|
||||||
|
simple_metadata = super().determine_metadata(request, view)
|
||||||
|
|
||||||
|
if request.path.endswith("/documents/"):
|
||||||
|
simple_metadata["actions"]["POST"]["language"] = {
|
||||||
|
"choices": [
|
||||||
|
{"value": code, "display_name": name}
|
||||||
|
for code, name in enums.ALL_LANGUAGES.items()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return simple_metadata
|
||||||
|
|
||||||
|
|
||||||
class DocumentViewSet(
|
class DocumentViewSet(
|
||||||
ResourceViewsetMixin,
|
ResourceViewsetMixin,
|
||||||
mixins.CreateModelMixin,
|
mixins.CreateModelMixin,
|
||||||
@@ -319,6 +338,7 @@ class DocumentViewSet(
|
|||||||
resource_field_name = "document"
|
resource_field_name = "document"
|
||||||
queryset = models.Document.objects.all()
|
queryset = models.Document.objects.all()
|
||||||
ordering = ["-updated_at"]
|
ordering = ["-updated_at"]
|
||||||
|
metadata_class = DocumentMetadata
|
||||||
|
|
||||||
def list(self, request, *args, **kwargs):
|
def list(self, request, *args, **kwargs):
|
||||||
"""Restrict resources returned by the list endpoint"""
|
"""Restrict resources returned by the list endpoint"""
|
||||||
@@ -455,10 +475,7 @@ class DocumentViewSet(
|
|||||||
serializer = serializers.LinkDocumentSerializer(
|
serializer = serializers.LinkDocumentSerializer(
|
||||||
document, data=request.data, partial=True
|
document, data=request.data, partial=True
|
||||||
)
|
)
|
||||||
if not serializer.is_valid():
|
serializer.is_valid(raise_exception=True)
|
||||||
return drf_response.Response(
|
|
||||||
serializer.errors, status=status.HTTP_400_BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
serializer.save()
|
serializer.save()
|
||||||
return drf_response.Response(serializer.data, status=status.HTTP_200_OK)
|
return drf_response.Response(serializer.data, status=status.HTTP_200_OK)
|
||||||
@@ -471,10 +488,7 @@ class DocumentViewSet(
|
|||||||
|
|
||||||
# Validate metadata in payload
|
# Validate metadata in payload
|
||||||
serializer = serializers.FileUploadSerializer(data=request.data)
|
serializer = serializers.FileUploadSerializer(data=request.data)
|
||||||
if not serializer.is_valid():
|
serializer.is_valid(raise_exception=True)
|
||||||
return drf_response.Response(
|
|
||||||
serializer.errors, status=status.HTTP_400_BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate a generic yet unique filename to store the image in object storage
|
# Generate a generic yet unique filename to store the image in object storage
|
||||||
file_id = uuid.uuid4()
|
file_id = uuid.uuid4()
|
||||||
@@ -482,13 +496,13 @@ class DocumentViewSet(
|
|||||||
key = f"{document.key_base}/{ATTACHMENTS_FOLDER:s}/{file_id!s}.{extension:s}"
|
key = f"{document.key_base}/{ATTACHMENTS_FOLDER:s}/{file_id!s}.{extension:s}"
|
||||||
|
|
||||||
# Prepare metadata for storage
|
# Prepare metadata for storage
|
||||||
metadata = {"Metadata": {"owner": str(request.user.id)}}
|
extra_args = {"Metadata": {"owner": str(request.user.id)}}
|
||||||
if serializer.validated_data["is_unsafe"]:
|
if serializer.validated_data["is_unsafe"]:
|
||||||
metadata["Metadata"]["is_unsafe"] = "true"
|
extra_args["Metadata"]["is_unsafe"] = "true"
|
||||||
|
|
||||||
file = serializer.validated_data["file"]
|
file = serializer.validated_data["file"]
|
||||||
default_storage.connection.meta.client.upload_fileobj(
|
default_storage.connection.meta.client.upload_fileobj(
|
||||||
file, default_storage.bucket_name, key, ExtraArgs=metadata
|
file, default_storage.bucket_name, key, ExtraArgs=extra_args
|
||||||
)
|
)
|
||||||
|
|
||||||
return drf_response.Response(
|
return drf_response.Response(
|
||||||
@@ -537,6 +551,63 @@ class DocumentViewSet(
|
|||||||
request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}")
|
request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}")
|
||||||
return drf_response.Response("authorized", headers=request.headers, status=200)
|
return drf_response.Response("authorized", headers=request.headers, status=200)
|
||||||
|
|
||||||
|
@decorators.action(
|
||||||
|
detail=True,
|
||||||
|
methods=["post"],
|
||||||
|
name="Apply a transformation action on a piece of text with AI",
|
||||||
|
url_path="ai-transform",
|
||||||
|
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
|
||||||
|
)
|
||||||
|
def ai_transform(self, request, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
POST /api/v1.0/documents/<resource_id>/ai-transform
|
||||||
|
with expected data:
|
||||||
|
- text: str
|
||||||
|
- action: str [prompt, correct, rephrase, summarize]
|
||||||
|
Return JSON response with the processed text.
|
||||||
|
"""
|
||||||
|
# Check permissions first
|
||||||
|
self.get_object()
|
||||||
|
|
||||||
|
serializer = serializers.AITransformSerializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
|
||||||
|
text = serializer.validated_data["text"]
|
||||||
|
action = serializer.validated_data["action"]
|
||||||
|
|
||||||
|
response = AIService().transform(text, action)
|
||||||
|
|
||||||
|
return drf_response.Response(response, status=status.HTTP_200_OK)
|
||||||
|
|
||||||
|
@decorators.action(
|
||||||
|
detail=True,
|
||||||
|
methods=["post"],
|
||||||
|
name="Translate a piece of text with AI",
|
||||||
|
serializer_class=serializers.AITranslateSerializer,
|
||||||
|
url_path="ai-translate",
|
||||||
|
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
|
||||||
|
)
|
||||||
|
def ai_translate(self, request, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
POST /api/v1.0/documents/<resource_id>/ai-translate
|
||||||
|
with expected data:
|
||||||
|
- text: str
|
||||||
|
- language: str [settings.LANGUAGES]
|
||||||
|
Return JSON response with the translated text.
|
||||||
|
"""
|
||||||
|
# Check permissions first
|
||||||
|
self.get_object()
|
||||||
|
|
||||||
|
serializer = self.get_serializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
|
||||||
|
text = serializer.validated_data["text"]
|
||||||
|
language = serializer.validated_data["language"]
|
||||||
|
|
||||||
|
response = AIService().translate(text, language)
|
||||||
|
|
||||||
|
return drf_response.Response(response, status=status.HTTP_200_OK)
|
||||||
|
|
||||||
|
|
||||||
class DocumentAccessViewSet(
|
class DocumentAccessViewSet(
|
||||||
ResourceAccessViewsetMixin,
|
ResourceAccessViewsetMixin,
|
||||||
|
|||||||
@@ -2,15 +2,11 @@
|
|||||||
Core application enums declaration
|
Core application enums declaration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from django.conf import global_settings, settings
|
from django.conf import global_settings
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
# Django sets `LANGUAGES` by default with all supported languages. We can use it for
|
# In Django's code base, `LANGUAGES` is set by default with all supported languages.
|
||||||
# the choice of languages which should not be limited to the few languages active in
|
# We can use it for the choice of languages which should not be limited to the few languages
|
||||||
# the app.
|
# active in the app.
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
ALL_LANGUAGES = getattr(
|
ALL_LANGUAGES = {language: _(name) for language, name in global_settings.LANGUAGES}
|
||||||
settings,
|
|
||||||
"ALL_LANGUAGES",
|
|
||||||
[(language, _(name)) for language, name in global_settings.LANGUAGES],
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -508,6 +508,8 @@ class Document(BaseModel):
|
|||||||
can_get = bool(roles)
|
can_get = bool(roles)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"ai_transform": is_owner_or_admin or is_editor,
|
||||||
|
"ai_translate": is_owner_or_admin or is_editor,
|
||||||
"attachment_upload": is_owner_or_admin or is_editor,
|
"attachment_upload": is_owner_or_admin or is_editor,
|
||||||
"destroy": RoleChoices.OWNER in roles,
|
"destroy": RoleChoices.OWNER in roles,
|
||||||
"link_configuration": is_owner_or_admin,
|
"link_configuration": is_owner_or_admin,
|
||||||
|
|||||||
0
src/backend/core/services/__init__.py
Normal file
0
src/backend/core/services/__init__.py
Normal file
89
src/backend/core/services/ai_services.py
Normal file
89
src/backend/core/services/ai_services.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""AI services."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from core import enums
|
||||||
|
|
||||||
|
AI_ACTIONS = {
|
||||||
|
"prompt": (
|
||||||
|
"Answer the prompt in markdown format. Return JSON: "
|
||||||
|
'{"answer": "Your markdown answer"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
"correct": (
|
||||||
|
"Correct grammar and spelling of the markdown text, "
|
||||||
|
"preserving language and markdown formatting. "
|
||||||
|
'Return JSON: {"answer": "your corrected markdown text"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
"rephrase": (
|
||||||
|
"Rephrase the given markdown text, "
|
||||||
|
"preserving language and markdown formatting. "
|
||||||
|
'Return JSON: {"answer": "your rephrased markdown text"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
"summarize": (
|
||||||
|
"Summarize the markdown text, preserving language and markdown formatting. "
|
||||||
|
'Return JSON: {"answer": "your markdown summary"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
AI_TRANSLATE = (
|
||||||
|
"Translate the markdown text to {language:s}, preserving markdown formatting. "
|
||||||
|
'Return JSON: {{"answer": "your translated markdown text in {language:s}"}}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AIService:
|
||||||
|
"""Service class for AI-related operations."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Ensure that the AI configuration is set properly."""
|
||||||
|
if (
|
||||||
|
settings.AI_BASE_URL is None
|
||||||
|
or settings.AI_API_KEY is None
|
||||||
|
or settings.AI_MODEL is None
|
||||||
|
):
|
||||||
|
raise ImproperlyConfigured("AI configuration not set")
|
||||||
|
self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY)
|
||||||
|
|
||||||
|
def call_ai_api(self, system_content, text):
|
||||||
|
"""Helper method to call the OpenAI API and process the response."""
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=settings.AI_MODEL,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_content},
|
||||||
|
{"role": "user", "content": json.dumps({"markdown_input": text})},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
sanitized_content = re.sub(r"(?<!\\)\n", "\\\\n", content)
|
||||||
|
sanitized_content = re.sub(r"(?<!\\)\t", "\\\\t", sanitized_content)
|
||||||
|
|
||||||
|
json_response = json.loads(sanitized_content)
|
||||||
|
|
||||||
|
if "answer" not in json_response:
|
||||||
|
raise RuntimeError("AI response does not contain an answer")
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
|
def transform(self, text, action):
|
||||||
|
"""Transform text based on specified action."""
|
||||||
|
system_content = AI_ACTIONS[action]
|
||||||
|
return self.call_ai_api(system_content, text)
|
||||||
|
|
||||||
|
def translate(self, text, language):
|
||||||
|
"""Translate text to a specified language."""
|
||||||
|
language_display = enums.ALL_LANGUAGES.get(language, language)
|
||||||
|
system_content = AI_TRANSLATE.format(language=language_display)
|
||||||
|
return self.call_ai_api(system_content, text)
|
||||||
@@ -0,0 +1,346 @@
|
|||||||
|
"""
|
||||||
|
Test AI transform API endpoint for users in impress's core app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
|
from core import factories
|
||||||
|
from core.tests.conftest import TEAM, USER, VIA
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
"""Fixture to clear the cache before each test."""
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ai_settings():
|
||||||
|
"""Fixture to set AI settings."""
|
||||||
|
with override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama"
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("restricted", "reader"),
|
||||||
|
("restricted", "editor"),
|
||||||
|
("authenticated", "reader"),
|
||||||
|
("authenticated", "editor"),
|
||||||
|
("public", "reader"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_documents_ai_transform_anonymous_forbidden(reach, role):
|
||||||
|
"""
|
||||||
|
Anonymous users should not be able to request AI transform if the link reach
|
||||||
|
and role don't allow it.
|
||||||
|
"""
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = APIClient().post(url, {"text": "hello", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Authentication credentials were not provided."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_transform_anonymous_success(mock_create):
|
||||||
|
"""
|
||||||
|
Anonymous users should be able to request AI transform to a document
|
||||||
|
if the link reach and role permit it.
|
||||||
|
"""
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = APIClient().post(url, {"text": "Hello", "action": "summarize"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Summarize the markdown text, preserving language and markdown formatting. "
|
||||||
|
'Return JSON: {"answer": "your markdown summary"}. Do not provide any other '
|
||||||
|
"information."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("restricted", "reader"),
|
||||||
|
("restricted", "editor"),
|
||||||
|
("authenticated", "reader"),
|
||||||
|
("public", "reader"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_documents_ai_transform_authenticated_forbidden(reach, role):
|
||||||
|
"""
|
||||||
|
Users who are not related to a document can't request AI transform if the
|
||||||
|
link reach and role don't allow it.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "You do not have permission to perform this action."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("authenticated", "editor"),
|
||||||
|
("public", "editor"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_transform_authenticated_success(mock_create, reach, role):
|
||||||
|
"""
|
||||||
|
Autenticated who are not related to a document should be able to request AI transform
|
||||||
|
if the link reach and role permit it.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
'Answer the prompt in markdown format. Return JSON: {"answer": '
|
||||||
|
'"Your markdown answer"}. Do not provide any other information.'
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("via", VIA)
|
||||||
|
def test_api_documents_ai_transform_reader(via, mock_user_teams):
|
||||||
|
"""
|
||||||
|
Users who are simple readers on a document should not be allowed to request AI transform.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_role="reader")
|
||||||
|
if via == USER:
|
||||||
|
factories.UserDocumentAccessFactory(document=document, user=user, role="reader")
|
||||||
|
elif via == TEAM:
|
||||||
|
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||||
|
factories.TeamDocumentAccessFactory(
|
||||||
|
document=document, team="lasuite", role="reader"
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "You do not have permission to perform this action."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
||||||
|
@pytest.mark.parametrize("via", VIA)
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_transform_success(mock_create, via, role, mock_user_teams):
|
||||||
|
"""
|
||||||
|
Editors, administrators and owners of a document should be able to request AI transform.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory()
|
||||||
|
if via == USER:
|
||||||
|
factories.UserDocumentAccessFactory(document=document, user=user, role=role)
|
||||||
|
elif via == TEAM:
|
||||||
|
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||||
|
factories.TeamDocumentAccessFactory(
|
||||||
|
document=document, team="lasuite", role=role
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
'Answer the prompt in markdown format. Return JSON: {"answer": '
|
||||||
|
'"Your markdown answer"}. Do not provide any other information.'
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documents_ai_transform_empty_text():
|
||||||
|
"""The text should not be empty when requesting AI transform."""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": " ", "action": "prompt"})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"text": ["This field may not be blank."]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documents_ai_transform_invalid_action():
|
||||||
|
"""The action should valid when requesting AI transform."""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "invalid"})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"action": ['"invalid" is not a valid choice.']}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_transform_throttling_document(mock_create):
|
||||||
|
"""
|
||||||
|
Throttling per document should be triggered on the AI transform endpoint.
|
||||||
|
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||||
|
"""
|
||||||
|
client = APIClient()
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
for _ in range(3):
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "summarize"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "summarize"})
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_transform_throttling_user(mock_create):
|
||||||
|
"""
|
||||||
|
Throttling per user should be triggered on the AI transform endpoint.
|
||||||
|
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "summarize"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||||
|
response = client.post(url, {"text": "Hello", "action": "summarize"})
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
@@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
Test AI translate API endpoint for users in impress's core app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
|
from core import factories
|
||||||
|
from core.tests.conftest import TEAM, USER, VIA
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
"""Fixture to clear the cache before each test."""
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ai_settings():
|
||||||
|
"""Fixture to set AI settings."""
|
||||||
|
with override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama"
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documents_ai_translate_viewset_options_metadata():
|
||||||
|
"""The documents endpoint should give us the list of available languages."""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
response = APIClient().options("/api/v1.0/documents/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
metadata = response.json()
|
||||||
|
assert metadata["name"] == "Document List"
|
||||||
|
assert metadata["actions"]["POST"]["language"]["choices"][0] == {
|
||||||
|
"value": "af",
|
||||||
|
"display_name": "Afrikaans",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("restricted", "reader"),
|
||||||
|
("restricted", "editor"),
|
||||||
|
("authenticated", "reader"),
|
||||||
|
("authenticated", "editor"),
|
||||||
|
("public", "reader"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_documents_ai_translate_anonymous_forbidden(reach, role):
|
||||||
|
"""
|
||||||
|
Anonymous users should not be able to request AI translate if the link reach
|
||||||
|
and role don't allow it.
|
||||||
|
"""
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = APIClient().post(url, {"text": "hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Authentication credentials were not provided."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_translate_anonymous_success(mock_create):
|
||||||
|
"""
|
||||||
|
Anonymous users should be able to request AI translate to a document
|
||||||
|
if the link reach and role permit it.
|
||||||
|
"""
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = APIClient().post(url, {"text": "Hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Translate the markdown text to Spanish, preserving markdown formatting. "
|
||||||
|
'Return JSON: {"answer": "your translated markdown text in Spanish"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("restricted", "reader"),
|
||||||
|
("restricted", "editor"),
|
||||||
|
("authenticated", "reader"),
|
||||||
|
("public", "reader"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_documents_ai_translate_authenticated_forbidden(reach, role):
|
||||||
|
"""
|
||||||
|
Users who are not related to a document can't request AI translate if the
|
||||||
|
link reach and role don't allow it.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "You do not have permission to perform this action."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reach, role",
|
||||||
|
[
|
||||||
|
("authenticated", "editor"),
|
||||||
|
("public", "editor"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_translate_authenticated_success(mock_create, reach, role):
|
||||||
|
"""
|
||||||
|
Autenticated who are not related to a document should be able to request AI translate
|
||||||
|
if the link reach and role permit it.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es-co"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Translate the markdown text to Colombian Spanish, "
|
||||||
|
"preserving markdown formatting. Return JSON: "
|
||||||
|
'{"answer": "your translated markdown text in Colombian Spanish"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("via", VIA)
|
||||||
|
def test_api_documents_ai_translate_reader(via, mock_user_teams):
|
||||||
|
"""
|
||||||
|
Users who are simple readers on a document should not be allowed to request AI translate.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_role="reader")
|
||||||
|
if via == USER:
|
||||||
|
factories.UserDocumentAccessFactory(document=document, user=user, role="reader")
|
||||||
|
elif via == TEAM:
|
||||||
|
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||||
|
factories.TeamDocumentAccessFactory(
|
||||||
|
document=document, team="lasuite", role="reader"
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "You do not have permission to perform this action."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("role", ["editor", "administrator", "owner"])
|
||||||
|
@pytest.mark.parametrize("via", VIA)
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_translate_success(mock_create, via, role, mock_user_teams):
|
||||||
|
"""
|
||||||
|
Editors, administrators and owners of a document should be able to request AI translate.
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory()
|
||||||
|
if via == USER:
|
||||||
|
factories.UserDocumentAccessFactory(document=document, user=user, role=role)
|
||||||
|
elif via == TEAM:
|
||||||
|
mock_user_teams.return_value = ["lasuite", "unknown"]
|
||||||
|
factories.TeamDocumentAccessFactory(
|
||||||
|
document=document, team="lasuite", role=role
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es-co"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
model="llama",
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Translate the markdown text to Colombian Spanish, "
|
||||||
|
"preserving markdown formatting. Return JSON: "
|
||||||
|
'{"answer": "your translated markdown text in Colombian Spanish"}. '
|
||||||
|
"Do not provide any other information."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": '{"markdown_input": "Hello"}'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documents_ai_translate_empty_text():
|
||||||
|
"""The text should not be empty when requesting AI translate."""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": " ", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"text": ["This field may not be blank."]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_documents_ai_translate_invalid_action():
|
||||||
|
"""The action should valid when requesting AI translate."""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "invalid"})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"language": ['"invalid" is not a valid choice.']}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_translate_throttling_document(mock_create):
|
||||||
|
"""
|
||||||
|
Throttling per document should be triggered on the AI translate endpoint.
|
||||||
|
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||||
|
"""
|
||||||
|
client = APIClient()
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
for _ in range(3):
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client.force_login(user)
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@pytest.mark.usefixtures("ai_settings")
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_documents_ai_translate_throttling_user(mock_create):
|
||||||
|
"""
|
||||||
|
Throttling per user should be triggered on the AI translate endpoint.
|
||||||
|
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||||
|
"""
|
||||||
|
user = factories.UserFactory()
|
||||||
|
client = APIClient()
|
||||||
|
client.force_login(user)
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"answer": "Salut"}
|
||||||
|
|
||||||
|
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||||
|
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||||
|
response = client.post(url, {"text": "Hello", "language": "es"})
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.json() == {
|
||||||
|
"detail": "Request was throttled. Expected available in 60 seconds."
|
||||||
|
}
|
||||||
@@ -21,6 +21,8 @@ def test_api_documents_retrieve_anonymous_public():
|
|||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"id": str(document.id),
|
"id": str(document.id),
|
||||||
"abilities": {
|
"abilities": {
|
||||||
|
"ai_transform": document.link_role == "editor",
|
||||||
|
"ai_translate": document.link_role == "editor",
|
||||||
"attachment_upload": document.link_role == "editor",
|
"attachment_upload": document.link_role == "editor",
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
@@ -75,6 +77,8 @@ def test_api_documents_retrieve_authenticated_unrelated_public_or_authenticated(
|
|||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"id": str(document.id),
|
"id": str(document.id),
|
||||||
"abilities": {
|
"abilities": {
|
||||||
|
"ai_transform": document.link_role == "editor",
|
||||||
|
"ai_translate": document.link_role == "editor",
|
||||||
"attachment_upload": document.link_role == "editor",
|
"attachment_upload": document.link_role == "editor",
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
|
|||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Test throttling on documents for the AI endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.test import APIRequestFactory
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
from core.api.utils import AIDocumentRateThrottle
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentAPIView(APIView):
|
||||||
|
"""A simple view to test the throttle"""
|
||||||
|
|
||||||
|
throttle_classes = [AIDocumentRateThrottle]
|
||||||
|
|
||||||
|
def get(self, request, *args, **kwargs):
|
||||||
|
"""Minimal get method for testing purposes."""
|
||||||
|
return Response({"message": "Success"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
"""Fixture to clear the cache before each test."""
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@patch("time.time")
|
||||||
|
def test_api_utils_ai_document_rate_throttle_minute_limit(mock_time):
|
||||||
|
"""Test that minute limit is enforced."""
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Simulate requests to the document API
|
||||||
|
for _i in range(3): # 3 first requests should be allowed
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 59
|
||||||
|
|
||||||
|
# 4th request should be throttled
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 60s backoff wait time has passed, we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 100000, "hour": 6, "day": 10}
|
||||||
|
)
|
||||||
|
@patch("time.time")
|
||||||
|
def test_ai_document_rate_throttle_hour_limit(mock_time):
|
||||||
|
"""Test that the hour limit is enforced without hitting the minute limit."""
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Make requests to the document API, one per 21 seconds to avoid hitting the minute limit
|
||||||
|
for _i in range(6):
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 21
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 3600 - 6 * 21 - 1
|
||||||
|
|
||||||
|
# 7th request should be throttled
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 1h backoff wait time has passed, we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@patch("time.time")
|
||||||
|
def test_api_utils_ai_document_rate_throttle_day_limit(mock_time):
|
||||||
|
"""Test that day limit is enforced."""
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Make requests to the document API, one per 10 minutes to avoid hitting
|
||||||
|
# the minute and hour limits
|
||||||
|
for _i in range(10): # 10 requests should be allowed
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 60 * 10
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 24 * 3600 - 10 * 60 * 10 - 1
|
||||||
|
|
||||||
|
# 11th request should be throttled
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 24h backoff wait time has passed we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
request = api_rf.get("/documents/1/")
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=1)
|
||||||
|
assert response.status_code == 200
|
||||||
146
src/backend/core/tests/test_api_utils_ai_user_rate_throttles.py
Normal file
146
src/backend/core/tests/test_api_utils_ai_user_rate_throttles.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
Test throttling on users for the AI endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.test import APIRequestFactory
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
from core.api.utils import AIUserRateThrottle
|
||||||
|
from core.factories import UserFactory
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentAPIView(APIView):
|
||||||
|
"""A simple view to test the throttle"""
|
||||||
|
|
||||||
|
throttle_classes = [AIUserRateThrottle]
|
||||||
|
|
||||||
|
def get(self, request, *args, **kwargs):
|
||||||
|
"""Minimal get method for testing purposes."""
|
||||||
|
return Response({"message": "Success"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
"""Fixture to clear the cache before each test."""
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@patch("time.time")
|
||||||
|
def test_api_utils_ai_user_rate_throttle_minute_limit(mock_time):
|
||||||
|
"""Test that minute limit is enforced."""
|
||||||
|
user = UserFactory()
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Simulate requests to the document API
|
||||||
|
for _i in range(3): # 3 first requests should be allowed
|
||||||
|
document_id = str(uuid4())
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 59
|
||||||
|
|
||||||
|
# 4th request should be throttled
|
||||||
|
document_id = str(uuid4())
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 60s backoff wait time has passed, we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
document_id = str(uuid4())
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 100000, "hour": 6, "day": 10})
|
||||||
|
@patch("time.time")
|
||||||
|
def test_ai_user_rate_throttle_hour_limit(mock_time):
|
||||||
|
"""Test that the hour limit is enforced without hitting the minute limit."""
|
||||||
|
user = UserFactory()
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Make requests to the document API, one per 21 seconds to avoid hitting the minute limit
|
||||||
|
for _i in range(6):
|
||||||
|
document_id = str(uuid4())
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 21
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 3600 - 6 * 21 - 1
|
||||||
|
|
||||||
|
# 7th request should be throttled
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 1h backoff wait time has passed, we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||||
|
@patch("time.time")
|
||||||
|
def test_api_utils_ai_user_rate_throttle_day_limit(mock_time):
|
||||||
|
"""Test that day limit is enforced."""
|
||||||
|
user = UserFactory()
|
||||||
|
api_rf = APIRequestFactory()
|
||||||
|
mock_time.return_value = 1000000
|
||||||
|
|
||||||
|
# Make requests to the document API, one per 10 minutes to avoid hitting
|
||||||
|
# the minute and hour limits
|
||||||
|
for _i in range(10): # 10 requests should be allowed
|
||||||
|
document_id = str(uuid4())
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 60 * 10
|
||||||
|
|
||||||
|
# Simulate passage of time
|
||||||
|
mock_time.return_value += 24 * 3600 - 10 * 60 * 10 - 1
|
||||||
|
|
||||||
|
# 11th request should be throttled
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# After the 24h backoff wait time has passed we can make a request again
|
||||||
|
mock_time.return_value += 1
|
||||||
|
|
||||||
|
request = api_rf.get(f"/documents/{document_id:s}/")
|
||||||
|
request.user = user
|
||||||
|
response = DocumentAPIView.as_view()(request, pk=document_id)
|
||||||
|
assert response.status_code == 200
|
||||||
@@ -83,6 +83,8 @@ def test_models_documents_get_abilities_forbidden(is_authenticated, reach, role)
|
|||||||
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
||||||
abilities = document.get_abilities(user)
|
abilities = document.get_abilities(user)
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": False,
|
||||||
|
"ai_translate": False,
|
||||||
"attachment_upload": False,
|
"attachment_upload": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
@@ -113,6 +115,8 @@ def test_models_documents_get_abilities_reader(is_authenticated, reach):
|
|||||||
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
||||||
abilities = document.get_abilities(user)
|
abilities = document.get_abilities(user)
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": False,
|
||||||
|
"ai_translate": False,
|
||||||
"attachment_upload": False,
|
"attachment_upload": False,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
@@ -143,6 +147,8 @@ def test_models_documents_get_abilities_editor(is_authenticated, reach):
|
|||||||
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
user = factories.UserFactory() if is_authenticated else AnonymousUser()
|
||||||
abilities = document.get_abilities(user)
|
abilities = document.get_abilities(user)
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": True,
|
||||||
|
"ai_translate": True,
|
||||||
"attachment_upload": True,
|
"attachment_upload": True,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
@@ -162,6 +168,8 @@ def test_models_documents_get_abilities_owner():
|
|||||||
access = factories.UserDocumentAccessFactory(role="owner", user=user)
|
access = factories.UserDocumentAccessFactory(role="owner", user=user)
|
||||||
abilities = access.document.get_abilities(access.user)
|
abilities = access.document.get_abilities(access.user)
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": True,
|
||||||
|
"ai_translate": True,
|
||||||
"attachment_upload": True,
|
"attachment_upload": True,
|
||||||
"destroy": True,
|
"destroy": True,
|
||||||
"link_configuration": True,
|
"link_configuration": True,
|
||||||
@@ -180,6 +188,8 @@ def test_models_documents_get_abilities_administrator():
|
|||||||
access = factories.UserDocumentAccessFactory(role="administrator")
|
access = factories.UserDocumentAccessFactory(role="administrator")
|
||||||
abilities = access.document.get_abilities(access.user)
|
abilities = access.document.get_abilities(access.user)
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": True,
|
||||||
|
"ai_translate": True,
|
||||||
"attachment_upload": True,
|
"attachment_upload": True,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": True,
|
"link_configuration": True,
|
||||||
@@ -201,6 +211,8 @@ def test_models_documents_get_abilities_editor_user(django_assert_num_queries):
|
|||||||
abilities = access.document.get_abilities(access.user)
|
abilities = access.document.get_abilities(access.user)
|
||||||
|
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": True,
|
||||||
|
"ai_translate": True,
|
||||||
"attachment_upload": True,
|
"attachment_upload": True,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
@@ -224,6 +236,8 @@ def test_models_documents_get_abilities_reader_user(django_assert_num_queries):
|
|||||||
abilities = access.document.get_abilities(access.user)
|
abilities = access.document.get_abilities(access.user)
|
||||||
|
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": False,
|
||||||
|
"ai_translate": False,
|
||||||
"attachment_upload": False,
|
"attachment_upload": False,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
@@ -248,6 +262,8 @@ def test_models_documents_get_abilities_preset_role(django_assert_num_queries):
|
|||||||
abilities = access.document.get_abilities(access.user)
|
abilities = access.document.get_abilities(access.user)
|
||||||
|
|
||||||
assert abilities == {
|
assert abilities == {
|
||||||
|
"ai_transform": False,
|
||||||
|
"ai_translate": False,
|
||||||
"attachment_upload": False,
|
"attachment_upload": False,
|
||||||
"destroy": False,
|
"destroy": False,
|
||||||
"link_configuration": False,
|
"link_configuration": False,
|
||||||
|
|||||||
104
src/backend/core/tests/test_services_ai_services.py
Normal file
104
src/backend/core/tests/test_services_ai_services.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
Test ai API endpoints in the impress core app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
from django.test.utils import override_settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai import OpenAIError
|
||||||
|
|
||||||
|
from core.services.ai_services import AIService
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"setting_name, setting_value",
|
||||||
|
[
|
||||||
|
("AI_BASE_URL", None),
|
||||||
|
("AI_API_KEY", None),
|
||||||
|
("AI_MODEL", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_api_ai_setting_missing(setting_name, setting_value):
|
||||||
|
"""Setting should be set"""
|
||||||
|
|
||||||
|
with override_settings(**{setting_name: setting_value}):
|
||||||
|
with pytest.raises(
|
||||||
|
ImproperlyConfigured,
|
||||||
|
match="AI configuration not set",
|
||||||
|
):
|
||||||
|
AIService()
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
|
)
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_ai__client_error(mock_create):
|
||||||
|
"""Fail when the client raises an error"""
|
||||||
|
|
||||||
|
mock_create.side_effect = OpenAIError("Mocked client error")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
OpenAIError,
|
||||||
|
match="Mocked client error",
|
||||||
|
):
|
||||||
|
AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
|
)
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_ai__client_invalid_response(mock_create):
|
||||||
|
"""Fail when the client response is invalid"""
|
||||||
|
|
||||||
|
answer = {"no_answer": "This is an invalid response"}
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError,
|
||||||
|
match="AI response does not contain an answer",
|
||||||
|
):
|
||||||
|
AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
|
)
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_ai__success(mock_create):
|
||||||
|
"""The AI request should work as expect when called with valid arguments."""
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
assert response == {"answer": "Salut"}
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model"
|
||||||
|
)
|
||||||
|
@patch("openai.resources.chat.completions.Completions.create")
|
||||||
|
def test_api_ai__success_sanitize(mock_create):
|
||||||
|
"""The AI response should be sanitized"""
|
||||||
|
|
||||||
|
answer = '{"answer": "Salut\\n \tle \nmonde"}'
|
||||||
|
mock_create.return_value = MagicMock(
|
||||||
|
choices=[MagicMock(message=MagicMock(content=answer))]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = AIService().transform("hello", "prompt")
|
||||||
|
|
||||||
|
assert response == {"answer": "Salut\n \tle \nmonde"}
|
||||||
@@ -454,6 +454,20 @@ class Base(Configuration):
|
|||||||
ALLOW_LOGOUT_GET_METHOD = values.BooleanValue(
|
ALLOW_LOGOUT_GET_METHOD = values.BooleanValue(
|
||||||
default=True, environ_name="ALLOW_LOGOUT_GET_METHOD", environ_prefix=None
|
default=True, environ_name="ALLOW_LOGOUT_GET_METHOD", environ_prefix=None
|
||||||
)
|
)
|
||||||
|
AI_API_KEY = values.Value(None, environ_name="AI_API_KEY", environ_prefix=None)
|
||||||
|
AI_BASE_URL = values.Value(None, environ_name="AI_BASE_URL", environ_prefix=None)
|
||||||
|
AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None)
|
||||||
|
|
||||||
|
AI_DOCUMENT_RATE_THROTTLE_RATES = {
|
||||||
|
"minute": 5,
|
||||||
|
"hour": 100,
|
||||||
|
"day": 500,
|
||||||
|
}
|
||||||
|
AI_USER_RATE_THROTTLE_RATES = {
|
||||||
|
"minute": 3,
|
||||||
|
"hour": 50,
|
||||||
|
"day": 200,
|
||||||
|
}
|
||||||
|
|
||||||
USER_OIDC_FIELDS_TO_FULLNAME = values.ListValue(
|
USER_OIDC_FIELDS_TO_FULLNAME = values.ListValue(
|
||||||
default=["first_name", "last_name"],
|
default=["first_name", "last_name"],
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ dependencies = [
|
|||||||
"jsonschema==4.23.0",
|
"jsonschema==4.23.0",
|
||||||
"markdown==3.7",
|
"markdown==3.7",
|
||||||
"nested-multipart-parser==1.5.0",
|
"nested-multipart-parser==1.5.0",
|
||||||
|
"openai==1.44.1",
|
||||||
"psycopg[binary]==3.2.3",
|
"psycopg[binary]==3.2.3",
|
||||||
"PyJWT==2.9.0",
|
"PyJWT==2.9.0",
|
||||||
"pypandoc==1.14",
|
"pypandoc==1.14",
|
||||||
|
|||||||
Reference in New Issue
Block a user