diff --git a/src/summary/summary/api/route/tasks.py b/src/summary/summary/api/route/tasks.py
index 42b212bb..0c5bc3b7 100644
--- a/src/summary/summary/api/route/tasks.py
+++ b/src/summary/summary/api/route/tasks.py
@@ -1,10 +1,15 @@
 """API routes related to application tasks."""
 
+from typing import Optional
+
 from celery.result import AsyncResult
 from fastapi import APIRouter
 from pydantic import BaseModel
 
-from summary.core.celery_worker import process_audio_transcribe_summarize
+from summary.core.celery_worker import (
+    process_audio_transcribe_summarize,
+    process_audio_transcribe_summarize_v2,
+)
 
 
 class TaskCreation(BaseModel):
@@ -13,6 +18,8 @@ class TaskCreation(BaseModel):
     filename: str
     email: str
     sub: str
+    # 1 runs process_audio_transcribe_summarize; any other value (or null) runs v2.
+    version: Optional[int] = 2
 
 
 router = APIRouter(prefix="/tasks")
@@ -21,9 +28,15 @@ router = APIRouter(prefix="/tasks")
 @router.post("/")
 async def create_task(request: TaskCreation):
     """Create a task."""
-    task = process_audio_transcribe_summarize.delay(
-        request.filename, request.email, request.sub
-    )
+    if request.version == 1:
+        task = process_audio_transcribe_summarize.delay(
+            request.filename, request.email, request.sub
+        )
+    else:
+        task = process_audio_transcribe_summarize_v2.delay(
+            request.filename, request.email, request.sub
+        )
+
     return {"id": task.id, "message": "Task created"}
 
 
diff --git a/src/summary/summary/core/celery_worker.py b/src/summary/summary/core/celery_worker.py
index 9ff55578..ad637f4c 100644
--- a/src/summary/summary/core/celery_worker.py
+++ b/src/summary/summary/core/celery_worker.py
@@ -56,6 +56,41 @@ def create_retry_session():
     return session
 
 
+def format_segments(transcription_data):
+    """Format WhisperX transcription segments as a readable conversation.
+
+    Accepts either a mapping with a "segments" key or an object exposing a
+    ``segments`` attribute; each segment is read as a mapping with optional
+    "speaker" and "text" keys. Consecutive segments from the same speaker are
+    merged, and every speaker change starts a new paragraph prefixed with the
+    speaker label. Returns an error message string when no segments are found.
+    """
+    if isinstance(transcription_data, dict):
+        segments = transcription_data.get("segments")
+    else:
+        segments = getattr(transcription_data, "segments", None)
+
+    if segments is None:
+        # Best-effort contract: report the problem as text instead of raising.
+        return "Error: Invalid transcription data format"
+
+    parts = []
+    previous_speaker = None
+
+    for segment in segments:
+        speaker = segment.get("speaker", "UNKNOWN_SPEAKER")
+        text = segment.get("text", "")
+        if not text:
+            continue
+        if speaker != previous_speaker:
+            parts.append(f"\n\n *{speaker}*: {text}")
+        else:
+            parts.append(f" {text}")
+        previous_speaker = speaker
+
+    return "".join(parts)
+
+
 def post_with_retries(url, data):
     """Send POST request with automatic retries."""
     session = create_retry_session()
@@ -140,3 +175,75 @@ def process_audio_transcribe_summarize(filename: str, email: str, sub: str):
 
     logger.info("Webhook submitted successfully. Status: %s", response.status_code)
     logger.debug("Response body: %s", response.text)
+
+
+@celery.task(max_retries=settings.celery_max_retries)
+def process_audio_transcribe_summarize_v2(filename: str, email: str, sub: str):
+    """Process an audio file by transcribing it and sending the transcript.
+
+    This Celery task performs the following operations:
+    1. Retrieves the audio file from MinIO storage
+    2. Transcribes the audio using WhisperX model
+    3. Sends the results via webhook
+    """
+    logger.info("Notification received")
+    logger.debug("filename: %s", filename)
+
+    minio_client = Minio(
+        settings.aws_s3_endpoint_url,
+        access_key=settings.aws_s3_access_key_id,
+        secret_key=settings.aws_s3_secret_access_key,
+        secure=settings.aws_s3_secure_access,
+    )
+
+    logger.debug("Connection to the Minio bucket successful")
+
+    audio_file_stream = minio_client.get_object(
+        settings.aws_storage_bucket_name, object_name=filename
+    )
+
+    try:
+        temp_file_path = save_audio_stream(audio_file_stream)
+    finally:
+        # MinIO requires the response to be closed and its connection released
+        # once consumed, otherwise the pooled HTTP connection is leaked.
+        audio_file_stream.close()
+        audio_file_stream.release_conn()
+    logger.debug("Recording successfully downloaded, filepath: %s", temp_file_path)
+
+    logger.debug("Initiating OpenAI client")
+    openai_client = openai.OpenAI(
+        api_key=settings.openai_api_key, base_url=settings.openai_base_url
+    )
+
+    try:
+        logger.debug("Querying transcription …")
+        with open(temp_file_path, "rb") as audio_file:
+            transcription = openai_client.audio.transcriptions.create(
+                model=settings.openai_asr_model, file=audio_file
+            )
+
+        logger.debug("Transcription: \n %s", transcription)
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+            logger.debug("Temporary file removed: %s", temp_file_path)
+
+    formatted_transcription = format_segments(transcription)
+
+    data = {
+        "title": "Transcription",
+        "content": formatted_transcription,
+        "email": email,
+        "sub": sub,
+    }
+
+    logger.debug("Submitting webhook to %s", settings.webhook_url)
+    logger.debug("Request payload: %s", json.dumps(data, indent=2))
+
+    response = post_with_retries(settings.webhook_url, data)
+
+    logger.info("Webhook submitted successfully. Status: %s", response.status_code)
+    logger.debug("Response body: %s", response.text)
+
+    # TODO: summarize the transcript and create a new document.