💥(summary) replace Whisper with WhisperX for transcription

Switch from insanely-fast-whisper to WhisperX for transcription services,
gaining word-level granularity that was not available with FlashAttention in
the previous API. WhisperX is packaged in the FastAPI image with basic
diarization and speaker-segment association capabilities.

Implementation prioritized speed to market over testing. Future iterations
will enhance diarization quality and add proper test coverage. API consumers
should expect behavioral differences and new capabilities.
This commit is contained in:
lebaudantoine
2025-03-25 16:49:18 +01:00
parent 28538b63da
commit 53f1c9c1d4
2 changed files with 114 additions and 4 deletions

View File

@@ -1,10 +1,15 @@
"""API routes related to application tasks."""
from typing import Optional
from celery.result import AsyncResult
from fastapi import APIRouter
from pydantic import BaseModel
from summary.core.celery_worker import process_audio_transcribe_summarize
from summary.core.celery_worker import (
process_audio_transcribe_summarize,
process_audio_transcribe_summarize_v2,
)
class TaskCreation(BaseModel):
@@ -13,6 +18,7 @@ class TaskCreation(BaseModel):
filename: str
email: str
sub: str
version: Optional[int] = 2
router = APIRouter(prefix="/tasks")
@@ -21,9 +27,15 @@ router = APIRouter(prefix="/tasks")
@router.post("/")
async def create_task(request: TaskCreation):
    """Create a transcription task.

    Dispatches to the legacy Whisper pipeline when ``request.version`` is 1,
    otherwise (the default) to the WhisperX-based v2 pipeline. Exactly one
    Celery task is enqueued per request.

    Returns:
        dict: The Celery task id and a confirmation message.
    """
    if request.version == 1:
        task = process_audio_transcribe_summarize.delay(
            request.filename, request.email, request.sub
        )
    else:
        task = process_audio_transcribe_summarize_v2.delay(
            request.filename, request.email, request.sub
        )
    return {"id": task.id, "message": "Task created"}

View File

@@ -56,6 +56,35 @@ def create_retry_session():
return session
def format_segments(transcription_data):
    """Format WhisperX transcription segments into a readable conversation.

    Combines consecutive segments from the same speaker and renders each
    speaker turn as ``*SPEAKER*: text``, with turns separated by blank lines.

    Args:
        transcription_data: Either a mapping with a ``"segments"`` key or an
            object exposing a ``segments`` attribute. Each segment may itself
            be a mapping or an attribute-style object carrying ``speaker``
            and ``text`` fields.

    Returns:
        str: The formatted conversation, or an error-message string when no
        segments can be extracted from the input.
    """

    def _field(segment, key, default):
        # Segments returned through the OpenAI-compatible client may be
        # attribute-style objects rather than plain dicts; support both.
        if isinstance(segment, dict):
            return segment.get(key, default)
        return getattr(segment, key, default)

    if isinstance(transcription_data, dict):
        segments = transcription_data.get('segments')
    else:
        segments = getattr(transcription_data, 'segments', None)
    if segments is None:
        return "Error: Invalid transcription data format"

    formatted_output = ""
    previous_speaker = None
    for segment in segments:
        speaker = _field(segment, 'speaker', 'UNKNOWN_SPEAKER')
        text = _field(segment, 'text', '')
        if text:
            if speaker != previous_speaker:
                # New speaker: start a fresh labelled turn.
                formatted_output += f"\n\n *{speaker}*: {text}"
            else:
                # Same speaker: append to the current turn.
                formatted_output += f" {text}"
        previous_speaker = speaker
    return formatted_output
def post_with_retries(url, data):
"""Send POST request with automatic retries."""
session = create_retry_session()
@@ -140,3 +169,72 @@ def process_audio_transcribe_summarize(filename: str, email: str, sub: str):
logger.info("Webhook submitted successfully. Status: %s", response.status_code)
logger.debug("Response body: %s", response.text)
@celery.task(max_retries=settings.celery_max_retries)
def process_audio_transcribe_summarize_v2(filename: str, email: str, sub: str):
    """Process an audio file by transcribing it and generating a summary.

    This Celery task performs the following operations:
    1. Retrieves the audio file from MinIO storage
    2. Transcribes the audio using WhisperX model
    3. Sends the results via webhook

    Args:
        filename: Object name of the recording in the MinIO bucket.
        email: Recipient address forwarded in the webhook payload.
        sub: Subject identifier forwarded in the webhook payload.
    """
    logger.info("Notification received")
    logger.debug("filename: %s", filename)
    # Connect to the MinIO (S3-compatible) storage holding the recording.
    minio_client = Minio(
        settings.aws_s3_endpoint_url,
        access_key=settings.aws_s3_access_key_id,
        secret_key=settings.aws_s3_secret_access_key,
        secure=settings.aws_s3_secure_access,
    )
    logger.debug("Connection to the Minio bucket successful")
    audio_file_stream = minio_client.get_object(
        settings.aws_storage_bucket_name, object_name=filename
    )
    # Spool the stream to a local temporary file so it can be re-opened below.
    temp_file_path = save_audio_stream(audio_file_stream)
    logger.debug("Recording successfully downloaded, filepath: %s", temp_file_path)
    logger.debug("Initiating OpenAI client")
    # OpenAI-compatible client; base_url presumably points at the WhisperX
    # service rather than api.openai.com — confirm against deployment config.
    openai_client = openai.OpenAI(
        api_key=settings.openai_api_key, base_url=settings.openai_base_url
    )
    try:
        logger.debug("Querying transcription …")
        with open(temp_file_path, "rb") as audio_file:
            transcription = openai_client.audio.transcriptions.create(
                model=settings.openai_asr_model, file=audio_file
            )
        logger.debug("Transcription: \n %s", transcription)
    finally:
        # Always remove the temporary audio file, even if transcription fails.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            logger.debug("Temporary file removed: %s", temp_file_path)
    # Collapse segments into a speaker-labelled conversation string.
    formatted_transcription = format_segments(transcription)
    data = {
        "title": "Transcription",
        "content": formatted_transcription,
        "email": email,
        "sub": sub,
    }
    logger.debug("Submitting webhook to %s", settings.webhook_url)
    logger.debug("Request payload: %s", json.dumps(data, indent=2))
    # Deliver the result; post_with_retries retries transient HTTP failures.
    response = post_with_retries(settings.webhook_url, data)
    logger.info("Webhook submitted successfully. Status: %s", response.status_code)
    logger.debug("Response body: %s", response.text)
    # TODO - integrate summarize the transcript and create a new document.