diff --git a/env.d/development/common.dist b/env.d/development/common.dist index 712c4dbd..97b8e08a 100644 --- a/env.d/development/common.dist +++ b/env.d/development/common.dist @@ -39,3 +39,8 @@ LOGOUT_REDIRECT_URL=http://localhost:3000 OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"] OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"} + +# AI +AI_BASE_URL=https://openaiendpoint.com +AI_API_KEY=password +AI_MODEL=llama diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 269ce728..2f498b26 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -9,7 +9,8 @@ from django.utils.translation import gettext_lazy as _ import magic from rest_framework import exceptions, serializers -from core import models +from core import enums, models +from core.services.ai_services import AI_ACTIONS class UserSerializer(serializers.ModelSerializer): @@ -378,3 +379,33 @@ class VersionFilterSerializer(serializers.Serializer): page_size = serializers.IntegerField( required=False, min_value=1, max_value=50, default=20 ) + + +class AITransformSerializer(serializers.Serializer): + """Serializer for AI transform requests.""" + + action = serializers.ChoiceField(choices=AI_ACTIONS, required=True) + text = serializers.CharField(required=True) + + def validate_text(self, value): + """Ensure the text field is not empty.""" + + if len(value.strip()) == 0: + raise serializers.ValidationError("Text field cannot be empty.") + return value + + +class AITranslateSerializer(serializers.Serializer): + """Serializer for AI translate requests.""" + + language = serializers.ChoiceField( + choices=tuple(enums.ALL_LANGUAGES.items()), required=True + ) + text = serializers.CharField(required=True) + + def validate_text(self, value): + """Ensure the text field is not empty.""" + + if len(value.strip()) == 0: + raise serializers.ValidationError("Text field cannot be empty.") + return value diff --git a/src/backend/core/api/utils.py b/src/backend/core/api/utils.py index 53d84f68..53ca09eb 100644 --- a/src/backend/core/api/utils.py +++ b/src/backend/core/api/utils.py @@ -1,8 +1,14 @@ """Util to generate S3 authorization headers for object storage access control""" +import time +from abc import ABC, abstractmethod + +from django.conf import settings +from django.core.cache import cache from django.core.files.storage import default_storage import botocore +from rest_framework.throttling import BaseThrottle def generate_s3_authorization_headers(key): @@ -31,3 +37,93 @@ def generate_s3_authorization_headers(key): auth.add_auth(request) return request + + +class AIBaseRateThrottle(BaseThrottle, ABC): + """Base throttle class for AI-related rate limiting with backoff.""" + + def __init__(self, rates): + """Initialize instance attributes with configurable rates.""" + super().__init__() + self.rates = rates + self.cache_key = None + self.recent_requests_minute = 0 + self.recent_requests_hour = 0 + self.recent_requests_day = 0 + + @abstractmethod + def get_cache_key(self, request, view): + """Abstract method to generate cache key for throttling.""" + + def allow_request(self, request, view): + """Check if the request is allowed based on rate limits.""" + self.cache_key = self.get_cache_key(request, view) + if not self.cache_key: + return True # Allow if no cache key is generated + + now = time.time() + history = cache.get(self.cache_key, []) + # Keep requests within the last 24 hours + history = [req for req in history if req > now - 86400] + + # Calculate recent requests + self.recent_requests_minute = len([req for req in history if req > now - 60]) + self.recent_requests_hour = len([req for req in history if req > now - 3600]) + self.recent_requests_day = len(history) + + # Check rate limits + if self.recent_requests_minute >= self.rates["minute"]: + return False + if self.recent_requests_hour >= self.rates["hour"]: + return False + if self.recent_requests_day >= self.rates["day"]: + return False + + # Log the request + history.append(now) + cache.set(self.cache_key, history, timeout=86400) + return True + + def wait(self): + """Implement a backoff strategy by increasing wait time based on limits hit.""" + if self.recent_requests_day >= self.rates["day"]: + return 86400 + if self.recent_requests_hour >= self.rates["hour"]: + return 3600 + if self.recent_requests_minute >= self.rates["minute"]: + return 60 + return None + + +class AIDocumentRateThrottle(AIBaseRateThrottle): + """Throttle for limiting AI requests per document with backoff.""" + + def __init__(self, *args, **kwargs): + super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES) + + def get_cache_key(self, request, view): + """Include document ID in the cache key.""" + document_id = view.kwargs["pk"] + return f"document_{document_id}_throttle_ai" + + +class AIUserRateThrottle(AIBaseRateThrottle): + """Throttle that limits requests per user or IP with backoff and rate limits.""" + + def __init__(self, *args, **kwargs): + super().__init__(settings.AI_USER_RATE_THROTTLE_RATES) + + def get_cache_key(self, request, view=None): + """Generate a cache key based on the user ID or IP for anonymous users.""" + if request.user.is_authenticated: + return f"user_{request.user.id!s}_throttle_ai" + return f"anonymous_{self.get_ident(request)}_throttle_ai" + + def get_ident(self, request): + """Return the request IP address.""" + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") + return ( + x_forwarded_for.split(",")[0] + if x_forwarded_for + else request.META.get("REMOTE_ADDR") + ) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 918ad426..66939686 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -21,6 +21,7 @@ from rest_framework import ( decorators, exceptions, filters, + metadata, mixins, pagination, status, @@ -30,7 +31,8 @@ from rest_framework import ( response as drf_response, ) -from core import models +from core import enums, models +from core.services.ai_services import AIService from . import permissions, serializers, utils @@ -302,6 +304,23 @@ class ResourceAccessViewsetMixin: serializer.save() +class DocumentMetadata(metadata.SimpleMetadata): + """Custom metadata class to add information""" + + def determine_metadata(self, request, view): + """Add language choices only for the list endpoint.""" + simple_metadata = super().determine_metadata(request, view) + + if request.path.endswith("/documents/"): + simple_metadata["actions"]["POST"]["language"] = { + "choices": [ + {"value": code, "display_name": name} + for code, name in enums.ALL_LANGUAGES.items() + ] + } + return simple_metadata + + class DocumentViewSet( ResourceViewsetMixin, mixins.CreateModelMixin, @@ -319,6 +338,7 @@ class DocumentViewSet( resource_field_name = "document" queryset = models.Document.objects.all() ordering = ["-updated_at"] + metadata_class = DocumentMetadata def list(self, request, *args, **kwargs): """Restrict resources returned by the list endpoint""" @@ -455,10 +475,7 @@ class DocumentViewSet( serializer = serializers.LinkDocumentSerializer( document, data=request.data, partial=True ) - if not serializer.is_valid(): - return drf_response.Response( - serializer.errors, status=status.HTTP_400_BAD_REQUEST - ) + serializer.is_valid(raise_exception=True) serializer.save() return drf_response.Response(serializer.data, status=status.HTTP_200_OK) @@ -471,10 +488,7 @@ class DocumentViewSet( # Validate metadata in payload serializer = serializers.FileUploadSerializer(data=request.data) - if not serializer.is_valid(): - return drf_response.Response( - serializer.errors, status=status.HTTP_400_BAD_REQUEST - ) + serializer.is_valid(raise_exception=True) # Generate a generic yet unique filename to store the image in object storage file_id = uuid.uuid4() @@ -482,13 +496,13 @@ class DocumentViewSet( key = f"{document.key_base}/{ATTACHMENTS_FOLDER:s}/{file_id!s}.{extension:s}" # Prepare metadata for storage - metadata = {"Metadata": {"owner": str(request.user.id)}} + extra_args = {"Metadata": {"owner": str(request.user.id)}} if serializer.validated_data["is_unsafe"]: - metadata["Metadata"]["is_unsafe"] = "true" + extra_args["Metadata"]["is_unsafe"] = "true" file = serializer.validated_data["file"] default_storage.connection.meta.client.upload_fileobj( - file, default_storage.bucket_name, key, ExtraArgs=metadata + file, default_storage.bucket_name, key, ExtraArgs=extra_args ) return drf_response.Response( @@ -537,6 +551,63 @@ class DocumentViewSet( request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}") return drf_response.Response("authorized", headers=request.headers, status=200) + @decorators.action( + detail=True, + methods=["post"], + name="Apply a transformation action on a piece of text with AI", + url_path="ai-transform", + throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], + ) + def ai_transform(self, request, *args, **kwargs): + """ + POST /api/v1.0/documents//ai-transform + with expected data: + - text: str + - action: str [prompt, correct, rephrase, summarize] + Return JSON response with the processed text. + """ + # Check permissions first + self.get_object() + + serializer = serializers.AITransformSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + text = serializer.validated_data["text"] + action = serializer.validated_data["action"] + + response = AIService().transform(text, action) + + return drf_response.Response(response, status=status.HTTP_200_OK) + + @decorators.action( + detail=True, + methods=["post"], + name="Translate a piece of text with AI", + serializer_class=serializers.AITranslateSerializer, + url_path="ai-translate", + throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], + ) + def ai_translate(self, request, *args, **kwargs): + """ + POST /api/v1.0/documents//ai-translate + with expected data: + - text: str + - language: str [settings.LANGUAGES] + Return JSON response with the translated text. + """ + # Check permissions first + self.get_object() + + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + text = serializer.validated_data["text"] + language = serializer.validated_data["language"] + + response = AIService().translate(text, language) + + return drf_response.Response(response, status=status.HTTP_200_OK) + class DocumentAccessViewSet( ResourceAccessViewsetMixin, diff --git a/src/backend/core/enums.py b/src/backend/core/enums.py index e67d7b5b..8f7e70cf 100644 --- a/src/backend/core/enums.py +++ b/src/backend/core/enums.py @@ -2,15 +2,11 @@ Core application enums declaration """ -from django.conf import global_settings, settings +from django.conf import global_settings from django.utils.translation import gettext_lazy as _ -# Django sets `LANGUAGES` by default with all supported languages. We can use it for -# the choice of languages which should not be limited to the few languages active in -# the app. +# In Django's code base, `LANGUAGES` is set by default with all supported languages. +# We can use it for the choice of languages which should not be limited to the few languages +# active in the app. # pylint: disable=no-member -ALL_LANGUAGES = getattr( - settings, - "ALL_LANGUAGES", - [(language, _(name)) for language, name in global_settings.LANGUAGES], -) +ALL_LANGUAGES = {language: _(name) for language, name in global_settings.LANGUAGES} diff --git a/src/backend/core/models.py b/src/backend/core/models.py index 52c401c1..3697568b 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -508,6 +508,8 @@ class Document(BaseModel): can_get = bool(roles) return { + "ai_transform": is_owner_or_admin or is_editor, + "ai_translate": is_owner_or_admin or is_editor, "attachment_upload": is_owner_or_admin or is_editor, "destroy": RoleChoices.OWNER in roles, "link_configuration": is_owner_or_admin, diff --git a/src/backend/core/services/__init__.py b/src/backend/core/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py new file mode 100644 index 00000000..b7c70405 --- /dev/null +++ b/src/backend/core/services/ai_services.py @@ -0,0 +1,89 @@ +"""AI services.""" + +import json +import re + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured + +from openai import OpenAI + +from core import enums + +AI_ACTIONS = { + "prompt": ( + "Answer the prompt in markdown format. Return JSON: " + '{"answer": "Your markdown answer"}. ' + "Do not provide any other information." + ), + "correct": ( + "Correct grammar and spelling of the markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your corrected markdown text"}. ' + "Do not provide any other information." + ), + "rephrase": ( + "Rephrase the given markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your rephrased markdown text"}. ' + "Do not provide any other information." + ), + "summarize": ( + "Summarize the markdown text, preserving language and markdown formatting. " + 'Return JSON: {"answer": "your markdown summary"}. ' + "Do not provide any other information." + ), +} + +AI_TRANSLATE = ( + "Translate the markdown text to {language:s}, preserving markdown formatting. " + 'Return JSON: {{"answer": "your translated markdown text in {language:s}"}}. ' + "Do not provide any other information." +) + + +class AIService: + """Service class for AI-related operations.""" + + def __init__(self): + """Ensure that the AI configuration is set properly.""" + if ( + settings.AI_BASE_URL is None + or settings.AI_API_KEY is None + or settings.AI_MODEL is None + ): + raise ImproperlyConfigured("AI configuration not set") + self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) + + def call_ai_api(self, system_content, text): + """Helper method to call the OpenAI API and process the response.""" + response = self.client.chat.completions.create( + model=settings.AI_MODEL, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": system_content}, + {"role": "user", "content": json.dumps({"markdown_input": text})}, + ], + ) + + content = response.choices[0].message.content + sanitized_content = re.sub(r"(?