"""Util to generate S3 authorization headers for object storage access control""" import time from abc import ABC, abstractmethod from django.conf import settings from django.core.cache import cache from django.core.files.storage import default_storage import botocore from rest_framework.throttling import BaseThrottle def generate_s3_authorization_headers(key): """ Generate authorization headers for an s3 object. These headers can be used as an alternative to signed urls with many benefits: - the urls of our files never expire and can be stored in our documents' content - we don't leak authorized urls that could be shared (file access can only be done with cookies) - access control is truly realtime - the object storage service does not need to be exposed on internet """ url = default_storage.unsigned_connection.meta.client.generate_presigned_url( "get_object", ExpiresIn=0, Params={"Bucket": default_storage.bucket_name, "Key": key}, ) request = botocore.awsrequest.AWSRequest(method="get", url=url) s3_client = default_storage.connection.meta.client # pylint: disable=protected-access credentials = s3_client._request_signer._credentials # noqa: SLF001 frozen_credentials = credentials.get_frozen_credentials() region = s3_client.meta.region_name auth = botocore.auth.S3SigV4Auth(frozen_credentials, "s3", region) auth.add_auth(request) return request class AIBaseRateThrottle(BaseThrottle, ABC): """Base throttle class for AI-related rate limiting with backoff.""" def __init__(self, rates): """Initialize instance attributes with configurable rates.""" super().__init__() self.rates = rates self.cache_key = None self.recent_requests_minute = 0 self.recent_requests_hour = 0 self.recent_requests_day = 0 @abstractmethod def get_cache_key(self, request, view): """Abstract method to generate cache key for throttling.""" def allow_request(self, request, view): """Check if the request is allowed based on rate limits.""" self.cache_key = self.get_cache_key(request, view) if not self.cache_key: return True # Allow if no cache key is generated now = time.time() history = cache.get(self.cache_key, []) # Keep requests within the last 24 hours history = [req for req in history if req > now - 86400] # Calculate recent requests self.recent_requests_minute = len([req for req in history if req > now - 60]) self.recent_requests_hour = len([req for req in history if req > now - 3600]) self.recent_requests_day = len(history) # Check rate limits if self.recent_requests_minute >= self.rates["minute"]: return False if self.recent_requests_hour >= self.rates["hour"]: return False if self.recent_requests_day >= self.rates["day"]: return False # Log the request history.append(now) cache.set(self.cache_key, history, timeout=86400) return True def wait(self): """Implement a backoff strategy by increasing wait time based on limits hit.""" if self.recent_requests_day >= self.rates["day"]: return 86400 if self.recent_requests_hour >= self.rates["hour"]: return 3600 if self.recent_requests_minute >= self.rates["minute"]: return 60 return None class AIDocumentRateThrottle(AIBaseRateThrottle): """Throttle for limiting AI requests per document with backoff.""" def __init__(self, *args, **kwargs): super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES) def get_cache_key(self, request, view): """Include document ID in the cache key.""" document_id = view.kwargs["pk"] return f"document_{document_id}_throttle_ai" class AIUserRateThrottle(AIBaseRateThrottle): """Throttle that limits requests per user or IP with backoff and rate limits.""" def __init__(self, *args, **kwargs): super().__init__(settings.AI_USER_RATE_THROTTLE_RATES) def get_cache_key(self, request, view=None): """Generate a cache key based on the user ID or IP for anonymous users.""" if request.user.is_authenticated: return f"user_{request.user.id!s}_throttle_ai" return f"anonymous_{self.get_ident(request)}_throttle_ai" def get_ident(self, request): """Return the request IP address.""" x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") return ( x_forwarded_for.split(",")[0] if x_forwarded_for else request.META.get("REMOTE_ADDR") )