🧠 IHFNet: Incomplete Multimodal Hierarchical Feature Fusion Network for Mild Cognitive Impairment Conversion Prediction
This repository implements IHFNet and multiple baselines for pMCI vs sMCI classification on ADNI using T1 sMRI + FDG-PET + clinical features.
- Training with stratified 5-fold cross-validation (default) via `main_rebuild.py`
- Multiple model options configured in `model_object.py` (IHFNet and baselines)
- Inference example on the provided `.nii` sample files via `Inference.py`
- Cross-dataset evaluation between ADNI1 and ADNI2 via `cross_dataset_test.py` (loads 5 fold checkpoints and reports per-fold + mean metrics)
.
├── assets/
├── example_data/
│   ├── test_mri.nii
│   └── test_pet.nii
├── Net/
│   ├── IHFNet.py
│   ├── TriLightNet.py
│   ├── ResnetEncoder.py
│   ├── poolformer.py
│   ├── metaformer3D.py
│   ├── defineViT.py
│   ├── ComparisonNet.py
│   ├── basic.py
│   └── kan.py
├── utils/
│   ├── api.py
│   ├── basic.py
│   └── observer.py
├── Config.py
├── Dataset.py
├── loss_function.py
├── model_object.py
├── model_params.py
├── main_rebuild.py
├── Inference.py
├── cross_dataset_test.py
├── requirements.txt
└── README.md
The dataset is obtained from the Alzheimer's Disease Neuroimaging Initiative (ADNI), specifically the ADNI-1 and ADNI-2 cohorts.
ADNI link: https://adni.loni.usc.edu/
This codebase expects NIfTI volumes named by a shared identifier (e.g., subject/image id). The dataset loader will look for:
- MRI file: `<mri_dir>/<ID>.nii`
- PET file: `<pet_dir>/<ID>.nii`
Example layout:
<DATA_ROOT>/
├── ADNI1/
│   ├── MRI/
│   │   ├── <ID>.nii
│   │   └── ...
│   ├── PET/
│   │   ├── <ID>.nii
│   │   └── ...
│   ├── labels.csv
│   └── clinical.csv
└── ADNI2/
    ├── MRI/
    ├── PET/
    ├── labels.csv
    └── clinical.csv
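Because the loader looks files up by ID, it can help to check which subjects have both modalities on disk before training. The following is a minimal sketch, not the logic in `Dataset.py`: the helper name `find_modalities` is hypothetical, and it assumes only the `<mri_dir>/<ID>.nii` and `<pet_dir>/<ID>.nii` naming described above.

```python
from pathlib import Path


def find_modalities(ids, mri_dir, pet_dir):
    """Report, for each ID, whether its MRI and PET volumes exist on disk.

    Hypothetical helper for illustration; not part of this repository.
    """
    mri_dir, pet_dir = Path(mri_dir), Path(pet_dir)
    report = {}
    for image_id in ids:
        report[image_id] = {
            "mri": (mri_dir / f"{image_id}.nii").is_file(),
            "pet": (pet_dir / f"{image_id}.nii").is_file(),
        }
    return report
```

Calling `find_modalities(ids, "<ADNI1/MRI>", "<ADNI1/PET>")` returns one entry per ID, so subjects with a missing volume can be listed before a long training run starts.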
Dataset.py reads a CSV where:
- Column 0: image id (without extension)
- Column 1: label string (must contain the task labels)
For the default task (pMCI vs sMCI), labels should be exactly pMCI or sMCI.
Minimal example:
PTID,label
002_S_0413,sMCI
002_S_0619,pMCI
Note: Internally the dataset maps `sMCI -> 0`, `pMCI -> 1`.
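The label format above can be validated with a few lines of standard-library Python. This is a sketch under stated assumptions, not the parser in `Dataset.py`: the names `LABEL_TO_INDEX` and `read_labels` are hypothetical, and the file is assumed to start with a header row as in the minimal example.

```python
import csv
import io

LABEL_TO_INDEX = {"sMCI": 0, "pMCI": 1}  # mapping stated in the note above


def read_labels(csv_text):
    """Parse label CSV text into (image_id, class_index) pairs.

    Hypothetical helper; Dataset.py may parse the file differently.
    """
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row, e.g. "PTID,label"
    pairs = []
    for row in reader:
        if not row:
            continue  # ignore blank lines
        image_id, label = row[0].strip(), row[1].strip()
        if label not in LABEL_TO_INDEX:
            raise ValueError(f"unexpected label {label!r} for {image_id}")
        pairs.append((image_id, LABEL_TO_INDEX[label]))
    return pairs


example = "PTID,label\n002_S_0413,sMCI\n002_S_0619,pMCI\n"
print(read_labels(example))  # [('002_S_0413', 0), ('002_S_0619', 1)]
```

Rejecting unknown label strings early catches variants such as a lowercase `pmci`, which would otherwise fail the exact-match requirement silently.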
The clinical file is not included in this repo due to privacy. The default clinical feature extraction (get_clinical in Dataset.py) expects the following columns:
- `PTID` (must match the image id)
- `PTGENDER`, `AGE`, `PTEDUCAT`, `FDG_bl`, `TAU_bl`, `PTAU_bl`, `APOE4`
Missing values are handled by the code: the current implementation fills them with zeros.
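Zero-filling of missing clinical values can be illustrated as follows. This is a rough sketch, not the body of `get_clinical`: the helper names are hypothetical, only the numeric columns are handled, and `PTGENDER` is left out because its encoding is not described here.

```python
import csv
import io
import math

# Numeric columns from the list above; PTGENDER is omitted because its
# encoding in get_clinical is not documented in this README.
NUMERIC_COLUMNS = ["AGE", "PTEDUCAT", "FDG_bl", "TAU_bl", "PTAU_bl", "APOE4"]


def to_float_or_zero(value):
    """Return the value as a float, or 0.0 when it is missing, NaN, or non-numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return 0.0 if math.isnan(number) else number


def read_clinical(csv_text):
    """Map each PTID to its numeric features, with missing values zero-filled."""
    features = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        features[row["PTID"]] = [to_float_or_zero(row.get(c)) for c in NUMERIC_COLUMNS]
    return features
```

Note that a zero fill makes a missing measurement indistinguishable from a true value of zero, which is worth keeping in mind when interpreting features such as `APOE4`.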
Create a clean Python environment (recommended) and install dependencies:
pip install -r requirements.txt
`requirements.txt` pins `torch==2.6.0+cu118` and `torchvision==0.21.0+cu118`. If `pip install -r requirements.txt` fails on your machine, install PyTorch first following the official instructions, then install the remaining packages.
The training entry in main_rebuild.py asserts CUDA availability. Run training on a CUDA-capable GPU environment.
The main training entry is:
python main_rebuild.py \
--model IHFNet_With_MLP \
--mri_dir <ADNI1/MRI> \
--pet_dir <ADNI1/PET> \
--cli_dir <ADNI1/clinical.csv> \
--csv_file <ADNI1/labels.csv> \
--device cuda:0 \
--n_splits 5 \
--batch_size 8 \
--seed 42
Windows/PowerShell users: you can run the same command on one line, or use backticks (`) for line continuation.
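The `--n_splits` and `--seed` options control the stratified cross-validation split, in which every fold keeps roughly the same pMCI/sMCI ratio as the full dataset. The pure-Python sketch below illustrates the idea only. The function names are hypothetical, and `main_rebuild.py` may build its folds differently, for example with a library implementation.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits=5, seed=42):
    """Assign each sample index to a fold so every fold keeps a similar class ratio.

    Illustrative only; main_rebuild.py may split the data differently.
    """
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    folds = [[] for _ in range(n_splits)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal the shuffled indices of this class round-robin across folds.
        for position, index in enumerate(indices):
            folds[position % n_splits].append(index)
    return folds


def train_val_split(folds, val_fold):
    """Use one fold for validation and the remaining folds for training."""
    val = sorted(folds[val_fold])
    train = sorted(i for k, fold in enumerate(folds) if k != val_fold for i in fold)
    return train, val
```

With 10 sMCI and 5 pMCI samples and `n_splits=5`, each fold receives 2 sMCI and 1 pMCI sample, and fixing the seed makes the assignment reproducible.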
Model indices are defined in model_object.py (dictionary models). Common choices:
IHFNet / IHFNet_With_MLP / IHFNetWithoutCMIMIMF, HFBSurv, TriLightNet, etc.
Use --model <MODEL_INDEX> to switch.
During training, a timestamped output directory is created:
./<checkpoints_dir>_<YYYY-mm-dd_HH-MM>/
For each fold, the best checkpoint is saved as:
<ExperimentName>_best_model_fold{1..n_splits}.pth
For IHFNet variants, <ExperimentName> is IHFNet (see model_object.py).
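The naming pattern above makes it easy to check that a run produced one checkpoint per fold. The sketch below is illustrative: `fold_checkpoints` is a hypothetical helper, and it assumes the `.pth` files sit directly inside the timestamped output directory.

```python
from pathlib import Path


def fold_checkpoints(output_dir, experiment_name="IHFNet", n_splits=5):
    """Build the expected best-checkpoint path for each fold (folds are numbered from 1)."""
    return [
        Path(output_dir) / f"{experiment_name}_best_model_fold{fold}.pth"
        for fold in range(1, n_splits + 1)
    ]


# Replace the placeholder with your timestamped output directory.
paths = fold_checkpoints("<output_dir>")
missing = [p for p in paths if not p.is_file()]
print(f"{len(paths) - len(missing)} of {len(paths)} checkpoints found")
```

A check like this is useful before cross-dataset testing, which expects all fold checkpoints to be present.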
TensorBoard logs are written under:
<output_dir>/summery/
We provide sample MRI and PET images in example_data/.
Run:
python Inference.py
Before running, edit `Inference.py` to set:
- `model_path`: path to your `.pth` checkpoint
- `cli_features`: a list of 9 clinical features used by the script
Note: This script is a minimal example and does not load clinical CSV automatically.
The cross-dataset test script loads the 5 fold checkpoints trained on one dataset and evaluates on the other dataset, printing per-fold metrics and the mean.
Run:
python cross_dataset_test.py \
--direction both \
--models all \
--n_splits 5 \
--batch_size 8 \
--device cuda:0 \
--adni1_mri_dir <ADNI1/MRI> --adni1_pet_dir <ADNI1/PET> --adni1_cli_dir <ADNI1/clinical.csv> --adni1_csv_file <ADNI1/labels.csv> --adni1_weights_dir <ADNI1_checkpoints_dir> \
--adni2_mri_dir <ADNI2/MRI> --adni2_pet_dir <ADNI2/PET> --adni2_cli_dir <ADNI2/clinical.csv> --adni2_csv_file <ADNI2/labels.csv> --adni2_weights_dir <ADNI2_checkpoints_dir>
Outputs:
- A JSON summary file will be written to `./cross_dataset_test_outputs/`.
In the Modality column, `M`, `P`, and `C` denote MRI, PET, and clinical data, respectively.
The following results are reported as mean ± std over 5 folds (as shown in the paper figure).
| Method | Modality | ACC ↑ | PRE ↑ | BACC ↑ | AUC ↑ | F1 ↑ |
|---|---|---|---|---|---|---|
| HOPE | M | 0.611 ± 0.038 | 0.599 ± 0.042 | 0.699 ± 0.051 | 0.648 ± 0.062 | 0.593 ± 0.098 |
| ResNet | M,P | 0.725 ± 0.041 | 0.671 ± 0.057 | 0.693 ± 0.048 | 0.653 ± 0.069 | 0.606 ± 0.087 |
| JSRL | M,P | 0.582 ± 0.047 | 0.580 ± 0.054 | 0.566 ± 0.059 | 0.571 ± 0.073 | 0.580 ± 0.105 |
| VAPL | M,C | 0.630 ± 0.039 | 0.648 ± 0.048 | 0.628 ± 0.054 | 0.635 ± 0.067 | 0.651 ± 0.092 |
| Diamond | M,P | 0.736 ± 0.044 | 0.719 ± 0.056 | 0.693 ± 0.047 | 0.638 ± 0.076 | 0.591 ± 0.094 |
| HyperFusionNet | M,P | 0.688 ± 0.020 | 0.704 ± 0.180 | 0.585 ± 0.034 | 0.613 ± 0.075 | 0.336 ± 0.113 |
| nnMamba | M,P | 0.694 ± 0.070 | 0.629 ± 0.150 | 0.686 ± 0.062 | 0.684 ± 0.085 | 0.609 ± 0.074 |
| MDLNet | M,P | 0.738 ± 0.058 | 0.820 ± 0.165 | 0.650 ± 0.051 | 0.676 ± 0.053 | 0.477 ± 0.104 |
| HFBSurv | M,P,C | 0.740 ± 0.036 | 0.701 ± 0.052 | 0.714 ± 0.049 | 0.711 ± 0.074 | 0.630 ± 0.091 |
| IMF | M,P,C | 0.756 ± 0.040 | 0.740 ± 0.058 | 0.710 ± 0.050 | 0.720 ± 0.077 | 0.605 ± 0.089 |
| MultiModalADNet | M,P,C | 0.700 ± 0.083 | 0.641 ± 0.183 | 0.693 ± 0.147 | 0.650 ± 0.076 | 0.637 ± 0.061 |
| IHFNet (Ours) | M,P,C | 0.737 ± 0.043 | 0.708 ± 0.086 | 0.723 ± 0.061 | 0.738 ± 0.078 | 0.666 ± 0.130 |

| Method | Modality | ACC ↑ | PRE ↑ | BACC ↑ | AUC ↑ | F1 ↑ |
|---|---|---|---|---|---|---|
| HOPE | M | 0.701 ± 0.045 | 0.706 ± 0.053 | 0.645 ± 0.049 | 0.624 ± 0.078 | 0.505 ± 0.102 |
| ResNet | M,P | 0.809 ± 0.039 | 0.721 ± 0.061 | 0.683 ± 0.052 | 0.709 ± 0.083 | 0.510 ± 0.095 |
| JSRL | M,P | 0.650 ± 0.052 | 0.600 ± 0.067 | 0.655 ± 0.054 | 0.694 ± 0.090 | 0.519 ± 0.112 |
| VAPL | M,C | 0.712 ± 0.043 | 0.750 ± 0.071 | 0.672 ± 0.058 | 0.623 ± 0.086 | 0.561 ± 0.089 |
| Diamond | M,P | 0.818 ± 0.041 | 0.739 ± 0.065 | 0.682 ± 0.053 | 0.645 ± 0.079 | 0.509 ± 0.108 |
| HyperFusionNet | M,P | 0.791 ± 0.008 | 0.537 ± 0.011 | 0.750 ± 0.041 | 0.790 ± 0.044 | 0.595 ± 0.043 |
| nnMamba | M,P | 0.750 ± 0.049 | 0.481 ± 0.077 | 0.719 ± 0.050 | 0.782 ± 0.069 | 0.554 ± 0.075 |
| MDLNet | M,P | 0.769 ± 0.080 | 0.714 ± 0.214 | 0.649 ± 0.033 | 0.679 ± 0.098 | 0.446 ± 0.063 |
| HFBSurv | M,P,C | 0.813 ± 0.042 | 0.701 ± 0.072 | 0.688 ± 0.051 | 0.672 ± 0.084 | 0.539 ± 0.098 |
| IMF | M,P,C | 0.838 ± 0.044 | 0.737 ± 0.068 | 0.713 ± 0.055 | 0.757 ± 0.092 | 0.564 ± 0.101 |
| MultiModalADNet | M,P,C | 0.750 ± 0.066 | 0.531 ± 0.105 | 0.754 ± 0.027 | 0.801 ± 0.011 | 0.621 ± 0.033 |
| IHFNet (Ours) | M,P,C | 0.856 ± 0.056 | 0.777 ± 0.073 | 0.735 ± 0.047 | 0.812 ± 0.103 | 0.584 ± 0.119 |
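Each `mean ± std` entry in the tables above aggregates one metric over the five folds. A minimal sketch of that aggregation follows. The helper `summarize` is hypothetical, and the use of the sample standard deviation (n - 1) is an assumption, since the source does not say which estimator the reported figures use.

```python
import statistics


def summarize(fold_values):
    """Format per-fold metric values as 'mean ± std' with three decimals."""
    mean = statistics.mean(fold_values)
    std = statistics.stdev(fold_values)  # sample std (n - 1); an assumption
    return f"{mean:.3f} ± {std:.3f}"


# Made-up per-fold accuracies for illustration only; not results from this work.
print(summarize([0.80, 0.85, 0.90, 0.82, 0.88]))
```

With only five folds, the sample and population standard deviations differ noticeably, so it is worth stating which one is reported when comparing against these tables.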
