zaris-ai/avwhisper

AVWhisper

Audio-Visual Speech Recognition for the UASpeech dysarthric-speech dataset. No fairseq — pure PyTorch + HuggingFace Transformers + PyTorch Lightning.

  • Unified entrypoint for audio-only (Whisper) and audio-visual (Whisper + AV-HuBERT / ResNet video) training.
  • Configurable fusion: encoder-only, decoder-only, both, or learned soft routing per sample.
  • Per-intelligibility-level metrics: WER / CER / WSR for very_low, low, mid, high, and total.
  • Character-level training for dysarthric speech (space-separated characters).
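The character-level target format from the last bullet can be sketched as follows. This is a minimal illustration that assumes each word is split into characters joined by single spaces; the helper names `to_char_level` and `from_char_level` are hypothetical, and how the project marks word boundaries in multi-word transcripts is not stated here, so only single words are handled.

```python
def to_char_level(word: str) -> str:
    """Turn a single word into space-separated characters (assumed format)."""
    return " ".join(word.strip())


def from_char_level(chars: str) -> str:
    """Invert to_char_level by dropping the separating spaces."""
    return "".join(chars.split())


print(to_char_level("paragraph"))   # p a r a g r a p h
```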

Documentation

| Document | What's in it |
| --- | --- |
| docs/ARCHITECTURE.md | Approved repository layout, package boundaries, and entry points. |
| This README (CLI commands) | How to prepare data, train, validate, and run multi-seed workflows. |
| docs/API.md | Full code contract: every module, class, function signature, tensor shapes, and invariants. Start here when reading or extending the code. |
| docs/UASpeech_Dataset_Info.md | UASpeech layout, vocabulary filtering, B1/B2/B3 splits, intelligibility levels, and sample counts as loaded by the project. |
| docs/Gating_And_Routing.md | Design rationale: why Whisper-Flamingo uses learned gates inside each fused block, why FusionRouter adds per-sample encoder/decoder scales, and what would change if the router emitted gates directly. |

Project structure

avwhisper/              # Python package
  core/                 # TrainingConfig, UASpeech constants
  data/                 # Dataset loaders and collators
  models/               # Whisper AV fusion + video encoders
  training/             # Lightning module, metrics, pipeline
  preparation/          # Offline UASpeech dataset prep
  utils/                # Shared helpers (ffmpeg, runtime)
  cli/                  # Command-line entry points
configs/                # Training YAML configs
docs/                   # Documentation
tests/                  # Tests

Full layout and dependency rules: docs/ARCHITECTURE.md.

Quick setup

cd /mnt/data/projects/avwhisper
./setup.sh
source .venv/bin/activate

Or with Make:

make setup
source .venv/bin/activate

This creates .venv/, installs the package in editable mode, and installs requirements.txt. See pyproject.toml and setup.sh for details.

Dataset

Place the prepared dataset at uaspeech_root (default /home/park/UASpeech-prepared for AV and audio pretrain):

{uaspeech_root}/
  speaker_wordlist.xls
  video/roi/{speaker}/*.mp4              # ROI lip clips
  video/{speaker}/outputfiles30/*.mp4    # original 30 fps video/audio
  Audio/vad/{speaker}/*.wav              # optional VAD audio

Splits: B1 + B3 → train, B2 → valid (= test). 12 speakers across 4 intelligibility levels. Full details — including per-level sample counts, vocabulary filtering rules, and per-speaker level mapping — are in docs/UASpeech_Dataset_Info.md.
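The block-to-split rule above can be sketched as a small lookup. This assumes the block id appears as an underscore-delimited token (B1, B2, or B3) in each clip's file stem; the example stems in the usage note and the helper name `split_for_clip` are hypothetical, and the project's actual loader may derive the block differently.

```python
import re
from typing import Optional

# Block-to-split mapping as described above: B1 and B3 train, B2 validates/tests.
BLOCK_TO_SPLIT = {"B1": "train", "B2": "valid", "B3": "train"}

# Assumption: the block id is an underscore-delimited token in the file stem.
_BLOCK_RE = re.compile(r"(?:^|_)(B[123])(?:_|$)")


def split_for_clip(stem: str) -> Optional[str]:
    """Return 'train' or 'valid' for a clip stem, or None if no block token is found."""
    match = _BLOCK_RE.search(stem)
    return BLOCK_TO_SPLIT[match.group(1)] if match else None
```

For example, `split_for_clip("M09_B2_C1_M3")` returns `"valid"`, while a stem with no block token returns `None` so the caller can decide whether to skip or fail.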

Prepare dataset (from raw UASpeech)

If you have the official home release (block videos + frame indices), build the AVWhisper layout with:

avwhisper-prepare-dataset \
  --source /home/park/UASpeech \
  --output /home/park/UASpeech-prepared

Then set uaspeech_root: /home/park/UASpeech-prepared in your training YAML.

Useful flags: --overwrite (rebuild existing clips), --skip-roi (clips only), --speakers M09 (single speaker), --validate-only (check existing output).

Requires pip install -e . so console commands are on your PATH.

CLI commands

Run all commands from the repo root, after installing the package in editable mode (pip install -e .) and with the venv active (source .venv/bin/activate).

| Task | Command |
| --- | --- |
| Prepare dataset | avwhisper-prepare-dataset --source SRC --output DST |
| Download Whisper | avwhisper-download-model small |
| Train | avwhisper-train configs/audio_ua.yaml |
| Validate / test | set validate_model: true in YAML, then avwhisper-train <config> |
| Multi-seed train | avwhisper-train-multi-seed configs/av_soft_ua.yaml --num_seeds 10 |
| Multi-seed validate | avwhisper-validate-ensemble configs/av_soft_ua.yaml --num_seeds 10 |

Alternative (module style, no console scripts needed):

python -m avwhisper.cli.prepare_dataset --source SRC --output DST
python -m avwhisper.cli.train configs/audio_ua.yaml

Download Whisper (optional, for offline runs)

avwhisper-download-model small

Model names: tiny, base, small, medium, large, large-v2, large-v3.

Then set local_model_only: true in your config.
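As a rough sketch of what an offline load could look like, the helper below builds arguments for a Transformers `from_pretrained` call. It assumes the size names map to the public `openai/whisper-*` Hub repositories and that `local_model_only` is passed through as `local_files_only`; the helper `whisper_load_args` is hypothetical and is not the project's loader.

```python
WHISPER_SIZES = ("tiny", "base", "small", "medium", "large", "large-v2", "large-v3")


def whisper_load_args(name: str, local_model_only: bool = False) -> dict:
    """Build hypothetical from_pretrained arguments for a Whisper size name."""
    if name not in WHISPER_SIZES:
        raise ValueError(f"unknown Whisper model {name!r}; expected one of {WHISPER_SIZES}")
    return {
        # Assumption: sizes map to the public openai/whisper-* Hub repositories.
        "pretrained_model_name_or_path": f"openai/whisper-{name}",
        # local_files_only=True makes Transformers read the cache and skip the network.
        "local_files_only": local_model_only,
    }


# Usage (requires transformers and a populated cache):
#   from transformers import WhisperForConditionalGeneration
#   model = WhisperForConditionalGeneration.from_pretrained(**whisper_load_args("small", True))
```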

Train

Behavior is fully driven by the YAML config. Pre-built configs are in configs/:

avwhisper-train configs/audio_ua.yaml              # audio pretrain (12 spk, video-synced)
avwhisper-train configs/audio_ua_isolated.yaml     # optional: 28 spk isolated WAV
avwhisper-train configs/av_encoder_ua.yaml         # AV fusion in encoder
avwhisper-train configs/av_decoder_ua.yaml         # AV fusion in decoder
avwhisper-train configs/av_both_ua.yaml            # AV fusion in encoder + decoder
avwhisper-train configs/av_soft_ua.yaml            # learned soft routing

Or via Make: make train-audio, make train-av.

Checkpoints are saved to weights/checkpoint/{train_id}/best.ckpt.
TensorBoard logs go to results/{train_name}/{train_id}/.

Recommended training order

  1. Audio pretrain (12 speakers, video-synced, 5 seeds)
    avwhisper-train-multi-seed configs/audio_ua.yaml --num_seeds 5
    Checkpoints: weights/checkpoint/whisper_en_small_ua_synced12_{N}/best.ckpt
  2. AV both (5 seeds, frozen Whisper)
    avwhisper-train-multi-seed configs/av_both_ua.yaml --num_seeds 5
    Uses {seed} in pt_audio_ckpt to match each audio seed.
  3. Average metrics
    avwhisper-validate-ensemble configs/audio_ua.yaml --num_seeds 5
    avwhisper-validate-ensemble configs/av_both_ua.yaml --num_seeds 5
  4. (Optional) Soft routing — set pt_video_ckpt to the AV-both checkpoint
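The {seed} placeholder in step 2 can be pictured as plain string substitution. The sketch below assumes the placeholder is replaced with the seed index of the current run; `resolve_seed_path` is a hypothetical helper, and the template reuses the checkpoint pattern shown in step 1.

```python
def resolve_seed_path(template: str, seed: int) -> str:
    """Fill a {seed} placeholder if present; paths without one are returned unchanged."""
    return template.replace("{seed}", str(seed))


template = "weights/checkpoint/whisper_en_small_ua_synced12_{seed}/best.ckpt"
for seed in range(1, 6):  # 5 seeds starting at 1
    print(resolve_seed_path(template, seed))
```

Using `str.replace` rather than `str.format` keeps any other braces in the path untouched, which also covers the single-seed case where the path has no placeholder.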

Single-seed shortcut: avwhisper-train configs/audio_ua.yaml then avwhisper-train configs/av_both_ua.yaml (set pt_audio_ckpt without {seed}).

Configs use SpecAugment, label smoothing, and early stopping for small-data regularization.

Validate / test (single checkpoint)

UASpeech uses block B2 for both validation and test. To run evaluation only (no training), set in your YAML:

validate_model: true

Then run:

avwhisper-train configs/audio_ua.yaml
  • If weights/checkpoint/{train_id}/best.ckpt exists → validates from that checkpoint
  • Otherwise → validates with current (initialized) weights
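The checkpoint fallback in the two bullets above amounts to a file-existence check. A minimal sketch, assuming the checkpoint layout shown earlier; `checkpoint_for_validation` is a hypothetical helper, not the project's function.

```python
from pathlib import Path
from typing import Optional


def checkpoint_for_validation(train_id: str, root: str = "weights/checkpoint") -> Optional[Path]:
    """Return best.ckpt for train_id if it exists, else None (validate initialized weights)."""
    ckpt = Path(root) / train_id / "best.ckpt"
    return ckpt if ckpt.is_file() else None
```

With PyTorch Lightning, the result could be passed as `ckpt_path` to `trainer.validate(...)`; a value of `None` there validates the model with whatever weights it currently holds.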

Useful related config flags:

| Flag | Effect |
| --- | --- |
| validation_decoding | teacher_forced or autoregressive |
| per_speaker_metrics | Print per-speaker WER/CER/WSR breakdown |
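For readers unfamiliar with the metrics, WER and CER are edit-distance rates over words and characters. The sketch below is a generic, dependency-free illustration, not the project's metrics module: it assumes corpus-level aggregation (total edits over total reference tokens) and that spaces are dropped before computing CER. WSR is not defined in this README, so it is left out.

```python
from typing import Callable, List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two token sequences (two-row DP)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(refs: List[str], hyps: List[str], tokenize: Callable[[str], list]) -> float:
    """Corpus-level rate: total edits divided by total reference tokens."""
    edits = sum(edit_distance(tokenize(r), tokenize(h)) for r, h in zip(refs, hyps))
    total = sum(len(tokenize(r)) for r in refs)
    return edits / total if total else 0.0


def wer(refs: List[str], hyps: List[str]) -> float:
    """Word error rate over whitespace-separated tokens."""
    return error_rate(refs, hyps, str.split)


def cer(refs: List[str], hyps: List[str]) -> float:
    """Character error rate, ignoring spaces (assumed convention)."""
    return error_rate(refs, hyps, lambda s: list(s.replace(" ", "")))
```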

Multi-seed training

Run the same config multiple times with different seeds. Each run gets a suffixed train_id ({train_id}_1, {train_id}_2, …):

avwhisper-train-multi-seed configs/av_soft_ua.yaml --num_seeds 10 --start_seed 1

| Flag | Default | Meaning |
| --- | --- | --- |
| --num_seeds | 10 | Number of training runs |
| --start_seed | 1 | First seed index (produces _1, _2, …) |

Checkpoints: weights/checkpoint/{train_id}_N/best.ckpt for each seed N.

See docs/MULTI_SEED_TRAINING.md for more detail.

Multi-seed validation (ensemble metrics)

After multi-seed training, validate every seed checkpoint and average the reported metrics (WER/CER/WSR per level):

avwhisper-validate-ensemble configs/av_soft_ua.yaml --num_seeds 10 --start_seed 1

| Flag | Default | Meaning |
| --- | --- | --- |
| --num_seeds | 10 | Number of seed checkpoints to look for |
| --start_seed | 1 | First seed index |
| --output_file | auto | JSON output path (CSV written alongside) |

Default output: results/{train_name}/{train_id}_seed_validation.json (+ .csv).

Missing checkpoints are skipped with a warning; at least one must exist.
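The skip-and-average behavior described above can be sketched as follows. This is an illustration under stated assumptions: `average_seed_metrics` and its `evaluate` callback (checkpoint path in, metric dictionary out) are hypothetical, every run is assumed to report the same metric keys, and the result is a plain arithmetic mean per key.

```python
import warnings
from pathlib import Path
from statistics import mean
from typing import Callable, Dict


def average_seed_metrics(
    train_id: str,
    evaluate: Callable[[Path], Dict[str, float]],  # hypothetical: checkpoint -> metrics
    num_seeds: int = 10,
    start_seed: int = 1,
    root: str = "weights/checkpoint",
) -> Dict[str, float]:
    """Average each metric over the seed checkpoints that exist on disk."""
    per_seed = []
    for seed in range(start_seed, start_seed + num_seeds):
        ckpt = Path(root) / f"{train_id}_{seed}" / "best.ckpt"
        if not ckpt.is_file():
            warnings.warn(f"missing checkpoint, skipping: {ckpt}")
            continue
        per_seed.append(evaluate(ckpt))
    if not per_seed:
        raise FileNotFoundError(f"no seed checkpoints found for {train_id!r}")
    # Assumes every run reports the same metric keys.
    return {key: mean(run[key] for run in per_seed) for key in per_seed[0]}
```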

Soft fusion routing

When soft_fusion_routing: true, both encoder and decoder layers are wrapped with gated cross-attention. A FusionRouter MLP consumes pooled statistics of the audio mel features and projected video features and produces three softmax probabilities — encoder-only, decoder-only, both — from which:

  • encoder_scale = P(encoder) + P(both)
  • decoder_scale = P(decoder) + P(both)

These scale the gated cross-attention outputs per sample. With router_train_only: true, only the router updates; everything else stays frozen. With routing_use_level: true, the router is also conditioned on the UASpeech intelligibility level.
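The two scale formulas above can be checked with a few lines of arithmetic. This is a dependency-free sketch for a single sample (the real router works on batched tensors); it assumes the logit order enc_only, dec_only, both, and `routing_scales` is a hypothetical name.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a short list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def routing_scales(logits: List[float]) -> Tuple[float, float]:
    """Map [enc_only, dec_only, both] logits to (encoder_scale, decoder_scale)."""
    p_enc, p_dec, p_both = softmax(logits)
    return p_enc + p_both, p_dec + p_both


enc, dec = routing_scales([0.0, 0.0, 0.0])   # uniform router output
print(round(enc, 3), round(dec, 3))          # 0.667 0.667
```

Because the three probabilities sum to 1, encoder_scale + decoder_scale = 1 + P(both): each scale stays in [0, 1], and both reach 1 only when the router puts all its mass on "both".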

Recommended: train AV-both first, then point pt_video_ckpt at that checkpoint before running soft routing.

FusionRouter architecture

Inputs to the router:

  • Audio: the raw Whisper log-mel spectrogram (input_features, shape [B, num_mel_bins, T_audio]). The router does not see Whisper encoder hidden states.
  • Video: projected video features (video_states, shape [B, T_vid, d_model]) — i.e. the output of the video encoder (AVHubertVideoEncoder or ResNetVideoEncoder) after the video_proj Linear that maps to Whisper's d_model.
flowchart TD
    AF["input_features (log-mel)\n[B, num_mel_bins, T_audio]"]
    VF["video_states (post video_encoder + video_proj)\n[B, T_vid, d_model]"]

    AF -->|"transpose + mean_std over time\n[B, 2 * num_mel_bins]"| AP["audio_proj\nLayerNorm -> Linear -> GELU -> Dropout\n-> [B, hidden_dim]"]
    VF -->|"masked mean_std over time\n[B, 2 * d_model]"| VP["video_proj_router\nLayerNorm -> Linear -> GELU -> Dropout\n-> [B, hidden_dim]"]

    AP --> CAT["concat [B, 2 * hidden_dim]\n(+ level_embed if routing_use_level)"]
    VP --> CAT
    LE["Level embedding\n[B, level_embed_dim]"] -.->|optional| CAT

    CAT --> CLS["classifier\nLinear -> GELU -> Dropout -> Linear\n-> logits [B, 3]"]
    CLS -->|softmax| PROBS["route_probs [B, 3]\nP(enc_only), P(dec_only), P(both)"]

    PROBS -->|"P(enc_only) + P(both)"| ES["encoder_scale [B]"]
    PROBS -->|"P(dec_only) + P(both)"| DS["decoder_scale [B]"]

    ES -->|"scale * gated xattn output"| ENC["Whisper Encoder layers"]
    DS -->|"scale * gated xattn output"| DEC["Whisper Decoder layers"]

Pooling is an unweighted mean + std over the time axis — every frame contributes equally (video uses video_padding_mask so padded frames are excluded).
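The pooling step can be illustrated for one sample without any tensor library. The sketch below makes several assumptions that are not confirmed here: the mask is True for padded frames, the standard deviation is the population (biased) form, and the output concatenates means before standard deviations to give the 2 * D vector shown in the diagram. `masked_mean_std` is a hypothetical name.

```python
import math
from typing import List, Sequence


def masked_mean_std(frames: Sequence[Sequence[float]], padded: Sequence[bool]) -> List[float]:
    """Pool [T, D] frames into a [2 * D] vector: per-dimension mean, then std.

    padded[t] is True for padding frames, which are ignored.
    """
    valid = [f for f, is_pad in zip(frames, padded) if not is_pad]
    if not valid:
        raise ValueError("at least one unpadded frame is required")
    n = len(valid)
    dims = range(len(valid[0]))
    means = [sum(f[d] for f in valid) / n for d in dims]
    # Assumption: population (biased) standard deviation; the project may differ.
    stds = [math.sqrt(sum((f[d] - means[d]) ** 2 for f in valid) / n) for d in dims]
    return means + stds
```

For the audio branch the same function applies with an all-False mask, since the description above only mentions masking for video.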

Implementation: see FusionRouter, GatedCrossAttentionBlock, WhisperFlamingoEncoderLayer, WhisperFlamingoDecoderLayer, and WhisperVideoFusion, importable from avwhisper/model.py. Per the Modules section, that path is a legacy re-export shim and the model code lives under avwhisper/models/. Reference: docs/API.md §5. Design rationale (why gates and scales, and what changes if the router emits gates directly): docs/Gating_And_Routing.md.

Config options (cheat sheet)

| Option | Description |
| --- | --- |
| modality | audio or audio_visual |
| fusion_location | encoder, decoder, or both (audio-visual only) |
| char_level | Character-level training (space between chars) |
| local_model_only | true = load Whisper from HF cache only (no network) |
| pt_audio_ckpt | Pretrained audio-only checkpoint (used to init AV backbone) |
| pt_video_ckpt | Pretrained audio-visual checkpoint (resume / soft-routing init) |
| freeze_whisper | Freeze Whisper backbone during AV fine-tuning |
| unfreeze_mlp_layers | Also unfreeze encoder/decoder fc1/fc2 when freeze_whisper is set |
| soft_fusion_routing | Enable learned soft routing between encoder/decoder fusion |
| routing_hidden_dim | Hidden size of the FusionRouter MLP (default 256) |
| routing_dropout | Dropout inside the FusionRouter (default 0.1) |
| routing_use_level | Condition the router on UASpeech intelligibility level |
| router_train_only | Freeze all parameters except the FusionRouter |
| validate_model | Run trainer.validate() instead of trainer.fit() |
| validation_decoding | teacher_forced or autoregressive |
| speaker_weights | Per-speaker sampling weights for WeightedRandomSampler |
| per_speaker_metrics | Also print per-speaker WER/CER/WSR at end of validation |

Full field list with defaults and types: API §3. Pre-built config matrix: API §12.
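To make the routing-related options concrete, here is a small dataclass sketch. It is not the project's TrainingConfig: the class name `RoutingOptions` is hypothetical, only the two defaults stated in the cheat sheet (256 and 0.1) come from this README, the other defaults are assumed, and the consistency rule in `__post_init__` is an assumption about how the flags interact.

```python
from dataclasses import dataclass


@dataclass
class RoutingOptions:
    """Hypothetical subset of the options above; not the project's TrainingConfig."""
    modality: str = "audio_visual"        # assumed default
    fusion_location: str = "both"         # assumed default
    soft_fusion_routing: bool = False     # assumed default
    routing_hidden_dim: int = 256         # default stated in the cheat sheet
    routing_dropout: float = 0.1          # default stated in the cheat sheet
    routing_use_level: bool = False       # assumed default
    router_train_only: bool = False       # assumed default

    def __post_init__(self) -> None:
        if self.modality not in ("audio", "audio_visual"):
            raise ValueError(f"modality must be audio or audio_visual, got {self.modality!r}")
        if self.fusion_location not in ("encoder", "decoder", "both"):
            raise ValueError(f"invalid fusion_location {self.fusion_location!r}")
        # Assumed consistency rule: the router only exists when soft routing is on.
        if self.router_train_only and not self.soft_fusion_routing:
            raise ValueError("router_train_only requires soft_fusion_routing")
```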

Modules

| Package | Key modules |
| --- | --- |
| avwhisper.core | config.py, constants.py |
| avwhisper.models | whisper_av.py, avhubert.py, resnet.py |
| avwhisper.training | module.py, metrics.py, optimizer.py, pipeline.py |
| avwhisper.data | uaspeech.py, dataset.py, video.py |
| avwhisper.preparation | offline dataset prep (prepare.py, outputfiles30.py, roi.py, …) |
| avwhisper.cli | train.py, prepare_dataset.py, … |

Legacy import paths (avwhisper.config, avwhisper.model, avwhisper.lightning, …) remain as thin re-export shims.

For per-function contracts and tensor shapes, see docs/API.md.

Dependencies

See requirements.txt and pyproject.toml. Core: torch, torchaudio, pytorch-lightning >= 2, transformers, editdistance, pandas, openpyxl, opencv-python(-headless), librosa, soundfile, pyyaml, tqdm.

Intelligibility levels (UASpeech)

| Level | Label | Speakers |
| --- | --- | --- |
| 0 | very_low | M12 |
| 1 | low | F02, M07, M16 |
| 2 | mid | F04, M05, M11 |
| 3 | high | F05, M08, M09, M10, M14 |
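The level table translates directly into a lookup that per-level metrics could use. The speaker lists below are copied from the table; the names `LEVEL_SPEAKERS`, `SPEAKER_LEVEL`, and `group_by_level` are hypothetical and may not match the constants in avwhisper.core.

```python
from collections import defaultdict
from typing import Dict, List

LEVEL_LABELS = {0: "very_low", 1: "low", 2: "mid", 3: "high"}

LEVEL_SPEAKERS: Dict[int, List[str]] = {
    0: ["M12"],
    1: ["F02", "M07", "M16"],
    2: ["F04", "M05", "M11"],
    3: ["F05", "M08", "M09", "M10", "M14"],
}

# Invert to a speaker -> label lookup.
SPEAKER_LEVEL = {spk: LEVEL_LABELS[lvl] for lvl, spks in LEVEL_SPEAKERS.items() for spk in spks}


def group_by_level(speakers: List[str]) -> Dict[str, List[int]]:
    """Group sample indices by the intelligibility label of each sample's speaker."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, spk in enumerate(speakers):
        groups[SPEAKER_LEVEL[spk]].append(idx)
    return dict(groups)


print(len(SPEAKER_LEVEL))   # 12
```

Grouping indices first keeps the metric code independent of the speaker list: WER, CER, or WSR can then be computed once per group and once over all samples for the total.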

Sample counts per split and per level are in docs/UASpeech_Dataset_Info.md.
