Stroke is a common cerebrovascular disease and one of the leading causes of death worldwide. Accurate segmentation of stroke lesions is critical for diagnosis, treatment planning, and disease monitoring. Over the last decade, various segmentation approaches have been proposed, and deep-learning-based methods have become increasingly important tools for clinicians. Despite this rapid growth, current methods are often highly specialized for a particular imaging modality and stroke stage, and they result in heavy, complex pipelines that are difficult to deploy in clinical settings.
In this work, we propose StrokeFormer, a lightweight two-stage pipeline for stroke lesion segmentation designed for clinical deployment. StrokeFormer combines a 2D region proposal mechanism with a native 3D segmentation model to reduce computational complexity, mitigate severe class imbalance, and better exploit anatomical context.
Despite promising theoretical results, the practical effectiveness of the current pipeline is largely limited by the region proposal stage, which has to deal with extreme class imbalance. In response, this work investigates and discusses strategies to combat overfitting and undesired biases in stroke lesion segmentation, with the goal of providing insights and guidance for future research.
Warning: The documentation is currently under embargo. The full source documentation will be released publicly in August 2026.
This README serves as a general overview of the proposed method and documents its usage.
The two phases of StrokeFormer are organized as follows:
- Region Proposal: Each 3D volume is first decomposed into its 2D slices, and region proposal is performed independently on each slice. Since multiple regions can be proposed within a single slice (for example, when several lesions are present), we consolidate them into a single larger bounding box that encompasses all proposed regions. Because the region proposal network (RPN) treats each slice independently, predicted regions may vary across consecutive slices, especially when the anatomical perspective changes between them. Therefore, for each 3D scan, we identify and group slices corresponding to distinct "anatomical regions" based on brain shape and size. To enforce spatial consistency, every slice in a group is assigned the group's most frequently proposed region; if no single region is the most frequent, a larger bounding box encompassing all of the group's proposals is used instead. Region proposals are thus aggregated within each group, and segmentation is performed independently for each anatomical region;
- Segmentation: The proposed regions are expanded to match a fixed RoI size (e.g., $64 \times 64 \times 64$). When a 3D region exceeds this fixed size, we apply a sliding-window approach over the larger region, processing overlapping sub-volumes sequentially; the segmentation outputs of the overlapping windows are then averaged to produce the final prediction. Regions that were not proposed by the RPN, and are therefore not segmented, are masked out to prevent non-relevant regions from distorting the training signal. Finally, the segmented outputs are concatenated back into a full 3D volume, reconstructing the whole scan.
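The aggregation logic of the region proposal phase can be sketched in plain Python. This is a hypothetical illustration, not the repository's code. It makes four assumptions:

- Boxes are `(x_min, y_min, x_max, y_max)` tuples.
- The grouping of slices into anatomical regions is already given; the brain shape and size criterion is not reproduced here.
- "No most frequent region" means a tie for the highest count.
- The helper names `union_box`, `aggregate_group`, and `aggregate_volume` are illustrative only.

```python
from collections import Counter
from typing import Optional

# Hypothetical box format: (x_min, y_min, x_max, y_max) in pixel coordinates.
Box = tuple[int, int, int, int]


def union_box(boxes: list[Box]) -> Optional[Box]:
    """Return the smallest box enclosing all `boxes`, or None if there are none."""
    if not boxes:
        return None
    return (
        min(b[0] for b in boxes),
        min(b[1] for b in boxes),
        max(b[2] for b in boxes),
        max(b[3] for b in boxes),
    )


def aggregate_group(slice_boxes: list[Optional[Box]]) -> Optional[Box]:
    """Choose one region for a group of slices from the same anatomical region.

    The most frequent per-slice box wins; on a tie for the highest count,
    the union of all proposals in the group is used instead.
    """
    proposals = [b for b in slice_boxes if b is not None]
    if not proposals:
        return None
    ranked = Counter(proposals).most_common()
    top_count = ranked[0][1]
    winners = [box for box, count in ranked if count == top_count]
    return winners[0] if len(winners) == 1 else union_box(proposals)


def aggregate_volume(
    per_slice_proposals: list[list[Box]], groups: list[list[int]]
) -> dict[int, Optional[Box]]:
    """Map every slice index to the region enforced for its group.

    `groups` lists the slice indices of each anatomical region and is assumed
    to be computed beforehand (hypothetically, from brain shape and size).
    """
    # Step 1: merge multiple proposals within a slice into one enclosing box.
    slice_boxes = [union_box(p) for p in per_slice_proposals]
    # Step 2: enforce one region per group for spatial consistency.
    regions: dict[int, Optional[Box]] = {}
    for group in groups:
        region = aggregate_group([slice_boxes[i] for i in group])
        for i in group:
            regions[i] = region
    return regions
```

The function assigns the group's region to every slice in the group, including slices for which the RPN proposed nothing. This is one reading of "enforce the most frequent proposed region" and may differ from the actual implementation.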
To mitigate class imbalance, we propose a data augmentation strategy to generate anatomically plausible lesions. The strategy involves four main steps:
- We identify similar healthy-sick slice pairs by comparing brain size, shape, and overall structure. Brain size and shape are compared using IoU, while overall structural similarity is computed with SSIM;
- To ensure anatomical plausibility, we refine the previously selected pairs by retaining only those whose slice depths differ by no more than $30$ slices;
- To increase lesion variability, for each healthy slice we randomly select lesions from its paired sick counterparts. Each lesion undergoes random rotation and/or erosion;
- Finally, the transformed lesions are transferred onto the healthy slice using Poisson image editing. To preserve the original slice intensities and avoid excessive blurring, the transfer is restricted to the lesion region on both slices.
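The pair-selection and depth-filtering steps can be sketched with NumPy. This is an illustrative sketch under several assumptions:

- Slices are 2D arrays normalized to [0, 1].
- The brain mask is obtained by thresholding non-zero (skull-stripped) intensities.
- Size and shape are compared through the IoU of these masks.
- SSIM is computed globally over the whole slice, a simplification of the usual windowed SSIM.
- `iou_thr` and `ssim_thr` are hypothetical thresholds, not the values used in StrokeFormer.

The rotation/erosion and Poisson editing steps are not covered.

```python
import numpy as np


def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two binary masks (1.0 if both are empty)."""
    a, b = a.astype(bool), b.astype(bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(a, b).sum() / union)


def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """SSIM computed over the whole image as a single window (simplified)."""
    x, y = x.astype(np.float64), y.astype(np.float64)
    c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    cov = ((x - mx) * (y - my)).mean()
    num = (2 * mx * my + c1) * (2 * cov + c2)
    den = (mx ** 2 + my ** 2 + c1) * (x.var() + y.var() + c2)
    return float(num / den)


def select_pairs(healthy, sick, iou_thr=0.8, ssim_thr=0.5, max_depth_gap=30):
    """Return (healthy_idx, sick_idx) pairs passing all similarity checks.

    `healthy` and `sick` are lists of (slice_depth, image) tuples; the
    thresholds are hypothetical placeholders.
    """
    pairs = []
    for h_idx, (h_depth, h_img) in enumerate(healthy):
        h_brain = h_img > 0  # assumes skull-stripped slices with zero background
        for s_idx, (s_depth, s_img) in enumerate(sick):
            if abs(h_depth - s_depth) > max_depth_gap:
                continue  # anatomical plausibility: slice depths too far apart
            if mask_iou(h_brain, s_img > 0) < iou_thr:
                continue  # brain size/shape mismatch
            if global_ssim(h_img, s_img) < ssim_thr:
                continue  # overall structure mismatch
            pairs.append((h_idx, s_idx))
    return pairs
```

The sketch applies the depth check first only because it is the cheapest. A pair must pass all three tests, so the order does not change which pairs are retained.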
To install the necessary requirements for the project, please follow the steps below.
Verify you have Python installed on your machine. The project is compatible with Python 3.11 or higher.
If you do not have Python installed, please refer to the official Python Guide.
It's strongly recommended to create a virtual environment for the project and activate it before proceeding.
Feel free to use any Python package manager to create the virtual environment; however, we recommend pip for a smooth installation of the requirements. For instructions, please refer to Creating a virtual environment.
You may skip this step, but keep in mind that doing so could lead to dependency conflicts with other projects on your machine.
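The version check and environment creation can also be done with Python's standard library. This is a hypothetical helper, not part of the repository. `venv.EnvBuilder` is the standard-library class behind `python -m venv`, and the `.venv` directory name is an arbitrary choice.

```python
import sys
import venv
from pathlib import Path

MIN_PYTHON = (3, 11)  # minimum version stated in this README


def python_is_supported(min_version: tuple[int, int] = MIN_PYTHON) -> bool:
    """True if the running interpreter is at least `min_version`."""
    return sys.version_info[:2] >= tuple(min_version)


def create_project_env(env_dir: str = ".venv") -> Path:
    """Create a virtual environment with pip available (like `python -m venv`)."""
    if not python_is_supported():
        found = sys.version.split()[0]
        raise RuntimeError(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ is required, found {found}")
    path = Path(env_dir)
    venv.EnvBuilder(with_pip=True).create(path)
    return path
```

After calling `create_project_env()`, activate the environment before installing the requirements:

- Linux/macOS: `source .venv/bin/activate`
- Windows: `.venv\Scripts\activate`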
To obtain the code, either download and extract the .zip archive using the Code button at the top right of the repository page, or clone the repository by running the following command in your terminal:
git clone https://github.com/angelonazzaro/StrokeFormer.git

If you have an NVIDIA GPU with CUDA installed on your system, run the following command in your shell:
pip install -r requirements-cuda.txt

Otherwise, if your machine does not support CUDA, run the following command in your shell to run the code on the CPU:
pip install -r requirements.txt

This repository provides the following entry-point scripts for brain lesion segmentation:
- cross_validation.py: End-to-end K-fold cross-validation pipeline (RPN + StrokeFormer) on 3D brain volumes.
- test_rpn.py: Evaluation and visualization of a trained RPN on 2D slices.
- test_strokeformer.py: Evaluation and visualization of a trained StrokeFormer model on 3D volumes, with optional Grad-CAM.
- train_rpn.py: Trains the RPN.
- train_strokeformer.py: Trains the segmentation model.
All scripts are intended to be run from the project root with Python 3.11 or higher.
cross_validation.py runs a cross-validation pipeline that:
- Splits 3D volumes into K folds with balanced lesion-size distribution;
- Trains an RPN on 2D slices derived from each fold;
- Uses the RPN to generate region proposals and saves them;
- Trains StrokeFormer using those proposals;
- Tests both RPN and StrokeFormer and logs metrics and predictions.
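One simple way to obtain folds with a balanced lesion-size distribution is to sort subjects by lesion volume and deal them out to the folds in chunks of K. This sketch illustrates the idea and is not the logic in cross_validation.py. It assumes lesion size is the voxel count of each mask, and the name `lesion_balanced_folds` is hypothetical.

```python
import random


def lesion_balanced_folds(lesion_sizes: dict[str, int], k: int = 5, seed: int = 42) -> list[list[str]]:
    """Assign subjects to k folds so that each fold spans the full lesion-size range.

    Subjects are sorted by lesion size and dealt out in chunks of k; within each
    chunk the fold order is shuffled so that no fold systematically receives the
    largest lesion of every chunk.
    """
    rng = random.Random(seed)
    ordered = sorted(lesion_sizes, key=lambda subject: lesion_sizes[subject])
    folds: list[list[str]] = [[] for _ in range(k)]
    for start in range(0, len(ordered), k):
        chunk = ordered[start:start + k]
        fold_ids = list(range(k))
        rng.shuffle(fold_ids)
        for subject, fold in zip(chunk, fold_ids):
            folds[fold].append(subject)
    return folds
```

With the masks stored as .npy files, the sizes could be computed as `{path.name: int(np.load(path).sum()) for path in mask_paths}`. The defaults `k=5` and `seed=42` mirror the script's `--k` and `--seed` defaults. Fold sizes differ by at most one subject.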
python cross_validation.py \
--scans_dir /path/to/Scans \
--masks_dir /path/to/Masks \
--model_prefix StrokeFormerATLAS \
--rpn_model_prefix RPNATLAS

| Argument | Type | Description |
|---|---|---|
| --scans_dir | str | Root directory containing 3D scan volumes (e.g., `*T1w*.npy`) |
| --masks_dir | str | Root directory containing corresponding lesion masks (e.g., `*T1lesion_mask*.npy`) |
| --model_prefix | str | Prefix for naming StrokeFormer runs and checkpoints |
| --rpn_model_prefix | str | Prefix for naming RPN runs and checkpoints |

| Argument | Type | Default | Description |
|---|---|---|---|
| --seed | int | 42 | Random seed for all random number generators |
| --k | int | 5 | Number of folds for cross-validation |
| --splits | 3 ints | 70 10 20 | Percentage split (train, val, test); must sum to 100 |
| --use_augmented | flag | False | Use augmented scans for training when available |
| --augmented_dir | str | data/ATLAS_2/Processed/Augmented/ | Base directory containing augmented masks |

| Argument | Type | Default | Description |
|---|---|---|---|
| --ext | str | .npy | File extension for scan/mask volumes |
| --transforms | list[str] | None | Data augmentation/transformation functions |
| --rpn_batch_size | int | 32 | Batch size for RPN training/testing |
| --batch_size | int | 8 | Batch size for StrokeFormer training/testing |
| --num_workers | int | 0 | Number of workers for PyTorch dataloaders |
| --subvolume_depth | int | 189 | Depth of 3D subvolumes for StrokeFormer |
| --overlap | float | None | Overlap ratio for sliding-window inference |
| --resize_to | 2 ints | None | In-plane resize (H, W) |
| --scan_dim | 4 ints | 1 189 192 192 | Expected scan shape (C, D, H, W) |

| Argument | Type | Default | Description |
|---|---|---|---|
| --num_classes | int | 2 | Number of segmentation classes |
| --in_channels | int | 1 | Input channels for StrokeFormer |
| --opt_lr | float | 3e-5 | Base learning rate |
| --warmup_lr | float | 4e-6 | Warmup learning rate |
| --max_lr | float | 4e-4 | Maximum learning rate |
| --warmup_steps | int | 10 | Warmup steps |
| --weight_decay | float | 1e-3 | Weight decay |
| --eps | float | 1e-8 | Optimizer epsilon |
| --betas | 2 floats | 0.9 0.999 | Adam beta coefficients |
| --roi_size | 3 ints | 64 64 64 | 3D ROI size for region proposals |
| --seg_loss | str | DiceCELoss | Segmentation loss function |
| --seg_loss_cfg | str | JSON config | Segmentation loss hyperparameters |
| --cls_loss | str | None | Classification loss (optional) |
| --cls_loss_weight | float | 0.5 | Classification loss weight |
| --seg_loss_weight | float | 0.5 | Segmentation loss weight |
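The `--seg_loss`, `--cls_loss`, and weight arguments suggest a weighted sum of loss terms. This NumPy sketch shows a generic soft Dice plus cross-entropy loss and one plausible way to combine it with an optional classification loss. It is an illustrative assumption, not the repository's implementation:

- The `DiceCELoss` default matches the name of MONAI's loss class, whose exact options (e.g., background handling, smoothing) may differ.
- Returning the segmentation loss unweighted when no classification loss is given is a guess.

```python
from typing import Optional

import numpy as np


def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    shifted = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def dice_ce_loss(logits: np.ndarray, target: np.ndarray, smooth: float = 1e-5) -> float:
    """Soft Dice loss + cross-entropy.

    logits: (K, ...) raw class scores; target: (...) integer labels in [0, K).
    """
    num_classes = logits.shape[0]
    probs = softmax(logits, axis=0)
    onehot = np.stack([target == c for c in range(num_classes)]).astype(np.float64)
    spatial = tuple(range(1, logits.ndim))
    intersection = (probs * onehot).sum(axis=spatial)
    denominator = probs.sum(axis=spatial) + onehot.sum(axis=spatial)
    dice_loss = 1.0 - ((2.0 * intersection + smooth) / (denominator + smooth)).mean()
    ce_loss = -(onehot * np.log(probs + 1e-12)).sum(axis=0).mean()
    return float(dice_loss + ce_loss)


def total_loss(seg_loss: float, cls_loss: Optional[float] = None,
               seg_loss_weight: float = 0.5, cls_loss_weight: float = 0.5) -> float:
    """Weighted sum of both terms; without a classification loss, seg_loss is returned as is."""
    if cls_loss is None:
        return seg_loss
    return seg_loss_weight * seg_loss + cls_loss_weight * cls_loss
```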

| Argument | Type | Default | Description |
|---|---|---|---|
| --default_root_dir | str | StrokeFormer | StrokeFormer checkpoints/logs directory |
| --project | str | StrokeFormer | WandB project name |
| --group | str | None | WandB group name |
| --max_epochs | int | 250 | Maximum training epochs |
| --patience | int | 50 | Early stopping patience |
| --entity | str | neurone-lab | WandB entity |
| --offline | flag | False | Run WandB offline |
| --devices | list[int] | [0] | GPU device indices |

| Argument | Type | Default | Description |
|---|---|---|---|
| --rpn_lr | float | 1e-4 | RPN learning rate |
| --rpn_eps | float | 1e-8 | RPN optimizer epsilon |
| --rpn_betas | 2 floats | 0.9 0.999 | RPN Adam betas |
| --rpn_weight_decay | float | 1e-4 | RPN weight decay |
| --dataset_anchors | str | ATLAS | Dataset type for anchor config |
| --rpn_backbone_weights | str | None | RPN backbone weights |
| --rpn_default_root_dir | str | RPN | RPN checkpoints/logs directory |
| --rpn_project | str | RPN | RPN WandB project |
| --rpn_max_epochs | int | 250 | RPN maximum epochs |
| --rpn_patience | int | 50 | RPN early stopping patience |
| --rpn_group | str | None | RPN WandB group |
| --rpn_model_prefix | str | required | RPN model prefix |

| Argument | Type | Default | Description |
|---|---|---|---|
| --min_delta | float | 0.001 | Minimum validation loss improvement |
| --lr_logging_interval | str | epoch | LR logging interval (epoch/step) |
| --num_samples | int | 5 | Number of prediction samples to log |
| --log_every_n_val_epochs | int | 5 | Prediction logging frequency |
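`--patience` and `--min_delta` describe a standard early-stopping rule on the validation loss. The class below is a hypothetical re-implementation of that rule for illustration. The `--default_root_dir` argument and the `.ckpt` checkpoints suggest, but do not confirm, that the scripts use PyTorch Lightning. If so, the built-in `EarlyStopping` callback with `monitor`, `min_delta`, and `patience` plays this role.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved by more than `min_delta`
    for `patience` consecutive validation epochs."""

    def __init__(self, patience: int = 50, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # consecutive epochs without sufficient improvement

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```

An improvement smaller than `min_delta` still counts as a non-improving epoch. This prevents tiny fluctuations from resetting the patience counter.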

| Argument | Type | Default | Description |
|---|---|---|---|
| --target_layers | str+ | None | Grad-CAM target layers for StrokeFormer |
| --scores_dir | str | ./scores | StrokeFormer scores directory |
| --scores_file | str | scores.csv | StrokeFormer scores filename |
| --rpn_scores_dir | str | ./rpn_scores | RPN scores directory |
| --rpn_scores_file | str | scores.csv | RPN scores filename |
test_rpn.py evaluates a trained RPN on 2D slices and saves metrics and visualizations.
python test_rpn.py \
--scans /path/to/2D/Scans \
--masks /path/to/2D/Masks \
--ckpt_path /path/to/rpn.ckpt \
--model_name RPNATLAS-test

| Argument | Type | Default | Description |
|---|---|---|---|
| --seed | int | 42 | Random seed |
| --scans | str | required | Directory of 2D scan slices |
| --masks | str | required | Directory of 2D masks |
| --batch_size | int | 32 | Test batch size |
| --num_workers | int | 0 | Dataloader workers |
| --resize_to | 2 ints | None | Resize (H, W) |
| --ckpt_path | str | required | Path to RPN checkpoint |
| --model_name | str | None | Name for scores/predictions |
| --num_classes | int | 2 | Number of classes |
| --n_predictions | int | 30 | Number of prediction visualizations |
| --scores_dir | str | ./rpn_scores | Output directory |
| --scores_file | str | rpn_scores.csv | Metrics CSV filename |
test_strokeformer.py evaluates StrokeFormer on 3D volumes, computing global and per-lesion-size metrics, with optional Grad-CAM.
python test_strokeformer.py \
--scans /path/to/Scans \
--masks /path/to/Masks \
--ckpt_path /path/to/strokeformer.ckpt \
--model_name StrokeFormerATLAS-test

| Argument | Type | Default | Description |
|---|---|---|---|
| --seed | int | 42 | Random seed |
| --scans | str+ | required | One or more scan paths/directories |
| --masks | str+ | None | One or more mask paths/directories |
| --subvolume_depth | int | 189 | Sliding window depth |
| --overlap | float | None | Sliding window overlap |
| --batch_size | int | 8 | Test batch size |
| --num_workers | int | 0 | Dataloader workers |
| --scan_dim | 4 ints | 1 189 192 192 | Expected scan shape |
| --regions | str | None | Precomputed regions JSON |
| --resize_to | 2 ints | None | In-plane resize (H, W) |
| --ckpt_path | str | required | StrokeFormer checkpoint |
| --model_name | str | None | Name for scores/predictions |
| --target_layers | str+ | None | Grad-CAM target layers |
| --n_predictions | int | 30 | Number of visualizations |
| --num_classes | int | 2 | Number of classes |
| --scores_dir | str | ./scores | Output directory |
| --scores_file | str | scores.csv | Global metrics CSV |
| --per_size_scores_file | str | per_size_scores.csv | Per-lesion-size metrics CSV |
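Per-lesion-size metrics group subjects by the size of their ground-truth lesion before averaging. The sketch below computes a mean Dice score per size bucket. The bucket names and voxel thresholds in `SIZE_BINS` are hypothetical, and Dice is only one of the metrics the script may report.

```python
import numpy as np

# Hypothetical size buckets: (name, min voxels inclusive, max voxels exclusive).
SIZE_BINS = (("small", 0, 1_000), ("medium", 1_000, 10_000), ("large", 10_000, float("inf")))


def dice_score(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice coefficient of two binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both empty: treat as perfect agreement
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


def per_size_dice(preds, targets, bins=SIZE_BINS) -> dict[str, float]:
    """Average Dice per lesion-size bucket (NaN for empty buckets)."""
    scores: dict[str, list[float]] = {name: [] for name, _, _ in bins}
    for pred, target in zip(preds, targets):
        volume = int(target.astype(bool).sum())  # ground-truth lesion volume in voxels
        for name, low, high in bins:
            if low <= volume < high:
                scores[name].append(dice_score(pred, target))
                break
    return {name: float(np.mean(v)) if v else float("nan") for name, v in scores.items()}
```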
The expected data layout is:
data/
├── Scans/
│   ├── sub-001_T1w.npy
│   └── sub-002_T1w.npy
├── Masks/
│   ├── sub-001_label-L_desc-T1lesion_mask.npy
│   └── sub-002_label-L_desc-T1lesion_mask.npy
└── Augmented/              # optional
    ├── Scans/
    └── Masks/
Outputs are written as follows:
scores/                     # test_strokeformer.py & cross_validation.py
├── StrokeFormerATLAS/
│   └── predictions/
├── scores.csv              # Global metrics
└── per_size_scores.csv     # Per-lesion-size metrics
rpn_scores/                 # test_rpn.py & cross_validation.py
├── RPNATLAS/
│   └── predictions/
└── rpn_scores.csv
For complete argument details and advanced usage, see the source code argument parsers.
If you find this work useful, please consider citing:
@article{nazzaro2026:strokeformer,
author = {Angelo Nazzaro},
title = {StrokeFormer: A lightweight approach for stroke lesion segmentation},
year = {2026},
institution = {University of Salerno},
url = {https://github.com/angelonazzaro/StrokeFormer}
}
