2024-08-19 22:38:41 +02:00
|
|
|
"""Util to generate S3 authorization headers for object storage access control"""
|
|
|
|
|
|
2024-09-20 22:42:46 +02:00
|
|
|
import time
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
|
|
from django.conf import settings
|
|
|
|
|
from django.core.cache import cache
|
2024-08-19 22:38:41 +02:00
|
|
|
from django.core.files.storage import default_storage
|
|
|
|
|
|
|
|
|
|
import botocore
|
2024-09-20 22:42:46 +02:00
|
|
|
from rest_framework.throttling import BaseThrottle
|
2024-08-19 22:38:41 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_s3_authorization_headers(key):
|
|
|
|
|
"""
|
|
|
|
|
Generate authorization headers for an s3 object.
|
|
|
|
|
These headers can be used as an alternative to signed urls with many benefits:
|
|
|
|
|
- the urls of our files never expire and can be stored in our documents' content
|
|
|
|
|
- we don't leak authorized urls that could be shared (file access can only be done
|
|
|
|
|
with cookies)
|
|
|
|
|
- access control is truly realtime
|
|
|
|
|
- the object storage service does not need to be exposed on internet
|
|
|
|
|
"""
|
|
|
|
|
url = default_storage.unsigned_connection.meta.client.generate_presigned_url(
|
|
|
|
|
"get_object",
|
|
|
|
|
ExpiresIn=0,
|
|
|
|
|
Params={"Bucket": default_storage.bucket_name, "Key": key},
|
|
|
|
|
)
|
|
|
|
|
request = botocore.awsrequest.AWSRequest(method="get", url=url)
|
|
|
|
|
|
|
|
|
|
s3_client = default_storage.connection.meta.client
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
|
credentials = s3_client._request_signer._credentials # noqa: SLF001
|
|
|
|
|
frozen_credentials = credentials.get_frozen_credentials()
|
|
|
|
|
region = s3_client.meta.region_name
|
|
|
|
|
auth = botocore.auth.S3SigV4Auth(frozen_credentials, "s3", region)
|
|
|
|
|
auth.add_auth(request)
|
|
|
|
|
|
|
|
|
|
return request
|
2024-09-20 22:42:46 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIBaseRateThrottle(BaseThrottle, ABC):
|
|
|
|
|
"""Base throttle class for AI-related rate limiting with backoff."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, rates):
|
|
|
|
|
"""Initialize instance attributes with configurable rates."""
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.rates = rates
|
|
|
|
|
self.cache_key = None
|
|
|
|
|
self.recent_requests_minute = 0
|
|
|
|
|
self.recent_requests_hour = 0
|
|
|
|
|
self.recent_requests_day = 0
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_cache_key(self, request, view):
|
|
|
|
|
"""Abstract method to generate cache key for throttling."""
|
|
|
|
|
|
|
|
|
|
def allow_request(self, request, view):
|
|
|
|
|
"""Check if the request is allowed based on rate limits."""
|
|
|
|
|
self.cache_key = self.get_cache_key(request, view)
|
|
|
|
|
if not self.cache_key:
|
|
|
|
|
return True # Allow if no cache key is generated
|
|
|
|
|
|
|
|
|
|
now = time.time()
|
|
|
|
|
history = cache.get(self.cache_key, [])
|
|
|
|
|
# Keep requests within the last 24 hours
|
|
|
|
|
history = [req for req in history if req > now - 86400]
|
|
|
|
|
|
|
|
|
|
# Calculate recent requests
|
|
|
|
|
self.recent_requests_minute = len([req for req in history if req > now - 60])
|
|
|
|
|
self.recent_requests_hour = len([req for req in history if req > now - 3600])
|
|
|
|
|
self.recent_requests_day = len(history)
|
|
|
|
|
|
|
|
|
|
# Check rate limits
|
|
|
|
|
if self.recent_requests_minute >= self.rates["minute"]:
|
|
|
|
|
return False
|
|
|
|
|
if self.recent_requests_hour >= self.rates["hour"]:
|
|
|
|
|
return False
|
|
|
|
|
if self.recent_requests_day >= self.rates["day"]:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# Log the request
|
|
|
|
|
history.append(now)
|
|
|
|
|
cache.set(self.cache_key, history, timeout=86400)
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def wait(self):
|
|
|
|
|
"""Implement a backoff strategy by increasing wait time based on limits hit."""
|
|
|
|
|
if self.recent_requests_day >= self.rates["day"]:
|
|
|
|
|
return 86400
|
|
|
|
|
if self.recent_requests_hour >= self.rates["hour"]:
|
|
|
|
|
return 3600
|
|
|
|
|
if self.recent_requests_minute >= self.rates["minute"]:
|
|
|
|
|
return 60
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIDocumentRateThrottle(AIBaseRateThrottle):
|
|
|
|
|
"""Throttle for limiting AI requests per document with backoff."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES)
|
|
|
|
|
|
|
|
|
|
def get_cache_key(self, request, view):
|
|
|
|
|
"""Include document ID in the cache key."""
|
|
|
|
|
document_id = view.kwargs["pk"]
|
|
|
|
|
return f"document_{document_id}_throttle_ai"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIUserRateThrottle(AIBaseRateThrottle):
|
|
|
|
|
"""Throttle that limits requests per user or IP with backoff and rate limits."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(settings.AI_USER_RATE_THROTTLE_RATES)
|
|
|
|
|
|
|
|
|
|
def get_cache_key(self, request, view=None):
|
|
|
|
|
"""Generate a cache key based on the user ID or IP for anonymous users."""
|
|
|
|
|
if request.user.is_authenticated:
|
|
|
|
|
return f"user_{request.user.id!s}_throttle_ai"
|
|
|
|
|
return f"anonymous_{self.get_ident(request)}_throttle_ai"
|
|
|
|
|
|
|
|
|
|
def get_ident(self, request):
|
|
|
|
|
"""Return the request IP address."""
|
|
|
|
|
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
|
|
|
|
|
return (
|
|
|
|
|
x_forwarded_for.split(",")[0]
|
|
|
|
|
if x_forwarded_for
|
|
|
|
|
else request.META.get("REMOTE_ADDR")
|
|
|
|
|
)
|