Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ VesselBoost is also available as a web application. To access the webapp, please
</p>

### **OpenRecon**
VesselBoost is also available on OpenRecon. To run VesselBoost on OpenRecon, please refer to the [usage example](https://github.com/KMarshallX/VesselBoost/blob/master/documentation/openrecon_example.md).
VesselBoost is also available on Siemens OpenRecon. To run VesselBoost on OpenRecon enabled scanners (>XA60), please refer to the [open recon container](https://github.com/neurodesk/neurocontainers/tree/main/recipes/vesselboost).

## **Installation**
This is a Python-based software package. To successfully run this project on your local machine, please follow the following steps to set up the necessary software environment.
Expand Down Expand Up @@ -98,6 +98,28 @@ This is a Python-based software package. To successfully run this project on you
conda activate vessel_boost_ci
```

### **Brain extraction in offline environments**
Brain extraction uses SynthStrip and requires the `synthstrip.1.pt` weights file. If the file is not available locally, VesselBoost tries to download it from the FreeSurfer server at runtime. When there is no internet connection and no local weights file, brain extraction fails with an error.

On a connected machine, download the weights into the standard VesselBoost location:

```
mkdir -p saved_models
curl -L \
-o saved_models/synthstrip.1.pt \
https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt
```

If `curl` is unavailable, use `wget`:

```
wget \
-O saved_models/synthstrip.1.pt \
https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt
```

For airgapped or offline deployments, copy `saved_models/synthstrip.1.pt` into the deployment image or runtime directory before running VesselBoost. Alternatively, set `VESSELBOOST_SYNTHSTRIP_WEIGHTS` to the weights file path or to the directory containing it.

## **Citation**
VesselBoost paper is now published! Please cite us if you use VesselBoost in your research:

Expand Down Expand Up @@ -125,4 +147,3 @@ Marshall Xu <[marshall.xu@uq.edu.au](marshall.xu@uq.edu.au)>
Saskia Bollmann <[saskia.bollmann@uq.edu.au](saskia.bollmann@uq.edu.au)>

Fernanda Ribeiro <[fernanda.ribeiro@uq.edu.au](fernanda.ribeiro@uq.edu.au)>

69 changes: 0 additions & 69 deletions documentation/openrecon_example.md

This file was deleted.

152 changes: 127 additions & 25 deletions library/synthstrip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

import os
import requests
from pathlib import Path
from typing import Optional, Tuple
from nitransforms.linear import Affine
from torch import nn
Expand All @@ -22,38 +22,140 @@
import torch
import scipy

def download_weights():
url = "https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt"
if not os.path.exists("../saved_models/"):
os.mkdir("../saved_models/")
if os.path.exists("../saved_models/synthstrip.1.pt"):
print("\nSynthStrip weights already exist. Skipping download.")
return
print(f"Downloading SynthStrip weights from {url}...")
response = requests.get(url)
if response.status_code == 200:
with open("../saved_models/synthstrip.1.pt", "wb") as f:
f.write(response.content)
print("Download complete!")
else:
print(f"Failed to download weights. Status code: {response.status_code}")
SYNTHSTRIP_WEIGHTS_FILENAME = "synthstrip.1.pt"
SYNTHSTRIP_WEIGHTS_ENV = "VESSELBOOST_SYNTHSTRIP_WEIGHTS"
SYNTHSTRIP_WEIGHTS_URL = "https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt"


def _repo_root() -> Path:
return Path(__file__).resolve().parents[1]


def _candidate_weight_paths(weights_path: Optional[os.PathLike[str] | str] = None) -> list[Path]:
"""Return local SynthStrip weight paths in lookup order."""
candidates: list[Path] = []

def add_candidate(path_like: os.PathLike[str] | str) -> None:
path = Path(path_like).expanduser()
candidates.append(path)
if path.name != SYNTHSTRIP_WEIGHTS_FILENAME:
candidates.append(path / SYNTHSTRIP_WEIGHTS_FILENAME)

if weights_path:
add_candidate(weights_path)

env_path = os.environ.get(SYNTHSTRIP_WEIGHTS_ENV)
if env_path:
add_candidate(env_path)

candidates.extend(
[
_repo_root() / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME,
Path.cwd() / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME,
Path.cwd().parent / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME,
]
)

deduped: list[Path] = []
seen: set[str] = set()
for candidate in candidates:
key = str(candidate)
if key not in seen:
deduped.append(candidate)
seen.add(key)
return deduped


def resolve_weights_path(weights_path: Optional[os.PathLike[str] | str] = None) -> Path:
"""
Resolve the local SynthStrip weights path without attempting network access.
"""
candidates = _candidate_weight_paths(weights_path)
for candidate in candidates:
if candidate.is_file():
return candidate.resolve()

searched = "\n - ".join(str(candidate) for candidate in candidates)
raise FileNotFoundError(
"SynthStrip brain extraction requires local model weights, but "
f"{SYNTHSTRIP_WEIGHTS_FILENAME} was not found.\n"
f"Place {SYNTHSTRIP_WEIGHTS_FILENAME} in ./saved_models, or set "
f"{SYNTHSTRIP_WEIGHTS_ENV} to the weights file or containing directory.\n"
f"Searched:\n - {searched}"
)


def _download_destination(weights_path: Optional[os.PathLike[str] | str] = None) -> Optional[Path]:
if weights_path:
return Path(weights_path).expanduser()

env_path = os.environ.get(SYNTHSTRIP_WEIGHTS_ENV)
if env_path:
return Path(env_path).expanduser()

return None


def load_strip_model(device: torch.device):
def download_weights(destination: Optional[os.PathLike[str] | str] = None, timeout: int = 60) -> Path:
"""
Download SynthStrip weights into the requested destination.
"""
import requests

destination_path = Path(destination).expanduser() if destination else _repo_root() / "saved_models"
if destination_path.name != SYNTHSTRIP_WEIGHTS_FILENAME:
destination_path = destination_path / SYNTHSTRIP_WEIGHTS_FILENAME
destination_path.parent.mkdir(parents=True, exist_ok=True)

if destination_path.exists():
print(f"\nSynthStrip weights already exist at {destination_path}. Skipping download.")
return destination_path.resolve()

print(f"Downloading SynthStrip weights from {SYNTHSTRIP_WEIGHTS_URL}...")
try:
response = requests.get(SYNTHSTRIP_WEIGHTS_URL, timeout=timeout)
response.raise_for_status()
except requests.RequestException as exc:
raise RuntimeError(
"Failed to download SynthStrip weights. Check the internet connection, "
f"or place {SYNTHSTRIP_WEIGHTS_FILENAME} in ./saved_models or set "
f"{SYNTHSTRIP_WEIGHTS_ENV} to a local weights path."
) from exc

destination_path.write_bytes(response.content)
print(f"Download complete: {destination_path}")
return destination_path.resolve()


def get_or_download_weights(weights_path: Optional[os.PathLike[str] | str] = None) -> Path:
"""
Resolve local SynthStrip weights, downloading them if they are missing.
"""
try:
return resolve_weights_path(weights_path)
except FileNotFoundError:
try:
return download_weights(_download_destination(weights_path))
except RuntimeError as download_error:
raise RuntimeError(
"SynthStrip weights were not found locally and could not be downloaded. "
"An internet connection is required for the automatic download; "
f"offline deployments must provide {SYNTHSTRIP_WEIGHTS_FILENAME} locally."
) from download_error


def load_strip_model(device: torch.device, weights_path: Optional[os.PathLike[str] | str] = None):
"""
Load the `StripModel` weights from a checkpoint file.
"""

download_weights()
modelfile = "../saved_models/synthstrip.1.pt"
if not os.path.exists(modelfile):
raise FileNotFoundError(modelfile)
modelfile = get_or_download_weights(weights_path)

model = StripModel()
model.to(device)
model.eval()

checkpoint = torch.load(modelfile, map_location=device)
checkpoint = torch.load(str(modelfile), map_location=device)
if 'model_state_dict' in checkpoint:
state = checkpoint['model_state_dict']
else:
Expand Down Expand Up @@ -114,19 +216,21 @@ def skull_strip(
image: nb.nifti1.Nifti1Image,
device: torch.device,
border: int = 1,
weights_path: Optional[os.PathLike[str] | str] = None,
) -> Tuple[np.ndarray, nb.nifti1.Nifti1Image]:
"""Run the synthstrip pipeline on an input image and return the mask.

Parameters
- image: input image as a Nifti1Image object
- device: torch device to use. If omitted, it's configured from `gpu`.
- border: border threshold in mm used to generate final mask
- weights_path: optional local path to SynthStrip weights

Returns
- mask: boolean numpy array of the brain mask in native image grid
"""

model = load_strip_model(device)
model = load_strip_model(device, weights_path)

# load input volume
conformed = conform(image)
Expand Down Expand Up @@ -286,5 +390,3 @@ def forward(self, x):
if self.activation is not None:
out = self.activation(out)
return out


Loading