Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions CorpusCallosum/cc_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
from pathlib import Path
from typing import Literal

import nibabel as nib
import numpy as np

from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template
from CorpusCallosum.shape.contour import CCContour
from CorpusCallosum.shape.mesh import CCMesh
from FastSurferCNN.utils import AffineMatrix4x4
from FastSurferCNN.utils.logging import get_logger, setup_logging
from FastSurferCNN.utils.lta import read_lta

logger = get_logger(__name__)

Expand Down Expand Up @@ -39,7 +42,7 @@ def make_parser() -> argparse.ArgumentParser:
"cc_mesh.vtk - VTK mesh file format "
"cc_mesh.fssurf - FreeSurfer surface file "
"cc_mesh_overlay.curv - FreeSurfer curvature overlay file "
"cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=1.3.1)",
"cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=2.1)",
metavar="OUTPUT_DIR"
)
parser.add_argument(
Expand Down Expand Up @@ -213,6 +216,17 @@ def main(
# 3D visualization
cc_mesh = CCMesh.from_contours(contours, smooth=0)

if Path(output_dir / "mri" / "upright.mgz").exists():
header = nib.load(output_dir / "mri" / "upright.mgz").header
# we need to get the upright image header, which is the same as cc_up.lta applied to orig.
elif Path(template_dir / "mri/orig.mgz").exists() and Path(template_dir / "mri/transforms/cc_up.lta").exists():
image = nib.load(template_dir / "mri" / "orig.mgz")
lta_mat: AffineMatrix4x4 = read_lta(template_dir / "mri/transforms/cc_up.lta")["lta"]
image.affine = lta_mat @ image.affine
header = image.header
else:
header = None

plot_kwargs = dict(
colormap=colormap,
color_range=color_range,
Expand All @@ -225,15 +239,16 @@ def main(
logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}")
cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk"))
logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}")
cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf"))
cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf"), image=header)
logger.info(f"Writing freesurfer overlay file to {output_dir / 'cc_mesh_overlay.curv'}")
cc_mesh.write_morph_data(str(output_dir / "cc_mesh_overlay.curv"))
try:
cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"))
cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png"), ref_header=header)
logger.info(f"Writing 3D snapshot image to {output_dir / 'cc_mesh_snap.png'}")
except RuntimeError:
logger.warning("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with "
"`pip install whippersnappy>=1.3.1` !")
except Exception:
logger.warning("The cc_visualization script requires whippersnappy>=2.1 to makes screenshots, install with "
"`pip install whippersnappy>=2.1` !")
raise
return 0

if __name__ == "__main__":
Expand Down
104 changes: 50 additions & 54 deletions CorpusCallosum/paint_cc_into_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,20 @@
import sys
from functools import partial
from pathlib import Path
from typing import TypeVar, cast

import nibabel as nib
import numpy as np
from numpy import typing as npt
from scipy import ndimage

import FastSurferCNN.utils.logging as logging
from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS
from FastSurferCNN.data_loader.conform import is_conform
from FastSurferCNN.data_loader.data_utils import load_image
from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save
from FastSurferCNN.utils import Mask2d, Mask3d, Shape3d, logging
from FastSurferCNN.utils.arg_types import path_or_none
from FastSurferCNN.utils.brainvolstats import mask_in_array
from FastSurferCNN.utils.parallel import thread_executor

_T = TypeVar("_T", bound=np.number)

logger = logging.get_logger(__name__)

HELPTEXT = """
Expand All @@ -55,7 +52,8 @@

Original Author: Leonie Henschel
Date: Jul-10-2020

Modified by: Clemens Pollak, David Kügler
Date: Dec-2025
"""


Expand Down Expand Up @@ -110,26 +108,23 @@ def make_parser() -> argparse.ArgumentParser:
return parser


def paint_in_cc(pred: npt.NDArray[np.int_],
aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
def paint_in_cc(
pred: np.ndarray[Shape3d, np.dtype[int]],
aseg_cc: np.ndarray[Shape3d, np.dtype[int]],
) -> np.ndarray[Shape3d, np.dtype[int]]:
"""Paint corpus callosum segmentation into aseg+dkt segmentation map.

Parameters
----------
pred : npt.NDArray[np.int_]
pred : np.ndarray
Deep-learning segmentation map.
aseg_cc : npt.NDArray[np.int_]
aseg_cc : np.ndarray
Aseg segmentation with CC.

Returns
-------
npt.NDArray[np.int_]
np.ndarray
Segmentation map with added CC.

Notes
-----
This function modifies the original array and does not create a copy.
The CC labels (251-255) from aseg_cc are copied into pred.
"""
cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS)

Expand All @@ -142,14 +137,14 @@ def paint_in_cc(pred: npt.NDArray[np.int_],
logger.info(f"Painting CC: {np.sum(cc_mask)} voxels (replacing {num_wm_replaced} WM, "
f"{num_background_replaced} background, {num_other_replaced} other)")

pred[cc_mask] = aseg_cc[cc_mask]
return pred

def _fill_gaps_in_direction(
corrected_pred: npt.NDArray[np.int_],
potential_fill: npt.NDArray[np.bool_],
source_binary: npt.NDArray[np.bool_],
target_binary: npt.NDArray[np.bool_],
out = np.where(cc_mask, aseg_cc, pred)
return out

def _fill_gaps_in_direction_(
corrected_pred: np.ndarray[Shape3d, np.dtype[int]],
potential_fill: Mask2d,
source_binary: Mask2d,
target_binary: Mask2d,
x_slice: int,
direction: str,
max_gap_voxels: int,
Expand All @@ -159,13 +154,13 @@ def _fill_gaps_in_direction(

Parameters
----------
corrected_pred : npt.NDArray[np.int_]
corrected_pred : np.ndarray
The segmentation array to modify in place.
potential_fill : npt.NDArray[np.bool_]
potential_fill : np.ndarray
2D mask of potential fill regions for this slice.
source_binary : npt.NDArray[np.bool_]
source_binary : np.ndarray
2D binary mask of source structure (e.g., CC).
target_binary : npt.NDArray[np.bool_]
target_binary : np.ndarray
2D binary mask of target structure (e.g., ventricle).
x_slice : int
The x-coordinate of the current slice.
Expand Down Expand Up @@ -254,10 +249,10 @@ def _fill_gaps_in_direction(
return voxels_filled


def _fill_gaps_between_structures(
corrected_pred: npt.NDArray[np.int_],
source_mask: npt.NDArray[np.bool_],
target_mask: npt.NDArray[np.bool_],
def _fill_gaps_between_structures_(
corrected_pred: np.ndarray[Shape3d, np.dtype[int]],
source_mask: Mask3d,
target_mask: Mask3d,
voxel_size: tuple[float, float, float],
close_gap_size_mm: float,
fillable_labels: set[int],
Expand All @@ -267,11 +262,11 @@ def _fill_gaps_between_structures(

Parameters
----------
corrected_pred : npt.NDArray[np.int_]
corrected_pred : np.ndarray
The segmentation array to modify in place.
source_mask : npt.NDArray[np.bool_]
source_mask : np.ndarray
3D binary mask of source structure (e.g., CC).
target_mask : npt.NDArray[np.bool_]
target_mask : np.ndarray
3D binary mask of target structure (e.g., ventricle or background).
voxel_size : tuple[float, float, float]
Voxel size in mm.
Expand Down Expand Up @@ -315,13 +310,13 @@ def _fill_gaps_between_structures(
potential_fill = (source_dilated & target_dilated) & ~(source_binary | target_binary)

# Fill gaps in inferior-superior direction
voxels_filled += _fill_gaps_in_direction(
voxels_filled += _fill_gaps_in_direction_(
corrected_pred, potential_fill, source_binary, target_binary,
x, 'inferior-superior', max_gap_vox_inferior_superior, fillable_labels
)

# Fill gaps in anterior-posterior direction
voxels_filled += _fill_gaps_in_direction(
voxels_filled += _fill_gaps_in_direction_(
corrected_pred, potential_fill, source_binary, target_binary,
x, 'anterior-posterior', max_gap_vox_anterior_posterior, fillable_labels
)
Expand All @@ -333,11 +328,11 @@ def _fill_gaps_between_structures(


def correct_wm_ventricles(
aseg_cc: npt.NDArray[np.int_],
fornix_mask: npt.NDArray[np.bool_],
aseg_cc: np.ndarray[Shape3d, np.dtype[int]],
fornix_mask: Mask3d,
voxel_size: tuple[float, float, float],
close_gap_size_mm: float = 3.0
) -> npt.NDArray[np.int_]:
) -> np.ndarray[Shape3d, np.dtype[int]]:
"""Fill small gaps between corpus callosum, ventricles, and background.

This function performs two gap-filling operations:
Expand All @@ -349,9 +344,9 @@ def correct_wm_ventricles(

Parameters
----------
aseg_cc : npt.NDArray[np.int_]
aseg_cc : np.ndarray
Aseg segmentation with CC already painted in.
fornix_mask : npt.NDArray[np.bool_]
fornix_mask : np.ndarray
Mask of the fornix. Not currently used (kept for interface compatibility).
voxel_size : tuple[float, float, float]
Voxel size of the aseg image in mm.
Expand All @@ -360,7 +355,7 @@ def correct_wm_ventricles(

Returns
-------
npt.NDArray[np.int_]
np.ndarray
Corrected segmentation map with filled gaps.
"""
# Create a copy to avoid modifying the original
Expand All @@ -374,37 +369,38 @@ def correct_wm_ventricles(

# Get background mask
background_mask = aseg_cc == 0

print(np.unique(corrected_pred))

# 1. Fill gaps between CC and ventricles (replace WM and background with ventricle labels)
_fill_gaps_between_structures(
_fill_gaps_between_structures_(
corrected_pred, cc_mask, ventricle_mask, voxel_size, close_gap_size_mm,
fillable_labels={0, 2, 41}, # background and WM
description="between CC and ventricles (WM/background → ventricle)"
)

print(np.unique(corrected_pred))

# 2. Fill WM gaps between CC and background (replace WM with background)
_fill_gaps_between_structures(
_fill_gaps_between_structures_(
corrected_pred, cc_mask, background_mask, voxel_size, close_gap_size_mm,
fillable_labels={2, 41}, # only WM
description="between CC and background (WM → background)"
)
print(np.unique(corrected_pred))

return corrected_pred


if __name__ == "__main__":
from FastSurferCNN.utils import nibabelImage

# Command Line options are error checking done here
options = argument_parse()

logging.setup_logging()

logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...")
cc_seg_image = cast(nibabelImage, nib.load(options.input_cc))
cc_seg_data = np.asanyarray(cc_seg_image.dataobj)
aseg_image = cast(nibabelImage, nib.load(options.input_pred))
aseg_data = np.asanyarray(aseg_image.dataobj)

tmap = thread_executor().map
(cc_seg_image, cc_seg_data), (aseg_image, aseg_data) = tmap(load_image, (options.input_cc, options.input_pred))

def _is_conform(img, dtype, verbose):
return is_conform(img, vox_size=None, img_size=None, verbose=verbose, dtype=dtype)
Expand Down Expand Up @@ -433,8 +429,8 @@ def _is_conform(img, dtype, verbose):
initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41))
initial_ventricles = np.sum((aseg_data == 4) | (aseg_data == 43))

# Paint CC into prediction (modifies aseg_data in place)
paint_in_cc(aseg_data, cc_seg_data)
# Paint CC into prediction
aseg_data = paint_in_cc(aseg_data, cc_seg_data)

# Apply ventricle gap filling corrections
fornix_mask = cc_seg_data == FORNIX_LABEL
Expand Down
Loading