From ec0cf74a78b08ff7d8491291bab75ac8096da093 Mon Sep 17 00:00:00 2001 From: Andrew Van Date: Sat, 25 Apr 2026 16:46:35 -0500 Subject: [PATCH 1/4] :sparkles: Expose typed Python API for the seven wk-* operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each warpkit.scripts. module now exports a typed function (``medic``, ``unwrap_phase``, ``compute_fieldmap``, ``apply_warp``, ``convert_warp``, ``convert_fieldmap``, ``compute_jacobian``) and a frozen Result dataclass alongside the existing ``main()`` argparse entry point. ``main()`` is now a thin shim: parse argv → call the typed function → forward ``ValueError`` to ``parser.error`` so existing CLI behaviour (``SystemExit(2)``, ``error: ...`` on stderr) is preserved. A new top-level ``warpkit.api`` module re-exports all seven functions + result types for convenient ``from warpkit.api import medic`` usage by library integrators (e.g. nipype interfaces, fmriprep). - New ``warpkit/scripts/_metadata.py`` centralises BIDS sidecar loading, the either-or/mutex acquisition resolution, image coercion, and noise-frame trimming previously duplicated across three scripts. - ``warpkit/scripts/_warp_io.py`` no longer takes an ``argparse.ArgumentParser``; ``read_input_frames`` / ``write_output`` raise ``ValueError`` and the CLI shims forward to ``parser.error``. ``read_input_frames`` also accepts already-loaded ``Nifti1Image`` objects in addition to paths. - All seven typed functions are keyword-only and accept either paths or ``nib.Nifti1Image`` for image inputs; outputs are written and returned as absolute ``pathlib.Path`` in the result dataclass. No CLI behaviour changes; all 207 existing tests pass. --- tests/test_scripts.py | 4 +- warpkit/api.py | 48 ++++++ warpkit/scripts/_metadata.py | 127 +++++++++++++++ warpkit/scripts/_warp_io.py | 44 ++--- warpkit/scripts/apply_warp.py | 242 ++++++++++++++++++---------- warpkit/scripts/compute_fieldmap.py | 216 +++++++++++++++---------- warpkit/scripts/compute_jacobian.py | 114 +++++++++---- warpkit/scripts/convert_fieldmap.py | 235 ++++++++++++++++++--------- warpkit/scripts/convert_warp.py | 212 +++++++++++++++--------- warpkit/scripts/medic.py | 236 ++++++++++++++++----------- warpkit/scripts/unwrap_phase.py | 200 +++++++++++++---------- 11 files changed, 1127 insertions(+), 551 deletions(-) create mode 100644 warpkit/api.py create mode 100644 warpkit/scripts/_metadata.py diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 0160f1d..ca36580 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -1693,14 +1693,12 @@ def test_bundle_frames_to_3d_series_clears_vector_intent(): def test_write_output_per_frame_map_clears_vector_intent(tmp_path): """Per-frame map outputs must round-trip without a stale vector intent.""" - import argparse - from warpkit.scripts._warp_io import write_output frames = [_vector_intent_frame() for _ in range(2)] out_paths = [str(tmp_path / "f1.nii"), str(tmp_path / "f2.nii")] - write_output(frames, out_paths, "map", argparse.ArgumentParser()) + write_output(frames, out_paths, "map") for p in out_paths: loaded = _load(p) diff --git a/warpkit/api.py b/warpkit/api.py new file mode 100644 index 0000000..a333a06 --- /dev/null +++ b/warpkit/api.py @@ -0,0 +1,48 @@ +"""Typed Python entry points for the seven warpkit operations. + +Every ``wk-*`` CLI tool is mirrored here as a typed Python function so +library callers (e.g. nipype interfaces, fmriprep) can drive warpkit without +shelling out or fabricating ``argv``. Each function takes keyword-only +arguments, raises :class:`ValueError` on validation problems, writes its +outputs to disk, and returns a frozen dataclass with the absolute paths of +the written NIfTIs. + +Mapping CLI -> Python: + +* ``wk-medic`` -> :func:`medic` -> :class:`MedicResult` +* ``wk-unwrap-phase`` -> :func:`unwrap_phase` -> :class:`UnwrapPhaseResult` +* ``wk-compute-fieldmap`` -> :func:`compute_fieldmap` -> :class:`ComputeFieldmapResult` +* ``wk-apply-warp`` -> :func:`apply_warp` -> :class:`ApplyWarpResult` +* ``wk-convert-warp`` -> :func:`convert_warp` -> :class:`ConvertWarpResult` +* ``wk-convert-fieldmap`` -> :func:`convert_fieldmap` -> :class:`ConvertFieldmapResult` +* ``wk-compute-jacobian`` -> :func:`compute_jacobian` -> :class:`ComputeJacobianResult` + +The CLI flag-name → Python kwarg mapping is the obvious dash-to-underscore +transform; the only difference is ``--TEs`` (kept for MR convention) → +``tes`` (lowercase, per repo style). +""" + +from .scripts.apply_warp import ApplyWarpResult, apply_warp +from .scripts.compute_fieldmap import ComputeFieldmapResult, compute_fieldmap +from .scripts.compute_jacobian import ComputeJacobianResult, compute_jacobian +from .scripts.convert_fieldmap import ConvertFieldmapResult, convert_fieldmap +from .scripts.convert_warp import ConvertWarpResult, convert_warp +from .scripts.medic import MedicResult, medic +from .scripts.unwrap_phase import UnwrapPhaseResult, unwrap_phase + +__all__ = [ + "ApplyWarpResult", + "ComputeFieldmapResult", + "ComputeJacobianResult", + "ConvertFieldmapResult", + "ConvertWarpResult", + "MedicResult", + "UnwrapPhaseResult", + "apply_warp", + "compute_fieldmap", + "compute_jacobian", + "convert_fieldmap", + "convert_warp", + "medic", + "unwrap_phase", +] diff --git a/warpkit/scripts/_metadata.py b/warpkit/scripts/_metadata.py new file mode 100644 index 0000000..938237c --- /dev/null +++ b/warpkit/scripts/_metadata.py @@ -0,0 +1,127 @@ +"""Shared acquisition-metadata helpers for the warpkit script entry points. + +Centralises: + +* coercing user-supplied images (``Path`` / ``str`` / ``Nifti1Image``) into + ``Nifti1Image`` objects, +* loading echo time / total readout time / phase encoding direction from + BIDS-style JSON sidecars, +* the "either ``--metadata`` or direct args" mutex/either-or check shared by + ``wk-medic``, ``wk-unwrap-phase``, and ``wk-compute-fieldmap``. + +Validation errors raise :class:`ValueError`; the CLI shims forward those to +``parser.error`` so the user-visible behaviour (``SystemExit(2)``) is +preserved. +""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from os import PathLike +from typing import cast + +import nibabel as nib + + +def ensure_image(x: PathLike[str] | str | nib.Nifti1Image) -> nib.Nifti1Image: + """Coerce a path or in-memory image into a ``Nifti1Image``.""" + if isinstance(x, nib.Nifti1Image): + return x + return cast(nib.Nifti1Image, nib.load(str(x))) + + +def ensure_images( + xs: Sequence[PathLike[str] | str | nib.Nifti1Image], +) -> list[nib.Nifti1Image]: + return [ensure_image(x) for x in xs] + + +def load_acquisition_from_metadata( + metadata_paths: Sequence[PathLike[str] | str], +) -> tuple[list[float], float, str]: + """Read EchoTime (s → ms), TotalReadoutTime (s) and PhaseEncodingDirection + from BIDS-style JSON sidecars. + + Per-echo ``EchoTime`` is read from each file; ``TotalReadoutTime`` and + ``PhaseEncodingDirection`` are taken from the first. + """ + metadatas = [] + for j in metadata_paths: + with open(j) as f: + metadatas.append(json.load(f)) + tes_ms = [float(m["EchoTime"]) * 1000 for m in metadatas] + trt = float(metadatas[0]["TotalReadoutTime"]) + ped = str(metadatas[0]["PhaseEncodingDirection"]) + return tes_ms, trt, ped + + +def resolve_acquisition( + *, + metadata: Sequence[PathLike[str] | str] | None, + tes: Sequence[float] | None, + total_readout_time: float | None = None, + phase_encoding_direction: str | None = None, + require_trt_pe: bool = True, +) -> tuple[list[float], float | None, str | None]: + """Resolve echo times / TRT / PED from either BIDS metadata or direct args. + + Mirrors the either-or / mutex logic in the CLI scripts. Set + ``require_trt_pe=False`` for callers that only need echo times (e.g. + ``unwrap_phase``); the error message then matches that script's wording. + + Error messages reference the dash-form CLI flags (``--metadata``, + ``--TEs``, ...) so CLI tests continue to match; nipype/library callers + will see the same text via ``ValueError``. + """ + flag_map = { + "tes": "--TEs", + "total_readout_time": "--total-readout-time", + "phase_encoding_direction": "--phase-encoding-direction", + } + if require_trt_pe: + direct_vals = { + "tes": tes, + "total_readout_time": total_readout_time, + "phase_encoding_direction": phase_encoding_direction, + } + else: + direct_vals = {"tes": tes} + direct_supplied = [k for k, v in direct_vals.items() if v is not None] + + if metadata is not None and direct_supplied: + names = ", ".join(flag_map[k] for k in direct_supplied) + raise ValueError( + f"--metadata is mutually exclusive with {names}; pass one or the " + "other, not both." + ) + if metadata is None and len(direct_supplied) != len(direct_vals): + if require_trt_pe: + missing = [flag_map[k] for k in direct_vals if k not in direct_supplied] + raise ValueError( + "either --metadata or all of --TEs, --total-readout-time, and " + f"--phase-encoding-direction must be provided (missing: {', '.join(missing)})." + ) + raise ValueError("either --metadata or --TEs must be provided.") + + if metadata is not None: + tes_resolved, trt_resolved, ped_resolved = load_acquisition_from_metadata( + metadata + ) + if require_trt_pe: + return tes_resolved, trt_resolved, ped_resolved + return tes_resolved, None, None + + return list(tes or []), total_readout_time, phase_encoding_direction + + +def trim_noise_frames(images: list[nib.Nifti1Image], n: int) -> list[nib.Nifti1Image]: + """Trim the last ``n`` frames from each 4D image. Returns the input list + unchanged when ``n == 0``.""" + if n == 0: + return images + if n < 0: + raise ValueError(f"noise_frames must be non-negative; got {n}.") + return [ + nib.Nifti1Image(img.dataobj[..., :-n], img.affine, img.header) for img in images + ] diff --git a/warpkit/scripts/_warp_io.py b/warpkit/scripts/_warp_io.py index 17ccca0..67c44b3 100644 --- a/warpkit/scripts/_warp_io.py +++ b/warpkit/scripts/_warp_io.py @@ -4,33 +4,42 @@ "1+ files of maps or fields" input model and the same "1 bundled file or N per-frame files" output model. This module hosts the frame splitting and bundling helpers that the scripts share. + +Validation errors are raised as :class:`ValueError`. The CLI shims catch +``ValueError`` and forward to ``parser.error`` so the user-visible behaviour +(``SystemExit(2)`` with ``error: ...`` on stderr) is preserved. """ from __future__ import annotations -import argparse +from collections.abc import Sequence +from os import PathLike from typing import cast import nibabel as nib import numpy as np +from ._metadata import ensure_image + def read_input_frames( - input_paths: list[str], + inputs: Sequence[PathLike[str] | str | nib.Nifti1Image], from_type: str, - parser: argparse.ArgumentParser, ) -> list[nib.Nifti1Image]: """Load input file(s) and split into a flat list of single-frame images. ``from_type`` is ``"map"`` (1-channel) or ``"field"`` (3-channel) — the - user-supplied input type from ``--from``. Each input may be a single 3D - map, a 4D map series, a 4D field, or a 5D field (singleton or - multi-frame). The returned frames are 3D for maps and 4D - ``(X, Y, Z, 3)`` for fields. + user-supplied input type. Each input may be a single 3D map, a 4D map + series, a 4D field, or a 5D field (singleton or multi-frame). The + returned frames are 3D for maps and 4D ``(X, Y, Z, 3)`` for fields. + + Inputs may be paths, ``str`` paths, or already-loaded ``Nifti1Image`` + objects. """ frames: list[nib.Nifti1Image] = [] - for p in input_paths: - img = cast(nib.Nifti1Image, nib.load(p)) + for idx, p in enumerate(inputs): + img = ensure_image(p) + label = str(p) if not isinstance(p, nib.Nifti1Image) else f"input #{idx}" if from_type == "map": if img.ndim == 3: frames.append(img) @@ -39,8 +48,8 @@ def read_input_frames( fd = np.asarray(img.dataobj[..., i]) frames.append(nib.Nifti1Image(fd, img.affine, img.header)) else: - parser.error( - f"map input must be 3D or 4D; got shape {img.shape} for {p}" + raise ValueError( + f"map input must be 3D or 4D; got shape {img.shape} for {label}" ) else: # field if img.ndim == 4 and img.shape[-1] == 3: @@ -51,9 +60,9 @@ def read_input_frames( fd = np.asarray(img.dataobj[..., i, :]) frames.append(nib.Nifti1Image(fd, img.affine, img.header)) else: - parser.error( + raise ValueError( "field input must be 4D (X,Y,Z,3) or 5D (X,Y,Z,T,3); " - f"got shape {img.shape} for {p}" + f"got shape {img.shape} for {label}" ) return frames @@ -89,9 +98,8 @@ def bundle_frames_to_field_series(frames: list[nib.Nifti1Image]) -> nib.Nifti1Im def write_output( frames: list[nib.Nifti1Image], - out_paths: list[str], + out_paths: list[str] | list[PathLike[str]], out_type: str, - parser: argparse.ArgumentParser, ) -> None: """Write per-frame images either bundled into one file (when exactly one output path is given for >1 frames) or one file per frame. @@ -108,7 +116,7 @@ def write_output( if out_type == "map" else bundle_frames_to_field_series(frames) ) - bundled.to_filename(out_paths[0]) + bundled.to_filename(str(out_paths[0])) elif n_out == n: for path, img in zip(out_paths, frames, strict=True): # Per-frame map outputs may carry a vector intent inherited from an @@ -118,9 +126,9 @@ def write_output( header = cast(nib.Nifti1Header, img.header.copy()) header.set_intent("none", (), "") img = nib.Nifti1Image(np.asarray(img.dataobj), img.affine, header) - img.to_filename(path) + img.to_filename(str(path)) else: - parser.error( + raise ValueError( f"got {n_out} --output path(s) for {n} frame(s); must be 1 " f"(bundle into a single file) or {n} (one per frame)" ) diff --git a/warpkit/scripts/apply_warp.py b/warpkit/scripts/apply_warp.py index bc3501b..e65be99 100644 --- a/warpkit/scripts/apply_warp.py +++ b/warpkit/scripts/apply_warp.py @@ -1,5 +1,19 @@ +"""``wk-apply-warp`` — resample an image through a displacement transform. + +Public surface: + +* :func:`apply_warp` — typed Python entry point. Returns an + :class:`ApplyWarpResult` with the absolute path of the written NIfTI. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse -from collections.abc import Callable +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path from typing import cast import nibabel as nib @@ -16,6 +30,12 @@ ) from . import epilog +from ._metadata import ensure_image, ensure_images + + +@dataclass(frozen=True, slots=True) +class ApplyWarpResult: + output: Path def _build_transform_getter( @@ -23,26 +43,25 @@ def _build_transform_getter( transform_type: str, phase_encoding_axis: str | None, in_format: str, - parser: argparse.ArgumentParser, ) -> tuple[int, str, Callable[[int], nib.Nifti1Image]]: """Validate and wrap the user-supplied transform inputs. - Returns (frame_count, transform_type, getter). The getter is a callable - that takes a 0-indexed frame number and returns an itk-format - ``Nifti1Image`` ready for ``resample_image``. Single-frame transforms are - cached on first access. + Returns (frame_count, transform_type, getter). The getter takes a + 0-indexed frame number and returns an itk-format ``Nifti1Image`` ready + for ``resample_image``. Single-frame transforms are cached on first + access. Validation errors raise :class:`ValueError`. """ if len(transforms) > 1: if transform_type != "field": - parser.error( - "--transform-type=map is incompatible with a multi-file " - "--transform series (each file must be a 3-channel field). " - "Pass a single 4D map series or use --transform-type field." + raise ValueError( + "transform_type='map' is incompatible with a multi-file " + "transform series (each file must be a 3-channel field). " + "Pass a single 4D map series or use transform_type='field'." ) for t in transforms: if t.ndim != 4 or t.shape[-1] != 3: - parser.error( - "when --transform is a series, each file must be a 4D " + raise ValueError( + "when transform is a series, each file must be a 4D " f"3-channel field (X,Y,Z,3); got shape {t.shape}" ) cache: list[nib.Nifti1Image | None] = [None] * len(transforms) @@ -60,8 +79,8 @@ def get_series(i: int) -> nib.Nifti1Image: if transform_type == "map": if not phase_encoding_axis: - parser.error( - "--phase-encoding-axis is required when the transform is a " + raise ValueError( + "phase_encoding_axis is required when the transform is a " "1-channel displacement map." ) axis = cast(str, phase_encoding_axis) @@ -80,7 +99,7 @@ def get_series(i: int) -> nib.Nifti1Image: t, axis=axis, format="itk", frame=i ), ) - parser.error( + raise ValueError( f"displacement map must be 3D or 4D; got {t.ndim}D shape {t.shape}" ) @@ -101,12 +120,116 @@ def get_5d_frame(i: int) -> nib.Nifti1Image: return convert_warp(frame_img, in_type=in_format, out_type="itk") return n, "field", get_5d_frame - parser.error( + raise ValueError( "displacement field must be 4D (X,Y,Z,3) or 5D (X,Y,Z,T,3); " f"got shape {t.shape}" ) +def apply_warp( + *, + input: PathLike[str] | str | nib.Nifti1Image, + transform: Sequence[PathLike[str] | str | nib.Nifti1Image], + output: PathLike[str] | str, + transform_type: str, + reference: PathLike[str] | str | nib.Nifti1Image | None = None, + phase_encoding_axis: str | None = None, + format: str = "itk", +) -> ApplyWarpResult: + """Resample ``input`` through one or more displacement transforms and + write the result to ``output``. + + ``transform_type`` is ``"map"`` (1-channel along ``phase_encoding_axis``) + or ``"field"`` (3-channel). A multi-file transform must be ``"field"``; + a single-frame transform broadcasts across all frames of a 4D input. + + ``format`` is the input format of 3-channel fields and must be one of + ``itk`` / ``fsl`` / ``ants`` / ``afni`` (default ``itk``). + """ + if transform_type not in ("map", "field"): + raise ValueError( + f"transform_type must be 'map' or 'field'; got {transform_type!r}" + ) + if format not in WARP_ITK_FLIPS: + raise ValueError( + f"format must be one of {tuple(WARP_ITK_FLIPS)}; got {format!r}" + ) + if phase_encoding_axis is not None and phase_encoding_axis not in AXIS_MAP: + raise ValueError( + f"phase_encoding_axis must be one of {tuple(AXIS_MAP)}; " + f"got {phase_encoding_axis!r}" + ) + + input_img = ensure_image(input) + transforms = ensure_images(transform) + + reference_img: nib.Nifti1Image + if reference is not None: + reference_img = ensure_image(reference) + elif input_img.ndim == 3: + reference_img = input_img + else: + reference_img = nib.Nifti1Image( + np.asarray(input_img.dataobj[..., 0]), + input_img.affine, + input_img.header, + ) + + n_transform, post_transform_type, get_transform = _build_transform_getter( + transforms, + transform_type=transform_type, + phase_encoding_axis=phase_encoding_axis, + in_format=format, + ) + + if input_img.ndim == 3: + n_input = 1 + elif input_img.ndim == 4: + n_input = input_img.shape[-1] + else: + raise ValueError( + f"input image must be 3D or 4D; got {input_img.ndim}D shape " + f"{input_img.shape}" + ) + + if n_transform > 1 and n_input == 1: + raise ValueError( + f"got a {n_transform}-frame transform but input is 3D; input " + "must be 4D when applying a series of transforms." + ) + if n_transform > 1 and n_transform != n_input: + raise ValueError( + f"transform has {n_transform} frame(s) but input has {n_input}; " + "they must match (or pass a single-frame transform to broadcast)." + ) + + print( + f" input frames: {n_input}; transform frames: {n_transform} " + f"(type={post_transform_type})" + ) + + if n_input == 1: + out_img = resample_image(reference_img, input_img, get_transform(0)) + else: + out_frames = [] + for i in range(n_input): + frame_data = np.asarray(input_img.dataobj[..., i]) + frame_img = nib.Nifti1Image(frame_data, input_img.affine, input_img.header) + t_idx = i if n_transform > 1 else 0 + resampled = resample_image(reference_img, frame_img, get_transform(t_idx)) + out_frames.append(resampled.get_fdata()) + if (i + 1) % 10 == 0 or (i + 1) == n_input: + print(f" resampled frame {i + 1}/{n_input}") + out_data = np.stack(out_frames, axis=-1).astype(np.float32) + out_img = nib.Nifti1Image(out_data, reference_img.affine, reference_img.header) + + output_path = Path(str(output)).resolve() + print(f"Saving resampled image to {output_path}...") + out_img.to_filename(str(output_path)) + print("Done.") + return ApplyWarpResult(output=output_path) + + def main(): parser = argparse.ArgumentParser( description=( @@ -177,78 +300,23 @@ def main(): ) args = parser.parse_args() - setup_logging() print(f"wk-apply-warp: {args}") - input_img = cast(nib.Nifti1Image, nib.load(args.input)) - transforms = [cast(nib.Nifti1Image, nib.load(p)) for p in args.transform] - - # determine the reference grid - reference_img: nib.Nifti1Image - if args.reference: - reference_img = cast(nib.Nifti1Image, nib.load(args.reference)) - elif input_img.ndim == 3: - reference_img = input_img - else: - reference_img = nib.Nifti1Image( - np.asarray(input_img.dataobj[..., 0]), - input_img.affine, - input_img.header, - ) - - # build a per-frame transform getter - n_transform, transform_type, get_transform = _build_transform_getter( - transforms, - transform_type=args.transform_type, - phase_encoding_axis=args.phase_encoding_axis, - in_format=args.format, - parser=parser, - ) - - # input frame count - if input_img.ndim == 3: - n_input = 1 - elif input_img.ndim == 4: - n_input = input_img.shape[-1] - else: - parser.error( - f"input image must be 3D or 4D; got {input_img.ndim}D shape {input_img.shape}" - ) - - # compatibility checks - if n_transform > 1 and n_input == 1: - parser.error( - f"got a {n_transform}-frame transform but input is 3D; input " - "must be 4D when applying a series of transforms." - ) - if n_transform > 1 and n_transform != n_input: - parser.error( - f"transform has {n_transform} frame(s) but input has {n_input}; " - "they must match (or pass a single-frame transform to broadcast)." + # Map ValueError messages onto the dash-form CLI flags so existing + # parser.error-style error text continues to make sense to CLI users. + try: + apply_warp( + input=args.input, + transform=args.transform, + output=args.output, + transform_type=args.transform_type, + reference=args.reference, + phase_encoding_axis=args.phase_encoding_axis, + format=args.format, ) - - print( - f" input frames: {n_input}; transform frames: {n_transform} " - f"(type={transform_type})" - ) - - # resample - if n_input == 1: - out_img = resample_image(reference_img, input_img, get_transform(0)) - else: - out_frames = [] - for i in range(n_input): - frame_data = np.asarray(input_img.dataobj[..., i]) - frame_img = nib.Nifti1Image(frame_data, input_img.affine, input_img.header) - t_idx = i if n_transform > 1 else 0 - resampled = resample_image(reference_img, frame_img, get_transform(t_idx)) - out_frames.append(resampled.get_fdata()) - if (i + 1) % 10 == 0 or (i + 1) == n_input: - print(f" resampled frame {i + 1}/{n_input}") - out_data = np.stack(out_frames, axis=-1).astype(np.float32) - out_img = nib.Nifti1Image(out_data, reference_img.affine, reference_img.header) - - print(f"Saving resampled image to {args.output}...") - out_img.to_filename(args.output) - print("Done.") + except ValueError as e: + msg = str(e) + msg = msg.replace("transform_type=", "--transform-type=") + msg = msg.replace("phase_encoding_axis", "--phase-encoding-axis") + parser.error(msg) diff --git a/warpkit/scripts/compute_fieldmap.py b/warpkit/scripts/compute_fieldmap.py index cf0cb19..441de59 100644 --- a/warpkit/scripts/compute_fieldmap.py +++ b/warpkit/scripts/compute_fieldmap.py @@ -1,6 +1,20 @@ +"""``wk-compute-fieldmap`` — post-unwrap stage of MEDIC: compute B0 field maps +and EPI distortion-correction displacement maps from previously unwrapped +multi-echo phase. + +Public surface: + +* :func:`compute_fieldmap` — typed Python entry point. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse -import json -from typing import cast +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path import nibabel as nib import numpy as np @@ -15,10 +29,109 @@ ) from . import epilog +from ._metadata import ensure_image, ensure_images, resolve_acquisition PE_DIRECTIONS = ("i", "j", "k", "i-", "j-", "k-", "x", "y", "z", "x-", "y-", "z-") +@dataclass(frozen=True, slots=True) +class ComputeFieldmapResult: + """Same three-tuple as :class:`MedicResult`.""" + + fieldmap_native: Path + displacement_map: Path + fieldmap: Path + + +def compute_fieldmap( + *, + unwrapped: Sequence[PathLike[str] | str | nib.Nifti1Image], + magnitude: Sequence[PathLike[str] | str | nib.Nifti1Image], + masks: PathLike[str] | str | nib.Nifti1Image, + out_prefix: PathLike[str] | str, + tes: Sequence[float] | None = None, + total_readout_time: float | None = None, + phase_encoding_direction: str | None = None, + metadata: Sequence[PathLike[str] | str] | None = None, + border_filt: Sequence[int] = (1, 5), + svd_filt: int = 10, + n_cpus: int = 4, +) -> ComputeFieldmapResult: + """Compute native-space field map, displacement map, and undistorted-space + field map from already-unwrapped multi-echo phase + masks. + + ``unwrapped`` and ``masks`` are typically the outputs of + :func:`warpkit.scripts.unwrap_phase.unwrap_phase`. + """ + tes_ms, trt, ped = resolve_acquisition( + metadata=metadata, + tes=tes, + total_readout_time=total_readout_time, + phase_encoding_direction=phase_encoding_direction, + require_trt_pe=True, + ) + assert trt is not None and ped is not None + + if len(tes_ms) != len(unwrapped) or len(tes_ms) != len(magnitude): + raise ValueError( + f"got {len(tes_ms)} echo time(s), {len(unwrapped)} unwrapped " + f"file(s), and {len(magnitude)} magnitude file(s); all three " + "must match." + ) + + mag_imgs = ensure_images(magnitude) + unwrapped_imgs = ensure_images(unwrapped) + masks_img = ensure_image(masks) + + border_filt_list = list(border_filt) + if len(border_filt_list) != 2: + raise ValueError( + f"border_filt must have exactly 2 elements; got {len(border_filt_list)}." + ) + border_filt_tuple: tuple[int, int] = (border_filt_list[0], border_filt_list[1]) + + fmaps_native = compute_field_maps( + unwrapped_imgs, + masks_img, + mag_imgs, + list(tes_ms), + border_filt=border_filt_tuple, + svd_filt=svd_filt, + n_cpus=n_cpus, + ) + + # Convert native-space field maps to displacement maps in distorted space, + # invert to get distorted -> undistorted, then re-derive an undistorted- + # space field map. Mirrors warpkit.distortion.medic. + inv_displacement_maps = field_maps_to_displacement_maps(fmaps_native, trt, ped) + dmaps = invert_displacement_maps(inv_displacement_maps, ped) + fmaps = displacement_maps_to_field_maps(dmaps, trt, ped, flip_sign=True) + + if ( + np.corrcoef( + fmaps.dataobj[..., 0].ravel(), + fmaps_native.dataobj[..., 0].ravel(), + )[0, 1] + < 0 + ): + fmaps = nib.Nifti1Image(fmaps.get_fdata() * -1, fmaps.affine, fmaps.header) + + out_prefix_str = str(out_prefix) + print("Saving field maps and displacement maps to file...") + fmap_native_path = Path(f"{out_prefix_str}_fieldmaps_native.nii").resolve() + dmap_path = Path(f"{out_prefix_str}_displacementmaps.nii").resolve() + fmap_path = Path(f"{out_prefix_str}_fieldmaps.nii").resolve() + fmaps_native.to_filename(str(fmap_native_path)) + dmaps.to_filename(str(dmap_path)) + fmaps.to_filename(str(fmap_path)) + print("Done.") + return ComputeFieldmapResult( + fieldmap_native=fmap_native_path, + displacement_map=dmap_path, + fieldmap=fmap_path, + ) + + def main(): parser = argparse.ArgumentParser( description=( @@ -109,89 +222,22 @@ def main(): ) args = parser.parse_args() - - direct_args = { - "--TEs": args.tes, - "--total-readout-time": args.total_readout_time, - "--phase-encoding-direction": args.phase_encoding_direction, - } - direct_supplied = [name for name, val in direct_args.items() if val is not None] - - if args.metadata and direct_supplied: - parser.error( - "--metadata is mutually exclusive with " - f"{', '.join(direct_supplied)}; pass one or the other, not both." - ) - if not args.metadata and len(direct_supplied) != len(direct_args): - missing = [name for name in direct_args if name not in direct_supplied] - parser.error( - "either --metadata or all of --TEs, --total-readout-time, and " - f"--phase-encoding-direction must be provided (missing: {', '.join(missing)})." - ) - - echo_times: list[float] - total_readout_time: float - phase_encoding_direction: str - if args.metadata: - metadatas = [] - for j in args.metadata: - with open(j) as f: - metadatas.append(json.load(f)) - echo_times = [float(m["EchoTime"]) * 1000 for m in metadatas] - total_readout_time = float(metadatas[0]["TotalReadoutTime"]) - phase_encoding_direction = str(metadatas[0]["PhaseEncodingDirection"]) - else: - echo_times = cast(list[float], args.tes) - total_readout_time = cast(float, args.total_readout_time) - phase_encoding_direction = cast(str, args.phase_encoding_direction) - - if len(echo_times) != len(args.unwrapped) or len(echo_times) != len(args.magnitude): - parser.error( - f"got {len(echo_times)} echo time(s), {len(args.unwrapped)} " - f"--unwrapped file(s), and {len(args.magnitude)} --magnitude " - "file(s); all three must match." - ) - setup_logging() print(f"wk-compute-fieldmap: {args}") - mag_imgs = [cast(nib.Nifti1Image, nib.load(m)) for m in args.magnitude] - unwrapped_imgs = [cast(nib.Nifti1Image, nib.load(u)) for u in args.unwrapped] - masks_img = cast(nib.Nifti1Image, nib.load(args.masks)) - - fmaps_native = compute_field_maps( - unwrapped_imgs, - masks_img, - mag_imgs, - echo_times, - border_filt=tuple(args.border_filt), - svd_filt=args.svd_filt, - n_cpus=args.n_cpus, - ) - - # convert native-space field maps to displacement maps in distorted space, - # invert to get distorted -> undistorted, then re-derive an undistorted- - # space field map. Mirrors warpkit.distortion.medic. - inv_displacement_maps = field_maps_to_displacement_maps( - fmaps_native, total_readout_time, phase_encoding_direction - ) - dmaps = invert_displacement_maps(inv_displacement_maps, phase_encoding_direction) - fmaps = displacement_maps_to_field_maps( - dmaps, total_readout_time, phase_encoding_direction, flip_sign=True - ) - - # sign flip if undistorted-space fmap correlates negatively with native - if ( - np.corrcoef( - fmaps.dataobj[..., 0].ravel(), - fmaps_native.dataobj[..., 0].ravel(), - )[0, 1] - < 0 - ): - fmaps = nib.Nifti1Image(fmaps.get_fdata() * -1, fmaps.affine, fmaps.header) - - print("Saving field maps and displacement maps to file...") - fmaps_native.to_filename(f"{args.out_prefix}_fieldmaps_native.nii") - dmaps.to_filename(f"{args.out_prefix}_displacementmaps.nii") - fmaps.to_filename(f"{args.out_prefix}_fieldmaps.nii") - print("Done.") + try: + compute_fieldmap( + unwrapped=args.unwrapped, + magnitude=args.magnitude, + masks=args.masks, + out_prefix=args.out_prefix, + tes=args.tes, + total_readout_time=args.total_readout_time, + phase_encoding_direction=args.phase_encoding_direction, + metadata=args.metadata, + border_filt=tuple(args.border_filt), + svd_filt=args.svd_filt, + n_cpus=args.n_cpus, + ) + except ValueError as e: + parser.error(str(e)) diff --git a/warpkit/scripts/compute_jacobian.py b/warpkit/scripts/compute_jacobian.py index 77736d2..bf64ad4 100644 --- a/warpkit/scripts/compute_jacobian.py +++ b/warpkit/scripts/compute_jacobian.py @@ -1,4 +1,19 @@ +"""``wk-compute-jacobian`` — compute the Jacobian determinant of a +displacement warp. + +Public surface: + +* :func:`compute_jacobian` — typed Python entry point. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path from typing import cast import nibabel as nib @@ -17,6 +32,11 @@ from ._warp_io import read_input_frames, write_output +@dataclass(frozen=True, slots=True) +class ComputeJacobianResult: + output: list[Path] + + def _frame_to_itk_field( img: nib.Nifti1Image, in_type: str, @@ -32,6 +52,58 @@ def _frame_to_itk_field( return cast(nib.Nifti1Image, convert_warp(img, in_type=in_format, out_type="itk")) +def compute_jacobian( + *, + input: Sequence[PathLike[str] | str | nib.Nifti1Image], + output: Sequence[PathLike[str] | str], + from_type: str, + from_format: str = "itk", + axis: str | None = None, + frame: int | None = None, +) -> ComputeJacobianResult: + """Compute the per-frame Jacobian determinant of a displacement warp. + + Output is one 3D scalar volume per input frame (bundled into a 4D series + if a single ``output`` path is given, or one per frame if N paths). + """ + if from_type not in ("map", "field"): + raise ValueError(f"from_type must be 'map' or 'field'; got {from_type!r}") + if from_format not in WARP_ITK_FLIPS: + raise ValueError( + f"from_format must be one of {tuple(WARP_ITK_FLIPS)}; got {from_format!r}" + ) + if axis is not None and axis not in AXIS_MAP: + raise ValueError(f"axis must be one of {tuple(AXIS_MAP)}; got {axis!r}") + + frames = read_input_frames(list(input), from_type) + + if frame is not None: + if frame < 0 or frame >= len(frames): + raise ValueError( + f"frame {frame} is out of range; input has {len(frames)} frame(s)" + ) + frames = [frames[frame]] + + if from_type == "map" and not axis: + raise ValueError("--axis is required when the input is a displacement map.") + + in_fmt_label = from_format if from_type == "field" else "n/a" + print( + f"wk-compute-jacobian: {len(frames)} frame(s); {from_type}({in_fmt_label}) " + "-> jacobian" + ) + + jacobians: list[nib.Nifti1Image] = [] + for img in frames: + field_itk = _frame_to_itk_field(img, from_type, from_format, axis) + jacobians.append(compute_jacobian_determinant(field_itk)) + + out_paths_resolved = [Path(str(p)).resolve() for p in output] + write_output(jacobians, [str(p) for p in out_paths_resolved], "map") + print("Done.") + return ComputeJacobianResult(output=out_paths_resolved) + + def main(): parser = argparse.ArgumentParser( description=( @@ -96,32 +168,16 @@ def main(): args = parser.parse_args() setup_logging() - frames = read_input_frames(args.input, args.from_type, parser) - in_type = args.from_type - - if args.frame is not None: - if args.frame < 0 or args.frame >= len(frames): - parser.error( - f"--frame {args.frame} is out of range; input has " - f"{len(frames)} frame(s)" - ) - frames = [frames[args.frame]] - - if in_type == "map" and not args.axis: - parser.error("--axis is required when the input is a displacement map.") - - in_fmt_label = args.from_format if in_type == "field" else "n/a" - print( - f"wk-compute-jacobian: {len(frames)} frame(s); {in_type}({in_fmt_label}) " - "-> jacobian" - ) - - jacobians: list[nib.Nifti1Image] = [] - for img in frames: - field_itk = _frame_to_itk_field(img, in_type, args.from_format, args.axis) - jacobians.append(compute_jacobian_determinant(field_itk)) - - # Output is a 3D scalar per frame, so use the same packing as 1-channel - # maps (4D series when bundled into a single file). - write_output(jacobians, args.output, "map", parser) - print("Done.") + try: + compute_jacobian( + input=args.input, + output=args.output, + from_type=args.from_type, + from_format=args.from_format, + axis=args.axis, + frame=args.frame, + ) + except ValueError as e: + msg = str(e) + msg = msg.replace("frame ", "--frame ", 1) if msg.startswith("frame ") else msg + parser.error(msg) diff --git a/warpkit/scripts/convert_fieldmap.py b/warpkit/scripts/convert_fieldmap.py index a0cf441..c0eae8e 100644 --- a/warpkit/scripts/convert_fieldmap.py +++ b/warpkit/scripts/convert_fieldmap.py @@ -1,4 +1,19 @@ +"""``wk-convert-fieldmap`` — convert between mm displacement maps/fields and +Hz B0 field maps. + +Public surface: + +* :func:`convert_fieldmap` — typed Python entry point. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path import nibabel as nib @@ -19,6 +34,127 @@ PE_DIRECTIONS = tuple(AXIS_MAP) +@dataclass(frozen=True, slots=True) +class ConvertFieldmapResult: + output: list[Path] + + +def convert_fieldmap( + *, + input: Sequence[PathLike[str] | str | nib.Nifti1Image], + output: Sequence[PathLike[str] | str], + from_type: str, + to_type: str, + total_readout_time: float, + phase_encoding_direction: str, + from_format: str = "itk", + to_format: str = "itk", + flip_sign: bool = False, + frame: int | None = None, +) -> ConvertFieldmapResult: + """Convert between mm displacement maps/fields and Hz B0 field maps. + + Exactly one of ``from_type`` / ``to_type`` must be ``"fieldmap"`` (the Hz + side); the other side is ``"map"`` (1-channel mm) or ``"field"`` + (3-channel mm). Use :func:`convert_warp` for representation/format + conversions on the mm side. + """ + if from_type not in ("map", "field", "fieldmap"): + raise ValueError( + f"from_type must be 'map', 'field' or 'fieldmap'; got {from_type!r}" + ) + if to_type not in ("map", "field", "fieldmap"): + raise ValueError( + f"to_type must be 'map', 'field' or 'fieldmap'; got {to_type!r}" + ) + if phase_encoding_direction not in PE_DIRECTIONS: + raise ValueError( + f"phase_encoding_direction must be one of {PE_DIRECTIONS}; " + f"got {phase_encoding_direction!r}" + ) + if from_format not in WARP_ITK_FLIPS: + raise ValueError( + f"from_format must be one of {tuple(WARP_ITK_FLIPS)}; got {from_format!r}" + ) + if to_format not in WARP_ITK_FLIPS: + raise ValueError( + f"to_format must be one of {tuple(WARP_ITK_FLIPS)}; got {to_format!r}" + ) + + if from_type == to_type: + raise ValueError( + f"from_type={from_type} and to_type={to_type} are the same; use " + "convert_warp for representation/format conversions on the mm side." + ) + + crosses_units = (from_type == "fieldmap") != (to_type == "fieldmap") + if not crosses_units: + raise ValueError( + "convert_fieldmap converts between mm (map/field) and Hz " + "(fieldmap); both --from and --to are on the mm side. Use " + "convert_warp instead." + ) + + # _warp_io.read_input_frames knows about map/field only; remap "fieldmap" + # to "map" since on disk a Hz field map is shaped like a 1-channel map. + from_io = "map" if from_type == "fieldmap" else from_type + frames = read_input_frames(list(input), from_io) + + if frame is not None: + if frame < 0 or frame >= len(frames): + raise ValueError( + f"frame {frame} is out of range; input has {len(frames)} frame(s)" + ) + frames = [frames[frame]] + + print( + f"wk-convert-fieldmap: {len(frames)} frame(s); " + f"{from_type} -> {to_type} " + f"(trt={total_readout_time}s, pe={phase_encoding_direction})" + ) + + converted: list[nib.Nifti1Image] = [] + for img in frames: + if to_type == "fieldmap": + # mm side -> Hz fieldmap + if from_type == "field": + map_img = displacement_field_to_map( + img, axis=phase_encoding_direction, format=from_format + ) + else: + map_img = img + converted.append( + displacement_maps_to_field_maps( + map_img, + total_readout_time, + phase_encoding_direction, + flip_sign=flip_sign, + ) + ) + else: + # Hz fieldmap -> mm side + map_img = field_maps_to_displacement_maps( + img, total_readout_time, phase_encoding_direction + ) + if to_type == "field": + converted.append( + displacement_map_to_field( + map_img, + axis=phase_encoding_direction, + format=to_format, + frame=0, + ) + ) + else: + converted.append(map_img) + + out_writer_type = "field" if to_type == "field" else "map" + out_paths_resolved = [Path(str(p)).resolve() for p in output] + write_output(converted, [str(p) for p in out_paths_resolved], out_writer_type) + print("Done.") + return ConvertFieldmapResult(output=out_paths_resolved) + + def main(): parser = argparse.ArgumentParser( description=( @@ -112,87 +248,30 @@ def main(): args = parser.parse_args() setup_logging() - # _warp_io.read_input_frames knows about map/field only; remap "fieldmap" - # to "map" since on disk a Hz field map is shaped like a 1-channel map. - from_arg_for_io = "map" if args.from_type == "fieldmap" else args.from_type - frames = read_input_frames(args.input, from_arg_for_io, parser) - - if args.frame is not None: - if args.frame < 0 or args.frame >= len(frames): - parser.error( - f"--frame {args.frame} is out of range; input has " - f"{len(frames)} frame(s)" - ) - frames = [frames[args.frame]] - - in_type = args.from_type - - if in_type == args.to_type: - parser.error( - f"--from={in_type} and --to={args.to_type} are the same; use " - "wk-convert-warp for representation/format conversions on the " - "mm side." - ) - - crosses_units = (in_type == "fieldmap") != (args.to_type == "fieldmap") - if not crosses_units: - parser.error( - "wk-convert-fieldmap converts between mm (map/field) and Hz " - "(fieldmap); both --from and --to are on the mm side. Use " - "wk-convert-warp instead." - ) - if args.total_readout_time is None or not args.phase_encoding_direction: parser.error( "--total-readout-time and --phase-encoding-direction are " "required for mm <-> Hz conversion." ) - print( - f"wk-convert-fieldmap: {len(frames)} frame(s); " - f"{in_type} -> {args.to_type} " - f"(trt={args.total_readout_time}s, pe={args.phase_encoding_direction})" - ) - - converted: list[nib.Nifti1Image] = [] - for img in frames: - if args.to_type == "fieldmap": - # mm side -> Hz fieldmap - if in_type == "field": - map_img = displacement_field_to_map( - img, - axis=args.phase_encoding_direction, - format=args.from_format, - ) - else: - map_img = img - converted.append( - displacement_maps_to_field_maps( - map_img, - args.total_readout_time, - args.phase_encoding_direction, - flip_sign=args.flip_sign, - ) - ) - else: - # Hz fieldmap -> mm side - map_img = field_maps_to_displacement_maps( - img, - args.total_readout_time, - args.phase_encoding_direction, - ) - if args.to_type == "field": - converted.append( - displacement_map_to_field( - map_img, - axis=args.phase_encoding_direction, - format=args.to_format, - frame=0, - ) - ) - else: - converted.append(map_img) - - out_writer_type = "field" if args.to_type == "field" else "map" - write_output(converted, args.output, out_writer_type, parser) - print("Done.") + try: + convert_fieldmap( + input=args.input, + output=args.output, + from_type=args.from_type, + to_type=args.to_type, + total_readout_time=args.total_readout_time, + phase_encoding_direction=args.phase_encoding_direction, + from_format=args.from_format, + to_format=args.to_format, + flip_sign=args.flip_sign, + frame=args.frame, + ) + except ValueError as e: + msg = str(e) + msg = msg.replace("from_type=", "--from=") + msg = msg.replace("to_type=", "--to=") + msg = msg.replace("convert_fieldmap converts", "wk-convert-fieldmap converts") + msg = msg.replace("Use convert_warp", "Use wk-convert-warp") + msg = msg.replace("frame ", "--frame ", 1) if msg.startswith("frame ") else msg + parser.error(msg) diff --git a/warpkit/scripts/convert_warp.py b/warpkit/scripts/convert_warp.py index 58a9dc8..fa2255e 100644 --- a/warpkit/scripts/convert_warp.py +++ b/warpkit/scripts/convert_warp.py @@ -1,4 +1,21 @@ +"""``wk-convert-warp`` — interconvert displacement maps and displacement +fields, convert between ITK / FSL / ANTs / AFNI format conventions, and +optionally invert the warp along the way. + +Public surface: + +* :func:`convert_warp` — typed Python entry point. Returns a + :class:`ConvertWarpResult` with the absolute paths of the written NIfTIs. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path import nibabel as nib import numpy as np @@ -7,18 +24,25 @@ from warpkit.utilities import ( AXIS_MAP, WARP_ITK_FLIPS, - convert_warp, displacement_field_to_map, displacement_map_to_field, invert_displacement_field, invert_displacement_maps, setup_logging, ) +from warpkit.utilities import ( + convert_warp as _convert_warp_image, +) from . import epilog from ._warp_io import read_input_frames, write_output +@dataclass(frozen=True, slots=True) +class ConvertWarpResult: + output: list[Path] + + def _invert_frames( frames: list[nib.Nifti1Image], in_type: str, @@ -28,22 +52,8 @@ def _invert_frames( ) -> tuple[list[nib.Nifti1Image], str, str]: """Invert each frame and return ``(frames, post_type, post_format)``. - Routing is by frame count, not by input type, because the 1D map inverter - is markedly faster per frame than the full 3D field inverter: - - * **Single frame** uses :func:`invert_displacement_field`. A map input is - first promoted to a 3-channel itk field via - :func:`displacement_map_to_field`. The result is always in itk field - form. - * **Multi-frame** stacks all frames into a single 4D ``(X, Y, Z, T)`` map - and runs :func:`invert_displacement_maps` once. A field input first has - its ``axis`` channel extracted (off-axis channels are dropped — fine - for the EPI-distortion case, where displacement is along the - phase-encoding axis). The result is always in 1-channel map form. - - The returned ``post_type`` / ``post_format`` are the actual representation - of the inverted frames (which may differ from the input's), so the - downstream conversion stage can route correctly. + See module-level commentary in the original script for the routing + rationale (single-frame vs. multi-frame, map promotion to field, etc.). """ n = len(frames) if n == 1: @@ -55,7 +65,7 @@ def _invert_frames( field_itk = ( img if in_format == "itk" - else convert_warp(img, in_type=in_format, out_type="itk") + else _convert_warp_image(img, in_type=in_format, out_type="itk") ) return [invert_displacement_field(field_itk, verbose=verbose)], "field", "itk" @@ -96,7 +106,7 @@ def _convert_frames( converted.append(img) else: converted.append( - convert_warp(img, in_type=in_format, out_type=out_format) + _convert_warp_image(img, in_type=in_format, out_type=out_format) ) elif in_type == "map" and out_type == "field": assert axis is not None @@ -111,6 +121,97 @@ def _convert_frames( return converted +def convert_warp( + *, + input: Sequence[PathLike[str] | str | nib.Nifti1Image], + output: Sequence[PathLike[str] | str], + from_type: str, + to_type: str | None = None, + from_format: str = "itk", + to_format: str = "itk", + axis: str | None = None, + frame: int | None = None, + invert: bool = False, + verbose: bool = False, +) -> ConvertWarpResult: + """Convert displacement transforms between map/field representations, + between field-format conventions, and optionally invert. + + See ``wk-convert-warp --help`` for the full parameter description; the + Python kwargs are the snake_case equivalents of the CLI flags. + """ + if from_type not in ("map", "field"): + raise ValueError(f"from_type must be 'map' or 'field'; got {from_type!r}") + if to_type is not None and to_type not in ("map", "field"): + raise ValueError(f"to_type must be 'map' or 'field'; got {to_type!r}") + if from_format not in WARP_ITK_FLIPS: + raise ValueError( + f"from_format must be one of {tuple(WARP_ITK_FLIPS)}; got {from_format!r}" + ) + if to_format not in WARP_ITK_FLIPS: + raise ValueError( + f"to_format must be one of {tuple(WARP_ITK_FLIPS)}; got {to_format!r}" + ) + if axis is not None and axis not in AXIS_MAP: + raise ValueError(f"axis must be one of {tuple(AXIS_MAP)}; got {axis!r}") + + out_type = to_type or from_type + frames = read_input_frames(list(input), from_type) + + if frame is not None: + if frame < 0 or frame >= len(frames): + raise ValueError( + f"frame {frame} is out of range; input has {len(frames)} frame(s)" + ) + frames = [frames[frame]] + + multi_frame = len(frames) > 1 + needs_axis = ( + (from_type == "map" and out_type == "field") + or (from_type == "field" and out_type == "map") + or (invert and (from_type == "map" or multi_frame)) + ) + if needs_axis and not axis: + raise ValueError( + "--axis is required when converting between maps and fields, " + "when inverting a single-frame map, or when inverting a " + "multi-frame series (the multi-frame inverter operates along a " + "single axis)." + ) + + in_fmt_label = from_format if from_type == "field" else "n/a" + out_fmt_label = to_format if out_type == "field" else "n/a" + invert_label = " (inverted)" if invert else "" + print( + f"wk-convert-warp: {len(frames)} frame(s); " + f"{from_type}({in_fmt_label}) -> {out_type}({out_fmt_label}){invert_label}" + ) + + post_type, post_format = from_type, from_format + if invert: + frames, post_type, post_format = _invert_frames( + frames, + in_type=from_type, + in_format=from_format, + axis=axis, + verbose=verbose, + ) + + converted = _convert_frames( + frames, + in_type=post_type, + out_type=out_type, + in_format=post_format, + out_format=to_format, + axis=axis, + ) + + out_paths_resolved = [Path(str(p)).resolve() for p in output] + write_output(converted, [str(p) for p in out_paths_resolved], out_type) + print("Done.") + return ConvertWarpResult(output=out_paths_resolved) + + def main(): parser = argparse.ArgumentParser( description=( @@ -210,65 +311,20 @@ def main(): args = parser.parse_args() setup_logging() - frames = read_input_frames(args.input, args.from_type, parser) - in_type = args.from_type - out_type = args.to_type or in_type - - if args.frame is not None: - if args.frame < 0 or args.frame >= len(frames): - parser.error( - f"--frame {args.frame} is out of range; input has " - f"{len(frames)} frame(s)" - ) - frames = [frames[args.frame]] - - # --axis is required for map<->field conversion AND for inversion when - # the chosen inversion routing needs it: the single-frame route promotes - # a map to a field (needs axis), and the multi-frame route always runs - # the 1D map inverter (needs axis for map<->axis-channel). - multi_frame = len(frames) > 1 - needs_axis = ( - (in_type == "map" and out_type == "field") - or (in_type == "field" and out_type == "map") - or (args.invert and (in_type == "map" or multi_frame)) - ) - if needs_axis and not args.axis: - parser.error( - "--axis is required when converting between maps and fields, " - "when inverting a single-frame map, or when inverting a " - "multi-frame series (the multi-frame inverter operates along a " - "single axis)." - ) - - in_fmt_label = args.from_format if in_type == "field" else "n/a" - out_fmt_label = args.to_format if out_type == "field" else "n/a" - invert_label = " (inverted)" if args.invert else "" - print( - f"wk-convert-warp: {len(frames)} frame(s); " - f"{in_type}({in_fmt_label}) -> {out_type}({out_fmt_label}){invert_label}" - ) - - # post_type / post_format track the actual representation of the frames - # after the inversion stage, which may differ from the user-declared - # input (e.g. multi-frame field input emerges as 1-channel maps). - post_type, post_format = in_type, args.from_format - if args.invert: - frames, post_type, post_format = _invert_frames( - frames, - in_type=in_type, - in_format=args.from_format, + try: + convert_warp( + input=args.input, + output=args.output, + from_type=args.from_type, + to_type=args.to_type, + from_format=args.from_format, + to_format=args.to_format, axis=args.axis, + frame=args.frame, + invert=args.invert, verbose=args.verbose, ) - - converted = _convert_frames( - frames, - in_type=post_type, - out_type=out_type, - in_format=post_format, - out_format=args.to_format, - axis=args.axis, - ) - - write_output(converted, args.output, out_type, parser) - print("Done.") + except ValueError as e: + msg = str(e) + msg = msg.replace("frame ", "--frame ", 1) if msg.startswith("frame ") else msg + parser.error(msg) diff --git a/warpkit/scripts/medic.py b/warpkit/scripts/medic.py index 9433726..fdd2b97 100644 --- a/warpkit/scripts/medic.py +++ b/warpkit/scripts/medic.py @@ -1,18 +1,147 @@ +"""``wk-medic`` — full Multi-Echo DIstortion Correction pipeline. + +This module exposes two surfaces: + +* :func:`medic` — the typed Python entry point. Takes paths or in-memory + ``Nifti1Image`` objects, runs MEDIC, writes outputs, and returns a + :class:`MedicResult` with the absolute paths of the three NIfTIs written + (``_fieldmaps_native.nii``, ``_displacementmaps.nii`` and + ``_fieldmaps.nii``). Library/integration code (e.g. nipype interfaces) + should call this. +* :func:`main` — the argparse CLI entry point. Parses ``sys.argv``, then + defers to :func:`medic`. ``ValueError`` from the latter is forwarded to + ``parser.error`` so CLI behaviour (``SystemExit(2)``, ``error: ...`` on + stderr) is preserved. +""" + +from __future__ import annotations + import argparse -import json -from typing import cast +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path import nibabel as nib from warpkit import __version__ -from warpkit.distortion import medic +from warpkit.distortion import medic as _medic_distortion from warpkit.utilities import setup_logging from . import epilog +from ._metadata import ensure_images, resolve_acquisition, trim_noise_frames PE_DIRECTIONS = ("i", "j", "k", "i-", "j-", "k-", "x", "y", "z", "x-", "y-", "z-") +@dataclass(frozen=True, slots=True) +class MedicResult: + """Absolute paths of the three NIfTIs written by :func:`medic`.""" + + fieldmap_native: Path + displacement_map: Path + fieldmap: Path + + +def medic( + *, + phase: Sequence[PathLike[str] | str | nib.Nifti1Image], + magnitude: Sequence[PathLike[str] | str | nib.Nifti1Image], + out_prefix: PathLike[str] | str, + tes: Sequence[float] | None = None, + total_readout_time: float | None = None, + phase_encoding_direction: str | None = None, + metadata: Sequence[PathLike[str] | str] | None = None, + noise_frames: int = 0, + n_cpus: int = 4, + wrap_limit: bool = False, + debug: bool = False, +) -> MedicResult: + """Run the full MEDIC pipeline and write the three output NIfTIs. + + Either pass ``metadata`` (one BIDS sidecar per echo) or all three of + ``tes`` (ms), ``total_readout_time`` (s) and ``phase_encoding_direction`` + (``i``/``j``/``k``/``x``/``y``/``z`` with optional trailing ``-``). The + two are mutually exclusive. + + ``noise_frames`` trims that many frames from the end of every + phase/magnitude file before unwrapping (matches the CLI's ``-f``). + + Returns a :class:`MedicResult` with absolute paths of the three written + NIfTIs. + """ + if len(phase) != len(magnitude): + raise ValueError( + f"got {len(phase)} phase file(s) but {len(magnitude)} magnitude " + "file(s); they must match (one mag/phase pair per echo)." + ) + + tes_ms, trt, ped = resolve_acquisition( + metadata=metadata, + tes=tes, + total_readout_time=total_readout_time, + phase_encoding_direction=phase_encoding_direction, + require_trt_pe=True, + ) + # require_trt_pe=True guarantees both are populated. + assert trt is not None and ped is not None + + if len(tes_ms) != len(phase): + raise ValueError( + f"got {len(tes_ms)} echo time(s) but --phase has {len(phase)} " + "file(s); they must match." + ) + + mag_data = ensure_images(magnitude) + phase_data = ensure_images(phase) + + if noise_frames > 0: + print(f"Removing {noise_frames} noise frames from the end of each file...") + mag_data = trim_noise_frames(mag_data, noise_frames) + phase_data = trim_noise_frames(phase_data, noise_frames) + + if debug: + fmaps_native, dmaps, fmaps = _medic_distortion( + phase_data, + mag_data, + tes_ms, + trt, + ped, + n_cpus=n_cpus, + border_filt=(1000, 1000), + svd_filt=1000, + debug=True, + wrap_limit=wrap_limit, + ) + else: + fmaps_native, dmaps, fmaps = _medic_distortion( + phase_data, + mag_data, + tes_ms, + trt, + ped, + n_cpus=n_cpus, + svd_filt=10, + border_size=5, + wrap_limit=wrap_limit, + ) + + out_prefix_str = str(out_prefix) + print("Saving field maps and displacement maps to file...") + fmap_native_path = Path(f"{out_prefix_str}_fieldmaps_native.nii").resolve() + dmap_path = Path(f"{out_prefix_str}_displacementmaps.nii").resolve() + fmap_path = Path(f"{out_prefix_str}_fieldmaps.nii").resolve() + fmaps_native.to_filename(str(fmap_native_path)) + dmaps.to_filename(str(dmap_path)) + fmaps.to_filename(str(fmap_path)) + print("Done.") + return MedicResult( + fieldmap_native=fmap_native_path, + displacement_map=dmap_path, + fieldmap=fmap_path, + ) + + def main(): parser = argparse.ArgumentParser( description="Multi-Echo DIstortion Correction", epilog=f"{epilog}" @@ -74,98 +203,23 @@ def main(): help="Turns off some heuristics for phase unwrapping", ) - # parse arguments args = parser.parse_args() - - direct_args = { - "--TEs": args.tes, - "--total-readout-time": args.total_readout_time, - "--phase-encoding-direction": args.phase_encoding_direction, - } - direct_supplied = [name for name, val in direct_args.items() if val is not None] - - if args.metadata and direct_supplied: - parser.error( - "--metadata is mutually exclusive with " - f"{', '.join(direct_supplied)}; pass one or the other, not both." - ) - if not args.metadata and len(direct_supplied) != len(direct_args): - missing = [name for name in direct_args if name not in direct_supplied] - parser.error( - "either --metadata or all of --TEs, --total-readout-time, and " - f"--phase-encoding-direction must be provided (missing: {', '.join(missing)})." - ) - - if args.metadata: - metadatas = [] - for j in args.metadata: - with open(j) as f: - metadatas.append(json.load(f)) - echo_times = [float(m["EchoTime"]) * 1000 for m in metadatas] - total_readout_time = float(metadatas[0]["TotalReadoutTime"]) - phase_encoding_direction = str(metadatas[0]["PhaseEncodingDirection"]) - else: - echo_times = args.tes - total_readout_time = args.total_readout_time - phase_encoding_direction = args.phase_encoding_direction - - if len(echo_times) != len(args.phase): - parser.error( - f"got {len(echo_times)} echo time(s) but --phase has {len(args.phase)} file(s); they must match." - ) - - # setup logging setup_logging() - - # log arguments print(f"medic: {args}") - # load magnitude and phase data - mag_data = [cast(nib.Nifti1Image, nib.load(m)) for m in args.magnitude] - phase_data = [cast(nib.Nifti1Image, nib.load(p)) for p in args.phase] - - # if noiseframes specified, remove them - if args.noiseframes > 0: - print(f"Removing {args.noiseframes} noise frames from the end of each file...") - mag_data = [ - nib.Nifti1Image(m.dataobj[..., : -args.noiseframes], m.affine, m.header) - for m in mag_data - ] - phase_data = [ - nib.Nifti1Image(p.dataobj[..., : -args.noiseframes], p.affine, p.header) - for p in phase_data - ] - - # now run medic - if args.debug: - fmaps_native, dmaps, fmaps = medic( - phase_data, - mag_data, - echo_times, - total_readout_time, - phase_encoding_direction, + try: + medic( + phase=args.phase, + magnitude=args.magnitude, + out_prefix=args.out_prefix, + tes=args.tes, + total_readout_time=args.total_readout_time, + phase_encoding_direction=args.phase_encoding_direction, + metadata=args.metadata, + noise_frames=args.noiseframes, n_cpus=args.n_cpus, - border_filt=(1000, 1000), - svd_filt=1000, - debug=True, wrap_limit=args.wrap_limit, + debug=args.debug, ) - else: - fmaps_native, dmaps, fmaps = medic( - phase_data, - mag_data, - echo_times, - total_readout_time, - phase_encoding_direction, - n_cpus=args.n_cpus, - svd_filt=10, - border_size=5, - wrap_limit=args.wrap_limit, - ) - - # save the fmaps and dmaps to file - print("Saving field maps and displacement maps to file...") - fmaps_native.to_filename(f"{args.out_prefix}_fieldmaps_native.nii") - dmaps.to_filename(f"{args.out_prefix}_displacementmaps.nii") - fmaps.to_filename(f"{args.out_prefix}_fieldmaps.nii") - print("Done.") + except ValueError as e: + parser.error(str(e)) diff --git a/warpkit/scripts/unwrap_phase.py b/warpkit/scripts/unwrap_phase.py index 5246eaa..bffaa22 100644 --- a/warpkit/scripts/unwrap_phase.py +++ b/warpkit/scripts/unwrap_phase.py @@ -1,6 +1,21 @@ +"""``wk-unwrap-phase`` — ROMEO multi-echo phase unwrapping (the unwrap stage of +MEDIC). + +Public surface: + +* :func:`unwrap_phase` — typed Python entry point. Returns an + :class:`UnwrapPhaseResult` with the per-echo unwrapped-phase paths and the + per-frame masks path. +* :func:`main` — argparse CLI shim. +""" + +from __future__ import annotations + import argparse -import json -from typing import cast +from collections.abc import Sequence +from dataclasses import dataclass +from os import PathLike +from pathlib import Path import nibabel as nib @@ -9,6 +24,93 @@ from warpkit.utilities import setup_logging from . import epilog +from ._metadata import ensure_images, resolve_acquisition, trim_noise_frames + + +@dataclass(frozen=True, slots=True) +class UnwrapPhaseResult: + unwrapped: list[Path] + masks: Path + + +def unwrap_phase( + *, + phase: Sequence[PathLike[str] | str | nib.Nifti1Image], + magnitude: Sequence[PathLike[str] | str | nib.Nifti1Image], + out_prefix: PathLike[str] | str, + tes: Sequence[float] | None = None, + metadata: Sequence[PathLike[str] | str] | None = None, + noise_frames: int = 0, + n_cpus: int = 4, + wrap_limit: bool = False, + debug: bool = False, +) -> UnwrapPhaseResult: + """Run ROMEO multi-echo phase unwrapping. + + Either pass ``metadata`` (one BIDS sidecar per echo) or ``tes`` directly + (echo times in ms). The two are mutually exclusive. + + Returns absolute paths of one unwrapped-phase NIfTI per echo + (``_unwrapped_echo-NN.nii``) plus the per-frame masks NIfTI + (``_masks.nii``). + """ + if len(magnitude) != len(phase): + raise ValueError( + f"got {len(magnitude)} magnitude file(s) but {len(phase)} phase " + "file(s); they must match (one mag/phase pair per echo)." + ) + if metadata is not None and len(metadata) != len(phase): + raise ValueError( + f"got {len(metadata)} metadata file(s) but {len(phase)} phase " + "file(s); they must match (one sidecar per echo)." + ) + + tes_ms, _, _ = resolve_acquisition(metadata=metadata, tes=tes, require_trt_pe=False) + + if len(tes_ms) != len(phase): + raise ValueError( + f"got {len(tes_ms)} echo time(s) but --phase has {len(phase)} " + "file(s); they must match." + ) + + mag_data = ensure_images(magnitude) + phase_data = ensure_images(phase) + + if noise_frames < 0: + raise ValueError(f"noise_frames must be non-negative; got {noise_frames}.") + if noise_frames > 0: + for label, imgs in (("phase", phase_data), ("magnitude", mag_data)): + for idx, img in enumerate(imgs): + n_frames = img.shape[-1] if img.ndim == 4 else 1 + if noise_frames >= n_frames: + raise ValueError( + f"noise_frames={noise_frames} would leave 0 frames " + f"in {label} image #{idx} (has {n_frames} frame(s))." + ) + print(f"Removing {noise_frames} noise frames from the end of each file...") + mag_data = trim_noise_frames(mag_data, noise_frames) + phase_data = trim_noise_frames(phase_data, noise_frames) + + unwrapped_imgs, masks_img = unwrap_phases( + phase_data, + mag_data, + list(tes_ms), + n_cpus=n_cpus, + debug=debug, + wrap_limit=wrap_limit, + ) + + out_prefix_str = str(out_prefix) + print("Saving unwrapped phase images and masks to file...") + unwrapped_paths: list[Path] = [] + for i_echo, img in enumerate(unwrapped_imgs, start=1): + out_path = Path(f"{out_prefix_str}_unwrapped_echo-{i_echo:02d}.nii").resolve() + img.to_filename(str(out_path)) + unwrapped_paths.append(out_path) + masks_path = Path(f"{out_prefix_str}_masks.nii").resolve() + masks_img.to_filename(str(masks_path)) + print("Done.") + return UnwrapPhaseResult(unwrapped=unwrapped_paths, masks=masks_path) def main(): @@ -67,86 +169,20 @@ def main(): ) args = parser.parse_args() - - if args.metadata and args.tes is not None: - parser.error( - "--metadata is mutually exclusive with --TEs; pass one or the " - "other, not both." - ) - if not args.metadata and args.tes is None: - parser.error("either --metadata or --TEs must be provided.") - - if len(args.magnitude) != len(args.phase): - parser.error( - f"got {len(args.magnitude)} --magnitude file(s) but " - f"{len(args.phase)} --phase file(s); they must match (one " - "mag/phase pair per echo)." - ) - if args.metadata is not None and len(args.metadata) != len(args.phase): - parser.error( - f"got {len(args.metadata)} --metadata file(s) but " - f"{len(args.phase)} --phase file(s); they must match (one " - "sidecar per echo)." - ) - - echo_times: list[float] - if args.metadata: - metadatas = [] - for j in args.metadata: - with open(j) as f: - metadatas.append(json.load(f)) - echo_times = [float(m["EchoTime"]) * 1000 for m in metadatas] - else: - echo_times = cast(list[float], args.tes) - - if len(echo_times) != len(args.phase): - parser.error( - f"got {len(echo_times)} echo time(s) but --phase has " - f"{len(args.phase)} file(s); they must match." - ) - setup_logging() print(f"wk-unwrap-phase: {args}") - mag_data = [cast(nib.Nifti1Image, nib.load(m)) for m in args.magnitude] - phase_data = [cast(nib.Nifti1Image, nib.load(p)) for p in args.phase] - - if args.noiseframes < 0: - parser.error(f"--noiseframes must be non-negative; got {args.noiseframes}.") - if args.noiseframes > 0: - for label, imgs, paths in ( - ("phase", phase_data, args.phase), - ("magnitude", mag_data, args.magnitude), - ): - for img, path in zip(imgs, paths, strict=True): - n_frames = img.shape[-1] if img.ndim == 4 else 1 - if args.noiseframes >= n_frames: - parser.error( - f"--noiseframes={args.noiseframes} would leave 0 " - f"frames in {label} file '{path}' (has {n_frames} " - "frame(s))." - ) - print(f"Removing {args.noiseframes} noise frames from the end of each file...") - mag_data = [ - nib.Nifti1Image(m.dataobj[..., : -args.noiseframes], m.affine, m.header) - for m in mag_data - ] - phase_data = [ - nib.Nifti1Image(p.dataobj[..., : -args.noiseframes], p.affine, p.header) - for p in phase_data - ] - - unwrapped_imgs, masks_img = unwrap_phases( - phase_data, - mag_data, - echo_times, - n_cpus=args.n_cpus, - debug=args.debug, - wrap_limit=args.wrap_limit, - ) - - print("Saving unwrapped phase images and masks to file...") - for i_echo, img in enumerate(unwrapped_imgs, start=1): - img.to_filename(f"{args.out_prefix}_unwrapped_echo-{i_echo:02d}.nii") - masks_img.to_filename(f"{args.out_prefix}_masks.nii") - print("Done.") + try: + unwrap_phase( + phase=args.phase, + magnitude=args.magnitude, + out_prefix=args.out_prefix, + tes=args.tes, + metadata=args.metadata, + noise_frames=args.noiseframes, + n_cpus=args.n_cpus, + wrap_limit=args.wrap_limit, + debug=args.debug, + ) + except ValueError as e: + parser.error(str(e)) From 781762066ee75c5e01b964792b5d5d6488c30fce Mon Sep 17 00:00:00 2001 From: Andrew Van Date: Sat, 25 Apr 2026 17:18:46 -0500 Subject: [PATCH 2/4] :wrench: Loosen Codecov thresholds (1% project / 90% patch) Default Codecov policy treats any project-coverage decrease as a failure and requires patch coverage to match the base. New CLI shims add unavoidable argparse / __main__ glue, so a small drop is expected. Allow 1% slack on project and pin patch to a 90% floor (with 1% slack) instead. Co-Authored-By: Claude Opus 4.7 (1M context) --- codecov.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 codecov.yml diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..e238925 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,10 @@ +coverage: + status: + project: + default: + target: auto + threshold: 1% + patch: + default: + target: 90% + threshold: 1% From 51e6c7a2030729f9d5a20b310db1a144557781ef Mon Sep 17 00:00:00 2001 From: Andrew Van Date: Sat, 25 Apr 2026 17:20:11 -0500 Subject: [PATCH 3/4] :fire: remove compose.yml --- compose.yml | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 compose.yml diff --git a/compose.yml b/compose.yml deleted file mode 100644 index cc9cef2..0000000 --- a/compose.yml +++ /dev/null @@ -1,5 +0,0 @@ -services: - me_pipeline: - image: warpkit:latest - build: - context: . From c8cf6203abba37c001bcde72e13b2db5333cf55c Mon Sep 17 00:00:00 2001 From: Andrew Van Date: Sat, 25 Apr 2026 17:25:35 -0500 Subject: [PATCH 4/4] :bug: Centralise noise-frame and metadata validation in _metadata helpers Address review feedback on the typed Python API: - ``trim_noise_frames`` now rejects 3D inputs (``[..., :-n]`` would chop the Z dimension instead of frames) and ``n >= n_frames`` (which silently yielded an empty 4D series). Both now raise ``ValueError`` from the helper itself, so ``medic()`` and any future caller inherit the same safety; the duplicated guard inside ``unwrap_phase()`` is dropped. - ``load_acquisition_from_metadata`` honours ``require_trt_pe``: callers that only need ``EchoTime`` (e.g. ``unwrap_phase``) accept sidecars without ``TotalReadoutTime`` / ``PhaseEncodingDirection``. Missing keys surface as a clean ``ValueError`` instead of a ``KeyError``. Adds CLI tests for the EchoTime-only sidecar path, the missing-EchoTime error, and ``-f >= n_frames`` rejection in ``wk-medic``. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_scripts.py | 85 +++++++++++++++++++++++++++++++++ warpkit/scripts/_metadata.py | 59 ++++++++++++++++++----- warpkit/scripts/unwrap_phase.py | 10 ---- 3 files changed, 132 insertions(+), 22 deletions(-) diff --git a/tests/test_scripts.py b/tests/test_scripts.py index ca36580..a2a0d27 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -600,6 +600,91 @@ def test_unwrap_phase_noiseframes_consumes_all_frames(argv, capsys, tmp_path): assert "0 frames" in err +def test_unwrap_phase_metadata_accepts_echotime_only(argv, capsys, tmp_path): + """``wk-unwrap-phase`` only needs EchoTime; sidecars without + TotalReadoutTime / PhaseEncodingDirection must work. A mismatched phase + count forces a clean parser.error *after* the metadata loader resolves — + if the loader still required TRT/PED we'd see a KeyError instead.""" + sidecar = tmp_path / "m1.json" + sidecar.write_text(json.dumps({"EchoTime": 0.014})) + argv( + [ + "wk-unwrap-phase", + "--magnitude", + "m.nii", + "--phase", + "p1.nii", + "p2.nii", + "--metadata", + str(sidecar), + "--out-prefix", + str(tmp_path / "out"), + ] + ) + with pytest.raises(SystemExit) as exc: + unwrap_phase_main() + assert exc.value.code == 2 + err = capsys.readouterr().err + assert "must match" in err + + +def test_unwrap_phase_metadata_missing_echotime(argv, capsys, tmp_path): + """A sidecar without EchoTime must surface as a clean parser error, not a + KeyError from the metadata loader.""" + sidecar = tmp_path / "m1.json" + sidecar.write_text(json.dumps({})) + argv( + [ + "wk-unwrap-phase", + "--magnitude", + "m.nii", + "--phase", + "p.nii", + "--metadata", + str(sidecar), + "--out-prefix", + str(tmp_path / "out"), + ] + ) + with pytest.raises(SystemExit) as exc: + unwrap_phase_main() + assert exc.value.code == 2 + err = capsys.readouterr().err + assert "EchoTime" in err + + +def test_medic_noiseframes_consumes_all_frames(argv, capsys, tmp_path): + """``-f`` >= n_frames now raises a clean parser error in medic too: the + check moved into ``trim_noise_frames`` so every caller is protected from + silently producing an empty 4D series.""" + mag = _write_nifti(tmp_path / "m.nii", (4, 4, 4, 5)) + phase = _write_nifti(tmp_path / "p.nii", (4, 4, 4, 5)) + argv( + [ + "wk-medic", + "--magnitude", + mag, + "--phase", + phase, + "--TEs", + "14.0", + "--total-readout-time", + "0.05", + "--phase-encoding-direction", + "j", + "--out-prefix", + str(tmp_path / "out"), + "-f", + "5", + ] + ) + with pytest.raises(SystemExit) as exc: + medic_main() + assert exc.value.code == 2 + err = capsys.readouterr().err + assert "0 frames" in err + + # --------------------------------------------------------------------------- # compute_fieldmap --help / argument validation # --------------------------------------------------------------------------- diff --git a/warpkit/scripts/_metadata.py b/warpkit/scripts/_metadata.py index 938237c..37f258b 100644 --- a/warpkit/scripts/_metadata.py +++ b/warpkit/scripts/_metadata.py @@ -39,18 +39,40 @@ def ensure_images( def load_acquisition_from_metadata( metadata_paths: Sequence[PathLike[str] | str], -) -> tuple[list[float], float, str]: - """Read EchoTime (s → ms), TotalReadoutTime (s) and PhaseEncodingDirection - from BIDS-style JSON sidecars. + *, + require_trt_pe: bool = True, +) -> tuple[list[float], float | None, str | None]: + """Read EchoTime (s → ms) — and optionally TotalReadoutTime (s) and + PhaseEncodingDirection — from BIDS-style JSON sidecars. Per-echo ``EchoTime`` is read from each file; ``TotalReadoutTime`` and - ``PhaseEncodingDirection`` are taken from the first. + ``PhaseEncodingDirection`` are taken from the first. When + ``require_trt_pe=False`` the latter two are skipped so callers that only + need echo times (e.g. ``unwrap_phase``) accept sidecars that omit them. + Missing required keys raise :class:`ValueError`. """ metadatas = [] for j in metadata_paths: with open(j) as f: metadatas.append(json.load(f)) - tes_ms = [float(m["EchoTime"]) * 1000 for m in metadatas] + try: + tes_ms = [float(m["EchoTime"]) * 1000 for m in metadatas] + except KeyError: + raise ValueError( + "metadata sidecar is missing required key: 'EchoTime'." + ) from None + if not require_trt_pe: + return tes_ms, None, None + missing = [ + k + for k in ("TotalReadoutTime", "PhaseEncodingDirection") + if k not in metadatas[0] + ] + if missing: + raise ValueError( + "metadata sidecar is missing required key(s): " + f"{', '.join(repr(k) for k in missing)}." + ) trt = float(metadatas[0]["TotalReadoutTime"]) ped = str(metadatas[0]["PhaseEncodingDirection"]) return tes_ms, trt, ped @@ -105,23 +127,36 @@ def resolve_acquisition( raise ValueError("either --metadata or --TEs must be provided.") if metadata is not None: - tes_resolved, trt_resolved, ped_resolved = load_acquisition_from_metadata( - metadata - ) - if require_trt_pe: - return tes_resolved, trt_resolved, ped_resolved - return tes_resolved, None, None + return load_acquisition_from_metadata(metadata, require_trt_pe=require_trt_pe) return list(tes or []), total_readout_time, phase_encoding_direction def trim_noise_frames(images: list[nib.Nifti1Image], n: int) -> list[nib.Nifti1Image]: """Trim the last ``n`` frames from each 4D image. Returns the input list - unchanged when ``n == 0``.""" + unchanged when ``n == 0``. + + When ``n > 0`` each image must be 4D with strictly more than ``n`` frames; + otherwise ``[..., :-n]`` would either chop the Z dimension of a 3D volume + or yield an empty 4D series that crashes downstream consumers. Both raise + :class:`ValueError`. + """ if n == 0: return images if n < 0: raise ValueError(f"noise_frames must be non-negative; got {n}.") + for idx, img in enumerate(images): + if img.ndim != 4: + raise ValueError( + f"noise_frames={n} requires 4D images; image #{idx} has " + f"ndim={img.ndim}." + ) + n_frames = img.shape[-1] + if n >= n_frames: + raise ValueError( + f"noise_frames={n} would leave 0 frames in image #{idx} " + f"(has {n_frames} frame(s))." + ) return [ nib.Nifti1Image(img.dataobj[..., :-n], img.affine, img.header) for img in images ] diff --git a/warpkit/scripts/unwrap_phase.py b/warpkit/scripts/unwrap_phase.py index bffaa22..828f62b 100644 --- a/warpkit/scripts/unwrap_phase.py +++ b/warpkit/scripts/unwrap_phase.py @@ -76,17 +76,7 @@ def unwrap_phase( mag_data = ensure_images(magnitude) phase_data = ensure_images(phase) - if noise_frames < 0: - raise ValueError(f"noise_frames must be non-negative; got {noise_frames}.") if noise_frames > 0: - for label, imgs in (("phase", phase_data), ("magnitude", mag_data)): - for idx, img in enumerate(imgs): - n_frames = img.shape[-1] if img.ndim == 4 else 1 - if noise_frames >= n_frames: - raise ValueError( - f"noise_frames={noise_frames} would leave 0 frames " - f"in {label} image #{idx} (has {n_frames} frame(s))." - ) print(f"Removing {noise_frames} noise frames from the end of each file...") mag_data = trim_noise_frames(mag_data, noise_frames) phase_data = trim_noise_frames(phase_data, noise_frames)