Files
calendars/src/backend/core/task_utils.py

171 lines
5.1 KiB
Python
Raw Normal View History

"""Task queue utilities.
Provides decorators and helpers that abstract away the underlying task
queue library (currently Dramatiq). Application code should import from
here instead of importing dramatiq directly.
"""
import json
import logging
from typing import Any, Optional
from django.core.cache import cache
from django.utils import timezone
import dramatiq
from dramatiq.brokers.stub import StubBroker
from dramatiq.middleware import CurrentMessage
logger = logging.getLogger(__name__)
TASK_PROGRESS_CACHE_TIMEOUT = 86400 # 24 hours
TASK_TRACKING_CACHE_TTL = 86400 * 30 # 30 days
# ---------------------------------------------------------------------------
# Task wrapper (Celery-compatible API)
# ---------------------------------------------------------------------------
class Task:
"""Wrapper around a Dramatiq Message with a Celery-like API."""
def __init__(self, message):
self._message = message
@property
def id(self):
"""Celery-compatible task ID (maps to message_id)."""
return self._message.message_id
def track_owner(self, user_id):
"""Register tracking metadata for permission checks."""
cache.set(
f"task_tracking:{self.id}",
json.dumps(
{
"owner": str(user_id),
"actor_name": self._message.actor_name,
"queue_name": self._message.queue_name,
}
),
timeout=TASK_TRACKING_CACHE_TTL,
)
def __getattr__(self, name):
return getattr(self._message, name)
class CeleryCompatActor(dramatiq.Actor):
"""Actor subclass that adds a .delay() method returning a Task."""
def delay(self, *args, **kwargs):
"""Dispatch the task asynchronously, returning a Task wrapper."""
message = self.send(*args, **kwargs)
return Task(message)
# ---------------------------------------------------------------------------
# Decorators
# ---------------------------------------------------------------------------
def register_task(*args, **kwargs):
"""Decorator to register a task (wraps dramatiq.actor).
Usage::
@register_task(queue="import")
def my_task(arg):
...
"""
kwargs.setdefault("store_results", True)
if "queue" in kwargs:
kwargs.setdefault("queue_name", kwargs.pop("queue"))
kwargs.setdefault("actor_class", CeleryCompatActor)
def decorator(fn):
return dramatiq.actor(fn, **kwargs)
if args and callable(args[0]):
return decorator(args[0])
return decorator
# ---------------------------------------------------------------------------
# Task tracking & progress
# ---------------------------------------------------------------------------
def get_task_tracking(task_id: str) -> Optional[dict]:
"""Get tracking metadata for a task, or None if not found."""
raw = cache.get(f"task_tracking:{task_id}")
if raw is None:
return None
return json.loads(raw)
def set_task_progress(progress: int, metadata: Optional[dict[str, Any]] = None) -> None:
"""Set the progress of the currently executing task."""
current_message = CurrentMessage.get_current_message()
if not current_message:
logger.warning("set_task_progress called outside of a task")
return
task_id = current_message.message_id
try:
progress = max(0, min(100, int(progress)))
except (TypeError, ValueError):
progress = 0
cache.set(
f"task_progress:{task_id}",
{
"progress": progress,
"timestamp": timezone.now().timestamp(),
"metadata": metadata or {},
},
timeout=TASK_PROGRESS_CACHE_TIMEOUT,
)
def get_task_progress(task_id: str) -> Optional[dict[str, Any]]:
"""Get the progress of a task by ID."""
return cache.get(f"task_progress:{task_id}")
# ---------------------------------------------------------------------------
# EagerBroker for tests
# ---------------------------------------------------------------------------
class EagerBroker(StubBroker):
"""Broker that executes tasks synchronously (for tests).
Equivalent to Celery's CELERY_TASK_ALWAYS_EAGER mode.
Only runs CurrentMessage and Results middleware.
"""
def enqueue(self, message, *, delay=None):
from dramatiq.results import Results # noqa: PLC0415 # pylint: disable=C0415
actor = self.get_actor(message.actor_name)
cm = next(
(m for m in self.middleware if isinstance(m, CurrentMessage)),
None,
)
rm = next((m for m in self.middleware if isinstance(m, Results)), None)
prev = CurrentMessage.get_current_message() if cm else None
if cm:
cm.before_process_message(self, message)
try:
result = actor.fn(*message.args, **message.kwargs)
if rm:
rm.after_process_message(self, message, result=result)
finally:
if cm:
cm.after_process_message(self, message)
if prev is not None:
cm.before_process_message(self, prev)
return message