171 lines
5.1 KiB
Python
171 lines
5.1 KiB
Python
|
|
"""Task queue utilities.
|
||
|
|
|
||
|
|
Provides decorators and helpers that abstract away the underlying task
|
||
|
|
queue library (currently Dramatiq). Application code should import from
|
||
|
|
here instead of importing dramatiq directly.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
from django.core.cache import cache
|
||
|
|
from django.utils import timezone
|
||
|
|
|
||
|
|
import dramatiq
|
||
|
|
from dramatiq.brokers.stub import StubBroker
|
||
|
|
from dramatiq.middleware import CurrentMessage
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
TASK_PROGRESS_CACHE_TIMEOUT = 86400 # 24 hours
|
||
|
|
TASK_TRACKING_CACHE_TTL = 86400 * 30 # 30 days
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Task wrapper (Celery-compatible API)
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class Task:
|
||
|
|
"""Wrapper around a Dramatiq Message with a Celery-like API."""
|
||
|
|
|
||
|
|
def __init__(self, message):
|
||
|
|
self._message = message
|
||
|
|
|
||
|
|
@property
|
||
|
|
def id(self):
|
||
|
|
"""Celery-compatible task ID (maps to message_id)."""
|
||
|
|
return self._message.message_id
|
||
|
|
|
||
|
|
def track_owner(self, user_id):
|
||
|
|
"""Register tracking metadata for permission checks."""
|
||
|
|
cache.set(
|
||
|
|
f"task_tracking:{self.id}",
|
||
|
|
json.dumps(
|
||
|
|
{
|
||
|
|
"owner": str(user_id),
|
||
|
|
"actor_name": self._message.actor_name,
|
||
|
|
"queue_name": self._message.queue_name,
|
||
|
|
}
|
||
|
|
),
|
||
|
|
timeout=TASK_TRACKING_CACHE_TTL,
|
||
|
|
)
|
||
|
|
|
||
|
|
def __getattr__(self, name):
|
||
|
|
return getattr(self._message, name)
|
||
|
|
|
||
|
|
|
||
|
|
class CeleryCompatActor(dramatiq.Actor):
|
||
|
|
"""Actor subclass that adds a .delay() method returning a Task."""
|
||
|
|
|
||
|
|
def delay(self, *args, **kwargs):
|
||
|
|
"""Dispatch the task asynchronously, returning a Task wrapper."""
|
||
|
|
message = self.send(*args, **kwargs)
|
||
|
|
return Task(message)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Decorators
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def register_task(*args, **kwargs):
|
||
|
|
"""Decorator to register a task (wraps dramatiq.actor).
|
||
|
|
|
||
|
|
Usage::
|
||
|
|
|
||
|
|
@register_task(queue="import")
|
||
|
|
def my_task(arg):
|
||
|
|
...
|
||
|
|
"""
|
||
|
|
kwargs.setdefault("store_results", True)
|
||
|
|
if "queue" in kwargs:
|
||
|
|
kwargs.setdefault("queue_name", kwargs.pop("queue"))
|
||
|
|
kwargs.setdefault("actor_class", CeleryCompatActor)
|
||
|
|
|
||
|
|
def decorator(fn):
|
||
|
|
return dramatiq.actor(fn, **kwargs)
|
||
|
|
|
||
|
|
if args and callable(args[0]):
|
||
|
|
return decorator(args[0])
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Task tracking & progress
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def get_task_tracking(task_id: str) -> Optional[dict]:
|
||
|
|
"""Get tracking metadata for a task, or None if not found."""
|
||
|
|
raw = cache.get(f"task_tracking:{task_id}")
|
||
|
|
if raw is None:
|
||
|
|
return None
|
||
|
|
return json.loads(raw)
|
||
|
|
|
||
|
|
|
||
|
|
def set_task_progress(progress: int, metadata: Optional[dict[str, Any]] = None) -> None:
|
||
|
|
"""Set the progress of the currently executing task."""
|
||
|
|
current_message = CurrentMessage.get_current_message()
|
||
|
|
if not current_message:
|
||
|
|
logger.warning("set_task_progress called outside of a task")
|
||
|
|
return
|
||
|
|
|
||
|
|
task_id = current_message.message_id
|
||
|
|
try:
|
||
|
|
progress = max(0, min(100, int(progress)))
|
||
|
|
except (TypeError, ValueError):
|
||
|
|
progress = 0
|
||
|
|
|
||
|
|
cache.set(
|
||
|
|
f"task_progress:{task_id}",
|
||
|
|
{
|
||
|
|
"progress": progress,
|
||
|
|
"timestamp": timezone.now().timestamp(),
|
||
|
|
"metadata": metadata or {},
|
||
|
|
},
|
||
|
|
timeout=TASK_PROGRESS_CACHE_TIMEOUT,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def get_task_progress(task_id: str) -> Optional[dict[str, Any]]:
|
||
|
|
"""Get the progress of a task by ID."""
|
||
|
|
return cache.get(f"task_progress:{task_id}")
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# EagerBroker for tests
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class EagerBroker(StubBroker):
|
||
|
|
"""Broker that executes tasks synchronously (for tests).
|
||
|
|
|
||
|
|
Equivalent to Celery's CELERY_TASK_ALWAYS_EAGER mode.
|
||
|
|
Only runs CurrentMessage and Results middleware.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def enqueue(self, message, *, delay=None):
|
||
|
|
from dramatiq.results import Results # noqa: PLC0415 # pylint: disable=C0415
|
||
|
|
|
||
|
|
actor = self.get_actor(message.actor_name)
|
||
|
|
cm = next(
|
||
|
|
(m for m in self.middleware if isinstance(m, CurrentMessage)),
|
||
|
|
None,
|
||
|
|
)
|
||
|
|
rm = next((m for m in self.middleware if isinstance(m, Results)), None)
|
||
|
|
prev = CurrentMessage.get_current_message() if cm else None
|
||
|
|
if cm:
|
||
|
|
cm.before_process_message(self, message)
|
||
|
|
try:
|
||
|
|
result = actor.fn(*message.args, **message.kwargs)
|
||
|
|
if rm:
|
||
|
|
rm.after_process_message(self, message, result=result)
|
||
|
|
finally:
|
||
|
|
if cm:
|
||
|
|
cm.after_process_message(self, message)
|
||
|
|
if prev is not None:
|
||
|
|
cm.before_process_message(self, prev)
|
||
|
|
return message
|