diff --git a/configs/config_camus.yaml b/configs/config_camus.yaml index 4f0465e71..e623a854b 100644 --- a/configs/config_camus.yaml +++ b/configs/config_camus.yaml @@ -1,29 +1,12 @@ # config_camus.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: hf://zeahub/camus-sample - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: val/patient0401/patient0401_4CH_half_sequence.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: hf://zeahub/camus-sample/val/patient0401/patient0401_4CH_half_sequence.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: image_sc - # The dynamic range for showing data in db [min, max] - dynamic_range: [-60, 0] - # The frame number to load when running the UI (null, int, 'all') - frame_no: all - # The type of data to convert to (raw_data, aligned_data, beamformed_data, - # envelope_data, image, image_sc) - to_dtype: image_sc - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: opencv \ No newline at end of file + # Indices into the data to load. null loads the default, 'all' loads every + # frame, int loads a single frame, list loads specific frames. + indices: all diff --git a/configs/config_carotid.yaml b/configs/config_carotid.yaml index 7e6d911b8..6faad51e2 100644 --- a/configs/config_carotid.yaml +++ b/configs/config_carotid.yaml @@ -1,36 +1,24 @@ # config_carotid.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: hf://zeahub/zea-carotid-2023 - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: 2_cross_bifur_right_0000.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: hf://zeahub/zea-carotid-2023/2_cross_bifur_right_0000.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: raw_data - # The dynamic range for showing data in db [min, max] - dynamic_range: [-40, 0] - # The frame number to load when running the UI (null, int, 'all') - frame_no: all + # Indices into the data to load. null loads the default, 'all' loads every + # frame, int loads a single frame, list loads specific frames. + indices: all -# The parameters section is a flat mapping of scan/probe/custom parameters that -# overwrite values loaded from the data file. Documented reconstruction -# parameters are listed below; arbitrary custom parameters are also allowed. +# Open mapping of scan/probe/custom parameters that overwrite values loaded from +# the data file. ProbeSpec and ScanSpec are the authoritative sources for valid +# parameter names — see the spec reference in data-acquisition. Arbitrary custom +# parameters are forwarded to the pipeline unchanged. parameters: - # The number of transmits in a frame. Can be 'all' for all transmits, an - # integer for a specific number of transmits selected evenly from the - # transmits in the frame, or a list of integers for specific transmits to - # select from the frame. - selected_transmits: all # 149 is all or 11 for instance for reduce number of Tx - # The number of channels in the raw data (1 for rf data, 2 for iq data) + selected_transmits: all n_ch: 1 - # The number of pixels in the beamforming grid in the x-direction grid_size_x: 400 - # The number of pixels in the beamforming grid in the z-direction grid_size_z: 600 # This section contains the necessary parameters for building the pipeline. @@ -48,17 +36,5 @@ pipeline: - name: normalize - name: log_compress -# The device to run on ('cpu', 'gpu:0', 'gpu:1', ...) +# The device to run on ('cpu', 'gpu:0', 'gpu:1', 'auto:1', ...) device: auto:1 - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: opencv - # Set to true to run the UI in headless mode - headless: false - # The name for the plot - tag: carotid diff --git a/configs/config_echonet.yaml b/configs/config_echonet.yaml index a4291295d..843ff7057 100644 --- a/configs/config_echonet.yaml +++ b/configs/config_echonet.yaml @@ -1,35 +1,23 @@ # config_echonet.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: echonet_v2025/train - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: 0X1A8F20B8BF0B4B45.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: echonet_v2025/train/0X1A8F20B8BF0B4B45.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: image - # The dynamic range for showing data in db [min, max] - dynamic_range: [-60, 0] - # The frame number to load when running the UI (null, int, 'all') - frame_no: all - # The type of data to convert to (raw_data, aligned_data, beamformed_data, - # envelope_data, image, image_sc) - to_dtype: image_sc + # Indices into the data to load. null loads the default, 'all' loads every + # frame, int loads a single frame, list loads specific frames. + indices: all -# The parameters section is a flat mapping of scan/probe/custom parameters that -# overwrite values loaded from the data file. Documented reconstruction -# parameters are listed below; arbitrary custom parameters are also allowed. +# Open mapping of scan/probe/custom parameters that overwrite values loaded from +# the data file. ProbeSpec and ScanSpec are the authoritative sources for valid +# parameter names — see the spec reference in data-acquisition. Arbitrary custom +# parameters are forwarded to the pipeline unchanged. parameters: - # The range of theta values in radians for scan conversion (null, [min, max]). - theta_range: [-0.78, 0.78] # [-45, 45] in rads - # The range of rho values in meters for scan conversion (null, [min, max]). + theta_range: [-0.78, 0.78] rho_range: [0, 1] - # Value to fill the image with outside the defined region (float, default - # 0.0). fill_value: -60 # This section contains the necessary parameters for building the pipeline. @@ -40,11 +28,3 @@ pipeline: - name: scan_convert params: jit_compile: false - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: opencv diff --git a/configs/config_echonetlvh.yaml b/configs/config_echonetlvh.yaml index e37f88e24..f8702dc24 100644 --- a/configs/config_echonetlvh.yaml +++ b/configs/config_echonetlvh.yaml @@ -1,35 +1,23 @@ # config_echonetlvh.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: echonetlvh_v2025/train - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: 0X1017398D3C3F5FF9.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: echonetlvh_v2025/train/0X1017398D3C3F5FF9.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: image - # The dynamic range for showing data in db [min, max] - dynamic_range: [-60, 0] - # The frame number to load when running the UI (null, int, 'all') - frame_no: all - # The type of data to convert to (raw_data, aligned_data, beamformed_data, - # envelope_data, image, image_sc) - to_dtype: image_sc + # Indices into the data to load. null loads the default, 'all' loads every + # frame, int loads a single frame, list loads specific frames. + indices: all -# The parameters section is a flat mapping of scan/probe/custom parameters that -# overwrite values loaded from the data file. Documented reconstruction -# parameters are listed below; arbitrary custom parameters are also allowed. +# Open mapping of scan/probe/custom parameters that overwrite values loaded from +# the data file. ProbeSpec and ScanSpec are the authoritative sources for valid +# parameter names — see the spec reference in data-acquisition. Arbitrary custom +# parameters are forwarded to the pipeline unchanged. parameters: - # The range of theta values in radians for scan conversion (null, [min, max]). - theta_range: [-0.78, 0.78] # [-45, 45] in rads - # The range of rho values in meters for scan conversion (null, [min, max]). + theta_range: [-0.78, 0.78] rho_range: [0, 256] - # Value to fill the image with outside the defined region (float, default - # 0.0). fill_value: -60 # This section contains the necessary parameters for building the pipeline. @@ -41,11 +29,3 @@ pipeline: params: jit_compile: false order: 2 - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: opencv diff --git a/configs/config_picmus_iq.yaml b/configs/config_picmus_iq.yaml index 5394b739c..1cffe69fa 100644 --- a/configs/config_picmus_iq.yaml +++ b/configs/config_picmus_iq.yaml @@ -1,44 +1,26 @@ # config_picmus_iq.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: hf://zeahub/picmus/database/simulation/contrast_speckle/contrast_speckle_simu_dataset_iq - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: contrast_speckle_simu_dataset_iq.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: hf://zeahub/picmus/database/simulation/contrast_speckle/contrast_speckle_simu_dataset_iq/contrast_speckle_simu_dataset_iq.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: raw_data - # The dynamic range for showing data in db [min, max] - dynamic_range: [-60, 0] -# The parameters section is a flat mapping of scan/probe/custom parameters that -# overwrite values loaded from the data file. Documented reconstruction -# parameters are listed below; arbitrary custom parameters are also allowed. +# Open mapping of scan/probe/custom parameters that overwrite values loaded from +# the data file. ProbeSpec and ScanSpec are the authoritative sources for valid +# parameter names — see the spec reference in data-acquisition. Arbitrary custom +# parameters are forwarded to the pipeline unchanged. parameters: - # The number of transmits in a frame. Can be 'all' for all transmits, an - # integer for a specific number of transmits selected evenly from the - # transmits in the frame, or a list of integers for specific transmits to - # select from the frame. selected_transmits: all - # The number of pixels in the beamforming grid in the x-direction grid_size_x: 300 - # The number of pixels in the beamforming grid in the z-direction grid_size_z: 500 - # Set to true to apply lens correction in the time-of-flight calculation apply_lens_correction: false - # The speed of sound in the lens in m/s. Usually around 1000 m/s lens_sound_speed: 1000 - # The thickness of the lens in meters lens_thickness: 0.001 - # The limits of the z-axis in the scan in meters (null, [min, max]) zlims: [0.006, 0.055] - # The limits of the x-axis in the scan in meters (null, [min, max]) xlims: [-0.02, 0.02] - # The number of channels in the raw data (1 for rf data, 2 for iq data) n_ch: 2 # This section contains the necessary parameters for building the pipeline. @@ -55,15 +37,5 @@ pipeline: - name: normalize - name: log_compress -# The device to run on ('cpu', 'gpu:0', 'gpu:1', ...) +# The device to run on ('cpu', 'gpu:0', 'gpu:1', 'auto:1', ...) device: auto:1 - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: matplotlib - # The name for the plot - tag: test diff --git a/configs/config_picmus_rf.yaml b/configs/config_picmus_rf.yaml index e793e6f7e..ffc2a6ae4 100644 --- a/configs/config_picmus_rf.yaml +++ b/configs/config_picmus_rf.yaml @@ -1,45 +1,26 @@ # config_picmus_rf.yaml - comments were autogenerated from PARAMETER_DESCRIPTIONS in zea/config/parameters.py -# The data section contains the parameters for the data. +# Data path and loading settings. data: - # The path of the folder to load data files from (relative to the user data - # root as set in users.yaml) - dataset_folder: hf://zeahub/picmus/database/simulation/contrast_speckle/contrast_speckle_simu_dataset_rf - # The path of the file to load when running the UI (either an absolute path or - # one relative to the dataset folder) - file_path: contrast_speckle_simu_dataset_rf.hdf5 + # Full path to the data file. Supports absolute paths, paths relative to the + # user data root (set in users.yaml), and Hugging Face Hub paths + # (hf://org/repo/path/to/file.hdf5). + path: hf://zeahub/picmus/database/simulation/contrast_speckle/contrast_speckle_simu_dataset_rf/contrast_speckle_simu_dataset_rf.hdf5 # true: use local data on this device, false: use data from NAS local: false - # The form of data to load (raw_data, rf_data, iq_data, beamformed_data, - # envelope_data, image, image_sc) - dtype: raw_data - # The dynamic range for showing data in db [min, max] - dynamic_range: [-50, 0] - -# The parameters section is a flat mapping of scan/probe/custom parameters that -# overwrite values loaded from the data file. Documented reconstruction -# parameters are listed below; arbitrary custom parameters are also allowed. +# Open mapping of scan/probe/custom parameters that overwrite values loaded from +# the data file. ProbeSpec and ScanSpec are the authoritative sources for valid +# parameter names — see the spec reference in data-acquisition. Arbitrary custom +# parameters are forwarded to the pipeline unchanged. parameters: - # The number of transmits in a frame. Can be 'all' for all transmits, an - # integer for a specific number of transmits selected evenly from the - # transmits in the frame, or a list of integers for specific transmits to - # select from the frame. selected_transmits: all - # The number of pixels in the beamforming grid in the x-direction grid_size_x: 400 - # The number of pixels in the beamforming grid in the z-direction grid_size_z: 600 - # The number of channels in the raw data (1 for rf data, 2 for iq data) n_ch: 1 - # Set to true to apply lens correction in the time-of-flight calculation apply_lens_correction: false - # The speed of sound in the lens in m/s. Usually around 1000 m/s lens_sound_speed: 1000 - # The thickness of the lens in meters lens_thickness: 0.001 - # The limits of the z-axis in the scan in meters (null, [min, max]) zlims: [0.006, 0.055] - # The limits of the x-axis in the scan in meters (null, [min, max]) xlims: [-0.02, 0.02] # This section contains the necessary parameters for building the pipeline. @@ -60,15 +41,5 @@ pipeline: - name: normalize - name: log_compress -# The device to run on ('cpu', 'gpu:0', 'gpu:1', ...) +# The device to run on ('cpu', 'gpu:0', 'gpu:1', 'auto:1', ...) device: auto:1 - -# Settings pertaining to plotting when running the UI (`zea --config -# `) -plot: - # Set to true to save the plots to disk, false to only display them in the UI - save: true - # The plotting library to use (opencv, matplotlib) - plot_lib: matplotlib - # The name for the plot - tag: picmus_rf \ No newline at end of file diff --git a/docs/_static/diagrams_dataflow.png b/docs/_static/diagrams_dataflow.png deleted file mode 100644 index 08d185a49..000000000 Binary files a/docs/_static/diagrams_dataflow.png and /dev/null differ diff --git a/docs/source/_autosummary/zea.rst b/docs/source/_autosummary/zea.rst index 63fde1ec2..9efe03b1d 100644 --- a/docs/source/_autosummary/zea.rst +++ b/docs/source/_autosummary/zea.rst @@ -34,7 +34,6 @@ zea display doppler func - interface io_lib log metrics diff --git a/docs/source/_spec_ref.rst b/docs/source/_spec_ref.rst index c799ec9be..883711c2b 100644 --- a/docs/source/_spec_ref.rst +++ b/docs/source/_spec_ref.rst @@ -71,6 +71,8 @@ Stored as HDF5 root-level attributes (not groups). - |badge-opt| +.. _group-reference: + Group reference ~~~~~~~~~~~~~~~ diff --git a/docs/source/cli.rst b/docs/source/cli.rst index a5fbb82f7..ad737a383 100644 --- a/docs/source/cli.rst +++ b/docs/source/cli.rst @@ -3,12 +3,10 @@ Command line interface Besides the main :doc:`zea API documentation <_autosummary/zea>`, ``zea`` also provides a command line interface (CLI). -------------------------------- -File reading and visualization -------------------------------- - -.. autoprogram:: zea.__main__:get_parser() - :prog: zea +.. note:: + The ``zea`` CLI is currently a placeholder. Extended visualization and data + inspection commands will be added in a future release. In the meantime, + use ``python -m zea.data.convert`` and ``python -m zea.data`` below. ------------------------------- Convert datasets diff --git a/docs/source/config.rst b/docs/source/config.rst index 9eda7a529..9a2101e43 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -1,189 +1,155 @@ -.. THIS FILE WAS AUTOGENERATED USING docs/source/parameters_doc.py. DO NOT EDIT MANUALLY. - .. _config: -Config -====== - -This page documents the ``zea`` configuration system. Configs are YAML files that -control data loading, preprocessing, model settings, and scan parameters. +Parameters +========== .. note:: - Configs are used to initialize :doc:`zea.Models ` and the :doc:`pipeline`. - For the data format and file I/O, see :doc:`data-acquisition`. + For the HDF5 data format and file I/O see :doc:`data-acquisition`. + For pipeline operations see :doc:`pipeline`. -.. note:: - The ``parameters`` section is a flat mapping that overrides values loaded from the - data file. It mirrors :class:`zea.Parameters` (the merged :class:`zea.Probe` and scan - parameters — see :doc:`data-acquisition`), and may additionally contain arbitrary - custom parameters that are passed straight through to the pipeline. The documented - reconstruction keys below are the most common ones, but they are not exhaustive. +---------------------------- +Parameters in the file +---------------------------- + +Every ``zea`` HDF5 file stores all parameters needed to process the acquisition +alongside the raw data. They are split into two groups: + +**Probe** (``probe/``) + Fixed for the whole acquisition — element geometry, center frequency, + bandwidth, lens properties. Shared across all tracks. + Defined by :class:`~zea.data.spec.ProbeSpec`. + +**Scan** (``scan/``) + Per-track transmit sequence — delays, apodizations, angles, waveforms, + sound speed. Each track has its own :class:`~zea.data.spec.ScanSpec`. -Configs are written in YAML format and can be loaded, edited, and saved using the ``zea`` API. +See the :ref:`group-reference` table for the complete field listing. -------------------------------- -How to Load and Save a Config -------------------------------- +---------------------------- +zea.Parameters +---------------------------- -Here is a minimal example of how to load and save a config file using zea: +:meth:`~zea.File.load_parameters` merges the probe and scan groups into a +single :class:`~zea.Parameters` object and adds derived quantities +(``wavelength``, ``n_tx``, ``grid``, ``xlims``/``zlims``, ``selected_transmits``): + +.. code-block:: python + + with zea.File("data.hdf5") as f: + parameters = f.load_parameters() # single-track + parameters = f.load_parameters(track=0) # multi-track + +---------------------------- +Config +---------------------------- + +A config is a YAML file (loaded as :class:`~zea.Config`) that specifies where +the data lives, the pipeline to run, the device to use, and any parameter +overrides. .. doctest:: >>> from zea import Config >>> from zea.config import check_config - >>> # Load a config from file >>> config = Config.from_path("../configs/config_picmus_rf.yaml") - >>> # or some predefined from Hugging Face Hub - >>> config = Config.from_path("hf://zeahub/configs/config_picmus_rf.yaml") - - >>> # We can check if the config has valid parameters (zea compliance) - >>> config = check_config(config) - - >>> # Access or change parameters - >>> config.parameters.grid_size_x = 512 - >>> print(config.parameters.grid_size_x) - 512 - - >>> # Save the config back to file - >>> config.to_yaml("my_new_config.yaml") + >>> config = check_config(config) # fills defaults, validates + >>> config.pipeline.operations + ['demodulate', 'downsample', 'beamform', 'envelope_detect', 'normalize', 'log_compress'] + >>> config.to_yaml("my_config.yaml") .. testcleanup:: import os + os.remove("my_config.yaml") - os.remove("my_new_config.yaml") - -------------------------------- -Parameter List -------------------------------- +Supported keys +~~~~~~~~~~~~~~ -Below is a hierarchical list of all configuration parameters, grouped by section. -Descriptions are shown for each parameter. +**data** — where to find the file -.. contents:: - :local: - :depth: 2 +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Key + - Default + - Description + * - ``path`` + - ``null`` + - Full path to the HDF5 file. Supports local absolute paths, paths relative + to the user data root (set in ``users.yaml``), and Hugging Face Hub paths + (``hf://org/repo/path/to/file.hdf5``). + * - ``local`` + - ``true`` + - Whether to use local data (``true``) or a network/NAS location (``false``). + * - ``indices`` + - ``null`` + - Which frames to load: ``null`` (default), ``'all'``, a single ``int``, or + a list of ints. + +**parameters** — override any field from the :ref:`group-reference` or pass +custom keys straight through to the pipeline: + +.. code-block:: yaml + + parameters: + center_frequency: 5.0e6 + xlims: [-0.02, 0.02] + grid_size_x: 512 + +**pipeline** — list of operations (see :doc:`pipeline`): -------------------------------- -Parameters Reference -------------------------------- +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Key + - Default + - Description + * - ``operations`` + - ``[identity]`` + - Ordered list of operations. Each entry is either an operation name (string) + or a mapping with ``name`` and optional ``params``. + * - ``jit_options`` + - ``'ops'`` + - JIT scope: ``'ops'`` (compile each op), ``'pipeline'`` (compile the whole + pipeline), or ``null`` (disable JIT). + * - ``with_batch_dim`` + - ``true`` + - Whether operations expect a leading batch dimension. + +**Top-level keys** .. list-table:: :header-rows: 1 - :widths: 20 80 - - * - **Parameter** - - **Description** - * - ``data`` - - The data section contains the parameters for the data. - * - ``data.apodization`` - - The receive apodization to use. - * - ``data.dataset_folder`` - - The path of the folder to load data files from (relative to the user data root as set in users.yaml) - * - ``data.dtype`` - - The form of data to load (raw_data, rf_data, iq_data, beamformed_data, envelope_data, image, image_sc) - * - ``data.dynamic_range`` - - The dynamic range for showing data in db [min, max] - * - ``data.file_path`` - - The path of the file to load when running the UI (either an absolute path or one relative to the dataset folder) - * - ``data.frame_no`` - - The frame number to load when running the UI (null, int, 'all') - * - ``data.input_range`` - - The range of the input data in db (null, [min, max]) - * - ``data.local`` - - true: use local data on this device, false: use data from NAS - * - ``data.output_range`` - - The output range to which the data should be mapped (e.g. [0, 1]). - * - ``data.resolution`` - - The spatial resolution of the data in meters per pixel (float, optional). - * - ``data.to_dtype`` - - The type of data to convert to (raw_data, aligned_data, beamformed_data, envelope_data, image, image_sc) - * - ``data.user`` - - The user to use when loading data (null, dict) + :widths: 20 15 65 + + * - Key + - Default + - Description * - ``device`` - - The device to run on ('cpu', 'gpu:0', 'gpu:1', ...) - * - ``git`` - - The git commit hash or branch for reproducibility (string, optional). + - ``'auto:1'`` + - Target hardware: ``cpu``, ``gpu``, ``cuda``, ``gpu:0``, ``auto:1`` + (auto-select; ``-1`` for last device). * - ``hide_devices`` - - List of device indices to hide from selection (list of int, optional). - * - ``parameters`` - - The parameters section is a flat mapping of scan/probe/custom parameters that overwrite values loaded from the data file. Documented reconstruction parameters are listed below; arbitrary custom parameters are also allowed. - * - ``parameters.apply_lens_correction`` - - Set to true to apply lens correction in the time-of-flight calculation - * - ``parameters.center_frequency`` - - The center frequency of the transmit pulse in Hz - * - ``parameters.demodulation_frequency`` - - The demodulation frequency of the data in Hz. This is the assumed center frequency of the transmit waveform used to demodulate the rf data to iq data. - * - ``parameters.f_number`` - - The receive f-number for apodization. Set to zero to disable masking. The f-number is the ratio between the distance from the transducer and the size of the aperture. - * - ``parameters.fill_value`` - - Value to fill the image with outside the defined region (float, default 0.0). - * - ``parameters.grid_size_x`` - - The number of pixels in the beamforming grid in the x-direction - * - ``parameters.grid_size_z`` - - The number of pixels in the beamforming grid in the z-direction - * - ``parameters.lens_sound_speed`` - - The speed of sound in the lens in m/s. Usually around 1000 m/s - * - ``parameters.lens_thickness`` - - The thickness of the lens in meters - * - ``parameters.n_ax`` - - The number of samples in a receive recording per channel. - * - ``parameters.n_ch`` - - The number of channels in the raw data (1 for rf data, 2 for iq data) - * - ``parameters.phi_range`` - - The range of phi values in radians for 3D scan conversion (null, [min, max]). - * - ``parameters.resolution`` - - The resolution for scan conversion in meters per pixel (float, optional). - * - ``parameters.rho_range`` - - The range of rho values in meters for scan conversion (null, [min, max]). - * - ``parameters.sampling_frequency`` - - The sampling frequency of the data in Hz - * - ``parameters.selected_transmits`` - - The number of transmits in a frame. Can be 'all' for all transmits, an integer for a specific number of transmits selected evenly from the transmits in the frame, or a list of integers for specific transmits to select from the frame. - * - ``parameters.theta_range`` - - The range of theta values in radians for scan conversion (null, [min, max]). - * - ``parameters.xlims`` - - The limits of the x-axis in the scan in meters (null, [min, max]) - * - ``parameters.ylims`` - - The limits of the y-axis in the scan in meters (null, [min, max]) - * - ``parameters.zlims`` - - The limits of the z-axis in the scan in meters (null, [min, max]) - * - ``pipeline`` - - This section contains the necessary parameters for building the pipeline. - * - ``pipeline.jit_kwargs`` - - Additional keyword arguments for the JIT compiler. Defaults to None. - * - ``pipeline.jit_options`` - - The JIT options to use. Must be 'pipeline', 'ops', or None. 'pipeline' compiles the entire pipeline as a single function. 'ops' compiles each operation separately. None disables JIT compilation. Defaults to 'ops'. - * - ``pipeline.name`` - - The name of the pipeline. Defaults to 'pipeline'. - * - ``pipeline.operations`` - - The operations to perform on the data. This is a list of dictionaries, where each dictionary contains the parameters for a single operation. - * - ``pipeline.validate`` - - Whether to validate the pipeline. Defaults to True. - * - ``pipeline.with_batch_dim`` - - Whether operations should expect a batch dimension in the input. Defaults to True. - * - ``plot`` - - Settings pertaining to plotting when running the UI (`zea --config `) - * - ``plot.fliplr`` - - Set to true to flip the image left to right - * - ``plot.fps`` - - Frames per second for video output. - * - ``plot.headless`` - - Set to true to run the UI in headless mode - * - ``plot.image_extension`` - - The file extension to use when saving the image (png, jpg) - * - ``plot.plot_lib`` - - The plotting library to use (opencv, matplotlib) - * - ``plot.save`` - - Set to true to save the plots to disk, false to only display them in the UI - * - ``plot.selector`` - - Type of selector to use for ROI selection in the UI ('rectangle', 'lasso', or None). - * - ``plot.selector_metric`` - - Metric to use for evaluating selected regions (e.g., 'gcnr'). - * - ``plot.tag`` - - The name for the plot - * - ``plot.video_extension`` - - The file extension to use when saving the video (mp4, gif) - * - ``scan`` - - Deprecated alias for 'parameters'. Supported for backward compatibility; prefer using 'parameters'. + - ``null`` + - Device indices to exclude from auto-selection (int or list of ints). + * - ``git`` + - ``null`` + - Git commit or branch recorded for reproducibility. + +The top-level config is **open**: arbitrary extra sections (``model:``, etc.) +are accepted and passed through unchanged. + +---------------------------- +API reference +---------------------------- + +.. autoclass:: zea.Config + :members: from_path, from_yaml, to_yaml + :undoc-members: + :show-inheritance: + +.. autofunction:: zea.config.check_config + :no-index: diff --git a/docs/source/parameters_doc.py b/docs/source/parameters_doc.py index 887d1d410..eafdd861a 100644 --- a/docs/source/parameters_doc.py +++ b/docs/source/parameters_doc.py @@ -1,6 +1,7 @@ """Automatically generate comments in YAML config files based on parameter descriptions. -Also generates a reStructuredText (RST) file for sphinx documentation. +Run this script to add inline comments to all YAML config files under ``configs/``. +Also checks that PARAMETER_DESCRIPTIONS stays in sync with the ConfigSchema spec. """ import os @@ -8,7 +9,8 @@ import sys from pathlib import Path -from schema import And, Optional, Or, Schema +# Ensure the workspace version of zea is used when this script is run directly. +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) from zea import log from zea.internal.config.parameters import PARAMETER_DESCRIPTIONS @@ -19,9 +21,8 @@ # Determine if we are running from the docs directory or project root def get_project_paths(): - """Get the project paths for configs and parameters.rst based on the current working directory.""" + """Get the project paths for configs based on the current working directory.""" cwd = Path.cwd().resolve() - # If we're in docs or docs/source, adjust accordingly if cwd.name == "source" and cwd.parent.name == "docs": project_root = cwd.parent.parent elif cwd.name == "docs": @@ -30,19 +31,14 @@ def get_project_paths(): project_root = cwd configs_dir = project_root / "configs" - rst_path = project_root / "docs" / "source" / "config.rst" - # Ensure the configs directory exists if not configs_dir.exists(): raise FileNotFoundError(f"Configs directory does not exist: {configs_dir}") - # Ensure the rst path is valid - if not rst_path.parent.exists(): - raise FileNotFoundError(f"RST path does not exist: {rst_path.parent}") - return configs_dir, rst_path + return configs_dir -CONFIGS_DIRNAME, PARAMETERS_RST_PATH = get_project_paths() +CONFIGS_DIRNAME = get_project_paths() def flatten_dict_keys(d, prefix=""): @@ -59,37 +55,6 @@ def flatten_dict_keys(d, prefix=""): return keys -def flatten_schema_keys(schema, prefix=""): - """Recursively flatten schema keys into dot notation, handling Optional/And/Or.""" - keys = set() - if isinstance(schema, Schema): - schema = schema.schema - if isinstance(schema, dict): - for k, v in schema.items(): - if isinstance(k, Optional): - key = getattr(k, "schema", getattr(k, "key", k)) - else: - key = k - full = f"{prefix}.{key}" if prefix else str(key) - if isinstance(v, And): - if hasattr(v, "_validators") and v._validators: - last = v._validators[-1] - keys |= flatten_schema_keys(last, full) - else: - keys.add(full) - elif isinstance(v, Or): - if hasattr(v, "_options"): - for opt in v._options: - keys |= flatten_schema_keys(opt, full) - else: - keys.add(full) - elif isinstance(v, (Schema, dict)): - keys |= flatten_schema_keys(v, full) - else: - keys.add(full) - return keys - - def wrap_string_as_comment(input_string, indent_level=0, max_line_length=100, indent_size=2): """Limit the length of lines in a string and adds a comment prefix.""" if isinstance(input_string, dict): @@ -121,8 +86,6 @@ def process_yaml_content(lines, descriptions, indent_size=2): """Process YAML content line by line and add comments.""" modified_lines = [] current_keys = [] - data_found = False - plot_found = False for line in lines: if line.strip().startswith("#"): continue @@ -131,17 +94,12 @@ def process_yaml_content(lines, descriptions, indent_size=2): current_keys = current_keys[:indent_level] key = line.split(":")[0].strip() current_keys.append(key) - if key == "data" and indent_level == 0: - data_found = True - if key == "plot" and indent_level == 0: - plot_found = True # Special handling: skip comments for operation entries inside pipeline.operations if ( len(current_keys) >= 3 and current_keys[0] == "pipeline" and current_keys[1] == "operations" ): - # Do not add comments for keys inside operations list modified_lines.append(line) continue description = descriptions @@ -153,17 +111,12 @@ def process_yaml_content(lines, descriptions, indent_size=2): comment_lines = wrap_string_as_comment( description, indent_level, max_line_length=80, indent_size=indent_size ) - # Only add comment if it's not just "# -" if comment_lines.strip() != "# -": modified_lines.append(comment_lines) modified_lines.append(line) else: modified_lines.append(line) - if data_found and plot_found: - return modified_lines - else: - print("data and/or plot key not found. Not adding comments.") - return lines + return modified_lines def add_comments_to_yaml(file_path, descriptions): @@ -179,125 +132,19 @@ def add_comments_to_yaml(file_path, descriptions): file.writelines(modified_content) -def check_parameter_descriptions(descriptions, schema): +def check_parameter_descriptions(descriptions, spec): """Check for missing or extra parameter descriptions (ignoring 'description' keys). - Returns (missing, extra) as sets. + + ``spec`` is the top-level :class:`~zea.internal.config.validation.ConfigSpec` + subclass. Returns (missing, extra) as sorted lists. """ rst_keys = flatten_dict_keys(descriptions) - schema_keys = flatten_schema_keys(schema) + schema_keys = spec.all_field_paths() missing = sorted(schema_keys - rst_keys) extra = sorted(rst_keys - schema_keys) return missing, extra -def dict_to_rst_table(param_dict): - """Convert a nested dictionary to a reStructuredText (RST) table.""" - lines = [] - lines.append(".. list-table::") - lines.append(" :header-rows: 1") - lines.append(" :widths: 20 80\n") - lines.append(" * - **Parameter**") - lines.append(" - **Description**") - - def recurse(d, prefix=""): - for k in sorted(d): - if k == "description": - continue - v = d[k] - param_name = f"{prefix}.{k}" if prefix else k - if isinstance(v, dict): - desc = v.get("description", "") - lines.append(f" * - ``{param_name}``\n - {desc}") - recurse(v, param_name) - else: - lines.append(f" * - ``{param_name}``\n - {v}") - - recurse(param_dict) - return "\n".join(lines) - - -def create_parameters_rst(param_dict, rst_path=PARAMETERS_RST_PATH): - """Generate a reStructuredText (RST) file from the parameter dictionary.""" - intro = """.. THIS FILE WAS AUTOGENERATED USING docs/source/parameters_doc.py. DO NOT EDIT MANUALLY. - -.. _config: - -Config -====== - -This page documents the ``zea`` configuration system. Configs are YAML files that -control data loading, preprocessing, model settings, and scan parameters. - -.. note:: - Configs are used to initialize :doc:`zea.Models ` and the :doc:`pipeline`. - For the data format and file I/O, see :doc:`data-acquisition`. - -.. note:: - The ``parameters`` section is a flat mapping that overrides values loaded from the - data file. It mirrors :class:`zea.Parameters` (the merged :class:`zea.Probe` and scan - parameters — see :doc:`data-acquisition`), and may additionally contain arbitrary - custom parameters that are passed straight through to the pipeline. The documented - reconstruction keys below are the most common ones, but they are not exhaustive. - -Configs are written in YAML format and can be loaded, edited, and saved using the ``zea`` API. - -------------------------------- -How to Load and Save a Config -------------------------------- - -Here is a minimal example of how to load and save a config file using zea: - -.. doctest:: - - >>> from zea import Config - >>> from zea.config import check_config - - >>> # Load a config from file - >>> config = Config.from_path("../configs/config_picmus_rf.yaml") - >>> # or some predefined from Hugging Face Hub - >>> config = Config.from_path("hf://zeahub/configs/config_picmus_rf.yaml") - - >>> # We can check if the config has valid parameters (zea compliance) - >>> config = check_config(config) - - >>> # Access or change parameters - >>> config.parameters.grid_size_x = 512 - >>> print(config.parameters.grid_size_x) - 512 - - >>> # Save the config back to file - >>> config.to_yaml("my_new_config.yaml") - -.. testcleanup:: - - import os - - os.remove("my_new_config.yaml") - -------------------------------- -Parameter List -------------------------------- - -Below is a hierarchical list of all configuration parameters, grouped by section. -Descriptions are shown for each parameter. - -.. contents:: - :local: - :depth: 2 - -------------------------------- -Parameters Reference -------------------------------- -""" - table = dict_to_rst_table(param_dict) - with open(rst_path, "w", encoding="utf-8") as f: - f.write(intro) - f.write("\n") - f.write(table) - f.write("\n") - log.info(f"Generated {rst_path} from PARAMETER_DESCRIPTIONS.") - - def update_configs(descriptions, configs_dir=CONFIGS_DIRNAME): """Update YAML config files with comments based on parameter descriptions.""" config_dir = Path(configs_dir) @@ -308,10 +155,10 @@ def update_configs(descriptions, configs_dir=CONFIGS_DIRNAME): if __name__ == "__main__": - from zea.internal.config.validation import config_schema + from zea.internal.config.validation import ConfigSchema # 1. Check parameter descriptions - missing, extra = check_parameter_descriptions(PARAMETER_DESCRIPTIONS, config_schema) + missing, extra = check_parameter_descriptions(PARAMETER_DESCRIPTIONS, ConfigSchema) if missing: log.warning( "The following config parameters are missing descriptions in PARAMETER_DESCRIPTIONS:" @@ -331,8 +178,5 @@ def update_configs(descriptions, configs_dir=CONFIGS_DIRNAME): else: log.info("All config parameters are documented in PARAMETER_DESCRIPTIONS.") - # 2. Generate config.rst - create_parameters_rst(PARAMETER_DESCRIPTIONS) - - # 3. Update YAML configs with comments + # 2. Update YAML configs with comments update_configs(PARAMETER_DESCRIPTIONS) diff --git a/docs/source/spec_doc.py b/docs/source/spec_doc.py index 48b2ccd84..b7c942f53 100644 --- a/docs/source/spec_doc.py +++ b/docs/source/spec_doc.py @@ -327,6 +327,8 @@ def generate() -> str: # --- Group reference tabs ------------------------------------------------- lines += [ + ".. _group-reference:", + "", "Group reference", "~~~~~~~~~~~~~~~", "", diff --git a/poetry.lock b/poetry.lock index 0961ea95c..ece148dd3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4844,18 +4844,6 @@ files = [ {file = "ruff-0.15.14.tar.gz", hash = "sha256:48e866b165be4a9bdbf310f7d3c9a07edef2fe8cd63ffeb4e00bb590506ebf9f"}, ] -[[package]] -name = "schema" -version = "0.7.7" -description = "Simple data validation library" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"}, - {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"}, -] - [[package]] name = "scikit-image" version = "0.25.2" @@ -6603,4 +6591,4 @@ tests = ["cloudpickle", "ipykernel", "ipywidgets", "papermill", "pre-commit", "p [metadata] lock-version = "2.1" python-versions = ">=3.10" -content-hash = "79841908effb8e2652ca2ef0e08762a38bf0dc8403835b9a89d1e82d5a123590" +content-hash = "b2322c414b93c57f2836316c56eabe2fc668f2045b63824a41c3df66e415d6fe" diff --git a/pyproject.toml b/pyproject.toml index 20d83d67b..5b88211c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "matplotlib >=3.8", "scipy >=1.13", "pillow >=12.2.0", - "schema >=0.7", "tqdm >=4", "pyyaml >=6", "decorator >=5", diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index cec6120a1..5cc476fd8 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -62,12 +62,12 @@ ) def test_dataset_indexing(file_idx, idx, expected_shape, dummy_dataset_path): """Test ui initialization function""" - config = {"data": {"dataset_folder": dummy_dataset_path, "dtype": "image"}} + config = {"data": {"path": str(dummy_dataset_path)}} config = check_config(Config(config)) dataset = Dataset.from_config(**config.data) file = dataset[file_idx] - data = file[file.format_key(config.data.dtype)]["values"][idx] + data = file[file.format_key("image")]["values"][idx] assert data.shape == expected_shape, ( f"Data shape {data.shape} does not match expected shape {expected_shape}" diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py new file mode 100644 index 000000000..b3c6645c1 --- /dev/null +++ b/tests/test_config_validation.py @@ -0,0 +1,129 @@ +"""Tests for the dataclass-based config validation (zea.internal.config.validation).""" + +import pytest + +from zea.config import Config, _migrate_legacy_config, check_config +from zea.internal.config.validation import ( + ConfigSchema, + ParametersConfig, + validate_config, +) + + +def test_defaults_are_filled(): + """Validation fills in defaults for all optional sections.""" + result = validate_config({}) + + assert result["device"] == "auto:1" + assert result["git"] is None + assert result["pipeline"]["operations"] == ["identity"] + # data defaults + assert result["data"]["local"] is True + assert result["data"]["path"] is None + assert result["data"]["indices"] is None + + +def test_validation_is_idempotent(): + """Validating an already-validated config yields the same dict.""" + once = validate_config({"data": {"path": "hf://zeahub/picmus/file.hdf5", "local": False}}) + twice = validate_config(once) + assert once == twice + + +def test_empty_config_is_valid(): + """An empty config is valid — no required fields in ConfigSchema.""" + result = validate_config({}) + assert result["device"] == "auto:1" + assert result["data"]["local"] is True + + +def test_missing_required_data_field_does_not_raise(): + """All data fields are optional — an empty data: section is valid.""" + result = validate_config({"data": {}}) + assert result["data"]["path"] is None + assert result["data"]["local"] is True + + +@pytest.mark.parametrize( + "config", + [ + {"device": "tpu:0"}, # invalid device + {"pipeline": {"jit_options": "bad_option"}}, # enum + {"data": {"local": "yes"}}, # must be bool + {"data": {"indices": {"bad": "type"}}}, # invalid indices type + ], +) +def test_invalid_values_raise(config): + with pytest.raises(ValueError): + validate_config(config) + + +@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", "cuda:0", "gpu:1", "auto:1", "auto:-1"]) +def test_valid_devices(device): + result = validate_config({"device": device}) + assert result["device"] == device + + +def test_arbitrary_parameters_keys_pass_through(): + """The parameters section accepts and round-trips arbitrary custom keys.""" + config = {"parameters": {"grid_size_x": 128, "my_custom_param": 42}} + result = validate_config(config) + assert result["parameters"]["grid_size_x"] == 128 + assert result["parameters"]["my_custom_param"] == 42 + + +def test_arbitrary_top_level_keys_preserved(): + """Unknown top-level sections (e.g. model:) are preserved unchanged.""" + config = {"model": {"name": "diffusion", "steps": 100}} + result = validate_config(config) + assert result["model"] == {"name": "diffusion", "steps": 100} + + +def test_parameters_config_is_open(): + assert ParametersConfig.ALLOW_EXTRA is True + assert ConfigSchema.ALLOW_EXTRA is True + + +def test_all_field_paths_includes_nested(): + paths = ConfigSchema.all_field_paths() + assert "data.path" in paths + assert "data.local" in paths + assert "data.indices" in paths + assert "pipeline.operations" in paths + assert "pipeline.jit_options" in paths + assert "device" in paths + assert "git" in paths + assert "plot.plot_lib" not in paths + assert "data.dtype" not in paths + assert "data.dynamic_range" not in paths + + +def test_scan_alias_migrated_to_parameters(): + """The deprecated scan: section is aliased to parameters: on load.""" + migrated = _migrate_legacy_config({"scan": {"grid_size_x": 64}}) + assert "scan" not in migrated + assert migrated["parameters"] == {"grid_size_x": 64} + + +def test_check_config_freezes_config_object(): + config = Config({}) + checked = check_config(config) + assert isinstance(checked, Config) + assert checked.__frozen__ is True + assert checked.pipeline.operations == ["identity"] + assert checked.data.local is True + + +def test_data_config_local_default(): + """DataConfig local defaults to True even without data: in the config.""" + result = validate_config({}) + assert result["data"]["local"] is True + + +def test_data_config_passthrough_with_full_section(): + """A full data: section validates correctly.""" + config = {"data": {"path": "hf://zeahub/picmus/file.hdf5", "local": False, "indices": "all"}} + result = validate_config(config) + assert result["data"]["path"] == "hf://zeahub/picmus/file.hdf5" + assert result["data"]["local"] is False + assert result["data"]["indices"] == "all" diff --git a/tests/test_configs.py b/tests/test_configs.py index 676e4f396..1d0c7c1ff 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -5,7 +5,6 @@ import pytest import yaml -from schema import SchemaError from zea.config import Config, check_config from zea.internal.setup_zea import setup_config @@ -70,11 +69,17 @@ def test_all_configs_valid(file): configuration = check_config(configuration) # check another time, since defaults are now set, which are not # checked by the first check_config. Basically this checks if the - # config_validation.py entries are correct. + # validation.py entries are correct. check_config(configuration) - except SchemaError as se: - raise ValueError(f"Error in config {file}") from se + except ValueError as ve: + raise ValueError(f"Error in config {file}") from ve + + +def test_config_rejects_string_path(): + """Config(path) must raise TypeError — use Config.from_path() instead.""" + with pytest.raises(TypeError, match="Config.from_path"): + Config("configs/config_picmus_rf.yaml") def test_dot_indexing(): diff --git a/tests/test_interface.py b/tests/test_interface.py deleted file mode 100644 index e5d4cf499..000000000 --- a/tests/test_interface.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Basic testing for interface / generate""" - -import sys -from pathlib import Path -from unittest.mock import MagicMock - -import numpy as np - -from tests.data import generate_example_dataset -from zea.data.file import File -from zea.interface import Interface -from zea.internal.setup_zea import setup_config - -wd = Path(__file__).parent.parent -sys.path.append(str(wd)) - - -def test_interface_initialization(): - """Test interface initialization""" - config = setup_config("hf://zeahub/configs/config_camus.yaml") - - interface = Interface(config) - interface.run(plot=True) - - data = interface.get_data() - assert data is not None - assert isinstance(data, np.ndarray), "Data is not a numpy array" - assert len(data.shape) == 2, "Data must be 2d (grid_size_z, grid_size_x)" - - -def test_interface_reads_map_backed_dataset(tmp_path): - """For map-backed types (e.g. image) the read must descend into the - 'values' sub-dataset rather than indexing the group directly.""" - path = tmp_path / "with_image.hdf5" - generate_example_dataset( - path, - add_optional_dtypes=True, - n_frames=3, - grid_size_z=8, - grid_size_x=8, - image_dtype=np.uint8, - ) - - with File(path) as f: - iface = object.__new__(Interface) - iface.file = f - iface.verbose = False - config = MagicMock() - config.data.dtype = "image" - config.data.frame_no = 0 - iface.config = config - - grp = f[f.format_key("image")] - if hasattr(grp, "keys") and "values" in grp: - data = grp["values"][0] - else: - data = grp[0] - - assert isinstance(data, np.ndarray), "get_data must return ndarray" - assert data.ndim >= 2, "returned frame must be at least 2-D" diff --git a/zea/__init__.py b/zea/__init__.py index b4a824b35..f67e9a7f8 100644 --- a/zea/__init__.py +++ b/zea/__init__.py @@ -31,7 +31,6 @@ from .data.datasets import Dataset, Folder from .data.file import File, load_file from .datapaths import set_data_paths - from .interface import Interface from .internal.device import init_device from .internal.setup_zea import setup, setup_config from .ops import Pipeline @@ -151,7 +150,6 @@ def _check_backend_installed(): "File": ("zea.data.file", "File"), "load_file": ("zea.data.file", "load_file"), "set_data_paths": ("zea.datapaths", "set_data_paths"), - "Interface": ("zea.interface", "Interface"), "init_device": ("zea.internal.device", "init_device"), "setup": ("zea.internal.setup_zea", "setup"), "setup_config": ("zea.internal.setup_zea", "setup_config"), diff --git a/zea/__main__.py b/zea/__main__.py index 4851fdf61..c025f927e 100644 --- a/zea/__main__.py +++ b/zea/__main__.py @@ -1,63 +1,12 @@ -"""Main entry point for zea - -Run as `zea --config path/to/config.yaml` to start the zea interface. -Or do not pass a config file to open a file dialog to choose a config file. +"""Main entry point for zea. +CLI functionality will be added in a future PR. """ -import argparse -import sys -from pathlib import Path - -from zea.visualize import set_mpl_style - - -def get_parser(): - """Command line argument parser""" - parser = argparse.ArgumentParser( - description="Load and process ultrasound data based on a configuration file." - ) - parser.add_argument("-c", "--config", type=str, default=None, help="path to the config file.") - parser.add_argument( - "-t", - "--task", - default="view", - choices=["view"], - type=str, - help="Which task to run. Currently only 'view' is supported.", - ) - parser.add_argument( - "--skip_validate_file", - default=False, - action="store_true", - help="Skip zea file integrity checks. Use with caution.", - ) - return parser - def main(): """main entrypoint for zea""" - args = get_parser().parse_args() - - set_mpl_style() - - wd = Path(__file__).parent.resolve() - sys.path.append(str(wd)) - - from zea.interface import Interface - from zea.internal.setup_zea import setup - - config = setup(args.config) - - if args.task == "view": - cli = Interface( - config, - validate_file=not args.skip_validate_file, - ) - - cli.run(plot=True) - else: - raise ValueError(f"Unknown task {args.task}, see `zea --help` for available tasks.") + print("zea CLI is not yet available. Please use the Python API directly.") if __name__ == "__main__": diff --git a/zea/config.py b/zea/config.py index 1cba326b5..56a92049e 100644 --- a/zea/config.py +++ b/zea/config.py @@ -26,11 +26,11 @@ >>> config = Config.from_path("hf://zeahub/configs/config_picmus_rf.yaml") >>> # Access attributes with dot notation - >>> print(config.data.dtype) - raw_data + >>> print(config.data.local) + False >>> # Update recursively - >>> config.update_recursive({"data": {"dtype": "raw_data"}}) + >>> config.update_recursive({"data": {"local": False}}) >>> # Save to YAML >>> config.to_yaml("new_config.yaml") @@ -53,7 +53,7 @@ import yaml from zea import log -from zea.internal.config.validation import config_schema +from zea.internal.config.validation import validate_config from zea.internal.core import dict_to_tensor from zea.internal.preset_utils import HF_PREFIX, _hf_resolve_path from zea.internal.utils import deprecated @@ -103,6 +103,11 @@ def __init__(self, dictionary=None, __parent__=None, **kwargs): super().__setattr__("__accessed__", {}) super().__setattr__("__parent__", __parent__) + if isinstance(dictionary, (str, Path)): + raise TypeError( + f"Config() expects a dict, not {type(dictionary).__name__!r}. " + "To load from a file use Config.from_path()." + ) if dictionary is None: dictionary = {} if kwargs: @@ -508,7 +513,7 @@ def check_config(config: Union[dict, Config], verbose: bool = False): def _try_validate_config(config): try: - config = config_schema.validate(config) + config = validate_config(config) return config except Exception as e: log.error(f"Config is not valid: {e}") diff --git a/zea/data/convert/verasonics.py b/zea/data/convert/verasonics.py index 823c908dc..881126e42 100644 --- a/zea/data/convert/verasonics.py +++ b/zea/data/convert/verasonics.py @@ -40,6 +40,7 @@ """ # noqa: E501 import os +import re import sys import traceback from pathlib import Path @@ -48,7 +49,6 @@ import numpy as np import yaml from keras import ops -from schema import And, Optional, Or, Regex, Schema from zea import log from zea.data.convert.utils import ( @@ -69,22 +69,60 @@ } -_CONVERT_YAML_SCHEMA = Schema( - { - "files": [ - { - "name": str, - Optional("first_frame"): And(int, lambda x: x >= 0), - Optional("frames"): Or( - "all", - And(str, Regex(r"^\d+(-\d+)?$")), # Matches "30-99" or single number like "5" - [And(int, lambda x: x >= 0)], # List of non-negative integers - ), - Optional("transmits"): Or("all", [And(int, lambda x: x >= 0)]), - } - ] - } -) +_FRAMES_RANGE_RE = re.compile(r"^\d+(-\d+)?$") + + +def _validate_convert_config(data): + """Validate the structure of a convert.yaml config dict. + + Expected shape:: + + files: + - name: + first_frame: = 0> # optional + frames: all | "N" | "N-M" | [N, ...] # optional + transmits: all | [N, ...] # optional + """ + if not isinstance(data, dict) or "files" not in data: + raise ValueError("convert.yaml must have a top-level 'files' key") + if not isinstance(data["files"], list): + raise ValueError("'files' must be a list") + for entry in data["files"]: + if not isinstance(entry, dict): + raise ValueError(f"each entry in 'files' must be a dict, got {type(entry).__name__}") + if not isinstance(entry.get("name"), str): + raise ValueError(f"each file entry must have a string 'name', got {entry!r}") + if "first_frame" in entry: + ff = entry["first_frame"] + if not isinstance(ff, int) or isinstance(ff, bool) or ff < 0: + raise ValueError(f"'first_frame' must be a non-negative int, got {ff!r}") + if "frames" in entry: + fr = entry["frames"] + if not ( + fr == "all" + or (isinstance(fr, str) and _FRAMES_RANGE_RE.fullmatch(fr)) + or ( + isinstance(fr, list) + and all(isinstance(x, int) and not isinstance(x, bool) and x >= 0 for x in fr) + ) + ): + raise ValueError( + f"'frames' must be 'all', a range string like '30-99', or a list of " + f"non-negative ints, got {fr!r}" + ) + if "transmits" in entry: + tr = entry["transmits"] + if not ( + tr == "all" + or ( + isinstance(tr, list) + and all(isinstance(x, int) and not isinstance(x, bool) and x >= 0 for x in tr) + ) + ): + raise ValueError( + f"'transmits' must be 'all' or a list of non-negative ints, got {tr!r}" + ) + return data class VerasonicsFile(h5py.File): @@ -452,7 +490,7 @@ def load_convert_config(self): data = yaml.load(file, Loader=yaml.FullLoader) # Validate the YAML structure - validated_data = _CONVERT_YAML_SCHEMA.validate(data) + validated_data = _validate_convert_config(data) files = validated_data["files"] filenames = [file["name"] for file in files] diff --git a/zea/data/datasets.py b/zea/data/datasets.py index f99de0b50..636e8a149 100644 --- a/zea/data/datasets.py +++ b/zea/data/datasets.py @@ -511,18 +511,11 @@ def find_files(self, paths) -> List[str]: return file_paths @classmethod - def from_config(cls, dataset_folder, user=None, **kwargs): + def from_config(cls, path, user=None, **kwargs): """Creates a Dataset from a config file.""" - path = format_data_path(dataset_folder, user) - - if "file_path" in kwargs: - log.warning( - "Found 'file_path' in config, this will be ignored since a Dataset is " - + "always multiple files." - ) - + resolved = format_data_path(path, user) reduced_params = reduce_to_signature(cls.__init__, kwargs) - return cls(path, **reduced_params) + return cls(resolved, **reduced_params) def __len__(self): """Returns the number of files in the dataset.""" diff --git a/zea/interface.py b/zea/interface.py deleted file mode 100644 index 724e1a418..000000000 --- a/zea/interface.py +++ /dev/null @@ -1,551 +0,0 @@ -"""Convenience interface for loading and displaying ultrasound data. - -Example usage -^^^^^^^^^^^^^^ - -.. doctest:: - - >>> import zea - >>> from zea.internal.setup_zea import setup_config - - >>> config = setup_config("hf://zeahub/configs/config_camus.yaml") - - >>> interface = zea.Interface(config) - >>> interface.run(plot=True) # doctest: +SKIP - -""" - -import asyncio -import time -from pathlib import Path -from typing import List - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from PIL import Image - -from zea import log -from zea.config import Config -from zea.data.file import File -from zea.datapaths import format_data_path -from zea.display import to_8bit -from zea.internal.core import DataTypes -from zea.internal.utils import keep_trying -from zea.internal.viewer import ( - ImageViewerMatplotlib, - ImageViewerOpenCV, - filename_from_window_dialog, - running_in_notebook, -) -from zea.io_lib import matplotlib_figure_to_numpy, save_video -from zea.ops import Pipeline - - -class Interface: - """Interface for selecting / loading / processing single ultrasound images. - - Useful for inspecting datasets and single ultrasound images. - - # TODO: maybe we can refactor such that it is clear what needs to be in config. - """ - - def __init__(self, config: Config = None, verbose: bool = True, validate_file: bool = True): - """Initialize Interface. - - Args: - config (Config): Configuration object. - verbose (bool): Whether to print verbose output. - validate_file (bool): Whether to validate the file. - """ - self.config = Config(config) - self.verbose = verbose - - self.file = File(self.file_path) - - if validate_file: - self.file.validate() - - # get probe and parameters from file - self.probe = self.file.probe - self._param_overrides = dict(self.config.get("parameters", {}) or {}) - self.parameters = self.file.load_parameters(**self._param_overrides) - - # initialize Pipeline - assert "pipeline" in self.config, ( - "Pipeline not found in config, please specify pipeline in config." - ) - - self.process = Pipeline.from_config( - self.config, - with_batch_dim=False, - jit_options=None, - ) - self.input_tensors = self.process.prepare_parameters( - self.parameters, **self._param_overrides - ) - - # initialize attributes for UI class - self.data = None - self.image = None - self.mpl_img = None - self.fig = None - self.ax = None - self.gui = None - self.image_viewer = None - - self.plot_lib = self.config.plot.plot_lib - - if self.config.plot.headless is None: - self.headless = False - else: - self.headless = self.config.plot.headless - - self.check_for_display() - - if self.plot_lib == "opencv": - self.image_viewer = ImageViewerOpenCV( - self.data_to_display, - window_name=self.file.name, - num_threads=1, - headless=self.headless, - ) - elif self.plot_lib == "matplotlib": - self.image_viewer = ImageViewerMatplotlib( - self.data_to_display, - window_name=self.file.name, - num_threads=1, - ) - - @property - def dtype(self): - """Data type of data when loaded from file.""" - return self.config.data.dtype - - @property - def dataset_folder(self): - """Path to dataset folder.""" - return format_data_path(self.config.data.dataset_folder, self.config.data.user) - - @property - def file_path(self): - """Path to data file.""" - if self.config.data.file_path: - return self.dataset_folder / self.config.data.file_path - else: - return self.choose_file_path() - - @file_path.setter - def file_path(self, value): - """Set file path to data file.""" - self.config.data.file_path = value - - def choose_file_path(self): - """Choose file path from window dialog.""" - if self.headless: - raise ValueError( - "No file path specified for data file, which is required " - "in headless mode as window dialog cannot be opened." - ) - filetype = "hdf5" - log.info("Please select file from window dialog...") - self.file_path = filename_from_window_dialog( - f"Choose .{filetype} file", - filetypes=((filetype, "*." + filetype),), - initialdir=self.dataset_folder, - ) - return self.file_path - - @property - def data_root(self): - """Root path to data file.""" - return Path(self.config.user.data_root) - - @dtype.setter - def dtype(self, value): - self.config.data.dtype = value - - @property - def to_dtype(self): - """Data type to convert to for display.""" - return self.config.data.to_dtype - - @to_dtype.setter - def to_dtype(self, value): - self.config.data.to_dtype = value - - @property - def frame_no(self): - """Frame number to display.""" - return self.config.data.get("frame_no") - - @frame_no.setter - def frame_no(self, value): - self.config.data.frame_no = value - - def check_for_display(self): - """check if in headless mode (no monitor available)""" - if self.headless is False: - if matplotlib.get_backend().lower() == "agg": - self.headless = True - log.warning("Could not connect to display, running headless.") - else: - # self.plot_lib = "matplotlib" # force matplotlib in headless mode - matplotlib.use("agg") - log.info("Running in headless mode as set by config.") - - def set_backend_for_notebooks(self): - """Set backend to QtAgg if running in notebook""" - if running_in_notebook() and not self.headless: - matplotlib.use("QtAgg") - - def get_data(self): - """Get data. Chosen datafile should be listed in the dataset. - - Using either file specified in config or if None, the ui window. - - Returns: - data (np.ndarray): data array of shape (n_tx, n_el, n_ax, N_ch) - """ - if self.verbose: - log.info(f"Selected {log.yellow(self.file_path)}") - - # grab frame number from config or user input if not set in config - if self.frame_no == "all": - log.info("Will run all frames as `all` was chosen in config...") - elif self.frame_no is None: - if self.file.n_frames == 1: - self.frame_no = 0 - else: - self.frame_no = keep_trying( - lambda: int(input(f">> Frame number (0 / {self.file.n_frames - 1}): ")) - ) - - # get data from dataset - grp = self.file[self.file.format_key(self.dtype)] - if hasattr(grp, "keys") and "values" in grp: - data = grp["values"][self.frame_no] - else: - data = grp[self.frame_no] - - return data - - def data_to_display(self, data=None): - """Get data and convert to display to_dtype.""" - if data is None: - self.data = self.get_data() - else: - self.data = data - - if self.to_dtype not in ["image", "image_sc"]: - log.warning( - f"Image to_dtype: {self.to_dtype} not supported for displaying data." - "falling back to to_dtype: `image_sc`" - ) - self.to_dtype = "image_sc" - - # select transmits if raw or aligned data - data_type = self.process.operations[0].input_data_type - if data_type in [DataTypes.RAW_DATA, DataTypes.ALIGNED_DATA]: - n_tx = self.data.shape[0] - assert len(self.parameters.selected_transmits) <= n_tx, ( - f"Number of selected transmits {len(self.parameters.selected_transmits)} " - f"exceeds number of transmits in raw data {n_tx}" - ) - self.data = np.take(self.data, self.parameters.selected_transmits, axis=0) - - inputs = {self.process.key: self.data} - - outputs = self.process(**inputs, **self.input_tensors) - - self.image = outputs[self.process.output_key] - - # match orientation if necessary - if self.config.plot.fliplr: - self.image = np.fliplr(self.image) - # opencv requires 8 bit images - if self.plot_lib == "opencv": - self.image = to_8bit(self.image, self.config.data.dynamic_range) - return self.image - - def run(self, plot=False, block=True): - """Run ui. Will retrieve, process and plot data if set to True.""" - save = self.config.plot.save - - if self.frame_no == "all": - try: - loop = asyncio.get_running_loop() - loop.create_task(self.run_movie(save)) # already running loop - except RuntimeError: - asyncio.run(self.run_movie(save)) # no loop yet - - else: - if plot: - self.image = self.plot( - save=save, - block=block, - ) - else: - self.image = self.data_to_display() - - return self.image - - def plot( - self, - data: np.ndarray = None, - save: bool = False, - block: bool = True, - ): - """Plot image using matplotlib or opencv. - - Args: - save (bool): whether to save the image to disk. - block (bool): whether to block the UI while plotting. - Returns: - image (np.ndarray): plotted image (grabbed from figure). - """ - assert self.image_viewer is not None, "Image viewer not initialized." - - self.image_viewer.threading = False - - if self.plot_lib == "matplotlib": - if self.image_viewer.fig is None: - self._init_plt_figure() - self.image_viewer.show(data) - if save: - self.save_image(self.fig) - if not self.headless and block: - plt.show(block=True) - self.image = matplotlib_figure_to_numpy(self.fig) - return self.image - - elif self.plot_lib == "opencv": - self.image_viewer.show(data) - if not self.headless and block: - self.image_viewer._cv2.waitKey(0) - self.save_image(self.image) - return self.image - - def _init_plt_figure(self): - figsize = (10, 10) - if self.parameters: - extent = [ - self.parameters.xlims[0] * 1e3, - self.parameters.xlims[1] * 1e3, - self.parameters.zlims[1] * 1e3, - self.parameters.zlims[0] * 1e3, - ] - # set figure aspect ratio to match scan - aspect_ratio = abs(extent[1] - extent[0]) / abs(extent[3] - extent[2]) - figsize = tuple(np.array(figsize) * aspect_ratio) - else: - extent = None - - self.fig, self.ax = plt.subplots(figsize=figsize) - - image_range = self.config.data.dynamic_range - imshow_kwargs = { - "cmap": "gray", - "vmin": image_range[0], - "vmax": image_range[1], - "origin": "upper", - "extent": extent, - "interpolation": "none", - } - cax_kwargs = { - "pad": 0.05, - "position": "right", - "size": "5%", - } - - self.ax.set_xlabel("Lateral Width (mm)", size=15) - self.ax.set_ylabel("Axial length (mm)", size=15) - self.ax.tick_params(axis="x") - self.ax.tick_params(axis="y") - - # assign properties of fig, ax to image viewer - self.image_viewer.imshow_kwargs = imshow_kwargs - self.image_viewer.cax_kwargs = cax_kwargs - self.image_viewer.fig = self.fig - self.image_viewer.ax = self.ax - - async def run_movie(self, save: bool = False): - """Run all frames in file in sequence""" - - log.info('Playing video, press/hold "q" while the window is active to exit...') - self.image_viewer.threading = True - images = await self._movie_loop(save) - - if save: - self.save_video(images) - - async def _movie_loop(self, save: bool = False) -> List[np.ndarray]: - """Process data and plot it in real time. - - NOTE: when plot loop is terminated by user, it will only save the shown frames. - This is to prevent long waiting times when saving a movie (for large datasets). - - Args: - save (bool): Whether to save the plotted images. - - Returns: - list: A list of the plotted images. - """ - # Initialize list of images - images = [] - - # Load correct number of frames (needs to get_data first) - self.frame_no = 0 - self.get_data() - n_frames = self.file.n_frames - - self.verbose = False - try: - while True: - # first frame is already plotted during initialization of plotting - start_time = time.time() - frame_counter = 0 - self.image_viewer.frame_no = 0 - while frame_counter < n_frames: - if self.gui: - await self.gui.check_freeze() - - await asyncio.sleep(0.01) - - self.frame_no = frame_counter - - if frame_counter == 0: - if self.plot_lib == "matplotlib": - if self.image_viewer.fig is None: - self._init_plt_figure() - - self.image_viewer.show() - - # set counter to frame number of image viewer (possibly not updated) - frame_counter = self.image_viewer.frame_no - - # check if frame counter updated - if frame_counter != self.frame_no: - fps = frame_counter / (time.time() - start_time) - print( - f"frame {frame_counter} / {n_frames} ({fps:.2f} fps)", - end="\r", - ) - if save and (len(images) < n_frames): - if self.plot_lib == "matplotlib": - # grab image from plt figure - image = matplotlib_figure_to_numpy(self.fig) - else: - image = np.array(self.image) - images.append(image) - - # For opencv, show frame for 25 ms and check if "q" is pressed - if not self.headless: - if self.plot_lib == "opencv": - if self.image_viewer._cv2.waitKey(25) & 0xFF == ord("q"): - self.image_viewer.close() - return images - if self.image_viewer.has_been_closed(): - return images - # For matplotlib, check if window has been closed - elif self.plot_lib == "matplotlib": - if time.sleep(0.025) and self.image_viewer.has_been_closed(): - return images - # For headless mode, check if all frames have been plotted - if self.headless: - if len(images) == n_frames: - return images - - # clear line, frame number - print("\x1b[2K", end="\r") - - # only loop once if in headless mode - if self.headless: - return images - - except KeyboardInterrupt: - if save: - if len(images) > 0: - self.save_video(images) - raise - - def save_image(self, fig, path=None): - """Save image to disk. - - Args: - fig (fig object): figure. - path (str, optional): path to save image to. Defaults to None. - - """ - if path is None: - if self.config.plot.tag: - tag = "_" + self.config.plot.tag - else: - tag = "" - - if self.frame_no is not None: - filename = self.file_path.stem + "-" + str(self.frame_no) + tag - else: - filename = self.file_path.stem + tag - - ext = f".{self.config.plot.image_extension.lstrip('.')}" - - path = Path("./figures", filename).with_suffix(ext) - Path("./figures").mkdir(parents=True, exist_ok=True) - - if isinstance(fig, plt.Figure): - fig.savefig(path, transparent=True) - elif isinstance(fig, Image.Image): - fig.save(path) - else: - raise ValueError( - f"Figure is not PIL image or matplotlib figure object, got {type(fig)}" - ) - - if self.verbose: - log.info(f"Image saved to {log.yellow(path)}") - - def save_video(self, images, path=None): - """Save video to disk. - - Args: - images (list): list of images. - path (str, optional): path to save image to. Defaults to None. - - """ - if path is None: - if self.config.plot.tag: - tag = "_" + self.config.plot.tag - else: - tag = "" - filename = self.file_path.stem + tag + "." + self.config.plot.video_extension - - path = Path("./figures", filename) - Path("./figures").mkdir(parents=True, exist_ok=True) - - if not isinstance(images[0], np.ndarray): - raise ValueError("Images are not numpy arrays.") - - fps = self.config.plot.fps - - save_video(images, path, fps=fps) - - if self.verbose: - log.info(f"Video saved to {log.yellow(path)}") - - def __del__(self): - try: - if self.image_viewer is not None: - self.image_viewer.close() - except Exception: - pass - try: - if self.fig is not None: - plt.close(self.fig) - except Exception: - pass - try: - if self.file is not None: - self.file.close() - except Exception: - pass diff --git a/zea/internal/config/create.py b/zea/internal/config/create.py index 36ad84201..eb4354f71 100644 --- a/zea/internal/config/create.py +++ b/zea/internal/config/create.py @@ -3,69 +3,73 @@ import sys from pathlib import Path -import schema +import yaml -from zea.config import Config +from zea.config import Config, check_config from zea.internal.config.parameters import PARAMETER_DESCRIPTIONS -from zea.internal.config.validation import check_config, config_schema +from zea.internal.config.validation import ConfigSchema from zea.log import green, red from zea.utils import get_date_string, strtobool -def _get_input_value(config, schema_key, schema_value, descriptions): +def _get_input_value(config, key, validator, descriptions): + """Prompt for a value for ``key``, parse it as YAML, and validate it.""" while True: - input_val = input(f"Enter a value for {schema_key}: ") - if not isinstance(schema_key, str): - _key = schema_key.key - else: - _key = schema_key + input_val = input(f"Enter a value for {key}: ") if input_val == "help": - if _key not in descriptions: - print(red(f"No description available for {_key}")) - continue - print("\t" + green(descriptions[_key])) + desc = descriptions.get(key) if isinstance(descriptions, dict) else None + if not desc: + print(red(f"No description available for {key}")) + else: + print("\t" + green(desc)) continue try: - config[_key] = input_val - if isinstance(schema_value, schema.And): - for _type in schema_value.args: - try: - config[_key] = _type(config[_key]) - break - except Exception: - pass - schema_value.validate(config[_key]) - else: - schema_value(config[_key]) + # YAML parsing mirrors how config files are actually loaded, so e.g. + # "5" -> int, "true" -> bool, "[1, 2]" -> list, "all" -> str. + parsed = yaml.safe_load(input_val) + if validator is not None: + validator(parsed) + config[key] = parsed break - except Exception as e: + except Exception as e: # noqa: BLE001 - report any parse/validation error and retry print(f"Invalid input: {red(e)}") return config +def _resolve_spec_field(keys): + """Resolve a (slash separated) key path to its validator and descriptions. + + Unknown sections/keys resolve to ``None`` validator (extra keys are allowed). + """ + spec_cls = ConfigSchema + descriptions = PARAMETER_DESCRIPTIONS + for k in keys[:-1]: + spec_cls = spec_cls.NESTED.get(k) if spec_cls is not None else None + if isinstance(descriptions, dict): + descriptions = descriptions.get(k, {}) + validator = spec_cls.VALIDATORS.get(keys[-1]) if spec_cls is not None else None + return validator, descriptions + + def create_config(): """Create a new config file by asking the user for input.""" - def _ask_user_input(config, schema_obj, descriptions): - for key, value in schema_obj.schema.items(): - if isinstance(value, schema.Schema): - if not isinstance(key, str): - _key = key.key - else: - _key = key - if isinstance(key, schema.Optional): - # skip optional keys - continue - config[_key] = _ask_user_input( - config.setdefault(_key, {}), value, descriptions[_key] - ) - elif not isinstance(key, schema.Optional): - config = _get_input_value(config, key, value, descriptions) - + def _ask_user_input(config, spec_cls, descriptions): + for name in spec_cls.required_fields(): + nested = spec_cls.NESTED.get(name) + if nested is not None: + sub_desc = descriptions.get(name, {}) if isinstance(descriptions, dict) else {} + config[name] = _ask_user_input(config.setdefault(name, {}), nested, sub_desc) + else: + validator = spec_cls.VALIDATORS.get(name) + config = _get_input_value(config, name, validator, descriptions) return config config = {} - _ask_user_input(config, config_schema, PARAMETER_DESCRIPTIONS) + _ask_user_input(config, ConfigSchema, PARAMETER_DESCRIPTIONS) + + # Sections that are validated nested specs (cannot be set as a single value). + base_schemas = list(ConfigSchema.NESTED) # Ask user if they want to change any optional keys while True: @@ -75,52 +79,27 @@ def _ask_user_input(config, schema_obj, descriptions): change_optional = strtobool(input_val) if change_optional: - key = input("Enter the key name (e.g., 'model/beamformer/param'): ") + key = input("Enter the key name (e.g., 'parameters/grid_size_x'): ") keys = key.split("/") - base_schemas = [ - "data", - "plot", - "model", - "preprocess", - "postprocess", - "scan", - ] - - if len(keys) > 1: - if keys[0] not in base_schemas: - print(red(f"Invalid key {key}, please try again.")) - continue - - if len(keys) == 1: - if keys[0] in base_schemas: - print( - red( - f"Invalid key, cannot be part of base keys {base_schemas} " - "please try again." - ) + + if len(keys) > 1 and keys[0] not in base_schemas: + print(red(f"Invalid key {key}, please try again.")) + continue + if len(keys) == 1 and keys[0] in base_schemas: + print( + red( + f"Invalid key, cannot be part of base keys {base_schemas} " + "please try again." ) - continue + ) + continue nested_dict = config for k in keys[:-1]: nested_dict = nested_dict.setdefault(k, {}) - # retrieve schema value from the nested key - schema_obj = config_schema - for k in keys: - sub_keys = [ - s.key if not isinstance(s, str) else s for s in schema_obj.schema.keys() - ] - - schema_key = list(schema_obj.schema.keys())[sub_keys.index(k)] - - schema_obj = schema_obj.schema[schema_key] - - descriptions = PARAMETER_DESCRIPTIONS - for k in keys[:-1]: - descriptions = descriptions[k] - - nested_dict = _get_input_value(nested_dict, keys[-1], schema_obj, descriptions) + validator, descriptions = _resolve_spec_field(keys) + nested_dict = _get_input_value(nested_dict, keys[-1], validator, descriptions) else: print("No optional keys will be changed.") break diff --git a/zea/internal/config/parameters.py b/zea/internal/config/parameters.py index ba19f80db..c875efc46 100644 --- a/zea/internal/config/parameters.py +++ b/zea/internal/config/parameters.py @@ -1,91 +1,28 @@ """Parameter descriptions for the config file.""" -from zea.internal.config.validation import _ALLOWED_PLOT_LIBS, _DATA_TYPES - - -def allows_type_to_str(allowed_types): - """Transforms a list of allowed types into a string for use in a comment.""" - ouput_str = ", ".join([str(a) if a is not None else "null" for a in allowed_types]) - return ouput_str - - PARAMETER_DESCRIPTIONS = { "data": { - "description": "The data section contains the parameters for the data.", - "dataset_folder": ( - "The path of the folder to load data files from (relative to the user data " - "root as set in users.yaml)" + "description": "Data path and loading settings.", + "path": ( + "Full path to the data file. Supports absolute paths, paths relative to " + "the user data root (set in users.yaml), and Hugging Face Hub paths " + "(hf://org/repo/path/to/file.hdf5)." ), - "to_dtype": (f"The type of data to convert to ({allows_type_to_str(_DATA_TYPES)})"), - "file_path": ( - "The path of the file to load when running the UI (either an absolute path " - "or one relative to the dataset folder)" - ), - "frame_no": "The frame number to load when running the UI (null, int, 'all')", - "input_range": "The range of the input data in db (null, [min, max])", - "apodization": "The receive apodization to use.", - "output_range": ("The output range to which the data should be mapped (e.g. [0, 1])."), - "resolution": ("The spatial resolution of the data in meters per pixel (float, optional)."), "local": "true: use local data on this device, false: use data from NAS", - "dtype": ( - "The form of data to load (raw_data, rf_data, iq_data, beamformed_data, " - "envelope_data, image, image_sc)" + "indices": ( + "Indices into the data to load. null loads the default, 'all' loads every frame, " + "int loads a single frame, list loads specific frames." ), - "dynamic_range": "The dynamic range for showing data in db [min, max]", - "user": "The user to use when loading data (null, dict)", + "user": "User path overrides set automatically by setup_zea (null, dict).", }, "parameters": { "description": ( - "The parameters section is a flat mapping of scan/probe/custom parameters " - "that overwrite values loaded from the data file. Documented reconstruction " - "parameters are listed below; arbitrary custom parameters are also allowed." - ), - "selected_transmits": ( - "The number of transmits in a frame. Can be 'all' for all transmits, an " - "integer for a specific number of transmits selected evenly from the " - "transmits in the frame, or a list of integers for specific transmits to " - "select from the frame." - ), - "grid_size_x": "The number of pixels in the beamforming grid in the x-direction", - "grid_size_z": "The number of pixels in the beamforming grid in the z-direction", - "n_ch": "The number of channels in the raw data (1 for rf data, 2 for iq data)", - "n_ax": "The number of samples in a receive recording per channel.", - "xlims": "The limits of the x-axis in the scan in meters (null, [min, max])", - "ylims": "The limits of the y-axis in the scan in meters (null, [min, max])", - "zlims": "The limits of the z-axis in the scan in meters (null, [min, max])", - "center_frequency": "The center frequency of the transmit pulse in Hz", - "sampling_frequency": "The sampling frequency of the data in Hz", - "demodulation_frequency": ( - "The demodulation frequency of the data in Hz. This is the assumed center " - "frequency of the transmit waveform used to demodulate the rf data to iq " - "data." - ), - "apply_lens_correction": ( - "Set to true to apply lens correction in the time-of-flight calculation" + "Open mapping of scan/probe/custom parameters that overwrite values loaded " + "from the data file. ProbeSpec and ScanSpec are the authoritative sources " + "for valid parameter names — see the spec reference in data-acquisition. " + "Arbitrary custom parameters are forwarded to the pipeline unchanged." ), - "lens_thickness": "The thickness of the lens in meters", - "lens_sound_speed": ("The speed of sound in the lens in m/s. Usually around 1000 m/s"), - "f_number": ( - "The receive f-number for apodization. Set to zero to disable masking. " - "The f-number is the ratio between the distance from the transducer and the " - "size of the aperture." - ), - "fill_value": ( - "Value to fill the image with outside the defined region (float, default 0.0)." - ), - "phi_range": ( - "The range of phi values in radians for 3D scan conversion (null, [min, max])." - ), - "theta_range": ( - "The range of theta values in radians for scan conversion (null, [min, max])." - ), - "rho_range": ("The range of rho values in meters for scan conversion (null, [min, max])."), - "resolution": ("The resolution for scan conversion in meters per pixel (float, optional)."), }, - "scan": ( - "Deprecated alias for 'parameters'. Supported for backward compatibility; " - "prefer using 'parameters'." - ), "pipeline": { "description": "This section contains the necessary parameters for building the pipeline.", "operations": ( @@ -101,29 +38,11 @@ def allows_type_to_str(allowed_types): "'ops' compiles each operation separately. None disables JIT compilation. " "Defaults to 'ops'." ), - "jit_kwargs": ("Additional keyword arguments for the JIT compiler. Defaults to None."), - "name": ("The name of the pipeline. Defaults to 'pipeline'."), - "validate": ("Whether to validate the pipeline. Defaults to True."), - }, - "device": "The device to run on ('cpu', 'gpu:0', 'gpu:1', ...)", - "plot": { - "description": ( - "Settings pertaining to plotting when running the UI " - "(`zea --config `)" - ), - "save": ("Set to true to save the plots to disk, false to only display them in the UI"), - "plot_lib": (f"The plotting library to use ({allows_type_to_str(_ALLOWED_PLOT_LIBS)})"), - "fps": "Frames per second for video output.", - "tag": "The name for the plot", - "fliplr": "Set to true to flip the image left to right", - "image_extension": "The file extension to use when saving the image (png, jpg)", - "video_extension": "The file extension to use when saving the video (mp4, gif)", - "headless": "Set to true to run the UI in headless mode", - "selector": ( - "Type of selector to use for ROI selection in the UI ('rectangle', 'lasso', or None)." - ), - "selector_metric": ("Metric to use for evaluating selected regions (e.g., 'gcnr')."), + "jit_kwargs": "Additional keyword arguments for the JIT compiler. Defaults to None.", + "name": "The name of the pipeline. Defaults to 'pipeline'.", + "validate": "Whether to validate the pipeline. Defaults to True.", }, + "device": "The device to run on ('cpu', 'gpu:0', 'gpu:1', 'auto:1', ...)", "git": "The git commit hash or branch for reproducibility (string, optional).", - "hide_devices": ("List of device indices to hide from selection (list of int, optional)."), + "hide_devices": "List of device indices to hide from selection (list of int, optional).", } diff --git a/zea/internal/config/validation.py b/zea/internal/config/validation.py index 43e32ed1a..aefc508df 100644 --- a/zea/internal/config/validation.py +++ b/zea/internal/config/validation.py @@ -1,178 +1,415 @@ -"""Validate configuration yaml files. +"""Validate configuration dictionaries. -https://github.com/keleshev/schema -https://www.andrewvillazon.com/validate-yaml-python-schema/ +Config validation follows the same dataclass-Spec pattern used elsewhere in zea +for :class:`~zea.data.spec.ProbeSpec` / :class:`~zea.data.spec.ScanSpec`: each +config section is a :func:`dataclasses.dataclass` with typed fields, default +values, and validation in ``__post_init__``. -This file specifies bare bone structure of the config files. -Furthermore it check the config file you create for validity and sets -missing (if optional) parameters to default values. When adding functionality -that needs parameters from the config file, make sure to add those paremeters here. -Also if that parameter is optional, add a default value. +Unlike the array Specs in :mod:`zea.data.spec` (which validate numpy +``dtype``/``shape`` and named-dimension consistency), config values are plain +Python scalars / lists / dicts, so validation here uses small validator +callables (enums, numeric ranges, regexes). +The ``parameters`` section and the top-level config are *open*: they accept +arbitrary extra keys, which are stored and re-emitted unchanged. This mirrors +:class:`zea.Parameters`, which keeps unknown keys as pass-through +``_custom_params``. """ +import re +from dataclasses import MISSING, dataclass, field, fields from pathlib import Path +from typing import Any, Callable, ClassVar, Optional, Type -from schema import And, Optional, Or, Regex, Schema - -from zea.internal.checks import _DATA_TYPES -from zea.metrics import metrics_registry - -# predefined checks, later used in schema to check validity of parameter -any_number = Or( - int, - float, - error="Must be a number, scientific notation should be of form x.xe+xx, " - "otherwise interpreted as string", -) -list_of_size_two = And(list, lambda _list: len(_list) == 2) -positive_integer = And(int, lambda i: i > 0) -positive_integer_and_zero = And(int, lambda i: i >= 0) -positive_float = And(float, lambda f: f > 0) -list_of_floats = And(list, lambda _list: all(isinstance(_l, float) for _l in _list)) -list_of_positive_integers = And(list, lambda _list: all(_l >= 0 for _l in _list)) -percentage = And(any_number, lambda f: 0 <= f <= 100) - -_ALLOWED_PLOT_LIBS = ("opencv", "matplotlib") - -# pipeline / operations -pipeline_schema = Schema( - { - Optional("operations", default=["identity"]): Or( - None, [Or(str, {"name": str, "params": dict}, {"name": str})] - ), - Optional("with_batch_dim", default=True): bool, - Optional("jit_options", default="ops"): Or(None, "ops", "pipeline"), - Optional("jit_kwargs", default=None): Or(None, dict), - Optional("name", default="pipeline"): str, - Optional("validate", default=True): bool, - } -) - -# postprocess DEPRECATED -postprocess_schema = Schema( - { - Optional("contrast_boost", default=None): Or( - None, - { - "k_p": float, - "k_n": float, - "threshold": float, - }, - ), - Optional("thresholding", default=None): Or( - None, - { - Optional("percentile", default=None): Or(None, percentage), - Optional("threshold", default=None): Or(None, any_number), - Optional("fill_value", default="min"): Or("min", "max", "threshold", any_number), - Optional("below_threshold", default=True): bool, - Optional("threshold_type", default="hard"): Or("hard", "soft"), - }, - ), - Optional("lista", default=None): Or(bool, None), +# --------------------------------------------------------------------------- +# Validator helpers +# +# Each validator is a ``Callable[[Any], Any]`` that returns the (possibly +# coerced) value or raises ``ValueError`` with a human-readable message. +# --------------------------------------------------------------------------- + + +def boolean(value: Any) -> bool: + """Validate a boolean.""" + if not isinstance(value, bool): + raise ValueError(f"must be a boolean, got {type(value).__name__}") + return value + + +def string(value: Any) -> str: + """Validate a string.""" + if not isinstance(value, str): + raise ValueError(f"must be a string, got {type(value).__name__}") + return value + + +def integer(value: Any) -> int: + """Validate an integer (``bool`` is rejected).""" + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"must be an integer, got {type(value).__name__}") + return value + + +def any_number(value: Any) -> Any: + """Validate a number (``int`` or ``float``, ``bool`` is rejected).""" + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ValueError( + "must be a number, scientific notation should be of form x.xe+xx, " + "otherwise interpreted as string" + ) + return value + + +def positive_integer(value: Any) -> int: + """Validate a strictly positive integer.""" + integer(value) + if value <= 0: + raise ValueError(f"must be a positive integer, got {value}") + return value + + +def positive_integer_and_zero(value: Any) -> int: + """Validate a non-negative integer.""" + integer(value) + if value < 0: + raise ValueError(f"must be a non-negative integer, got {value}") + return value + + +def positive_float(value: Any) -> float: + """Validate a strictly positive float.""" + if isinstance(value, bool) or not isinstance(value, float): + raise ValueError(f"must be a float, got {type(value).__name__}") + if value <= 0: + raise ValueError(f"must be a positive float, got {value}") + return value + + +def mapping(value: Any) -> dict: + """Validate a dict (mapping).""" + if not isinstance(value, dict): + raise ValueError(f"must be a dict, got {type(value).__name__}") + return value + + +def list_of_size_two(value: Any) -> list: + """Validate a list of exactly two elements.""" + if not isinstance(value, list) or len(value) != 2: + raise ValueError(f"must be a list of length two, got {value!r}") + return value + + +def list_of_positive_integers(value: Any) -> list: + """Validate a list of non-negative integers.""" + if not isinstance(value, list) or not all( + isinstance(x, int) and not isinstance(x, bool) and x >= 0 for x in value + ): + raise ValueError(f"must be a list of non-negative integers, got {value!r}") + return value + + +def string_or_path(value: Any) -> Any: + """Validate a string or :class:`pathlib.Path`.""" + if not isinstance(value, (str, Path)): + raise ValueError(f"must be a string or path, got {type(value).__name__}") + return value + + +def enum(*allowed: Any) -> Callable[[Any], Any]: + """Build a validator that accepts only one of ``allowed`` values.""" + + def validate(value: Any) -> Any: + if value not in allowed: + raise ValueError(f"must be one of {list(allowed)}, got {value!r}") + return value + + return validate + + +def regex(pattern: str) -> Callable[[Any], str]: + """Build a validator that fully matches ``pattern``.""" + compiled = re.compile(pattern) + + def validate(value: Any) -> str: + if not isinstance(value, str) or compiled.fullmatch(value) is None: + raise ValueError(f"must match pattern {pattern!r}, got {value!r}") + return value + + return validate + + +def any_of(*validators: Callable[[Any], Any]) -> Callable[[Any], Any]: + """Build a validator that passes if any of ``validators`` passes.""" + + def validate(value: Any) -> Any: + errors = [] + for validator in validators: + try: + return validator(value) + except ValueError as exc: + errors.append(str(exc)) + raise ValueError(" or ".join(errors)) + + return validate + + +def optional(validator: Callable[[Any], Any]) -> Callable[[Any], Any]: + """Build a validator that also accepts ``None``.""" + + def validate(value: Any) -> Any: + if value is None: + return None + return validator(value) + + return validate + + +def operations_list(value: Any) -> list: + """Validate the pipeline ``operations`` list. + + Each element is either an operation name (str) or a mapping with a ``name`` + (str) and optional ``params`` (dict). + """ + if not isinstance(value, list): + raise ValueError(f"must be a list of operations, got {type(value).__name__}") + for op in value: + if isinstance(op, str): + continue + if isinstance(op, dict): + if not isinstance(op.get("name"), str): + raise ValueError(f"operation {op!r} must have a string 'name'") + unexpected = set(op) - {"name", "params"} + if unexpected: + raise ValueError(f"operation {op!r} has unexpected keys {sorted(unexpected)}") + if "params" in op and not isinstance(op["params"], dict): + raise ValueError(f"operation {op!r} 'params' must be a dict") + continue + raise ValueError(f"invalid operation {op!r}") + return value + + +# --------------------------------------------------------------------------- +# Config Spec base class +# --------------------------------------------------------------------------- + + +@dataclass +class ConfigSpec: + """Base class for config sections. + + Subclasses are dataclasses that declare their fields (with defaults for + optional fields) and the following class variables: + + - ``VALIDATORS``: maps a field name to a validator callable. + - ``NESTED``: maps a field name to a nested :class:`ConfigSpec` subclass. + - ``ALLOW_EXTRA``: when ``True`` arbitrary extra keys are accepted and + passed through unchanged (used for the open ``parameters`` section and the + top-level config). + """ + + VALIDATORS: ClassVar[dict[str, Callable[[Any], Any]]] = {} + NESTED: ClassVar[dict[str, Type["ConfigSpec"]]] = {} + ALLOW_EXTRA: ClassVar[bool] = False + + def __post_init__(self) -> None: + if not hasattr(self, "_extra"): + self._extra: dict[str, Any] = {} + for name in self.field_names(): + value = getattr(self, name) + nested = self.NESTED.get(name) + if nested is not None: + setattr(self, name, self._coerce_nested(name, nested, value)) + continue + validator = self.VALIDATORS.get(name) + if validator is not None: + try: + value = validator(value) + except ValueError as exc: + raise ValueError(f"{type(self).__name__}.{name}: {exc}") from exc + setattr(self, name, value) + + def _coerce_nested(self, name: str, nested: Type["ConfigSpec"], value: Any) -> "ConfigSpec": + if value is None: + # Optional nested section: fall back to its defaults. + return nested.from_dict({}) + if isinstance(value, nested): + return value + if isinstance(value, dict): + try: + return nested.from_dict(value) + except ValueError as exc: + raise ValueError(f"{type(self).__name__}.{name}: {exc}") from exc + raise ValueError( + f"{type(self).__name__}.{name}: expected a mapping for " + f"{nested.__name__}, got {type(value).__name__}" + ) + + # -- construction / serialization -------------------------------------- + + @classmethod + def from_dict(cls, dictionary: Optional[dict]) -> "ConfigSpec": + """Validate ``dictionary`` and return a populated spec instance.""" + if dictionary is None: + dictionary = {} + if not isinstance(dictionary, dict): + raise ValueError(f"{cls.__name__}: expected a mapping, got {type(dictionary).__name__}") + + field_names = set(cls.field_names()) + known = {k: v for k, v in dictionary.items() if k in field_names} + extra = {k: v for k, v in dictionary.items() if k not in field_names} + + if extra and not cls.ALLOW_EXTRA: + raise ValueError(f"{cls.__name__}: unexpected keys {sorted(extra)}") + + missing = [name for name in cls.required_fields() if name not in known] + if missing: + raise ValueError(f"{cls.__name__}: missing required keys {missing}") + + obj = cls(**known) + if extra: + obj._extra.update(extra) + return obj + + def to_dict(self) -> dict[str, Any]: + """Return a plain dict with defaults filled and nested specs expanded.""" + result: dict[str, Any] = {} + for name in self.field_names(): + value = getattr(self, name) + if isinstance(value, ConfigSpec): + value = value.to_dict() + result[name] = value + result.update(self._extra) + return result + + # -- introspection (used by tooling / docs) ---------------------------- + + @classmethod + def field_names(cls) -> tuple[str, ...]: + """Return the names of all declared fields.""" + return tuple(f.name for f in fields(cls)) + + @classmethod + def required_fields(cls) -> tuple[str, ...]: + """Return the names of fields without a default value.""" + return tuple( + f.name for f in fields(cls) if f.default is MISSING and f.default_factory is MISSING + ) + + @classmethod + def optional_fields(cls) -> tuple[str, ...]: + """Return the names of fields with a default value.""" + required = set(cls.required_fields()) + return tuple(name for name in cls.field_names() if name not in required) + + @classmethod + def all_field_paths(cls, prefix: str = "") -> set[str]: + """Return all documented field paths in dot notation, recursing nested specs.""" + paths: set[str] = set() + for name in cls.field_names(): + full = f"{prefix}.{name}" if prefix else name + nested = cls.NESTED.get(name) + if nested is not None: + paths |= nested.all_field_paths(full) + else: + paths.add(full) + return paths + + +# --------------------------------------------------------------------------- +# Config sections +# --------------------------------------------------------------------------- + + +@dataclass +class DataConfig(ConfigSpec): + """The ``data:`` section: data path and loading settings.""" + + path: Any = None + local: Any = True + indices: Any = None + user: Any = None + + VALIDATORS: ClassVar[dict] = { + "path": optional(string_or_path), + "local": boolean, + "indices": optional(any_of(enum("all"), integer, list_of_positive_integers)), + "user": optional(mapping), } -) - -# scan -# Schema for the flat ``parameters:`` config section (formerly ``scan:``). -# ``ignore_extra_keys`` allows arbitrary custom/manual parameters in addition -# to the documented recon parameters below. -parameters_schema = Schema( - { - Optional("xlims", default=None): Or(None, list_of_size_two), - Optional("zlims", default=None): Or(None, list_of_size_two), - Optional("ylims", default=None): Or(None, list_of_size_two), - Optional("selected_transmits", default=None): Or( - None, - positive_integer, - list_of_positive_integers, - "all", - "center", - ), - Optional("grid_size_x", default=None): Or(None, positive_integer), - Optional("grid_size_z", default=None): Or(None, positive_integer), - Optional("n_ch", default=None): Or(None, int), - Optional("n_ax", default=None): Or(None, int), - Optional("center_frequency", default=None): Or(None, any_number), - Optional("sampling_frequency", default=None): Or(None, any_number), - Optional("demodulation_frequency", default=None): Or(None, any_number), - Optional("f_number", default=None): Or(None, positive_float), - Optional("apply_lens_correction", default=False): bool, - Optional("lens_thickness", default=1e-3): positive_float, - Optional("lens_sound_speed", default=1000): Or(positive_float, positive_integer), - Optional("theta_range", default=None): Or(None, list_of_size_two), - Optional("phi_range", default=None): Or(None, list_of_size_two), - Optional("rho_range", default=None): Or(None, list_of_size_two), - Optional("fill_value", default=0.0): any_number, - Optional("resolution", default=None): Or(None, positive_float), - }, - ignore_extra_keys=True, -) - -# plot -plot_schema = Schema( - { - Optional("save", default=False): bool, - Optional("plot_lib", default="opencv"): Or(*_ALLOWED_PLOT_LIBS), - Optional("fps", default=20): int, - Optional("tag", default=None): Or(None, str), - Optional("headless", default=False): bool, - Optional("selector", default=None): Or(None, "rectangle", "lasso"), - Optional("selector_metric", default="gcnr"): Or(*metrics_registry.registered_names()), - Optional("fliplr", default=False): bool, - Optional("image_extension", default="png"): Or("png", "jpg"), - Optional("video_extension", default="gif"): Or("mp4", "gif"), + + +@dataclass +class ParametersConfig(ConfigSpec): + """The ``parameters:`` section — open pass-through for scan/probe/custom parameters. + + ProbeSpec and ScanSpec are the single source of truth for which parameter + names are valid. Any key listed here overrides the value loaded from the + data file; arbitrary custom keys are forwarded to the pipeline unchanged. + """ + + ALLOW_EXTRA: ClassVar[bool] = True + + +@dataclass +class PipelineConfig(ConfigSpec): + """The ``pipeline:`` section: operations and JIT settings.""" + + operations: Any = field(default_factory=lambda: ["identity"]) + with_batch_dim: Any = True + jit_options: Any = "ops" + jit_kwargs: Any = None + name: Any = "pipeline" + validate: Any = True + + VALIDATORS: ClassVar[dict] = { + "operations": optional(operations_list), + "with_batch_dim": boolean, + "jit_options": optional(enum("ops", "pipeline")), + "jit_kwargs": optional(mapping), + "name": string, + "validate": boolean, } -) - -data_schema = Schema( - { - "dtype": Or(*_DATA_TYPES), - "dataset_folder": str, - Optional("resolution", default=None): Or(None, positive_float), - Optional("to_dtype", default="image"): Or(*_DATA_TYPES), - Optional("file_path", default=None): Or(None, str, Path), - Optional("local", default=True): bool, - Optional("frame_no", default=None): Or(None, "all", int), - Optional("dynamic_range", default=[-60, 0]): list_of_size_two, - Optional("input_range", default=None): Or(None, list_of_size_two), - Optional("output_range", default=None): Or(None, list_of_size_two), - Optional("apodization", default=None): Or(None, str), - Optional("user", default=None): Or(None, dict), + + +@dataclass +class ConfigSchema(ConfigSpec): + """The top-level config. + + This is *open*: arbitrary extra top-level sections (e.g. ``data:``, + ``model:``) are accepted and passed through unchanged. The deprecated + ``scan:`` section is aliased to ``parameters:`` before validation (see + :func:`zea.config._migrate_legacy_config`). + """ + + data: Any = None + pipeline: Any = None + parameters: Any = None + device: Any = "auto:1" + hide_devices: Any = None + git: Any = None + + ALLOW_EXTRA: ClassVar[bool] = True + NESTED: ClassVar[dict] = { + "data": DataConfig, + "pipeline": PipelineConfig, + "parameters": ParametersConfig, } -) - -# top level schema -config_schema = Schema( - { - "data": data_schema, - Optional("plot", default=plot_schema.validate({})): plot_schema, - Optional("pipeline", default=pipeline_schema.validate({})): pipeline_schema, - # Flat mapping of scan/probe/custom parameters that overwrite values - # loaded from the file (see ``File.load_parameters`` and - # ``Pipeline.prepare_parameters``). Documented recon parameters are - # validated; arbitrary custom keys are also allowed (ignore_extra_keys). - Optional("parameters", default=parameters_schema.validate({})): parameters_schema, - # Deprecated alias for ``parameters``; still accepted for backward - # compatibility (migrated to ``parameters`` on load). - Optional("scan"): Or(None, dict), - Optional("device", default="auto:1"): Or( - "cpu", - "gpu", - "cuda", - Regex(r"cuda:\d+"), - Regex(r"gpu:\d+"), - Regex(r"auto:\d+"), - Regex(r"auto:-\d+"), - None, + VALIDATORS: ClassVar[dict] = { + "device": optional( + any_of( + enum("cpu", "gpu", "cuda"), + regex(r"cuda:\d+"), + regex(r"gpu:\d+"), + regex(r"auto:-?\d+"), + ) ), - Optional("hide_devices", default=None): Or( - None, list_of_positive_integers, positive_integer_and_zero - ), - Optional("git", default=None): Or(None, str), - }, - # Allow arbitrary extra top-level keys; they are ignored by the workflow - # unless accessed manually from code (see redesign of Config). - ignore_extra_keys=True, -) + "hide_devices": optional(any_of(list_of_positive_integers, positive_integer_and_zero)), + "git": optional(string), + } + + +def validate_config(config: Optional[dict]) -> dict: + """Validate a config dict and return a plain dict with defaults filled in. + + This is the replacement for the previous ``config_schema.validate(...)``. + """ + return ConfigSchema.from_dict(config).to_dict()