UNITE: Distilling UNI-2 into a Lightweight Vision Transformer for Histopathology
UNITE (UNified Image-to-Text Embedding Distillation) is a lightweight student Vision Transformer (ViT-Small) distilled from the large UNI-2 foundation model.
It uses a symmetric CLIP-style contrastive alignment loss to replicate the geometry of the teacher embedding space, without needing access to the teacher's weights.
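As a rough illustration of what a symmetric CLIP-style alignment loss computes, here is a minimal NumPy sketch. It is not the code in `src/models.py`. It assumes the student embeddings have already been mapped to the same dimension as the teacher embeddings (for example by a learned projection head, which this README does not describe), and the function name and the temperature of 0.07 are illustrative choices rather than values taken from this repository.

```python
import numpy as np

def _log_softmax(logits, axis):
    # Numerically stable log-softmax along the given axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def symmetric_alignment_loss(student, teacher, temperature=0.07):
    """Symmetric InfoNCE loss between matched student/teacher embeddings.

    student, teacher: arrays of shape (batch, dim); row i of each comes from
    the same patch. temperature=0.07 is an assumed, illustrative default.
    """
    s = student / np.linalg.norm(student, axis=1, keepdims=True)
    t = teacher / np.linalg.norm(teacher, axis=1, keepdims=True)
    logits = s @ t.T / temperature            # (batch, batch) scaled cosine similarities
    idx = np.arange(len(logits))
    loss_s2t = -_log_softmax(logits, axis=1)[idx, idx].mean()  # student -> teacher
    loss_t2s = -_log_softmax(logits, axis=0)[idx, idx].mean()  # teacher -> student
    return 0.5 * (loss_s2t + loss_t2s)

# Teacher embeddings are just a fixed array here: no teacher weights are involved.
rng = np.random.default_rng(0)
teacher = rng.normal(size=(8, 1536))
aligned = teacher + 0.01 * rng.normal(size=teacher.shape)   # near-copy of the teacher
unrelated = rng.normal(size=teacher.shape)                  # no relation to the teacher
loss_aligned = symmetric_alignment_loss(aligned, teacher)
loss_unrelated = symmetric_alignment_loss(unrelated, teacher)
```

The loss is small when each student row is closest to its own teacher row and grows toward `log(batch)` when the pairing carries no information, which is what pushes the student to reproduce the teacher's neighbourhood structure.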
Key Highlights
Efficiency
Compresses UNI-2 (~681M parameters) → UNITE (~19M parameters): 35× reduction
Specialization
Outperforms the teacher on linear probing on CPTAC-OV (balanced accuracy, BACC: 0.412 vs. 0.346), which suggests that distillation acts as domain specialization
Black-Box Distillation
Requires only teacher embeddings, not weights
Works with closed-source or API-only foundation models
Repository Structure
UNITE-Distillation/
├── src/
│   ├── models.py               # Student ViT + CLIP-style alignment loss
│   ├── dataset.py              # PyTorch Dataset class
│   └── utils.py                # H5 dimension fixing & helpers
├── scripts/
│   ├── 1_preprocess_wsi.py     # Patch extraction from .svs
│   ├── 2_train.py              # Distillation training loop
│   ├── 3_extract_student.py    # Student inference
│   ├── 4_extract_uni2.py       # Teacher (UNI-2) baseline extraction
│   ├── 5_pool_features.py      # Patch → slide/case pooling
│   └── 6_benchmark.py          # PathoBench: retrieval + linear probe
├── requirements.txt
└── README.md
Installation & Setup
- Clone the repository
    git clone https://github.com/AyushChaurasia18/UNITE_Distillation.git
    cd UNITE_Distillation
- Install dependencies
    pip install -r requirements.txt
- Install external frameworks
  This project uses TRIDENT for WSI processing and PathoBench for evaluation.
    pip install git+https://github.com/mahmoodlab/TRIDENT.git
    pip install git+https://github.com/mahmoodlab/Patho-Bench.git
Usage Pipeline
Step 1: Data Pre-processing
Extract patches from .svs Whole Slide Images.
python scripts/1_preprocess_wsi.py
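To make the idea of patch extraction concrete, the sketch below enumerates a regular grid of patch coordinates over an image of a given size. It only illustrates the tiling step. The actual script relies on TRIDENT, and this sketch does not read `.svs` files or skip background regions. The function name, the 256-pixel patch size, and the example dimensions are all hypothetical.

```python
def patch_grid(width, height, patch_size=256, stride=None):
    """Top-left (x, y) coordinates of every full patch that fits in the image.

    patch_size=256 is an assumed value for illustration only.
    stride defaults to patch_size, which gives non-overlapping patches.
    """
    stride = stride or patch_size
    xs = range(0, width - patch_size + 1, stride)
    ys = range(0, height - patch_size + 1, stride)
    # Row-major order: all patches of the first row, then the second, and so on.
    return [(x, y) for y in ys for x in xs]

# A toy 1000 x 600 image yields a 3 x 2 grid of full 256-pixel patches.
coords = patch_grid(width=1000, height=600, patch_size=256)
```

Partial patches at the right and bottom edges are dropped in this sketch; a real pipeline may pad or handle them differently.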
Step 2: Distillation Training
Train UNITE using pre-computed teacher embeddings.
python scripts/2_train.py
Step 3: Feature Extraction
Generate student embeddings for the full dataset.
python scripts/3_extract_student.py
Step 4: Benchmarking
Pool patch features to the slide/case level, then evaluate retrieval and linear-probing performance.
python scripts/5_pool_features.py
python scripts/6_benchmark.py
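The two steps above can be pictured with a small NumPy sketch. It assumes mean pooling of patch embeddings per slide, which is one common choice; the README does not state which pooling `5_pool_features.py` uses. The metric helpers follow the usual textbook definitions (balanced accuracy as mean per-class recall, top-1 retrieval as nearest-neighbour label agreement) and may differ in detail from PathoBench's implementation. All function names are hypothetical.

```python
import numpy as np

def mean_pool(patch_features, slide_ids):
    """Average patch embeddings per slide. Returns (sorted slide ids, pooled array)."""
    patch_features = np.asarray(patch_features, dtype=float)
    slide_ids = np.asarray(slide_ids)
    unique_ids = np.unique(slide_ids)
    pooled = np.stack([patch_features[slide_ids == s].mean(axis=0) for s in unique_ids])
    return unique_ids, pooled

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so rare classes weigh as much as common ones."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [(y_pred[y_true == c] == c).mean() for c in np.unique(y_true)]
    return float(np.mean(recalls))

def top1_retrieval_precision(features, labels):
    """Fraction of items whose nearest neighbour (cosine, excluding itself) shares their label."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sims = f @ f.T
    np.fill_diagonal(sims, -np.inf)   # an item must not retrieve itself
    nearest = sims.argmax(axis=1)
    return float((labels[nearest] == labels).mean())

# Four patches from two slides collapse to two slide-level vectors.
ids, pooled = mean_pool([[1, 0], [3, 0], [0, 2], [0, 4]], ["a", "a", "b", "b"])
bacc = balanced_accuracy([0, 0, 0, 1], [0, 0, 1, 1])   # recalls are 2/3 and 1
p_at_1 = top1_retrieval_precision([[1, 0], [0.9, 0.1], [0, 1], [0.1, 0.9]], [0, 0, 1, 1])
```

Balanced accuracy matters when classes are unevenly sized, because plain accuracy can look good by favouring the majority class.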
Results (CPTAC-OV)

| Model | Params | Embedding Dim | Retrieval (mAP@1) | Linear Probe (BACC) |
|---|---|---|---|---|
| UNI-2 (Teacher) | ~681M | 1536 | 0.400 | 0.346 |
| UNITE (Student) | ~19M | 384 | 0.352 | 0.412 |
UNITE achieves 35× compression and outperforms the teacher on linear probing (0.412 vs. 0.346 BACC), while trailing it on retrieval (0.352 vs. 0.400 mAP@1).
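For readers who want the trade-off in relative terms, the following snippet works through the arithmetic using only the figures reported for CPTAC-OV. The parameter counts are approximate in the source, so the ratio is approximate too; the variable names are just for illustration.

```python
# Figures as reported in the CPTAC-OV results (parameter counts in millions).
teacher = {"params_m": 681, "dim": 1536, "map_at_1": 0.400, "bacc": 0.346}
student = {"params_m": 19, "dim": 384, "map_at_1": 0.352, "bacc": 0.412}

param_ratio = teacher["params_m"] / student["params_m"]            # about 35.8
dim_ratio = teacher["dim"] / student["dim"]                        # 4.0
retrieval_change = student["map_at_1"] / teacher["map_at_1"] - 1   # about -0.12
bacc_change = student["bacc"] / teacher["bacc"] - 1                # about +0.19
```

In other words, the student gives up about 12% of the teacher's retrieval score and gains about 19% in balanced accuracy, with embeddings a quarter of the size.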
Computational Requirements
Training
NVIDIA H100 80GB recommended
Batch size 64+
bf16 mixed precision
Total time: ~23 hours for 150 epochs
Inference
Works on consumer GPUs (RTX 3090/4090)
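A back-of-the-envelope estimate helps explain why the student fits comfortably on consumer hardware. The sketch below counts weight storage only and assumes 2 bytes per parameter (bf16 or fp16). Whether inference actually runs in that precision is an assumption, and activations, framework overhead, and batch size are ignored.

```python
def weight_memory_mb(num_params, bytes_per_param=2):
    """Approximate memory for model weights alone, in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6

student_mb = weight_memory_mb(19e6)    # 38.0 MB for ~19M parameters
teacher_mb = weight_memory_mb(681e6)   # 1362.0 MB for ~681M parameters
```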
Acknowledgements
TRIDENT & PathoBench: Mahmood Lab, Harvard Medical School
Mentor: Prof. Maitrik Shah (Ahmedabad University)