2026-01-22 20:15:07 +01:00
|
|
|
"""JWT token service."""
|
|
|
|
|
|
|
|
|
|
# pylint: disable=R0913,R0917
|
|
|
|
|
# ruff: noqa: PLR0913
|
|
|
|
|
|
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
|
|
|
|
|
|
|
|
import jwt
|
|
|
|
|
|
|
|
|
|
|
2026-01-26 16:23:46 +01:00
|
|
|
class JWTError(Exception):
    """Common base class for the JWT token errors defined in this module.

    Catching this type handles expired, invalid and undecodable tokens
    in a single ``except`` clause.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenExpiredError(JWTError):
    """Signals that a token was rejected because its expiry (``exp``) has passed."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenInvalidError(JWTError):
    """Signals an issuer (``iss``) or audience (``aud``) claim mismatch."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenDecodeError(JWTError):
    """Catch-all for any other token failure that cannot be recovered from.

    For example, a malformed token or a bad signature.
    """
|
|
|
|
|
|
|
|
|
|
|
2026-01-22 20:15:07 +01:00
|
|
|
class JwtTokenService:
    """Generic JWT token service with configurable settings.

    One instance holds the signing key, algorithm, optional issuer and
    audience, token lifetime and the ``token_type`` label returned to
    clients. An empty/``None`` issuer or audience disables that claim on
    both the encode and the decode side.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str,
        issuer: Optional[str],
        audience: Optional[str],
        expiration_seconds: int,
        token_type: str,
    ):
        """
        Initialize the token service with custom settings.

        Args:
            secret_key: Secret key for JWT encoding/decoding
            algorithm: JWT algorithm name passed to PyJWT
            issuer: Token issuer identifier; falsy disables the ``iss`` claim
            audience: Token audience identifier; falsy disables the ``aud`` claim
            expiration_seconds: Token lifetime in seconds
            token_type: Value reported as ``token_type`` in generated responses

        Raises:
            ImproperlyConfigured: If secret_key, algorithm or token_type is
                empty/None, or expiration_seconds is None
        """
        if not secret_key:
            raise ImproperlyConfigured("Secret key is required.")
        if not algorithm:
            raise ImproperlyConfigured("Algorithm is required.")
        if not token_type:
            raise ImproperlyConfigured("Token's type is required.")
        if expiration_seconds is None:
            raise ImproperlyConfigured("Expiration's seconds is required.")

        self._key = secret_key
        self._algorithm = algorithm
        self._issuer = issuer
        self._audience = audience
        self._expiration_seconds = expiration_seconds
        self._token_type = token_type

    def generate_jwt(
        self, user, scope: str, extra_payload: Optional[dict] = None
    ) -> dict:
        """
        Generate an access token for the given user.

        Note: keys ``iat``, ``exp`` and ``user_id`` in extra_payload are always
        overwritten by this service; ``iss``, ``aud`` and ``scope`` are
        overwritten when an issuer, audience or scope is set.

        Args:
            user: Object with an ``id`` attribute; stored as ``str(user.id)``
            scope: Space-separated scope string (omitted when empty)
            extra_payload: Additional claims to embed; the dict is copied,
                not mutated

        Returns:
            Dictionary containing access_token, token_type, expires_in, and
            scope when a non-empty scope was given
        """
        now = datetime.now(timezone.utc)

        # Copy so the caller's dict is never mutated by the updates below.
        payload = extra_payload.copy() if extra_payload else {}
        payload.update(
            {
                "iat": now,
                "exp": now + timedelta(seconds=self._expiration_seconds),
                "user_id": str(user.id),
            }
        )

        # Optional claims are only emitted when configured; decode_jwt
        # mirrors this by not requiring them when they are falsy.
        if self._issuer:
            payload["iss"] = self._issuer
        if self._audience:
            payload["aud"] = self._audience
        if scope:
            payload["scope"] = scope

        token = jwt.encode(
            payload,
            self._key,
            algorithm=self._algorithm,
        )

        response = {
            "access_token": token,
            "token_type": self._token_type,
            "expires_in": self._expiration_seconds,
        }
        if scope:
            response["scope"] = scope
        return response

    def decode_jwt(self, token: str) -> dict:
        """Decode and validate JWT token.

        Args:
            token: JWT token string

        Returns:
            Decoded payload dict.

        Raises:
            TokenExpiredError: If the token has expired.
            TokenInvalidError: If the token has an invalid issuer or audience.
            TokenDecodeError: If the token is malformed, lacks a required
                claim, or cannot be decoded.
        """
        try:
            return jwt.decode(
                token,
                self._key,
                algorithms=[self._algorithm],
                # PyJWT skips the iss/aud checks only for None; an empty
                # string would demand a claim that generate_jwt never emits
                # for a falsy issuer/audience, rejecting our own tokens.
                issuer=self._issuer or None,
                audience=self._audience or None,
                # generate_jwt always sets "exp"; refuse tokens without it so
                # a non-expiring token signed with the same key is rejected.
                options={"require": ["exp"]},
            )
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token expired.") from e
        except (jwt.InvalidIssuerError, jwt.InvalidAudienceError) as e:
            raise TokenInvalidError("Invalid token.") from e
        except jwt.InvalidTokenError as e:
            raise TokenDecodeError("Token decode error.") from e
|