♻️(summary) enhance file handling in the Celery worker

The previous code lacked proper encapsulation, resulting in an overly complex
worker. While the initial naive approach was great for bootstrapping the
feature, the refactor introduces more maturity with dedicated service classes
that have clear, single responsibilities.

During the extraction to services, several minor issues were fixed:

1) Properly closing the MinIO response.

2) Enhanced validation of object filenames and extensions to ensure
correct file handling.

3) Introduced a context manager that automatically cleans up temporary
local files, so cleanup no longer relies on developers remembering to do it.

4) Slightly improved logging and naming for clarity.

5) The temporary file extension is now derived dynamically from the object
name; previously it was always a hardcoded .ogg, even when the file was not
an OGG file.
This commit is contained in:
lebaudantoine
2025-12-30 18:59:39 +01:00
committed by aleb_the_flash
parent 4cb6320b83
commit 4e5032a7a4
3 changed files with 159 additions and 73 deletions

View File

@@ -1,26 +1,22 @@
"""Celery workers."""
# ruff: noqa: PLR0913, PLR0915
# ruff: noqa: PLR0913
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Optional
import openai
import sentry_sdk
from celery import Celery, signals
from celery.utils.log import get_task_logger
from minio import Minio
from mutagen import File
from requests import Session, exceptions
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from summary.core.analytics import MetadataManager, get_analytics
from summary.core.config import get_settings
from summary.core.file_service import FileService
from summary.core.llm_service import LLMException, LLMObservability, LLMService
from summary.core.prompt import (
FORMAT_NEXT_STEPS,
@@ -59,17 +55,7 @@ if settings.sentry_dsn and settings.sentry_is_enabled:
sentry_sdk.init(dsn=settings.sentry_dsn, enable_tracing=True)
class AudioValidationError(Exception):
"""Custom exception for audio validation errors."""
pass
def save_audio_stream(audio_stream, chunk_size=32 * 1024):
"""Save an audio stream to a temporary OGG file."""
with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as tmp:
tmp.writelines(audio_stream.stream(chunk_size))
return Path(tmp.name)
file_service = FileService(logger=logger)
def create_retry_session():
@@ -151,40 +137,6 @@ def process_audio_transcribe_summarize_v2(
task_id = self.request.id
logger.info("Download recording | Filename: %s", filename)
minio_client = Minio(
settings.aws_s3_endpoint_url,
access_key=settings.aws_s3_access_key_id,
secret_key=settings.aws_s3_secret_access_key.get_secret_value(),
secure=settings.aws_s3_secure_access,
)
logger.debug("Connection to the Minio bucket successful")
audio_file_stream = minio_client.get_object(
settings.aws_storage_bucket_name, object_name=filename
)
temp_file_path = save_audio_stream(audio_file_stream)
logger.info("Recording successfully downloaded")
logger.debug("Recording filepath: %s", temp_file_path)
audio_file = File(temp_file_path)
metadata_manager.track(task_id, {"audio_length": audio_file.info.length})
if (
settings.recording_max_duration is not None
and audio_file.info.length > settings.recording_max_duration
):
error_msg = "Recording too long: %.2fs > %.2fs limit" % (
audio_file.info.length,
settings.recording_max_duration,
)
logger.error(error_msg)
raise AudioValidationError(error_msg)
logger.info("Initiating WhisperX client")
whisperx_client = openai.OpenAI(
api_key=settings.whisperx_api_key.get_secret_value(),
@@ -192,31 +144,31 @@ def process_audio_transcribe_summarize_v2(
max_retries=settings.whisperx_max_retries,
)
try:
with (
file_service.prepare_audio_file(filename) as (audio_file, metadata),
):
metadata_manager.track(task_id, {"audio_length": metadata["duration"]})
logger.info(
"Querying transcription for %s seconds of audio in %s",
audio_file.info.length,
"Querying transcription in '%s' language",
language,
)
transcription_start_time = time.time()
with open(temp_file_path, "rb") as audio_file:
transcription = whisperx_client.audio.transcriptions.create(
model=settings.whisperx_asr_model,
file=audio_file,
language=language or settings.whisperx_default_language,
)
transcription_time = round(time.time() - transcription_start_time, 2)
metadata_manager.track(
task_id,
{"transcription_time": transcription_time},
)
logger.info("Transcription received in %s seconds.", transcription_time)
logger.debug("Transcription: \n %s", transcription)
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
logger.debug("Temporary file removed: %s", temp_file_path)
transcription_start_time = time.time()
transcription = whisperx_client.audio.transcriptions.create(
model=settings.whisperx_asr_model,
file=audio_file,
language=language or settings.whisperx_default_language,
)
transcription_time = round(time.time() - transcription_start_time, 2)
metadata_manager.track(
task_id,
{"transcription_time": transcription_time},
)
logger.info("Transcription received in %.2f seconds.", transcription_time)
logger.debug("Transcription: \n %s", transcription)
metadata_manager.track_transcription_metadata(task_id, transcription)

View File

@@ -1,7 +1,7 @@
"""Application configuration and settings."""
from functools import lru_cache
from typing import Annotated, List, Optional
from typing import Annotated, List, Optional, Set
from fastapi import Depends
from pydantic import SecretStr
@@ -19,6 +19,7 @@ class Settings(BaseSettings):
# Audio recordings
recording_max_duration: Optional[int] = None
recording_allowed_extensions: Set[str] = {".ogg"}
# Celery settings
celery_broker_url: str = "redis://redis/0"

View File

@@ -0,0 +1,133 @@
"""File service to encapsulate files' manipulations."""
import os
import tempfile
from contextlib import contextmanager
from pathlib import Path
import mutagen
from minio import Minio
from summary.core.config import get_settings
# Resolve the application settings once at import time; FileService reads
# its MinIO credentials and recording limits from this object.
settings = get_settings()
class FileService:
    """Service for downloading and preparing files from MinIO storage.

    The service downloads a remote object into a local temporary file,
    validates it (extension and duration) and exposes it through a context
    manager that guarantees the temporary file is removed afterwards.
    """

    def __init__(self, logger):
        """Initialize FileService with MinIO client and configuration.

        Args:
            logger: Logger used for every message emitted by the service.
        """
        self._logger = logger

        self._minio_client = Minio(
            settings.aws_s3_endpoint_url,
            access_key=settings.aws_s3_access_key_id,
            secret_key=settings.aws_s3_secret_access_key.get_secret_value(),
            secure=settings.aws_s3_secure_access,
        )

        self._bucket_name = settings.aws_storage_bucket_name
        self._stream_chunk_size = 32 * 1024

        self._allowed_extensions = settings.recording_allowed_extensions
        self._max_duration = settings.recording_max_duration

    def _remove_temporary_file(self, path) -> None:
        """Delete a temporary file, never raising.

        A missing file (or a ``None`` path) is silently ignored; any other
        OS-level failure is logged as a warning so that cleanup problems never
        mask the error that triggered the cleanup.
        """
        if path is None:
            return

        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing to clean up: the file is already gone.
            return
        except OSError as e:
            self._logger.warning(
                "Failed to remove temporary file %s: %s", path, e
            )
        else:
            self._logger.debug("Temporary file removed: %s", path)

    def _download_from_minio(self, remote_object_key) -> Path:
        """Download file from MinIO to local temporary file.

        The file is downloaded to a temporary location for local manipulation
        such as validation, conversion, or processing before being used.

        Args:
            remote_object_key: Key of the object in the configured bucket.

        Returns:
            Path of the local temporary file. The caller owns the file and is
            responsible for deleting it.

        Raises:
            ValueError: If the key is empty or its extension is not allowed.
        """
        self._logger.info("Download recording | object_key: %s", remote_object_key)

        if not remote_object_key:
            self._logger.warning("Invalid object_key '%s'", remote_object_key)
            raise ValueError("Invalid object_key")

        extension = Path(remote_object_key).suffix.lower()

        if extension not in self._allowed_extensions:
            self._logger.warning("Invalid file extension '%s'", extension)
            raise ValueError(f"Invalid file extension '{extension}'")

        response = None
        local_path = None
        completed = False

        try:
            response = self._minio_client.get_object(
                self._bucket_name, remote_object_key
            )

            with tempfile.NamedTemporaryFile(
                suffix=extension, delete=False, prefix="minio_download_"
            ) as tmp:
                # Record the path before streaming so that a partial download
                # can still be cleaned up if the transfer fails half-way.
                local_path = Path(tmp.name)
                for chunk in response.stream(self._stream_chunk_size):
                    tmp.write(chunk)

            completed = True
        finally:
            if not completed:
                # The file was created with delete=False and the caller never
                # receives its path on failure, so it must be removed here.
                self._remove_temporary_file(local_path)
            if response is not None:
                # Closing alone does not hand the connection back to the
                # pool; the MinIO SDK documents calling both.
                response.close()
                response.release_conn()

        self._logger.info("Recording successfully downloaded")
        self._logger.debug("Recording local file path: %s", local_path)

        return local_path

    def _validate_duration(self, local_path: Path) -> float:
        """Validate audio file duration against configured maximum.

        Args:
            local_path: Path of the local audio file to inspect.

        Returns:
            The duration of the recording, in seconds.

        Raises:
            ValueError: If the audio metadata cannot be read, or if the
                recording is longer than the configured maximum duration.
        """
        audio = mutagen.File(local_path)

        # mutagen returns None when it cannot identify the file type.
        if audio is None or audio.info is None:
            error_msg = "Unable to read audio metadata from recording"
            self._logger.error(error_msg)
            raise ValueError(error_msg)

        duration = audio.info.length

        self._logger.info(
            "Recording file duration: %.2f seconds",
            duration,
        )

        if self._max_duration is not None and duration > self._max_duration:
            error_msg = "Recording too long. Limit is %.2f seconds" % (
                self._max_duration,
            )
            self._logger.error(error_msg)
            raise ValueError(error_msg)

        return duration

    @contextmanager
    def prepare_audio_file(self, remote_object_key: str):
        """Download and prepare audio file for processing.

        Downloads file from MinIO, validates duration, and yields an open
        file handle with metadata. Automatically cleans up temporary files
        when the context exits.

        Args:
            remote_object_key: Key of the recording in the configured bucket.

        Yields:
            A ``(file_handle, metadata)`` tuple where ``file_handle`` is the
            recording opened in binary read mode and ``metadata`` is a dict
            with the ``"duration"`` (seconds) and ``"extension"`` keys.

        Raises:
            ValueError: If the key, the extension or the duration is invalid.
        """
        downloaded_path = None

        try:
            downloaded_path = self._download_from_minio(remote_object_key)
            duration = self._validate_duration(downloaded_path)

            metadata = {
                "duration": duration,
                "extension": downloaded_path.suffix.lower(),
            }

            with open(downloaded_path, "rb") as file_handle:
                yield file_handle, metadata
        finally:
            # Runs on success, on validation failure and when the caller's
            # block raises, so the temporary file never outlives the context.
            self._remove_temporary_file(downloaded_path)