From 4e5032a7a4f6cb4e251d742caf808da0198d6fb9 Mon Sep 17 00:00:00 2001 From: lebaudantoine Date: Tue, 30 Dec 2025 18:59:39 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F(summary)=20enhance=20file=20?= =?UTF-8?q?handling=20in=20the=20Celery=20worker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous code lacked proper encapsulation, resulting in an overly complex worker. While the initial naive approach was great for bootstrapping the feature, the refactor introduces more maturity with dedicated service classes that have clear, single responsibilities. During the extraction to services, several minor issues were fixed: 1) Properly closing the MinIO response. 2) Enhanced validation of object filenames and extensions to ensure correct file handling. 3) Introduced a context manager to automatically clean up temporary local files, removing reliance on developers. 4) Slightly improved logging and naming for clarity. 5) Dynamic temporary file extension handling when it was previously always an hardcoded .ogg file, even when it was not the case. --- src/summary/summary/core/celery_worker.py | 96 ++++------------ src/summary/summary/core/config.py | 3 +- src/summary/summary/core/file_service.py | 133 ++++++++++++++++++++++ 3 files changed, 159 insertions(+), 73 deletions(-) create mode 100644 src/summary/summary/core/file_service.py diff --git a/src/summary/summary/core/celery_worker.py b/src/summary/summary/core/celery_worker.py index d99f6a44..66bf5c87 100644 --- a/src/summary/summary/core/celery_worker.py +++ b/src/summary/summary/core/celery_worker.py @@ -1,26 +1,22 @@ """Celery workers.""" -# ruff: noqa: PLR0913, PLR0915 +# ruff: noqa: PLR0913 import json -import os -import tempfile import time -from pathlib import Path from typing import Optional import openai import sentry_sdk from celery import Celery, signals from celery.utils.log import get_task_logger -from minio import Minio -from mutagen import File from requests import Session, exceptions from requests.adapters import HTTPAdapter from urllib3.util import Retry from summary.core.analytics import MetadataManager, get_analytics from summary.core.config import get_settings +from summary.core.file_service import FileService from summary.core.llm_service import LLMException, LLMObservability, LLMService from summary.core.prompt import ( FORMAT_NEXT_STEPS, @@ -59,17 +55,7 @@ if settings.sentry_dsn and settings.sentry_is_enabled: sentry_sdk.init(dsn=settings.sentry_dsn, enable_tracing=True) -class AudioValidationError(Exception): - """Custom exception for audio validation errors.""" - - pass - - -def save_audio_stream(audio_stream, chunk_size=32 * 1024): - """Save an audio stream to a temporary OGG file.""" - with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as tmp: - tmp.writelines(audio_stream.stream(chunk_size)) - return Path(tmp.name) +file_service = FileService(logger=logger) def create_retry_session(): @@ -151,40 +137,6 @@ def process_audio_transcribe_summarize_v2( task_id = self.request.id - logger.info("Download recording | Filename: %s", filename) - - minio_client = Minio( - settings.aws_s3_endpoint_url, - access_key=settings.aws_s3_access_key_id, - secret_key=settings.aws_s3_secret_access_key.get_secret_value(), - secure=settings.aws_s3_secure_access, - ) - - logger.debug("Connection to the Minio bucket successful") - - audio_file_stream = minio_client.get_object( - settings.aws_storage_bucket_name, object_name=filename - ) - - temp_file_path = save_audio_stream(audio_file_stream) - - logger.info("Recording successfully downloaded") - logger.debug("Recording filepath: %s", temp_file_path) - - audio_file = File(temp_file_path) - metadata_manager.track(task_id, {"audio_length": audio_file.info.length}) - - if ( - settings.recording_max_duration is not None - and audio_file.info.length > settings.recording_max_duration - ): - error_msg = "Recording too long: %.2fs > %.2fs limit" % ( - audio_file.info.length, - settings.recording_max_duration, - ) - logger.error(error_msg) - raise AudioValidationError(error_msg) - logger.info("Initiating WhisperX client") whisperx_client = openai.OpenAI( api_key=settings.whisperx_api_key.get_secret_value(), @@ -192,31 +144,31 @@ def process_audio_transcribe_summarize_v2( max_retries=settings.whisperx_max_retries, ) - try: + with ( + file_service.prepare_audio_file(filename) as (audio_file, metadata), + ): + metadata_manager.track(task_id, {"audio_length": metadata["duration"]}) + logger.info( - "Querying transcription for %s seconds of audio in %s …", - audio_file.info.length, + "Querying transcription in '%s' language …", language, ) - transcription_start_time = time.time() - with open(temp_file_path, "rb") as audio_file: - transcription = whisperx_client.audio.transcriptions.create( - model=settings.whisperx_asr_model, - file=audio_file, - language=language or settings.whisperx_default_language, - ) - transcription_time = round(time.time() - transcription_start_time, 2) - metadata_manager.track( - task_id, - {"transcription_time": transcription_time}, - ) - logger.info("Transcription received in %s seconds.", transcription_time) - logger.debug("Transcription: \n %s", transcription) - finally: - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - logger.debug("Temporary file removed: %s", temp_file_path) + transcription_start_time = time.time() + + transcription = whisperx_client.audio.transcriptions.create( + model=settings.whisperx_asr_model, + file=audio_file, + language=language or settings.whisperx_default_language, + ) + + transcription_time = round(time.time() - transcription_start_time, 2) + metadata_manager.track( + task_id, + {"transcription_time": transcription_time}, + ) + logger.info("Transcription received in %.2f seconds.", transcription_time) + logger.debug("Transcription: \n %s", transcription) metadata_manager.track_transcription_metadata(task_id, transcription) diff --git a/src/summary/summary/core/config.py b/src/summary/summary/core/config.py index ee7654a5..2624cced 100644 --- a/src/summary/summary/core/config.py +++ b/src/summary/summary/core/config.py @@ -1,7 +1,7 @@ """Application configuration and settings.""" from functools import lru_cache -from typing import Annotated, List, Optional +from typing import Annotated, List, Optional, Set from fastapi import Depends from pydantic import SecretStr @@ -19,6 +19,7 @@ class Settings(BaseSettings): # Audio recordings recording_max_duration: Optional[int] = None + recording_allowed_extensions: Set[str] = {".ogg"} # Celery settings celery_broker_url: str = "redis://redis/0" diff --git a/src/summary/summary/core/file_service.py b/src/summary/summary/core/file_service.py new file mode 100644 index 00000000..69e44629 --- /dev/null +++ b/src/summary/summary/core/file_service.py @@ -0,0 +1,133 @@ +"""File service to encapsulate files' manipulations.""" + +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path + +import mutagen +from minio import Minio + +from summary.core.config import get_settings + +settings = get_settings() + + +class FileService: + """Service for downloading and preparing files from MinIO storage.""" + + def __init__(self, logger): + """Initialize FileService with MinIO client and configuration.""" + self._logger = logger + + self._minio_client = Minio( + settings.aws_s3_endpoint_url, + access_key=settings.aws_s3_access_key_id, + secret_key=settings.aws_s3_secret_access_key.get_secret_value(), + secure=settings.aws_s3_secure_access, + ) + + self._bucket_name = settings.aws_storage_bucket_name + self._stream_chunk_size = 32 * 1024 + + self._allowed_extensions = settings.recording_allowed_extensions + self._max_duration = settings.recording_max_duration + + def _download_from_minio(self, remote_object_key) -> Path: + """Download file from MinIO to local temporary file. + + The file is downloaded to a temporary location for local manipulation + such as validation, conversion, or processing before being used. + """ + self._logger.info("Download recording | object_key: %s", remote_object_key) + + if not remote_object_key: + self._logger.warning("Invalid object_key '%s'", remote_object_key) + raise ValueError("Invalid object_key") + + extension = Path(remote_object_key).suffix.lower() + + if extension not in self._allowed_extensions: + self._logger.warning("Invalid file extension '%s'", extension) + raise ValueError(f"Invalid file extension '{extension}'") + + response = None + + try: + response = self._minio_client.get_object( + self._bucket_name, remote_object_key + ) + + with tempfile.NamedTemporaryFile( + suffix=extension, delete=False, prefix="minio_download_" + ) as tmp: + for chunk in response.stream(self._stream_chunk_size): + tmp.write(chunk) + + tmp.flush() + local_path = Path(tmp.name) + + self._logger.info("Recording successfully downloaded") + self._logger.debug("Recording local file path: %s", local_path) + + return local_path + + finally: + if response: + response.close() + + def _validate_duration(self, local_path: Path) -> float: + """Validate audio file duration against configured maximum.""" + file_metadata = mutagen.File(local_path).info + duration = file_metadata.length + + self._logger.info( + "Recording file duration: %.2f seconds", + duration, + ) + + if self._max_duration is not None and duration > self._max_duration: + error_msg = "Recording too long. Limit is %.2fs seconds" % ( + self._max_duration, + ) + self._logger.error(error_msg) + raise ValueError(error_msg) + + return duration + + @contextmanager + def prepare_audio_file(self, remote_object_key: str): + """Download and prepare audio file for processing. + + Downloads file from MinIO, validates duration, and yields an open + file handle with metadata. Automatically cleans up temporary files + when the context exits. + """ + downloaded_path = None + file_handle = None + + try: + downloaded_path = self._download_from_minio(remote_object_key) + duration = self._validate_duration(downloaded_path) + + extension = downloaded_path.suffix.lower() + metadata = {"duration": duration, "extension": extension} + + file_handle = open(downloaded_path, "rb") + yield file_handle, metadata + + finally: + if file_handle: + file_handle.close() + + for path in [downloaded_path]: + if path is None or not os.path.exists(path): + continue + + try: + os.remove(path) + self._logger.debug("Temporary file removed: %s", path) + except OSError as e: + self._logger.warning( + "Failed to remove temporary file %s: %s", path, e + )