Batch Processing Transcripts with Celery

Implementing reliable batch transcription at scale requires decoupling heavy inference workloads from synchronous API gateways. When processing podcast episodes, multi-track video interviews, or archival audio, Celery provides deterministic task routing, explicit retry semantics, and horizontal worker scaling. The architecture must handle Whisper Large V3 inference, Pyannote speaker diarization, timestamp alignment, and audio normalization without exhausting GPU memory or triggering broker connection storms. Proper configuration of Celery’s concurrency model, acknowledgment policies, and chunking strategy prevents cascading failures during high-throughput transcription bursts.

Celery Worker Configuration and Broker Setup

Celery defaults to the prefork pool, but GPU-bound transcription should run on dedicated workers that process one task at a time. Enable persistence on the broker so queued tasks survive a broker restart, and configure a result backend so task state can still be queried after a worker restarts. For production deployments, consult the official Celery configuration documentation to validate broker-specific tuning parameters.

# celery_config.py
import os

# Broker and backend configuration
broker_url = os.getenv("CELERY_BROKER_URL", "redis://redis-broker:6379/0")
result_backend = os.getenv("CELERY_RESULT_BACKEND", "redis://redis-broker:6379/1")

# Prevent task loss during worker restarts
task_acks_late = True
task_reject_on_worker_lost = True
task_ignore_result = False

# GPU worker isolation: one task at a time to prevent CUDA OOM
worker_concurrency = 1
worker_prefetch_multiplier = 1
worker_max_tasks_per_child = 50  # Recycle the child process after 50 tasks (prefork pool only)

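# Timeouts. Retry delay and max retries are set on the task decorator;
# task_default_retry_delay / task_default_max_retries are not among
# Celery's documented settings, so they would have no effect here.
task_soft_time_limit = 1800
task_time_limit = 2100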

# Serialization
task_serializer = "json"
result_serializer = "json"
accept_content = ["json"]

# Logging
worker_hijack_root_logger = False
worker_log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
worker_task_log_format = "%(asctime)s [%(levelname)s] %(name)s.%(task_name)s: %(message)s"

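The worker_max_tasks_per_child parameter recycles a pool child process after the given number of tasks, which releases leaked host and GPU memory in long-running Whisper inference loops. It applies to the prefork pool only: under --pool=solo, as in the worker command later in this article, tasks run in the worker's main process and nothing is recycled. Setting task_acks_late = True delays acknowledgment until the task has finished executing instead of when it starts, so a task interrupted by a worker crash is redelivered (together with task_reject_on_worker_lost = True). Because redelivery can run the same task twice, late acknowledgment requires idempotent task execution, which Async Transcription Queue Management architectures also rely on.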
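
To make the worker isolation and late-acknowledgment settings concrete, the sketch below adds two settings to celery_config.py. It assumes the Redis broker configured above. The queue name gpu_transcription is hypothetical, and doubling the hard time limit for the visibility timeout is an illustrative choice, not a recommendation from the Celery documentation.

```python
# celery_config.py (additional settings, illustrative sketch)

# Hypothetical queue name; any string works as long as a worker consumes it.
GPU_QUEUE = "gpu_transcription"

# Route the transcription task to the GPU queue. Tasks without a route
# stay on Celery's default queue.
task_routes = {
    "tasks.transcription.process_transcription_batch": {"queue": GPU_QUEUE},
}

# With the Redis transport, a task that is not acknowledged within the
# visibility timeout is redelivered to another worker. With late acks the
# task stays unacknowledged while it runs, so the timeout should exceed
# the hard time limit (task_time_limit = 2100 above).
TASK_TIME_LIMIT_SEC = 2100
broker_transport_options = {"visibility_timeout": TASK_TIME_LIMIT_SEC * 2}
```

A worker started with -Q gpu_transcription would then consume only the routed transcription tasks.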

Task Implementation and Chunking Strategy

Transcription tasks must split long audio files into manageable segments before passing them to Whisper Large V3. Chunking prevents out-of-memory (OOM) crashes and enables parallel diarization alignment. The following task definition implements VAD-based segmentation, batch inference, and diarization handoff with explicit error boundaries and diagnostic telemetry.

# tasks/transcription.py
import logging
import os
import traceback
import numpy as np
import torch
import librosa
from celery import Celery, shared_task
from pyannote.audio import Pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

app = Celery()
app.config_from_object("celery_config")

logger = logging.getLogger(__name__)

WHISPER_MODEL_ID = "openai/whisper-large-v3"
PYANNOTE_MODEL_ID = "pyannote/speaker-diarization-3.1"

try:
    whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL_ID)
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        WHISPER_MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        use_safetensors=True
    ).to("cuda")
    whisper_model.eval()

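    diarization_pipeline = Pipeline.from_pretrained(
        PYANNOTE_MODEL_ID,
        use_auth_token=os.getenv("HF_TOKEN")
    )
    # from_pretrained returns None when the token cannot access the gated model.
    if diarization_pipeline is None:
        raise RuntimeError("Pyannote pipeline unavailable; check HF_TOKEN access.")
    # Pyannote pipelines load on CPU by default; move to the GPU explicitly.
    diarization_pipeline.to(torch.device("cuda"))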
    logger.info("Successfully loaded Whisper Large V3 and Pyannote 3.1 models.")
except Exception as e:
    logger.critical("Model initialization failed: %s", e)
    raise RuntimeError("GPU model preload failed. Worker cannot start.") from e


def chunk_audio_by_vad(
    audio_path: str,
    sample_rate: int = 16000,
    max_chunk_sec: int = 30,
    top_db: int = 30
) -> list[tuple[np.ndarray, float, float]]:
    """Split audio on energy-based VAD intervals, capped at max_chunk_sec to prevent OOM."""
    y, sr = librosa.load(audio_path, sr=sample_rate)
    intervals = librosa.effects.split(y, top_db=top_db)
    max_samples = int(max_chunk_sec * sr)
    chunks: list[tuple[np.ndarray, float, float]] = []
    for start_idx, end_idx in intervals:
        cursor = start_idx
        while cursor < end_idx:
            window_end = min(cursor + max_samples, end_idx)
            chunks.append((y[cursor:window_end], cursor / sr, window_end / sr))
            cursor = window_end
    return chunks


@shared_task(bind=True, max_retries=3, default_retry_delay=15)
def process_transcription_batch(self, audio_url: str, task_id: str) -> dict:
    """Execute batch transcription with chunking, diarization, and explicit diagnostics."""
    temp_path = None
    try:
        logger.info("Task %s: Preparing audio: %s", self.request.id, audio_url)
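        temp_path = f"/tmp/{task_id}.wav"
        # Minimal stdlib download so the file exists before chunking.
        # In production, replace with a robust downloader (e.g., requests + streaming write).
        from urllib.request import urlretrieve  # local import keeps the listing's import block unchanged
        urlretrieve(audio_url, temp_path)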

        chunks = chunk_audio_by_vad(temp_path)
        logger.info("Task %s: Audio split into %d chunks.", self.request.id, len(chunks))

        transcript_segments = []
        for i, (chunk_audio, start_sec, end_sec) in enumerate(chunks):
            # Match the model's fp16 dtype to avoid a CUDA dtype mismatch at forward.
            inputs = whisper_processor(
                chunk_audio, sampling_rate=16000, return_tensors="pt"
            ).to("cuda", dtype=torch.float16)
            with torch.no_grad():
                generated_ids = whisper_model.generate(inputs["input_features"])
            transcription = whisper_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0].strip()
            transcript_segments.append({
                "chunk_index": i,
                "start": start_sec,
                "end": end_sec,
                "text": transcription
            })

            if torch.cuda.is_available():
                mem_alloc = torch.cuda.memory_allocated() / (1024 ** 2)
                logger.debug("Chunk %d processed. GPU VRAM allocated: %.2f MB", i, mem_alloc)

        logger.info("Task %s: Running Pyannote diarization.", self.request.id)
        diarization_result = diarization_pipeline(temp_path)

        aligned_transcript = []
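        for seg in transcript_segments:
            # Label a chunk by the diarization turn that contains its start time.
            # Chunks that start outside every turn stay "UNKNOWN" instead of
            # being credited to SPEAKER_00.
            speaker = "UNKNOWN"
            for turn, _, speaker_label in diarization_result.itertracks(yield_label=True):
                if turn.start <= seg["start"] <= turn.end:
                    speaker = speaker_label
                    break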
            aligned_transcript.append({
                "start": seg["start"],
                "end": seg["end"],
                "speaker": speaker,
                "text": seg["text"]
            })

        logger.info("Task %s: Successfully aligned %d segments.", self.request.id, len(aligned_transcript))
        return {"status": "success", "segments": aligned_transcript}

    except Exception as exc:
        logger.error("Task %s failed: %s", self.request.id, traceback.format_exc())
        raise self.retry(exc=exc, countdown=15 * (2 ** self.request.retries))
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("Task %s: CUDA cache cleared.", self.request.id)
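
The alignment loop in the task labels each chunk with the speaker whose turn contains the chunk's start time, so a chunk that spans a speaker change is credited entirely to the first speaker. One alternative, sketched here as an assumption rather than as part of the original pipeline, is to pick the turn with the longest overlap. The function assign_speaker and the Turn alias are hypothetical names. In the task, the turns could be built as (turn.start, turn.end, label) from diarization_result.itertracks(yield_label=True).

```python
from typing import Iterable, Optional

# (start_sec, end_sec, speaker_label); a hypothetical plain-tuple form of a diarization turn.
Turn = tuple[float, float, str]


def assign_speaker(
    seg_start: float,
    seg_end: float,
    turns: Iterable[Turn],
    default: Optional[str] = None,
) -> Optional[str]:
    """Return the label of the turn that overlaps the segment the longest."""
    best_label = default
    best_overlap = 0.0
    for turn_start, turn_end, label in turns:
        # Overlap is negative or zero when the intervals do not intersect.
        overlap = min(seg_end, turn_end) - max(seg_start, turn_start)
        if overlap > best_overlap:
            best_overlap = overlap
            best_label = label
    return best_label


turns = [(0.0, 4.0, "SPEAKER_00"), (4.0, 12.0, "SPEAKER_01")]
print(assign_speaker(3.0, 10.0, turns))   # SPEAKER_01 (6 s of overlap vs 1 s)
print(assign_speaker(20.0, 25.0, turns))  # None, because no turn overlaps
```

Returning the default when nothing overlaps leaves it to the caller to decide how unattributed speech should be labeled.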

Diagnostics, Retry Semantics, and Pipeline Execution

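Production transcription pipelines require explicit failure boundaries. The except block in the task above implements exponential backoff by passing countdown=15 * (2 ** self.request.retries) to self.retry(), so the three retries wait 15, 30, and 60 seconds. This keeps transient network drops or temporary GPU scheduling delays from permanently failing a job, although it also retries deterministic failures such as a corrupt file. When integrating with Transcription & Speaker Diarization workflows, validate that chunk boundaries align with natural speech pauses: chunk_audio_by_vad cuts at silence first, but any voiced interval longer than max_chunk_sec is split at a fixed offset, which can truncate a word.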
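
The retry schedule can be checked in isolation. The helper below is a minimal sketch of the same doubling rule used in the task's countdown expression. The cap and jitter parameters are assumptions added for illustration and are not present in the task above.

```python
import random


def backoff_delay(
    retries: int,
    base: float = 15.0,
    cap: float = 600.0,
    jitter: float = 0.0,
) -> float:
    """Delay before the next attempt: base * 2**retries, capped, with optional +/- jitter."""
    delay = min(base * (2 ** retries), cap)
    if jitter:
        # Spread retries out so many failed tasks do not wake at the same moment.
        delay *= random.uniform(1.0 - jitter, 1.0 + jitter)
    return delay


# With max_retries=3 the task retries at request.retries = 0, 1, 2.
print([backoff_delay(n) for n in range(3)])  # [15.0, 30.0, 60.0]
```

Inside the task, the value would be passed as countdown=backoff_delay(self.request.retries).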

For low-quality audio sources, normalize the audio before chunking. librosa.effects.trim() removes leading and trailing silence only; it does not reduce background noise, so pair it with level normalization (for example librosa.util.normalize) or a separate denoising step when noise is degrading Whisper’s recognition accuracy. When routing to cost-optimized endpoints, monitor task_time_limit violations and route failed chunks to fallback models or human review queues.
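
As one way to picture the fallback routing, the sketch below chooses a destination for a failed chunk. Everything in it is hypothetical: the queue names, the ChunkFailure record, the reason strings, and the rule itself are illustrative assumptions, not Celery features or part of the task above.

```python
from dataclasses import dataclass

# Hypothetical queue names; nothing in Celery defines them.
PRIMARY_QUEUE = "gpu_transcription"
FALLBACK_QUEUE = "fallback_model"
REVIEW_QUEUE = "human_review"


@dataclass(frozen=True)
class ChunkFailure:
    chunk_index: int
    reason: str    # e.g. "time_limit" or "cuda_oom" (illustrative values)
    attempts: int  # how many times this chunk has been tried so far


def choose_queue(failure: ChunkFailure, max_attempts: int = 3) -> str:
    """Pick where a failed chunk goes next."""
    if failure.attempts >= max_attempts:
        return REVIEW_QUEUE    # automated options exhausted
    if failure.reason == "time_limit":
        return FALLBACK_QUEUE  # too slow on the primary model
    return PRIMARY_QUEUE       # other failures get another try


print(choose_queue(ChunkFailure(4, "time_limit", 1)))  # fallback_model
print(choose_queue(ChunkFailure(4, "time_limit", 3)))  # human_review
print(choose_queue(ChunkFailure(7, "cuda_oom", 1)))    # gpu_transcription
```

Keeping the decision in a pure function makes the routing policy testable without a broker.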

To deploy workers, execute:

celery -A tasks.transcription worker --loglevel=info --pool=solo --hostname=gpu_worker@%h

The --pool=solo flag runs tasks one at a time in the worker's main process, the same process that loaded the models at import, which avoids re-initializing CUDA in forked child processes. Monitor queue depth and worker health via Flower (celery flower) or Prometheus exporters. Diagnostics such as VRAM allocation logging, combined with task_acks_late, make partial failures visible and retryable instead of letting them silently corrupt downstream timestamp alignment or archival metadata.