diff --git a/src/summary/summary/core/celery_worker.py b/src/summary/summary/core/celery_worker.py index d99f6a44..66bf5c87 100644 --- a/src/summary/summary/core/celery_worker.py +++ b/src/summary/summary/core/celery_worker.py @@ -1,26 +1,22 @@ """Celery workers.""" -# ruff: noqa: PLR0913, PLR0915 +# ruff: noqa: PLR0913 import json -import os -import tempfile import time -from pathlib import Path from typing import Optional import openai import sentry_sdk from celery import Celery, signals from celery.utils.log import get_task_logger -from minio import Minio -from mutagen import File from requests import Session, exceptions from requests.adapters import HTTPAdapter from urllib3.util import Retry from summary.core.analytics import MetadataManager, get_analytics from summary.core.config import get_settings +from summary.core.file_service import FileService from summary.core.llm_service import LLMException, LLMObservability, LLMService from summary.core.prompt import ( FORMAT_NEXT_STEPS, @@ -59,17 +55,7 @@ if settings.sentry_dsn and settings.sentry_is_enabled: sentry_sdk.init(dsn=settings.sentry_dsn, enable_tracing=True) -class AudioValidationError(Exception): - """Custom exception for audio validation errors.""" - - pass - - -def save_audio_stream(audio_stream, chunk_size=32 * 1024): - """Save an audio stream to a temporary OGG file.""" - with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as tmp: - tmp.writelines(audio_stream.stream(chunk_size)) - return Path(tmp.name) +file_service = FileService(logger=logger) def create_retry_session(): @@ -151,40 +137,6 @@ def process_audio_transcribe_summarize_v2( task_id = self.request.id - logger.info("Download recording | Filename: %s", filename) - - minio_client = Minio( - settings.aws_s3_endpoint_url, - access_key=settings.aws_s3_access_key_id, - secret_key=settings.aws_s3_secret_access_key.get_secret_value(), - secure=settings.aws_s3_secure_access, - ) - - logger.debug("Connection to the Minio bucket successful") - - audio_file_stream = minio_client.get_object( - settings.aws_storage_bucket_name, object_name=filename - ) - - temp_file_path = save_audio_stream(audio_file_stream) - - logger.info("Recording successfully downloaded") - logger.debug("Recording filepath: %s", temp_file_path) - - audio_file = File(temp_file_path) - metadata_manager.track(task_id, {"audio_length": audio_file.info.length}) - - if ( - settings.recording_max_duration is not None - and audio_file.info.length > settings.recording_max_duration - ): - error_msg = "Recording too long: %.2fs > %.2fs limit" % ( - audio_file.info.length, - settings.recording_max_duration, - ) - logger.error(error_msg) - raise AudioValidationError(error_msg) - logger.info("Initiating WhisperX client") whisperx_client = openai.OpenAI( api_key=settings.whisperx_api_key.get_secret_value(), @@ -192,31 +144,31 @@ def process_audio_transcribe_summarize_v2( max_retries=settings.whisperx_max_retries, ) - try: + with ( + file_service.prepare_audio_file(filename) as (audio_file, metadata), + ): + metadata_manager.track(task_id, {"audio_length": metadata["duration"]}) + logger.info( - "Querying transcription for %s seconds of audio in %s …", - audio_file.info.length, + "Querying transcription in '%s' language …", language, ) - transcription_start_time = time.time() - with open(temp_file_path, "rb") as audio_file: - transcription = whisperx_client.audio.transcriptions.create( - model=settings.whisperx_asr_model, - file=audio_file, - language=language or settings.whisperx_default_language, - ) - transcription_time = round(time.time() - transcription_start_time, 2) - metadata_manager.track( - task_id, - {"transcription_time": transcription_time}, - ) - logger.info("Transcription received in %s seconds.", transcription_time) - logger.debug("Transcription: \n %s", transcription) - finally: - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - logger.debug("Temporary file removed: %s", temp_file_path) + transcription_start_time = time.time() + + transcription = whisperx_client.audio.transcriptions.create( + model=settings.whisperx_asr_model, + file=audio_file, + language=language or settings.whisperx_default_language, + ) + + transcription_time = round(time.time() - transcription_start_time, 2) + metadata_manager.track( + task_id, + {"transcription_time": transcription_time}, + ) + logger.info("Transcription received in %.2f seconds.", transcription_time) + logger.debug("Transcription: \n %s", transcription) metadata_manager.track_transcription_metadata(task_id, transcription) diff --git a/src/summary/summary/core/config.py b/src/summary/summary/core/config.py index ee7654a5..2624cced 100644 --- a/src/summary/summary/core/config.py +++ b/src/summary/summary/core/config.py @@ -1,7 +1,7 @@ """Application configuration and settings.""" from functools import lru_cache -from typing import Annotated, List, Optional +from typing import Annotated, List, Optional, Set from fastapi import Depends from pydantic import SecretStr @@ -19,6 +19,7 @@ class Settings(BaseSettings): # Audio recordings recording_max_duration: Optional[int] = None + recording_allowed_extensions: Set[str] = {".ogg"} # Celery settings celery_broker_url: str = "redis://redis/0" diff --git a/src/summary/summary/core/file_service.py b/src/summary/summary/core/file_service.py new file mode 100644 index 00000000..69e44629 --- /dev/null +++ b/src/summary/summary/core/file_service.py @@ -0,0 +1,133 @@ +"""File service to encapsulate files' manipulations.""" + +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path + +import mutagen +from minio import Minio + +from summary.core.config import get_settings + +settings = get_settings() + + +class FileService: + """Service for downloading and preparing files from MinIO storage.""" + + def __init__(self, logger): + """Initialize FileService with MinIO client and configuration.""" + self._logger = logger + + self._minio_client = Minio( + settings.aws_s3_endpoint_url, + access_key=settings.aws_s3_access_key_id, + secret_key=settings.aws_s3_secret_access_key.get_secret_value(), + secure=settings.aws_s3_secure_access, + ) + + self._bucket_name = settings.aws_storage_bucket_name + self._stream_chunk_size = 32 * 1024 + + self._allowed_extensions = settings.recording_allowed_extensions + self._max_duration = settings.recording_max_duration + + def _download_from_minio(self, remote_object_key) -> Path: + """Download file from MinIO to local temporary file. + + The file is downloaded to a temporary location for local manipulation + such as validation, conversion, or processing before being used. + """ + self._logger.info("Download recording | object_key: %s", remote_object_key) + + if not remote_object_key: + self._logger.warning("Invalid object_key '%s'", remote_object_key) + raise ValueError("Invalid object_key") + + extension = Path(remote_object_key).suffix.lower() + + if extension not in self._allowed_extensions: + self._logger.warning("Invalid file extension '%s'", extension) + raise ValueError(f"Invalid file extension '{extension}'") + + response = None + + try: + response = self._minio_client.get_object( + self._bucket_name, remote_object_key + ) + + with tempfile.NamedTemporaryFile( + suffix=extension, delete=False, prefix="minio_download_" + ) as tmp: + for chunk in response.stream(self._stream_chunk_size): + tmp.write(chunk) + + tmp.flush() + local_path = Path(tmp.name) + + self._logger.info("Recording successfully downloaded") + self._logger.debug("Recording local file path: %s", local_path) + + return local_path + + finally: + if response: + response.close() + + def _validate_duration(self, local_path: Path) -> float: + """Validate audio file duration against configured maximum.""" + file_metadata = mutagen.File(local_path).info + duration = file_metadata.length + + self._logger.info( + "Recording file duration: %.2f seconds", + duration, + ) + + if self._max_duration is not None and duration > self._max_duration: + error_msg = "Recording too long. Limit is %.2fs seconds" % ( + self._max_duration, + ) + self._logger.error(error_msg) + raise ValueError(error_msg) + + return duration + + @contextmanager + def prepare_audio_file(self, remote_object_key: str): + """Download and prepare audio file for processing. + + Downloads file from MinIO, validates duration, and yields an open + file handle with metadata. Automatically cleans up temporary files + when the context exits. + """ + downloaded_path = None + file_handle = None + + try: + downloaded_path = self._download_from_minio(remote_object_key) + duration = self._validate_duration(downloaded_path) + + extension = downloaded_path.suffix.lower() + metadata = {"duration": duration, "extension": extension} + + file_handle = open(downloaded_path, "rb") + yield file_handle, metadata + + finally: + if file_handle: + file_handle.close() + + for path in [downloaded_path]: + if path is None or not os.path.exists(path): + continue + + try: + os.remove(path) + self._logger.debug("Temporary file removed: %s", path) + except OSError as e: + self._logger.warning( + "Failed to remove temporary file %s: %s", path, e + )