♻️(summary) enhance file handling in the Celery worker
The previous code lacked proper encapsulation, resulting in an overly complex worker. While the initial naive approach was great for bootstrapping the feature, this refactor brings more maturity by introducing dedicated service classes that have clear, single responsibilities. During the extraction to services, several minor issues were fixed: 1) The MinIO response is now properly closed. 2) Object filenames and extensions are validated more strictly to ensure correct file handling. 3) A context manager now cleans up temporary local files automatically, instead of relying on developers to remember to delete them. 4) Logging and naming were slightly improved for clarity. 5) The temporary file extension is now derived dynamically from the object name; previously it was always a hardcoded .ogg, even when the recording was not an OGG file.
This commit is contained in:
committed by
aleb_the_flash
parent
4cb6320b83
commit
4e5032a7a4
@@ -1,26 +1,22 @@
|
||||
"""Celery workers."""
|
||||
|
||||
# ruff: noqa: PLR0913, PLR0915
|
||||
# ruff: noqa: PLR0913
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from celery import Celery, signals
|
||||
from celery.utils.log import get_task_logger
|
||||
from minio import Minio
|
||||
from mutagen import File
|
||||
from requests import Session, exceptions
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util import Retry
|
||||
|
||||
from summary.core.analytics import MetadataManager, get_analytics
|
||||
from summary.core.config import get_settings
|
||||
from summary.core.file_service import FileService
|
||||
from summary.core.llm_service import LLMException, LLMObservability, LLMService
|
||||
from summary.core.prompt import (
|
||||
FORMAT_NEXT_STEPS,
|
||||
@@ -59,17 +55,7 @@ if settings.sentry_dsn and settings.sentry_is_enabled:
|
||||
sentry_sdk.init(dsn=settings.sentry_dsn, enable_tracing=True)
|
||||
|
||||
|
||||
class AudioValidationError(Exception):
    """Raised when a recording fails audio validation (e.g. it is too long)."""
|
||||
|
||||
|
||||
def save_audio_stream(audio_stream, chunk_size=32 * 1024, suffix=".ogg"):
    """Save an audio stream to a temporary file and return its path.

    Args:
        audio_stream: Object exposing ``stream(chunk_size)``, which yields
            the recording as successive ``bytes`` chunks.
        chunk_size: Number of bytes requested per chunk from the stream.
        suffix: File extension given to the temporary file. Defaults to
            ``".ogg"`` so existing callers keep their current behavior.

    Returns:
        Path of the temporary file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Whatever ``audio_stream.stream`` or the file write raises. In that
        case the partially written temporary file is removed first.
    """
    tmp_path = None

    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.writelines(audio_stream.stream(chunk_size))
    except BaseException:
        # The caller never receives the path on failure, so nobody else can
        # clean up the partial download: remove it here before re-raising.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise

    return tmp_path
|
||||
file_service = FileService(logger=logger)
|
||||
|
||||
|
||||
def create_retry_session():
|
||||
@@ -151,40 +137,6 @@ def process_audio_transcribe_summarize_v2(
|
||||
|
||||
task_id = self.request.id
|
||||
|
||||
logger.info("Download recording | Filename: %s", filename)
|
||||
|
||||
minio_client = Minio(
|
||||
settings.aws_s3_endpoint_url,
|
||||
access_key=settings.aws_s3_access_key_id,
|
||||
secret_key=settings.aws_s3_secret_access_key.get_secret_value(),
|
||||
secure=settings.aws_s3_secure_access,
|
||||
)
|
||||
|
||||
logger.debug("Connection to the Minio bucket successful")
|
||||
|
||||
audio_file_stream = minio_client.get_object(
|
||||
settings.aws_storage_bucket_name, object_name=filename
|
||||
)
|
||||
|
||||
temp_file_path = save_audio_stream(audio_file_stream)
|
||||
|
||||
logger.info("Recording successfully downloaded")
|
||||
logger.debug("Recording filepath: %s", temp_file_path)
|
||||
|
||||
audio_file = File(temp_file_path)
|
||||
metadata_manager.track(task_id, {"audio_length": audio_file.info.length})
|
||||
|
||||
if (
|
||||
settings.recording_max_duration is not None
|
||||
and audio_file.info.length > settings.recording_max_duration
|
||||
):
|
||||
error_msg = "Recording too long: %.2fs > %.2fs limit" % (
|
||||
audio_file.info.length,
|
||||
settings.recording_max_duration,
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise AudioValidationError(error_msg)
|
||||
|
||||
logger.info("Initiating WhisperX client")
|
||||
whisperx_client = openai.OpenAI(
|
||||
api_key=settings.whisperx_api_key.get_secret_value(),
|
||||
@@ -192,31 +144,31 @@ def process_audio_transcribe_summarize_v2(
|
||||
max_retries=settings.whisperx_max_retries,
|
||||
)
|
||||
|
||||
try:
|
||||
with (
|
||||
file_service.prepare_audio_file(filename) as (audio_file, metadata),
|
||||
):
|
||||
metadata_manager.track(task_id, {"audio_length": metadata["duration"]})
|
||||
|
||||
logger.info(
|
||||
"Querying transcription for %s seconds of audio in %s …",
|
||||
audio_file.info.length,
|
||||
"Querying transcription in '%s' language …",
|
||||
language,
|
||||
)
|
||||
transcription_start_time = time.time()
|
||||
with open(temp_file_path, "rb") as audio_file:
|
||||
transcription = whisperx_client.audio.transcriptions.create(
|
||||
model=settings.whisperx_asr_model,
|
||||
file=audio_file,
|
||||
language=language or settings.whisperx_default_language,
|
||||
)
|
||||
|
||||
transcription_time = round(time.time() - transcription_start_time, 2)
|
||||
metadata_manager.track(
|
||||
task_id,
|
||||
{"transcription_time": transcription_time},
|
||||
)
|
||||
logger.info("Transcription received in %s seconds.", transcription_time)
|
||||
logger.debug("Transcription: \n %s", transcription)
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
logger.debug("Temporary file removed: %s", temp_file_path)
|
||||
transcription_start_time = time.time()
|
||||
|
||||
transcription = whisperx_client.audio.transcriptions.create(
|
||||
model=settings.whisperx_asr_model,
|
||||
file=audio_file,
|
||||
language=language or settings.whisperx_default_language,
|
||||
)
|
||||
|
||||
transcription_time = round(time.time() - transcription_start_time, 2)
|
||||
metadata_manager.track(
|
||||
task_id,
|
||||
{"transcription_time": transcription_time},
|
||||
)
|
||||
logger.info("Transcription received in %.2f seconds.", transcription_time)
|
||||
logger.debug("Transcription: \n %s", transcription)
|
||||
|
||||
metadata_manager.track_transcription_metadata(task_id, transcription)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Application configuration and settings."""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Annotated, List, Optional
|
||||
from typing import Annotated, List, Optional, Set
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import SecretStr
|
||||
@@ -19,6 +19,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# Audio recordings
|
||||
recording_max_duration: Optional[int] = None
|
||||
recording_allowed_extensions: Set[str] = {".ogg"}
|
||||
|
||||
# Celery settings
|
||||
celery_broker_url: str = "redis://redis/0"
|
||||
|
||||
133
src/summary/summary/core/file_service.py
Normal file
133
src/summary/summary/core/file_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""File service to encapsulate files' manipulations."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import mutagen
|
||||
from minio import Minio
|
||||
|
||||
from summary.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class FileService:
    """Service for downloading and preparing files from MinIO storage.

    The service downloads a remote object to a local temporary file,
    validates it (extension and duration) and hands it to the caller through
    a context manager that removes the temporary file on exit.
    """

    def __init__(self, logger):
        """Initialize FileService with MinIO client and configuration.

        Args:
            logger: Logger used for every message emitted by the service.
        """
        self._logger = logger

        self._minio_client = Minio(
            settings.aws_s3_endpoint_url,
            access_key=settings.aws_s3_access_key_id,
            secret_key=settings.aws_s3_secret_access_key.get_secret_value(),
            secure=settings.aws_s3_secure_access,
        )

        self._bucket_name = settings.aws_storage_bucket_name
        self._stream_chunk_size = 32 * 1024

        self._allowed_extensions = settings.recording_allowed_extensions
        self._max_duration = settings.recording_max_duration

    def _remove_temp_file(self, path):
        """Remove a temporary file, logging instead of raising on failure.

        A ``None`` path or an already-missing file is silently ignored so
        the method can be called unconditionally from cleanup code.
        """
        if path is None:
            return

        try:
            os.remove(path)
        except FileNotFoundError:
            return
        except OSError as e:
            self._logger.warning("Failed to remove temporary file %s: %s", path, e)
            return

        self._logger.debug("Temporary file removed: %s", path)

    def _download_from_minio(self, remote_object_key) -> Path:
        """Download file from MinIO to local temporary file.

        The file is downloaded to a temporary location for local manipulation
        such as validation, conversion, or processing before being used.

        Args:
            remote_object_key: Key of the object in the configured bucket.

        Returns:
            Path of the local temporary file. It keeps the (lower-cased)
            extension of the remote object and must be removed by the caller.

        Raises:
            ValueError: If the key is empty or its extension is not allowed.
            Exception: Any error raised by the MinIO client or the local
                write; the partially written temporary file is removed first.
        """
        self._logger.info("Download recording | object_key: %s", remote_object_key)

        if not remote_object_key:
            self._logger.warning("Invalid object_key '%s'", remote_object_key)
            raise ValueError("Invalid object_key")

        extension = Path(remote_object_key).suffix.lower()

        if extension not in self._allowed_extensions:
            self._logger.warning("Invalid file extension '%s'", extension)
            raise ValueError(f"Invalid file extension '{extension}'")

        response = None
        local_path = None

        try:
            response = self._minio_client.get_object(
                self._bucket_name, remote_object_key
            )

            with tempfile.NamedTemporaryFile(
                suffix=extension, delete=False, prefix="minio_download_"
            ) as tmp:
                local_path = Path(tmp.name)
                for chunk in response.stream(self._stream_chunk_size):
                    tmp.write(chunk)
        except BaseException:
            # The path is never returned on failure, so the caller cannot
            # clean it up: drop the partial download here.
            self._remove_temp_file(local_path)
            raise
        finally:
            if response is not None:
                # MinIO requires both calls: close() ends the stream and
                # release_conn() returns the connection to the pool.
                response.close()
                response.release_conn()

        self._logger.info("Recording successfully downloaded")
        self._logger.debug("Recording local file path: %s", local_path)

        return local_path

    def _validate_duration(self, local_path: Path) -> float:
        """Validate audio file duration against configured maximum.

        Args:
            local_path: Path of the local audio file to inspect.

        Returns:
            Duration of the recording, in seconds.

        Raises:
            ValueError: If the audio metadata cannot be read, or if the
                recording exceeds the configured maximum duration.
        """
        audio = mutagen.File(local_path)

        # mutagen returns None when it cannot determine the file type.
        if audio is None:
            self._logger.warning(
                "Unable to read audio metadata from '%s'", local_path
            )
            raise ValueError("Unable to read audio metadata")

        duration = audio.info.length

        self._logger.info(
            "Recording file duration: %.2f seconds",
            duration,
        )

        if self._max_duration is not None and duration > self._max_duration:
            error_msg = "Recording too long. Limit is %.2f seconds" % (
                self._max_duration,
            )
            self._logger.error(error_msg)
            raise ValueError(error_msg)

        return duration

    @contextmanager
    def prepare_audio_file(self, remote_object_key: str):
        """Download and prepare audio file for processing.

        Downloads file from MinIO, validates duration, and yields an open
        file handle with metadata. Automatically cleans up temporary files
        when the context exits.

        Args:
            remote_object_key: Key of the recording in the configured bucket.

        Yields:
            A ``(file_handle, metadata)`` tuple. ``file_handle`` is the
            downloaded file opened in binary read mode, and ``metadata`` is a
            dict with the keys ``"duration"`` (seconds) and ``"extension"``.

        Raises:
            ValueError: If the key, extension, metadata or duration is
                invalid.
        """
        downloaded_path = None

        try:
            downloaded_path = self._download_from_minio(remote_object_key)
            duration = self._validate_duration(downloaded_path)

            metadata = {
                "duration": duration,
                "extension": downloaded_path.suffix.lower(),
            }

            # The handle is closed by the `with` before the file is removed.
            with open(downloaded_path, "rb") as file_handle:
                yield file_handle, metadata
        finally:
            self._remove_temp_file(downloaded_path)
|
||||
Reference in New Issue
Block a user