💥(summary) replace Whisper with WhisperX for transcription
Switch from insanely-fast-whisper to WhisperX for transcription services, gaining word-level granularity not available with FlashAttention in previous API. Whisper was packaged in FastAPI image with basic diarization and speaker-segment association capabilities. Implementation prioritized speed to market over testing. Future iterations will enhance diarization quality and add proper test coverage. API consumers should expect behavioral differences and new capabilities.
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
"""API routes related to application tasks."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from summary.core.celery_worker import process_audio_transcribe_summarize
|
||||
from summary.core.celery_worker import (
|
||||
process_audio_transcribe_summarize,
|
||||
process_audio_transcribe_summarize_v2,
|
||||
)
|
||||
|
||||
|
||||
class TaskCreation(BaseModel):
|
||||
@@ -13,6 +18,7 @@ class TaskCreation(BaseModel):
|
||||
filename: str
|
||||
email: str
|
||||
sub: str
|
||||
version: Optional[int] = 2
|
||||
|
||||
|
||||
router = APIRouter(prefix="/tasks")
|
||||
@@ -21,9 +27,15 @@ router = APIRouter(prefix="/tasks")
|
||||
@router.post("/")
|
||||
async def create_task(request: TaskCreation):
|
||||
"""Create a task."""
|
||||
task = process_audio_transcribe_summarize.delay(
|
||||
request.filename, request.email, request.sub
|
||||
)
|
||||
if request.version == 1:
|
||||
task = process_audio_transcribe_summarize.delay(
|
||||
request.filename, request.email, request.sub
|
||||
)
|
||||
else:
|
||||
task = process_audio_transcribe_summarize_v2.delay(
|
||||
request.filename, request.email, request.sub
|
||||
)
|
||||
|
||||
return {"id": task.id, "message": "Task created"}
|
||||
|
||||
|
||||
|
||||
@@ -56,6 +56,35 @@ def create_retry_session():
|
||||
return session
|
||||
|
||||
|
||||
def format_segments(transcription_data):
|
||||
"""Format transcription segments from WhisperX into a readable conversation format.
|
||||
|
||||
Processes transcription data with segments containing speaker information and text,
|
||||
combining consecutive segments from the same speaker and formatting them as a
|
||||
conversation with speaker labels.
|
||||
"""
|
||||
formatted_output = ""
|
||||
if not transcription_data or not hasattr(transcription_data, 'segments'):
|
||||
if isinstance(transcription_data, dict) and 'segments' in transcription_data:
|
||||
segments = transcription_data['segments']
|
||||
else:
|
||||
return "Error: Invalid transcription data format"
|
||||
else:
|
||||
segments = transcription_data.segments
|
||||
|
||||
previous_speaker = None
|
||||
|
||||
for segment in segments:
|
||||
speaker = segment.get('speaker', 'UNKNOWN_SPEAKER')
|
||||
text = segment.get('text', '')
|
||||
if text:
|
||||
if speaker != previous_speaker:
|
||||
formatted_output += f"\n\n *{speaker}*: {text}"
|
||||
else:
|
||||
formatted_output += f" {text}"
|
||||
previous_speaker = speaker
|
||||
return formatted_output
|
||||
|
||||
def post_with_retries(url, data):
|
||||
"""Send POST request with automatic retries."""
|
||||
session = create_retry_session()
|
||||
@@ -140,3 +169,72 @@ def process_audio_transcribe_summarize(filename: str, email: str, sub: str):
|
||||
|
||||
logger.info("Webhook submitted successfully. Status: %s", response.status_code)
|
||||
logger.debug("Response body: %s", response.text)
|
||||
|
||||
|
||||
@celery.task(max_retries=settings.celery_max_retries)
|
||||
def process_audio_transcribe_summarize_v2(filename: str, email: str, sub: str):
|
||||
"""Process an audio file by transcribing it and generating a summary.
|
||||
|
||||
This Celery task performs the following operations:
|
||||
1. Retrieves the audio file from MinIO storage
|
||||
2. Transcribes the audio using WhisperX model
|
||||
3. Sends the results via webhook
|
||||
|
||||
"""
|
||||
logger.info("Notification received")
|
||||
logger.debug("filename: %s", filename)
|
||||
|
||||
minio_client = Minio(
|
||||
settings.aws_s3_endpoint_url,
|
||||
access_key=settings.aws_s3_access_key_id,
|
||||
secret_key=settings.aws_s3_secret_access_key,
|
||||
secure=settings.aws_s3_secure_access,
|
||||
)
|
||||
|
||||
logger.debug("Connection to the Minio bucket successful")
|
||||
|
||||
audio_file_stream = minio_client.get_object(
|
||||
settings.aws_storage_bucket_name, object_name=filename
|
||||
)
|
||||
|
||||
temp_file_path = save_audio_stream(audio_file_stream)
|
||||
logger.debug("Recording successfully downloaded, filepath: %s", temp_file_path)
|
||||
|
||||
logger.debug("Initiating OpenAI client")
|
||||
openai_client = openai.OpenAI(
|
||||
api_key=settings.openai_api_key, base_url=settings.openai_base_url
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug("Querying transcription …")
|
||||
with open(temp_file_path, "rb") as audio_file:
|
||||
transcription = openai_client.audio.transcriptions.create(
|
||||
model=settings.openai_asr_model, file=audio_file
|
||||
)
|
||||
|
||||
logger.debug("Transcription: \n %s", transcription)
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
logger.debug("Temporary file removed: %s", temp_file_path)
|
||||
|
||||
formatted_transcription = format_segments(transcription)
|
||||
|
||||
data = {
|
||||
"title": "Transcription",
|
||||
"content": formatted_transcription,
|
||||
"email": email,
|
||||
"sub": sub,
|
||||
}
|
||||
|
||||
logger.debug("Submitting webhook to %s", settings.webhook_url)
|
||||
logger.debug("Request payload: %s", json.dumps(data, indent=2))
|
||||
|
||||
response = post_with_retries(settings.webhook_url, data)
|
||||
|
||||
logger.info("Webhook submitted successfully. Status: %s", response.status_code)
|
||||
logger.debug("Response body: %s", response.text)
|
||||
|
||||
# TODO - integrate summarize the transcript and create a new document.
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user