Refactor external API authentication classes to inherit from a common base authentication backend. Prepare the introduction of a new authentication class responsible for verifying tokens provided to calendar integrations. Move token decoding responsibility to the new token service so it can both generate and validate tokens. Encapsulate external exceptions and expose a clear interface by defining custom Python exceptions raised during token validation. Taken from #897.
154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
"""JWT token service."""
|
|
|
|
# pylint: disable=R0913,R0917
|
|
# ruff: noqa: PLR0913
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
|
|
import jwt
|
|
|
|
|
|
class JWTError(Exception):
    """Base exception for all JWT token errors.

    Every token-validation error defined in this module derives from this
    class, so callers can catch ``JWTError`` to handle all of them at once.
    """
|
|
|
|
|
|
class TokenExpiredError(JWTError):
    """Raised when the JWT token has expired.

    ``JwtTokenService.decode_jwt`` raises it in place of PyJWT's
    ``ExpiredSignatureError``.
    """
|
|
|
|
|
|
class TokenInvalidError(JWTError):
    """Raised when the JWT token has an invalid issuer or audience.

    ``JwtTokenService.decode_jwt`` raises it in place of PyJWT's
    ``InvalidIssuerError`` and ``InvalidAudienceError``.
    """
|
|
|
|
|
|
class TokenDecodeError(JWTError):
    """Raised for any other unrecoverable JWT decode failure.

    ``JwtTokenService.decode_jwt`` raises it for every PyJWT
    ``InvalidTokenError`` that is neither an expiry nor an issuer/audience
    mismatch (for example a malformed token or a bad signature).
    """
|
|
|
|
|
|
class JwtTokenService:
    """Generic JWT token service with configurable settings.

    A single instance both issues tokens (``generate_jwt``) and validates
    them (``decode_jwt``) with the key, algorithm, issuer and audience given
    at construction time.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str,
        issuer: str,
        audience: str,
        expiration_seconds: int,
        token_type: str,
    ):
        """
        Initialize the token service with custom settings.

        Args:
            secret_key: Secret key for JWT encoding/decoding
            algorithm: JWT algorithm
            issuer: Token issuer identifier; when falsy, no ``iss`` claim is
                emitted and the issuer is not checked on decode
            audience: Token audience identifier; when falsy, no ``aud`` claim
                is emitted on generation
            expiration_seconds: Token expiration time in seconds
            token_type: Token type

        Raises:
            ImproperlyConfigured: If secret_key, algorithm or token_type is
                None or empty, or if expiration_seconds is None
        """
        if not secret_key:
            raise ImproperlyConfigured("Secret key is required.")
        if not algorithm:
            raise ImproperlyConfigured("Algorithm is required.")
        if not token_type:
            raise ImproperlyConfigured("Token's type is required.")
        # Only None is rejected here: 0 is a falsy but explicit value.
        if expiration_seconds is None:
            raise ImproperlyConfigured("Expiration's seconds is required.")

        self._key = secret_key
        self._algorithm = algorithm
        self._issuer = issuer
        self._audience = audience
        self._expiration_seconds = expiration_seconds
        self._token_type = token_type

    def generate_jwt(
        self, user, scope: str, extra_payload: Optional[dict] = None
    ) -> dict:
        """
        Generate an access token for the given user.

        Note: any extra_payload variables named iat, exp, or user_id will
        be overwritten by this service. The same applies to iss, aud and
        scope whenever the service has an issuer, an audience, or the call
        has a non-empty scope.

        Args:
            user: User instance for whom to generate the token
            scope: Space-separated scope string
            extra_payload: Optional additional claims to embed in the token;
                the given dict is copied, never mutated

        Returns:
            Dictionary containing access_token, token_type, expires_in, and scope optionally
        """
        now = datetime.now(timezone.utc)

        # Work on a copy so the caller's dict is left untouched.
        payload = extra_payload.copy() if extra_payload else {}

        payload.update(
            {
                "iat": now,
                "exp": now + timedelta(seconds=self._expiration_seconds),
                "user_id": str(user.id),
            }
        )

        # Optional claims are only emitted when a value is available.
        if self._issuer:
            payload["iss"] = self._issuer
        if self._audience:
            payload["aud"] = self._audience
        if scope:
            payload["scope"] = scope

        token = jwt.encode(
            payload,
            self._key,
            algorithm=self._algorithm,
        )

        response = {
            "access_token": token,
            "token_type": self._token_type,
            "expires_in": self._expiration_seconds,
        }

        if scope:
            response["scope"] = scope

        return response

    def decode_jwt(self, token):
        """Decode and validate JWT token.

        Args:
            token: JWT token string

        Returns:
            Decoded payload dict.

        Raises:
            TokenExpiredError: If the token has expired.
            TokenInvalidError: If the token has an invalid issuer or audience.
            TokenDecodeError: If the token is malformed or cannot be decoded.
        """
        try:
            return jwt.decode(
                token,
                self._key,
                algorithms=[self._algorithm],
                # generate_jwt omits "iss"/"aud" when the configured value is
                # falsy. Normalise such values to None so the decoder does not
                # demand a claim this service never emits (an empty string
                # would make PyJWT require the claim to be present).
                issuer=self._issuer or None,
                audience=self._audience or None,
            )
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token expired.") from e
        except (jwt.InvalidIssuerError, jwt.InvalidAudienceError) as e:
            raise TokenInvalidError("Invalid token.") from e
        except jwt.InvalidTokenError as e:
            # Catch-all for the remaining PyJWT validation failures.
            raise TokenDecodeError("Token decode error.") from e
|