diff --git a/README.md b/README.md index 2c69b76..094c9b3 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ VesselBoost is also available as a web application. To access the webapp, please

### **OpenRecon** -VesselBoost is also available on OpenRecon. To run VesselBoost on OpenRecon, please refer to the [usage example](https://github.com/KMarshallX/VesselBoost/blob/master/documentation/openrecon_example.md). +VesselBoost is also available on Siemens OpenRecon. To run VesselBoost on OpenRecon enabled scanners (>XA60), please refer to the [open recon container](https://github.com/neurodesk/neurocontainers/tree/main/recipes/vesselboost). ## **Installation** This is a Python-based software package. To successfully run this project on your local machine, please follow the following steps to set up the necessary software environment. @@ -98,6 +98,28 @@ This is a Python-based software package. To successfully run this project on you conda activate vessel_boost_ci ``` +### **Brain extraction in offline environments** +Brain extraction uses SynthStrip and requires the `synthstrip.1.pt` weights file. If the file is not available locally, VesselBoost tries to download it from the FreeSurfer server at runtime. When there is no internet connection and no local weights file, brain extraction fails with an error. + +On a connected machine, download the weights into the standard VesselBoost location: + +``` +mkdir -p saved_models +curl -L \ + -o saved_models/synthstrip.1.pt \ + https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt +``` + +If `curl` is unavailable, use `wget`: + +``` +wget \ + -O saved_models/synthstrip.1.pt \ + https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt +``` + +For airgapped or offline deployments, copy `saved_models/synthstrip.1.pt` into the deployment image or runtime directory before running VesselBoost. Alternatively, set `VESSELBOOST_SYNTHSTRIP_WEIGHTS` to the weights file path or to the directory containing it. + ## **Citation** VesselBoost paper is now published! Please cite us if you use VesselBoost in your research: @@ -125,4 +147,3 @@ Marshall Xu <[marshall.xu@uq.edu.au](marshall.xu@uq.edu.au)> Saskia Bollmann <[saskia.bollmann@uq.edu.au](saskia.bollmann@uq.edu.au)> Fernanda Ribeiro <[fernanda.ribeiro@uq.edu.au](fernanda.ribeiro@uq.edu.au)> - diff --git a/documentation/openrecon_example.md b/documentation/openrecon_example.md deleted file mode 100644 index 73dfd76..0000000 --- a/documentation/openrecon_example.md +++ /dev/null @@ -1,69 +0,0 @@ -# **OpenRecon Usage Example** - -## **Current Version** - VesselBoost 2.0.1 -### **Notes on the latest version** -If you set '--prep_mode' to 1,2 or 3, which means (1) N4 bias field correction, (2) denosing, or (3) both N4 biasfield correction and denoising will happen, then you have to set a path to store the preprocessed images. In the mean time, we also added an option to enable brain extraction ('--enable_brain_extraction') using Synthstrip (from FreeSurfer) to improve the robustness of the preprocessing step. - -If you set '--prep_mode' to 4, which means **no preprocessing** will happen, then you don't have to set a path to store the preprocessed images. Also, there will be **no brain extraction** for this case. - -For patches-based prediction, we added a new feature to use Gaussian blending to reduce edge artifacts and improve the quality of the final segmentation. - -### **Prediction Module** -For OpenRecon, there are few configurable options to run the prediction module: - -1. Set **"Vessel Boost Modules"** _("id": "vbmodules")_ to "prediction" to run the prediction module. -2. Set **"Preprocessing Mode"** _("id": "vbprepmode")_ to 1, 2, 3 or 4 to select the preprocessing method. We recommend setting it to 1 for applying N4 bias field correction. -3. When you set "Preprocessing Mode" to 1, 2 or 3, you can also choose to enable brain extraction by setting **"Brain extraction flag"** _("id": "vbbrainextraction")_ to true. We recommend enabling this feature. - -In this case, the system will run the equivalent command below: -```python - python prediction.py \ - --image_path "./data/img/" \ - --preprocessed_path "./data/preprocessed/" \ - --output_path "./data/pred_seg" \ - --pretrained "./saved_models/manual_0429" \ - --prep_mode 1 \ - --enable_brain_extraction \ - --use_blending \ - --overlap_ratio 0.5 -``` - -### **TTA Module** -The configurable options for running the TTA module are basically the same, but you have to set TWO MORE parameters for TTA: -1. Set **"Vessel Boost Modules"** _("id": "vbmodules")_ to "tta" to run the TTA module. -2. Set **"Epoch number"** _("id": "vbepochs") to the number of epochs you want to run for TTA. The default value is 200 epochs. -3. Set **"Learning rate"** _("id": "vbrate") to the learning rate you want to use for TTA. The default value is 1e-3. - -The following equivalent command will be executed: -```python - python test_time_adaptation.py \ - --image_path "./data/img/" \ - --preprocessed_path "./data/preprocessed/" \ - --output_path "./data/pred_seg" \ - --pretrained "./saved_models/manual_0429" \ - --prep_mode 1 \ - --enable_brain_extraction \ - --epochs 100 \ - --learning_rate 1e-3 \ - --use_blending \ - --overlap_ratio 0.5 -``` - -### **AngiBoost Module** -*Note: This module was designed to adapt booster module on open recon, while the function is the same as TTA. Might be deprecated in future versions.* - -The configurable options for running the AngiBoost module are basically the same as TTA, but you have to set **"Vessel Boost Modules"** _("id": "vbmodules")_ to "booster" to run the AngiBooster module. The following equivalent command will be executed: - -```python - python angiboost.py \ - --image_path "./data/img/" \ - --preprocessed_path "./data/preprocessed/" \ - --pretrained "./saved_models/manual_0429" \ - --label_path "./data/seg/" \ # to store the initial segmentation - --output_path "./data/boost_seg/" \ - --output_model "./data/boost_seg/boost_model" \ - --prep_mode 1 \ - --enable_brain_extraction \ - --epochs 100 \ - --learning_rate 1e-2 -``` \ No newline at end of file diff --git a/library/synthstrip_utils.py b/library/synthstrip_utils.py index 7396b16..0baf082 100644 --- a/library/synthstrip_utils.py +++ b/library/synthstrip_utils.py @@ -12,7 +12,7 @@ from __future__ import annotations import os -import requests +from pathlib import Path from typing import Optional, Tuple from nitransforms.linear import Affine from torch import nn @@ -22,38 +22,140 @@ import torch import scipy -def download_weights(): - url = "https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt" - if not os.path.exists("../saved_models/"): - os.mkdir("../saved_models/") - if os.path.exists("../saved_models/synthstrip.1.pt"): - print("\nSynthStrip weights already exist. Skipping download.") - return - print(f"Downloading SynthStrip weights from {url}...") - response = requests.get(url) - if response.status_code == 200: - with open("../saved_models/synthstrip.1.pt", "wb") as f: - f.write(response.content) - print("Download complete!") - else: - print(f"Failed to download weights. Status code: {response.status_code}") +SYNTHSTRIP_WEIGHTS_FILENAME = "synthstrip.1.pt" +SYNTHSTRIP_WEIGHTS_ENV = "VESSELBOOST_SYNTHSTRIP_WEIGHTS" +SYNTHSTRIP_WEIGHTS_URL = "https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/requirements/synthstrip.1.pt" + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _candidate_weight_paths(weights_path: Optional[os.PathLike[str] | str] = None) -> list[Path]: + """Return local SynthStrip weight paths in lookup order.""" + candidates: list[Path] = [] + + def add_candidate(path_like: os.PathLike[str] | str) -> None: + path = Path(path_like).expanduser() + candidates.append(path) + if path.name != SYNTHSTRIP_WEIGHTS_FILENAME: + candidates.append(path / SYNTHSTRIP_WEIGHTS_FILENAME) + + if weights_path: + add_candidate(weights_path) + + env_path = os.environ.get(SYNTHSTRIP_WEIGHTS_ENV) + if env_path: + add_candidate(env_path) + + candidates.extend( + [ + _repo_root() / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME, + Path.cwd() / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME, + Path.cwd().parent / "saved_models" / SYNTHSTRIP_WEIGHTS_FILENAME, + ] + ) + + deduped: list[Path] = [] + seen: set[str] = set() + for candidate in candidates: + key = str(candidate) + if key not in seen: + deduped.append(candidate) + seen.add(key) + return deduped + + +def resolve_weights_path(weights_path: Optional[os.PathLike[str] | str] = None) -> Path: + """ + Resolve the local SynthStrip weights path without attempting network access. + """ + candidates = _candidate_weight_paths(weights_path) + for candidate in candidates: + if candidate.is_file(): + return candidate.resolve() + + searched = "\n - ".join(str(candidate) for candidate in candidates) + raise FileNotFoundError( + "SynthStrip brain extraction requires local model weights, but " + f"{SYNTHSTRIP_WEIGHTS_FILENAME} was not found.\n" + f"Place {SYNTHSTRIP_WEIGHTS_FILENAME} in ./saved_models, or set " + f"{SYNTHSTRIP_WEIGHTS_ENV} to the weights file or containing directory.\n" + f"Searched:\n - {searched}" + ) + + +def _download_destination(weights_path: Optional[os.PathLike[str] | str] = None) -> Optional[Path]: + if weights_path: + return Path(weights_path).expanduser() + + env_path = os.environ.get(SYNTHSTRIP_WEIGHTS_ENV) + if env_path: + return Path(env_path).expanduser() + + return None -def load_strip_model(device: torch.device): +def download_weights(destination: Optional[os.PathLike[str] | str] = None, timeout: int = 60) -> Path: + """ + Download SynthStrip weights into the requested destination. + """ + import requests + + destination_path = Path(destination).expanduser() if destination else _repo_root() / "saved_models" + if destination_path.name != SYNTHSTRIP_WEIGHTS_FILENAME: + destination_path = destination_path / SYNTHSTRIP_WEIGHTS_FILENAME + destination_path.parent.mkdir(parents=True, exist_ok=True) + + if destination_path.exists(): + print(f"\nSynthStrip weights already exist at {destination_path}. Skipping download.") + return destination_path.resolve() + + print(f"Downloading SynthStrip weights from {SYNTHSTRIP_WEIGHTS_URL}...") + try: + response = requests.get(SYNTHSTRIP_WEIGHTS_URL, timeout=timeout) + response.raise_for_status() + except requests.RequestException as exc: + raise RuntimeError( + "Failed to download SynthStrip weights. Check the internet connection, " + f"or place {SYNTHSTRIP_WEIGHTS_FILENAME} in ./saved_models or set " + f"{SYNTHSTRIP_WEIGHTS_ENV} to a local weights path." + ) from exc + + destination_path.write_bytes(response.content) + print(f"Download complete: {destination_path}") + return destination_path.resolve() + + +def get_or_download_weights(weights_path: Optional[os.PathLike[str] | str] = None) -> Path: + """ + Resolve local SynthStrip weights, downloading them if they are missing. + """ + try: + return resolve_weights_path(weights_path) + except FileNotFoundError: + try: + return download_weights(_download_destination(weights_path)) + except RuntimeError as download_error: + raise RuntimeError( + "SynthStrip weights were not found locally and could not be downloaded. " + "An internet connection is required for the automatic download; " + f"offline deployments must provide {SYNTHSTRIP_WEIGHTS_FILENAME} locally." + ) from download_error + + +def load_strip_model(device: torch.device, weights_path: Optional[os.PathLike[str] | str] = None): """ Load the `StripModel` weights from a checkpoint file. """ - download_weights() - modelfile = "../saved_models/synthstrip.1.pt" - if not os.path.exists(modelfile): - raise FileNotFoundError(modelfile) + modelfile = get_or_download_weights(weights_path) model = StripModel() model.to(device) model.eval() - checkpoint = torch.load(modelfile, map_location=device) + checkpoint = torch.load(str(modelfile), map_location=device) if 'model_state_dict' in checkpoint: state = checkpoint['model_state_dict'] else: @@ -114,6 +216,7 @@ def skull_strip( image: nb.nifti1.Nifti1Image, device: torch.device, border: int = 1, + weights_path: Optional[os.PathLike[str] | str] = None, ) -> Tuple[np.ndarray, nb.nifti1.Nifti1Image]: """Run the synthstrip pipeline on an input image and return the mask. @@ -121,12 +224,13 @@ def skull_strip( - image: input image as a Nifti1Image object - device: torch device to use. If omitted, it's configured from `gpu`. - border: border threshold in mm used to generate final mask + - weights_path: optional local path to SynthStrip weights Returns - mask: boolean numpy array of the brain mask in native image grid """ - model = load_strip_model(device) + model = load_strip_model(device, weights_path) # load input volume conformed = conform(image) @@ -286,5 +390,3 @@ def forward(self, x): if self.activation is not None: out = self.activation(out) return out - -