"""
OceanOSSE

Python toolbox for performing Observing System Simulation Experiments (OSSEs)
in ocean general circulation models.
"""

__author__ = "Ollie Tooth (oliver.tooth@noc.ac.uk)"
__credits__ = "National Oceanography Centre (NOC), Southampton, UK"

from importlib.metadata import version as _version

from OceanOSSE import cli, pipeline
from OceanOSSE.gridding.regridder import Regridder
from OceanOSSE.gridding.regridder_simple import SwapRegridder
from OceanOSSE.io.dataloader import DataLoader
from OceanOSSE.io.datawriter import DataWriter
from OceanOSSE.sampling.sampler import ErrorKernel, ObsSampler

try:
    __version__ = _version("OceanOSSE")
except Exception:
    # Local copy or not installed with setuptools.
    # Disable minimum version checks on downstream libraries.
    __version__ = "9999.0.0"

__all__ = (
    "cli",
    "pipeline",
    "DataLoader",
    "DataWriter",
    "ErrorKernel",
    "ObsSampler",
    "Regridder",
    # BUG FIX: SwapRegridder was imported above but omitted from __all__,
    # leaving it out of the declared public API.
    "SwapRegridder",
)
# -- Import dependencies -- #
import logging
import sys

import typer
from typing_extensions import Annotated, Optional

# BUG FIX: importing from ``.__init__`` executes the package module a second
# time under a different module name; import from the package itself instead.
from OceanOSSE import __version__
from OceanOSSE.pipeline import describe_pipeline, run_pipeline

app = typer.Typer()
logger = logging.getLogger(__name__)


# -- Define CLI Functions -- #
def create_header(
    config_path: str,
    log_path: str,
) -> None:
    """
    Add OceanOSSE header to log.

    Parameters:
    -----------
    config_path : str
        Filepath to OceanOSSE config .toml file.
    log_path : str
        Filepath to OceanOSSE log file.
    """
    logger.info(
        f"""
╔══════════════════════════════════════════════════════════════╗
║                          OceanOSSE                           ║
║       Ocean Observing System Simulation Experiment Tool      ║
╠══════════════════════════════════════════════════════════════╣
  OceanOSSE Version : {__version__}
  Python Version    : {sys.version.split()[0]}
  Config File       : {config_path}
  Log File          : {log_path}
╚══════════════════════════════════════════════════════════════╝
""",
        extra={"simple": True},
    )


def init_logging(log_path: str) -> None:
    """
    Initialise OceanOSSE logging.

    Parameters:
    -----------
    log_path : str
        Filepath to log file.

    Raises
    ------
    TypeError
        If ``log_path`` is not a string.
    """
    # === Validate Inputs === #
    # DOC FIX: the original docstring claimed None was handled ("If None,
    # logs to 'ocean_osse.log'"), but a non-str raises TypeError here; the
    # default filename is supplied by the Typer option on ``run`` instead.
    if not isinstance(log_path, str):
        raise TypeError("log_path must be a string.")

    logging.basicConfig(
        format="⦿══⦿ OceanOSSE ⦿══⦿ ║ %(levelname)10s ║ %(asctime)s ║ %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
    )


# === Create Typer App === #
@app.callback()
def main() -> None:
    """
    Main callback for Typer app to allow
    single run command to be defined.
    """
    pass


@app.command()
def run(
    config: Annotated[str, typer.Argument(help="Path to OceanOSSE config .toml file")],
    log: Annotated[
        Optional[str],
        typer.Option(
            help="Path to write OceanOSSE log file", rich_help_panel="Options"
        ),
    ] = "ocean_osse.log",
    dry_run: Annotated[
        Optional[bool],
        typer.Option(
            help="Describe OceanOSSE workflow without execution.",
            rich_help_panel="Options",
        ),
    ] = False,
) -> None:
    """
    Run OceanOSSE workflow defined by configuration (.toml) file in current process.
    """
    # === Initialise Logging === #
    init_logging(log_path=log)
    create_header(config_path=config, log_path=log)

    # === Run OceanOSSE === #
    args = {
        "config_file": config,
        "log_filepath": log,
    }
    if dry_run:
        describe_pipeline(args=args)
    else:
        run_pipeline(args=args)

    # Use the module logger rather than the root logger for consistency.
    logger.info("✔ OceanOSSE Completed ✔")


if __name__ == "__main__":
    app()
+ +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import abc +import logging +from typing import Self + +import xarray as xr + +logger = logging.getLogger(__name__) + + +# -- Regridder Abstract Base Class -- # +class Regridder(abc.ABC): + """ + Abstract base class for regridding synthetic ocean observations onto + the original model grid, using methods such as objective analysis + or interpolation. + + Parameters + ---------- + target_grid : xarray.Dataset or None, optional + Dataset describing the target grid (coordinates, masks, etc.). + """ + + def __init__( + self, + target_grid: xr.Dataset | None = None, + ) -> None: + if target_grid is not None and not isinstance(target_grid, xr.Dataset): + raise TypeError("``target_grid`` must be an xarray.Dataset or None.") + self._target_grid = target_grid + + def __repr__(self) -> str: + has_grid = self._target_grid is not None + return f"{type(self).__name__}(target_grid={'' if has_grid else None})" + + @classmethod + @abc.abstractmethod + def from_config(cls, config: dict) -> Self: + """ + Construct a Regridder from the from the `[regridding]` table of + the .toml configuration file. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised Regridder instance. + """ + ... + + @abc.abstractmethod + def regrid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Regrid the synthetic observation dataset onto the target grid. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset. + + Returns + ------- + xarray.Dataset + Dataset of synthetic observations regridded onto target grid. + """ + ... + + +# -- Regridder Implementations -- # + + +class TestRegridder(Regridder): + """ + Regridder used for testing and scaffold validation. + + Returns the the synthetic observations dataset unchanged. 
+ """ + + @classmethod + def from_config(cls, config: dict) -> Self: + """ + Instantiate a TestRegridder from the `[regridding]` table of + the .toml configuration file. + """ + return cls() + + def regrid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Regrid the synthetic observation dataset onto the target grid. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset. + + Returns + ------- + xarray.Dataset + Dataset of synthetic observations (unchanged from input). + """ + logger.debug( + "Regridding synthetic observations with TestRegridder -> returns input dataset unchanged." + ) + logging.info( + "--> Completed: Regridded synthetic observations with TestRegridder." + ) + return ds diff --git a/OceanOSSE/gridding/regridder_simple.py b/OceanOSSE/gridding/regridder_simple.py new file mode 100644 index 0000000..c36dc35 --- /dev/null +++ b/OceanOSSE/gridding/regridder_simple.py @@ -0,0 +1,143 @@ +# =================================================================== +# Copyright 2025 National Oceanography Centre +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# =================================================================== + +""" +sampler_nearest_neighbour.py + +Description: Sampling module for OceanOSSE package. 
+ +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import logging + +import xarray as xr +import numpy as np + +from OceanOSSE.utils import import_class +from OceanOSSE.gridding.regridder import Regridder + +logger = logging.getLogger(__name__) + + +class SwapRegridder(Regridder): + """ + Regridding class for synthetic ocean observations onto + the original model grid using climatology and exchanging profiles. + + Parameters + ---------- + target_grid : xarray.Dataset or None, optional + Dataset describing the target grid (coordinates, masks, etc.). + """ + + def __init__( + self, + target_grid: xr.Dataset | None = None, + ) -> None: + if target_grid is not None and not isinstance(target_grid, xr.Dataset): + raise TypeError("``target_grid`` must be an xarray.Dataset or None.") + self._target_grid = target_grid + self.ds_clim = self.climatology() + + def __repr__(self) -> str: + has_grid = self._target_grid is not None + return f"{type(self).__name__}(target_grid={'' if has_grid else None})" + + + def from_config(cls, config: dict) -> Self: + """ + Construct a Regridder from the from the `[regridding]` table of + the .toml configuration file. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised Regridder instance. + """ + + return self + + + def regrid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Regrid the synthetic observation dataset into the target grid. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset. + + Returns + ------- + xarray.Dataset + Dataset of synthetic observations placed into target grid. 
+ """ + # Use indices in synthetic profile set to replace data in the climatology with model data + ds_model = self.ds_clim + n_profile = len(ds.coords['profile_id']) + + # loop over profiles + for p in range(n_profile): + i_ind = ds.coords['i'][p].to_numpy() + j_ind = ds.coords['j'][p].to_numpy() + t_ind = ds.coords['t'][p].to_numpy() + ps = ds.coords['profile_id'][p].to_numpy() + + profile = ds['votemper'].isel(profile_id=ps) + + ds_model['votemper'].loc[ + dict( + t=ds.t.sel(profile_id=ps), + j=ds.j.sel(profile_id=ps), + i=ds.i.sel(profile_id=ps)) + ] = profile.values + + return ds_model + + + def climatology(self): + """ + Calculate the climatology of the target grid. + + Returns + ------- + xarray.Dataset + Dataset of monthly means. + """ + ds = self._target_grid.assign_coords( + monthday=("t", self._target_grid.t.dt.strftime("%m-%d").data) + ) + # calculate climatology + ds_clim = ds.groupby('monthday').mean() + + # tile the climatology data back over full time series + ds_clim_full = ds_clim.sel(monthday=ds.monthday) + + # Remove not needed time dim from variables + for v in ["lat", "lon", "depth"]: + ds_clim_full[v] = ds_clim_full[v].isel(t=0, drop=True) + ds_clim_full = ds_clim_full.drop_vars('monthday') + + return ds_clim_full + \ No newline at end of file diff --git a/OceanOSSE/io/dataloader.py b/OceanOSSE/io/dataloader.py new file mode 100644 index 0000000..8e6b721 --- /dev/null +++ b/OceanOSSE/io/dataloader.py @@ -0,0 +1,354 @@ +""" +dataloader.py + +Description: DataLoader module for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import abc +import glob +import logging +from typing import Self + +import xarray as xr + +logger = logging.getLogger(__name__) + + +# -- DataLoader Abstract Base Class -- # +class DataLoader(abc.ABC): + """ + DataLoader Base Class to load gridded ocean model data from + persistent local filesystem or cloud object storage. 
+ + Parameters + ---------- + source : dict[str, dict] + Source dictionary of gridded ocean model variables to load. + dimensions : dict[str, str] + Mapping of standard dimension names to input dataset dimension names. + coordinates : dict[str, str] + Mapping of standard coordinate names to input dataset coordinate names. + + Attributes + ---------- + _source : dict[str, dict] + Source dictionary of gridded ocean model variables. + _dimensions : dict[str, str] + Mapping of standard dimension names to input dataset dimension names. + _coordinates : dict[str, str] + Mapping of standard coordinate names to input dataset coordinate names. + """ + + def __init__( + self, + source: dict[str, dict], + dimensions: dict[str, str], + coordinates: dict[str, str], + ): + # -- Verify Inputs -- # + if not isinstance(source, dict): + raise TypeError("``source`` must be a specfied as a dictionary.") + if not isinstance(dimensions, dict): + raise TypeError("``dimensions`` must be a specfied as a dictionary.") + if not isinstance(coordinates, dict): + raise TypeError("``coordinates`` must be a specfied as a dictionary.") + + # -- Class Attributes -- # + self._source = source + self._dimensions = dimensions + self._coordinates = coordinates + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"source={self._source!r}, " + f"dimensions={self._dimensions!r}, " + f"coordinates={self._coordinates!r})" + ) + + @classmethod + @abc.abstractmethod + def from_config(cls, config: dict) -> Self: + """ + Abstract class method to instantiate a DataLoader from the `[inputs]` table + of the .toml configuration file. + + This is the required constructor for all DataLoader subclasses - plugin + authors must implement this method for use in OceanOSSE. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised DataLoader instance. + """ + ... 
+ + @abc.abstractmethod + def load_data(self) -> xr.Dataset: + """ + Abstract method to load gridded ocean model data into a standardised xarray.Dataset. + + The returned Dataset will be validated by the validate_dataset() method. + + Returns + ------- + xarray.Dataset + Dataset containing standardised gridded ocean model variables. + """ + ... + + def _standardise_dataset(self, ds: xr.Dataset) -> xr.Dataset: + """ + Standardise gridded ocean model xarray.Dataset by renaming dimensions and coordinates to OceanOSSE standard names. + + Parameters + ---------- + ds : xarray.Dataset + Dataset with original dimension and coordinate names. + + Returns + ------- + xarray.Dataset + Dataset with standardised dimension and coordinate names. + """ + # -- Rename dimensions to standard dimensions names -- # + rename_dims = {value: key for key, value in self._dimensions.items()} + ds = ds.rename_dims(rename_dims) + + # -- Assign standard coordinate names and drop any non-standard coordinates -- # + ds = ds.assign_coords( + {coord: ds[var] for coord, var in self._coordinates.items()} + ) + drop_coords = [coord for coord in ds.coords if coord not in self._coordinates] + ds = ds.drop_vars(drop_coords) + + return ds + + def _validate_dataset(self, ds: xr.Dataset) -> None: + """ + Validate the standardised xarray.Dataset. + + Parameters + ---------- + dataset : xarray.Dataset + Dataset containing standardised gridded ocean model variables. + + Raises + ------- + ValueError + If the dataset does not contain the required variables or dimensions. + """ + # -- Validate Required Dimensions -- # + required_dims = ["time", "lev", "j", "i"] + missing_dims = [dim for dim in required_dims if dim not in ds.dims] + if missing_dims: + raise ValueError( + f"{type(self).__name__}: loaded dataset is missing required " + f"dimension(s): {missing_dims}. Found dimensions: {list(ds.dims)}." 
# -- DataLoader Implementations -- #
class NetCDFDataLoader(DataLoader):
    """
    DataLoader implementation to load gridded ocean model data from NetCDF files.

    Parameters
    ----------
    source : dict[str, dict]
        Source dictionary of gridded ocean model variables to load.
    dimensions : dict[str, str]
        Mapping of standard dimension names to input dataset dimension names.
    coordinates : dict[str, str]
        Mapping of standard coordinate names to input dataset coordinate names.
    """

    def __init__(
        self,
        source: dict[str, dict],
        dimensions: dict[str, str],
        coordinates: dict[str, str],
    ) -> None:
        # -- Initialise parent DataLoader class -- #
        super().__init__(source, dimensions, coordinates)

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Instantiate a NetCDFDataLoader from the `[inputs]` table of the .toml configuration file.
        """
        # -- Verify Input -- #
        if not isinstance(config, dict):
            raise TypeError("config must be a dictionary.")

        # -- Instantiate DataLoader with source dict from config -- #
        source = config["inputs"].get("variables", None)
        if source is None:
            raise ValueError(
                "Missing 'variables' entry in [inputs] table of config .toml file."
            )
        dimensions = config["inputs"].get("dimensions", None)
        if dimensions is None:
            raise ValueError(
                "Missing 'dimensions' entry in [inputs] table of config .toml file."
            )
        coordinates = config["inputs"].get("coordinates", None)
        if coordinates is None:
            raise ValueError(
                "Missing 'coordinates' entry in [inputs] table of config .toml file."
            )

        return cls(source=source, dimensions=dimensions, coordinates=coordinates)

    def _open_dataset(
        self,
        filepath: str,
        variables: list[str],
        open_kwargs: dict | None,
    ) -> xr.Dataset:
        """
        Open input variable Dataset from a netCDF file(s).

        Parameters:
        -----------
        filepath : str
            Filepath pattern to input variable netCDF file(s).
        variables : list[str]
            Name of variable(s) to load from the dataset.
        open_kwargs : dict or None
            Additional keyword arguments to pass to xarray.open_dataset
            or xarray.open_mfdataset. If None or empty, sensible defaults
            are applied.

        Returns:
        --------
        xr.Dataset
            Standardised variable Dataset.
        """
        # -- Validate Inputs -- #
        if not isinstance(filepath, str):
            raise TypeError("filepath must be a string.")
        if not isinstance(variables, list):
            raise TypeError("variables must be a list of strings.")
        # BUG FIX: the original raised TypeError for non-dict values *before*
        # its ``if open_kwargs is None`` default branches, making the default
        # kwargs unreachable; ``load_data`` also passes ``{}``, which skipped
        # them. Accept None and treat None/empty as "use defaults".
        if open_kwargs is not None and not isinstance(open_kwargs, dict):
            raise TypeError("open_kwargs must be a dictionary or None.")

        filepaths = glob.glob(filepath)
        if len(filepaths) == 0:
            raise FileNotFoundError(f"No files found matching filepath: {filepath}")

        # Define CFDatetimeCoder to decode time coords:
        coder = xr.coders.CFDatetimeCoder(time_unit="s")

        # -- Open input variable dataset -- #
        if len(filepaths) == 1:
            if not open_kwargs:
                open_kwargs = {"engine": "netcdf4"}
            try:
                logger.info(
                    f"In Progress: Opening variable(s) {variables} from netCDF file '{filepath}'"
                )
                ds_var = xr.open_dataset(
                    filepaths[0], decode_times=coder, **open_kwargs
                )[variables]
                logger.info(f"--> Completed: Opened variable(s) {variables}.")

            except FileNotFoundError as e:
                raise FileNotFoundError(
                    f"Failed to open netCDF file: {filepath}"
                ) from e
        else:
            if not open_kwargs:
                open_kwargs = {
                    "data_vars": "minimal",
                    "compat": "no_conflicts",
                    "parallel": False,
                    "engine": "netcdf4",
                }
            if variables is not None:
                # Subset each file to the requested variables before combining.
                open_kwargs["preprocess"] = lambda ds: ds[variables]
            try:
                logger.info(
                    f"In Progress: Opening variable(s) {variables} from multiple netCDF files '{filepath}'"
                )
                ds_var = xr.open_mfdataset(
                    filepaths, decode_times=coder, **open_kwargs
                )[variables]
                logger.info(f"--> Completed: Opened variable(s) {variables}.")

            except FileNotFoundError as e:
                raise FileNotFoundError(
                    f"Failed to open netCDF files: {filepaths}"
                ) from e

        return ds_var

    def load_data(self) -> xr.Dataset:
        """
        Load gridded ocean model data from netCDF file(s) into a standardised xarray.Dataset.
        """
        # -- Define variable names, filepaths and open_kwargs from source dict -- #
        variables = self._source.keys()
        open_kwargs_list = [
            d_var.get("open_kwargs", {}) for d_var in self._source.values()
        ]
        filepaths = [d_var.get("path", "") for d_var in self._source.values()]
        n_filepaths = len(set(filepaths))

        if n_filepaths == 1:
            # -- All Variables Share Common Filepath Pattern -- #
            # Merge open_kwargs for all variables sharing the same filepath pattern:
            open_kwargs = {}
            for d_open_kwargs in open_kwargs_list:
                open_kwargs.update(d_open_kwargs)

            # Load all variables into single xarray.Dataset:
            ds = self._open_dataset(
                filepath=filepaths[0],
                variables=list(variables),
                open_kwargs=open_kwargs,
            )
        else:
            # -- Variables Have Different Filepath Patterns -- #
            ds_list = []
            for filepath, variable, open_kwargs in zip(
                filepaths, variables, open_kwargs_list, strict=True
            ):
                ds_list.append(
                    self._open_dataset(
                        filepath=filepath, variables=[variable], open_kwargs=open_kwargs
                    )
                )
            # Merge individual variable datasets into single dataset:
            ds = xr.merge(ds_list, compat="no_conflicts", join="exact")

        # -- Standardise dataset dimensions and coordinates -- #
        ds = self._standardise_dataset(ds)
        logger.info(
            f"--> Completed: Standardised dataset dimensions {list(ds.sizes.keys())} and coordinates {list(ds.coords.keys())}."
        )

        # -- Validate standardised dataset -- #
        self._validate_dataset(ds)
        logger.info("--> Completed: Validated standardised dataset.")

        return ds
+ ) + + # -- Validate standardised dataset -- # + self._validate_dataset(ds) + logging.info("--> Completed: Validated standardised dataset.") + + return ds diff --git a/OceanOSSE/io/datawriter.py b/OceanOSSE/io/datawriter.py new file mode 100644 index 0000000..1b0a1f4 --- /dev/null +++ b/OceanOSSE/io/datawriter.py @@ -0,0 +1,363 @@ +""" +datawriter.py + +Description: DataWriter module for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import abc +import logging +from typing import Self + +import cftime +import numpy as np +import xarray as xr + +logger = logging.getLogger(__name__) + + +# -- DataWriter Abstract Base Class -- # +class DataWriter(abc.ABC): + """ + Abstract base class for writing a processed xarray.Dataset to persistent + storage (netCDF, Zarr, etc.). + + Parameters + ---------- + dimensions : dict[str, str] + Mapping of standard dimension names to input dataset dimension names. + coordinates : dict[str, str] + Mapping of standard coordinate names to input dataset variable names. + output_dir : str + Directory in which to write the output file. + output_name : str + Name of the output file (without extension). + date_format : str + Date format for time dimension in output filename. + Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD). + chunks : dict[str, int], optional + Dictionary defining chunk sizes for output dataset. + Default is None, meaning no chunking is applied. + writer_kwargs : dict[str, any], optional + Additional keyword arguments to pass to the underlying writing function + (e.g. ``xarray.Dataset.to_netcdf()`` or ``xarray.Dataset.to_zarr()``). + Default is None, meaning no additional keyword arguments are applied. + + Attributes + ---------- + _chunks : dict[str, int] or None + Dictionary defining chunk sizes for output dataset, or None if no chunking is applied. 
+ _coordinates : dict[str, str] + Mapping of standard coordinate names to input dataset variable names. + _date_format : str + Date format for time dimension in output filename. + _dimensions : dict[str, str] + Mapping of standard dimension names to input dataset dimension names. + _output_name : str + Name of the output file (without extension). + _output_dir : str + Directory in which to write the output file. + _writer_kwargs : dict[str, any] or None + Additional keyword arguments to pass to the underlying writing function, or None if no additional keyword arguments are applied. + """ + + def __init__( + self, + dimensions: dict[str, str], + coordinates: dict[str, str], + output_dir: str, + output_name: str, + date_format: str, + chunks: dict[str, int] | None = None, + writer_kwargs: dict[str, any] | None = None, + ) -> None: + # -- Validate Input -- # + if not isinstance(dimensions, dict): + raise TypeError("``dimensions`` must be a specfied as a dictionary.") + if not isinstance(coordinates, dict): + raise TypeError("``coordinates`` must be a specfied as a dictionary.") + if not isinstance(output_dir, str): + raise TypeError("``output_dir`` must be a string.") + if not isinstance(output_name, str): + raise TypeError("``output_name`` must be a string.") + if not isinstance(date_format, str): + raise TypeError("``date_format`` must be a string.") + if (chunks is not None) and not isinstance(chunks, dict): + raise TypeError("``chunks`` must be a dict or None.") + if (writer_kwargs is not None) and not isinstance(writer_kwargs, dict): + raise TypeError("``writer_kwargs`` must be a dict or None.") + + # -- Class Attributes -- # + self._dimensions = dimensions + self._coordinates = coordinates + self._output_dir = output_dir + self._output_name = output_name + self._date_format = date_format + self._chunks = chunks + self._writer_kwargs = writer_kwargs + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"output_dir={self._output_dir!r}, " + 
f"output_name={self._output_name!r}, " + f"date_format={self._date_format!r}, " + f"chunks={self._chunks!r}, " + f"writer_kwargs={self._writer_kwargs!r})" + ) + + @classmethod + @abc.abstractmethod + def from_config(cls, config: dict) -> Self: + """ + Abstract class method to instantiate a DataWriter from the `[outputs]` table + of the .toml configuration file. + + This is the required constructor for all DataWriter subclasses - plugin + authors must implement this method for use in OceanOSSE. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised DataWriter instance. + """ + ... + + @abc.abstractmethod + def write_data(self, ds: xr.Dataset) -> str: + """ + Abstract method to write OceanOSSE output xarray.Dataset to persistent storage. + + Parameters + ---------- + ds : xarray.Dataset + Processed dataset to write. + + Returns + ------- + str + Resolved output path where data was written (for logging). + """ + ... + + def _reconstruct_dataset(self, ds: xr.Dataset) -> xr.Dataset: + """ + Reconstruct gridded ocean model xarray.Dataset from OceanOSSE output by renaming dimensions and coordinates to input names. + + Parameters + ---------- + ds : xarray.Dataset + Dataset with OceanOSSE standardised dimension and coordinate names. + + Returns + ------- + xarray.Dataset + Dataset with original gridded ocean model dimension and coordinate names. 
+ """ + # -- Rename dimensions to original gridded ocean model dimension names -- # + ds = ds.rename_dims(self._dimensions) + + # -- Assign original coordinate names and drop any standard coordinates -- # + d_coords = {value: ds[key] for key, value in self._coordinates.items()} + ds = ds.assign_coords(d_coords) + drop_coords = [coord for coord in ds.coords if coord not in d_coords] + ds = ds.drop_vars(drop_coords) + + return ds + + def _get_output_filepath( + self, + ds: xr.Dataset, + output_dir: str, + output_name: str, + file_format: str, + date_format: str, + ) -> str: + """ + Define resolved filepath to OceanOSSE output file(s). + + Parameters: + ----------- + ds : xr.Dataset + Output xarray Dataset. + output_dir : str + Directory to save output file. + output_name : str + Prefix of output file name. + file_format : str + Output file format. Options are 'netcdf' or 'zarr'. + date_format : str + Date format for datetime limits in output filename. + Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD). + """ + # -- Validate Inputs -- # + if not isinstance(ds, xr.Dataset): + raise TypeError("ds must be an xr.Dataset.") + if not isinstance(output_dir, str): + raise TypeError("output_dir must be a string.") + if not isinstance(output_name, str): + raise TypeError("output_name must be a string.") + if file_format not in ["netcdf", "zarr"]: + raise ValueError("file_format must be either 'netcdf' or 'zarr'.") + + # -- Create Date String for Output Fileapath -- # + # Define time-limits of output dataset: + time_limits = ds["time"].values[[0, -1]] + + # Create date string from CFTime datetime objects: + if isinstance(time_limits[0], cftime.datetime): + if date_format == "Y": + fmt = "%Y" + elif date_format == "M": + fmt = "%Y-%m" + elif date_format == "D": + fmt = "%Y-%m-%d" + else: + raise ValueError( + f"Invalid date_format: '{date_format}'. Options are 'Y', 'M', 'D'." 
# -- DataWriter Implementations -- #
class NetCDFDataWriter(DataWriter):
    """
    DataWriter that serialises an xarray.Dataset to a netCDF file on disk.

    Parameters
    ----------
    dimensions : dict[str, str]
        Mapping of standard dimension names to input dataset dimension names.
    coordinates : dict[str, str]
        Mapping of standard coordinate names to input dataset variable names.
    output_dir : str
        Directory in which to write the output file.
    output_name : str
        Base filename (without ``.nc`` extension).
    date_format : str
        Date format for time dimension in output filename.
        Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD).
    chunks : dict[str, int], optional
        Dictionary defining chunk sizes for output dataset.
        Default is None, meaning no chunking is applied.
    writer_kwargs : dict[str, Any], optional
        Additional keyword arguments to pass to xarray.Dataset.to_netcdf.
    """

    def __init__(
        self,
        dimensions: dict[str, str],
        coordinates: dict[str, str],
        output_dir: str,
        output_name: str,
        date_format: str,
        chunks: dict[str, int] | None = None,
        writer_kwargs: dict[str, Any] | None = None,
    ) -> None:
        # -- Initialise parent DataWriter class -- #
        super().__init__(
            dimensions=dimensions,
            coordinates=coordinates,
            output_dir=output_dir,
            output_name=output_name,
            date_format=date_format,
            chunks=chunks,
            writer_kwargs=writer_kwargs,
        )

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Instantiate a NetCDFDataWriter from the `[outputs]` table of the .toml configuration file.
        """
        # -- Verify Input -- #
        if not isinstance(config, dict):
            raise TypeError("config must be a dictionary.")

        # -- Instantiate NetCDFDataWriter from config -- #
        inputs = config["inputs"]
        outputs = config["outputs"]
        return cls(
            dimensions=inputs["dimensions"],
            coordinates=inputs["coordinates"],
            output_dir=outputs["output_dir"],
            output_name=outputs["output_name"],
            date_format=outputs["date_format"],
            chunks=outputs.get("chunks", None),
            writer_kwargs=outputs.get("writer_kwargs", None),
        )

    def write_data(self, ds: xr.Dataset) -> str:
        """
        Write OceanOSSE output xarray.Dataset to a netCDF file.

        Parameters
        ----------
        ds : xarray.Dataset
            Processed dataset to write.

        Returns
        -------
        str
            Resolved output path where data was written.
        """
        # -- Validate Inputs -- #
        if not isinstance(ds, xr.Dataset):
            raise TypeError("ds must be an xr.Dataset.")
        if self._chunks is not None and not isinstance(self._chunks, dict):
            raise TypeError("chunks must be a dictionary.")
        if not isinstance(self._output_dir, str):
            raise TypeError("output_dir must be a string.")
        if not isinstance(self._output_name, str):
            raise TypeError("output_name must be a string.")
        if self._date_format not in ["Y", "M", "D"]:
            raise ValueError("date_format must be 'Y', 'M' or 'D'.")

        # -- Define Output Filepath -- #
        output_filepath = self._get_output_filepath(
            ds=ds,
            output_dir=self._output_dir,
            output_name=self._output_name,
            file_format="netcdf",
            date_format=self._date_format,
        )

        # -- Reconstruct Dataset with Original Dimension and Coordinate Names -- #
        ds = self._reconstruct_dataset(ds)

        # -- Optionally Apply Chunking -- #
        if self._chunks is not None:
            ds = ds.chunk(self._chunks)

        # -- Write Dataset to NetCDF -- #
        # BUG FIX: the original assigned the defaults onto
        # ``self._writer_kwargs``, mutating instance state as a side effect
        # of write_data(); use a local instead.
        writer_kwargs = self._writer_kwargs
        if writer_kwargs is None:
            # Default writer_kwargs for netCDF output:
            writer_kwargs = {
                "unlimited_dims": self._dimensions.get("time"),
                "mode": "w",
            }
        ds.to_netcdf(path=output_filepath, **writer_kwargs)
        logger.info(
            f"--> Completed: Written output dataset to netCDF file: {output_filepath}"
        )

        # BUG FIX: the base class contract (``write_data -> str``) requires
        # returning the resolved output path; the original returned None.
        return output_filepath
logging.getLogger(__name__) + +# -- Registries -- # +_DATA_LOADER_REGISTRY: dict[str, type[DataLoader]] = {"netcdf": NetCDFDataLoader} + +_OBS_SAMPLER_REGISTRY: dict[str, type[ObsSampler]] = {"test": TestObsSampler} + +_REGRIDDER_REGISTRY: dict[str, type[Regridder]] = {"test": TestRegridder} + +_DATA_WRITER_REGISTRY: dict[str, type[DataWriter]] = {"netcdf": NetCDFDataWriter} + + +# -- Factory Functions -- # +def _create_DataLoader(config: dict) -> DataLoader: + """ + Instantiate a DataLoader from the `[inputs]` table of the .toml configuration file. + + Options: + - Built-in Registry: + - name: "netcdf" -> NetCDFDataLoader + + - Plugins: + - Custom DataLoader imported from `module` and `name` specified in config. + """ + # -- Validate Inputs -- # + if not isinstance(config, dict): + raise TypeError("``config`` must be a dictionary.") + + # -- Instantiate DataLoader -- # + inputs = config["inputs"] + + # 1. Plugin DataLoader: + if (inputs.get("module") is not None) and (inputs.get("name") is not None): + # -- Import custom DataLoader class -- # + data_loader = import_class( + module=inputs["module"], class_name=inputs["name"], class_type=DataLoader + ) + logger.info( + f"Completed: Created DataLoader from Plugin: {inputs['module']}.{inputs['name']}" + ) + + # 2. Registry DataLoader: + else: + # -- Use DataLoader class from registry -- # + format = inputs.get("format", "netcdf4") + try: + data_loader = _DATA_LOADER_REGISTRY[format] + except KeyError as e: + raise KeyError( + f"DataLoader '{format}' not found in registry. Available options: {list(_DATA_LOADER_REGISTRY.keys())}" + ) from e + logger.info( + f"Completed: Created DataLoader from Registry -> {format}: {data_loader.__name__}" + ) + + return data_loader.from_config(config=config) + + +def _create_ObsSampler(config: dict) -> ObsSampler: + """ + Instantiate an ObsSampler from the `[sampling]` table of the .toml configuration file. 
+ + Options: + - Built-in Registry: + - name: "test" -> TestObsSampler + + - Plugins: + - Custom ObsSampler imported from `module` and `name` specified in config. + """ + # -- Validate Inputs -- # + if not isinstance(config, dict): + raise TypeError("``config`` must be a dictionary.") + + # -- Instantiate ObsSampler -- # + sampling = config["sampling"] + + # 1. Plugin ObsSampler: + if (sampling.get("module") is not None) and (sampling.get("name") is not None): + # -- Import custom ObsSampler class -- # + obs_sampler = import_class( + module=sampling["module"], + class_name=sampling["name"], + class_type=ObsSampler, + ) + logger.info( + f"Completed: Created ObsSampler from Plugin: {sampling['module']}.{sampling['name']}" + ) + + # 2. Registry ObsSampler: + else: + # -- Use ObsSampler class from registry -- # + name = sampling.get("name", "test") + try: + obs_sampler = _OBS_SAMPLER_REGISTRY[name] + except KeyError as e: + raise KeyError( + f"ObsSampler '{name}' not found in registry. Available options: {list(_OBS_SAMPLER_REGISTRY.keys())}" + ) from e + logger.info( + f"Completed: Created ObsSampler from Registry -> {name}: {obs_sampler.__name__}" + ) + + return obs_sampler.from_config(config=config) + + +def _create_Regridder(config: dict) -> Regridder: + """ + Instantiate a Regridder from the `[regridding]` table of the .toml configuration file. + + Options: + - Built-in Registry: + - name: "test" -> TestRegridder + + - Plugins: + - Custom Regridder imported from `module` and `name` specified in config. + """ + # -- Validate Inputs -- # + if not isinstance(config, dict): + raise TypeError("``config`` must be a dictionary.") + + # -- Instantiate Regridder -- # + regridding = config["regridding"] + + # 1. 
Plugin Regridder: + if (regridding.get("module") is not None) and (regridding.get("name") is not None): + # -- Import custom Regridder class -- # + regridder = import_class( + module=regridding["module"], + class_name=regridding["name"], + class_type=Regridder, + ) + logger.info( + f"Completed: Created Regridder from Plugin: {regridding['module']}.{regridding['name']}" + ) + + # 2. Registry Regridder: + else: + # -- Use Regridder class from registry -- # + name = regridding.get("name", "test") + try: + regridder = _REGRIDDER_REGISTRY[name] + except KeyError as e: + raise KeyError( + f"Regridder '{name}' not found in registry. Available options: {list(_REGRIDDER_REGISTRY.keys())}" + ) from e + logger.info( + f"Completed: Created Regridder from Registry -> {name}: {regridder.__name__}" + ) + + return regridder.from_config(config=config) + + +def _create_DataWriter(config: dict) -> DataWriter: + """ + Instantiate a DataWriter from the `[outputs]` table of the .toml configuration file. + + Options: + - Built-in Registry: + - name: "netcdf" -> NetCDFDataWriter + + - Plugins: + - Custom DataWriter imported from `module` and `name` specified in config. + """ + # -- Validate Inputs -- # + if not isinstance(config, dict): + raise TypeError("``config`` must be a dictionary.") + + # -- Instantiate DataWriter -- # + outputs = config["outputs"] + + # 1. Plugin DataWriter: + if (outputs.get("module") is not None) and (outputs.get("name") is not None): + # -- Import custom DataWriter class -- # + data_writer = import_class( + module=outputs["module"], class_name=outputs["name"], class_type=DataWriter + ) + logger.info( + f"Completed: Created DataWriter from Plugin: {outputs['module']}.{outputs['name']}" + ) + + # 2. Registry DataWriter: + else: + # -- Use DataWriter class from registry -- # + format = outputs.get("format", "netcdf4") + try: + data_writer = _DATA_WRITER_REGISTRY[format] + except KeyError as e: + raise KeyError( + f"DataWriter '{format}' not found in registry. 
Available options: {list(_DATA_WRITER_REGISTRY.keys())}" + ) from e + logger.info( + f"Completed: Created DataWriter from Registry -> {format}: {data_writer.__name__}" + ) + + return data_writer.from_config(config=config) + + +# -- Define Pipeline Functions -- # +def run_pipeline(args: dict) -> None: + """ + Run OceanOSSE pipeline using specified config .ini file. + + Pipeline Steps: + 1. Instantiate DataLoader -> Load standardised ocean model dataset. + 2. Instantiate ObsSampler -> Sample synthetic ocean observations from model dataset. + 3. Instantiate Regridder -> Regrid synthetic observations onto original model grid. + 4. Instantiate DataWriter -> Write output dataset to file. + + Parameters: + ----------- + args : dict + Command line arguments. + """ + # === Inputs === # + logger.info("==== Inputs ====") + # Load config .toml file: + config = load_config(config_path=args["config_file"]) + logger.info(f"Completed: Read & validated config file -> {args['config_file']}") + + # Load ocean model dataset using DataLoader: + data_loader = _create_DataLoader(config=config) + logger.info(f"In Progress: Loading ocean model dataset using {data_loader}...") + ds_mdl = data_loader.load_data() + + # === Sampling === # + logger.info("==== Sampling ====") + # Sample synthetic ocean observations using ObsSampler: + obs_sampler = _create_ObsSampler(config=config) + logger.info( + f"In Progress: Sampling synthetic ocean observations using {obs_sampler}..." + ) + ds_obs = obs_sampler.sample(ds=ds_mdl) + + # === Regridding === # + logger.info("==== Regridding ====") + # Regrid synthetic ocean observations using Regridder: + regridder = _create_Regridder(config=config) + logger.info( + f"In Progress: Regridding synthetic ocean observations using {regridder}..." 
+ ) + ds_regridded = regridder.regrid(ds=ds_obs) + + # === Outputs === # + logger.info("==== Outputs ====") + # Write output dataset to file using DataWriter: + data_writer = _create_DataWriter(config=config) + logger.info(f"In Progress: Writing output dataset to file using {data_writer}...") + data_writer.write_data(ds=ds_regridded) + # Close all files: + for ds in [ds_mdl, ds_obs, ds_regridded]: + ds.close() + logger.info("--> Completed: Closed all dataset files.") + + +def describe_pipeline(args: dict) -> str: + """ + Describe & validate OceanOSSE pipeline using config. + + Parameters: + ----------- + args : dict + Command line arguments. + + Returns: + -------- + str + Description of OceanOSSE pipeline. + """ + # === Inputs === # + logger.info("==== Inputs ====") + # Load config .toml file: + config = load_config(config_path=args["config_file"]) + logger.info(f"Completed: Read & validated config file -> {args['config_file']}") + + # Load ocean model dataset using DataLoader: + data_loader = _create_DataLoader(config=config) + logger.info("Action: Load ocean model dataset using...") + logger.info("* DataLoader : %r", data_loader) + + # === Sampling === # + logger.info("==== Sampling ====") + # Sample synthetic ocean observations using ObsSampler: + obs_sampler = _create_ObsSampler(config=config) + logger.info("Action: Sample synthetic ocean observations using...") + logger.info("* ObsSampler : %r", obs_sampler) + + # === Regridding === # + logger.info("==== Regridding ====") + # Regrid synthetic ocean observations using Regridder: + regridder = _create_Regridder(config=config) + logger.info("Action: Regrid synthetic ocean observations using...") + logger.info("* Regridder : %r", regridder) + + # === Outputs === # + logger.info("==== Outputs ====") + # Write output dataset to file using DataWriter: + data_writer = _create_DataWriter(config=config) + logger.info("Action: Write output dataset to file using...") + logger.info("* DataWriter : %r", data_writer) + # 
Define output filepath: + outputs = config["outputs"] + if outputs["format"] == "netcdf": + extension = "nc" + else: + extension = "zarr" + logger.info( + f"* Output File = {outputs['output_dir']}/{outputs['output_name']}_YYYY-MM_YYYY-MM.{extension}" + ) diff --git a/OceanOSSE/sampling/sampler.py b/OceanOSSE/sampling/sampler.py new file mode 100644 index 0000000..8e23363 --- /dev/null +++ b/OceanOSSE/sampling/sampler.py @@ -0,0 +1,334 @@ +""" +sampler.py + +Description: Sampling module for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import abc +import logging +from typing import Self + +import xarray as xr + +from OceanOSSE.utils import import_class + +logger = logging.getLogger(__name__) + + +# -- Utility Functions -- # +def get_error_kernels(config: dict) -> list[ErrorKernel] | None: + """ + Utility function to instantiate ErrorKernel instances from the `[sampling]` + table of the .toml configuration file. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + list[ErrorKernel] | None + List of initialised ErrorKernel instances, or None if no kernels are + specified in the configuration. 
+ """ + error_kernels_config = config["sampling"].get("error_kernels", None) + + if error_kernels_config is not None: + _ERROR_KERNEL_REGISTRY = {"test": TestErrorKernel} + kernels: list[ErrorKernel] = [] + for kernel_cfg in error_kernels_config: + if ("module" in kernel_cfg) and ("name" in kernel_cfg): + # -- Import custom ErrorKernel class -- # + Kernel = import_class( + module=kernel_cfg["module"], + class_name=kernel_cfg["name"], + class_type=ErrorKernel, + ) + + else: + # -- Use ErrorKernel class from registry -- # + try: + Kernel = _ERROR_KERNEL_REGISTRY[kernel_cfg["name"]] + except KeyError as e: + raise KeyError( + f"ErrorKernel name '{kernel_cfg['name']}' not found in registry." + ) from e + + # -- Instantiate ErrorKernel from configuration -- # + kernels.append(Kernel.from_config(config=config)) + + return kernels + + +# -- ErrorKernel Abstract Base Class -- # +class ErrorKernel(abc.ABC): + """ + Abstract base class for applying instrument or representation errors + to synthetic ocean observations. + + ErrorKernel transforms a sampled xarray.Dataset by adding noise, + applying a bias, convolving a point-spread function, etc. + + Multiple kernels can be chained by an :class:`ObsSampler` and are applied + sequentially in declaration order. + """ + + def __repr__(self) -> str: + return f"{type(self).__name__}()" + + @classmethod + @abc.abstractmethod + def from_config(cls, config: dict) -> Self: + """ + Abstract class method to instantiate an ErrorKernel from the `[sampling]` + table of the .toml configuration file. + + This is the required constructor for all ErrorKernel subclasses - plugin + authors must implement this method for use in OceanOSSE. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised ErrorKernel instance. + """ + ... 
+ + @abc.abstractmethod + def apply(self, ds: xr.Dataset) -> xr.Dataset: + """ + Abstract method to apply the error kernel to an xarray.Dataset of + synthetic observations. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset produced by `ObsSampler.sample`. + + Returns + ------- + xarray.Dataset + Synthetic observations dataset with error applied. + """ + ... + + +class TestErrorKernel(ErrorKernel): + """ + ErrorKernel used for testing and scaffold validation. + + Returns the synthetic observations xarray.Dataset unchanged. + """ + + @classmethod + def from_config(cls, config: dict) -> Self: + """ + Instantiate a TestErrorKernel from the `[sampling]` table of + the .toml configuration file. + """ + return cls() + + def apply(self, ds: xr.Dataset) -> xr.Dataset: + """ + Apply the TestErrorKernel to an xarray.Dataset of synthetic observations. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset produced by `ObsSampler.sample`. + + Returns + ------- + xarray.Dataset + Synthetic observations dataset unchanged. + """ + logger.debug( + "Applying TestErrorKernel -> returns synthetic observations dataset unchanged." + ) + return ds + + +# -- ObsSampler Abstract Base Class -- # +class ObsSampler(abc.ABC): + """ + Abstract base class for sampling gridded ocean model output analogously + to an ocean observing platform (e.g., Argo floats). + + Parameters + ---------- + error_kernels : list[ErrorKernel], optional + List of ErrorKernel instances to apply sequentially to the sampled + synthetic observations dataset, by default None. + """ + + def __init__(self, error_kernels: list[ErrorKernel] | None = None): + # -- Validate Inputs -- # + if error_kernels is not None: + if not isinstance(error_kernels, list): + raise TypeError( + "`error_kernels` must be a list of ErrorKernel instances." 
+ ) + for n, kernel in enumerate(error_kernels): + if not isinstance(kernel, ErrorKernel): + raise TypeError(f"`error_kernels[{n}]` must be an ErrorKernel.") + + # -- Class Attributes -- # + self._error_kernels = error_kernels + + def __repr__(self) -> str: + return f"{type(self).__name__}(error_kernels={self._error_kernels!r})" + + @classmethod + @abc.abstractmethod + def from_config(cls, config: dict) -> Self: + """ + Abstract class method to instantiate an ObsSampler from the `[sampling]` + table of the .toml configuration file. + + This is the required constructor for all ObsSampler subclasses - plugin + authors must implement this method for use in OceanOSSE. + + Parameters + ---------- + config : dict + Configuration dictionary containing input parameters from .toml + configuration file. + + Returns + ------- + Self + Initialised ObsSampler instance. + """ + ... + + @abc.abstractmethod + def collect_samples(self, ds: xr.Dataset) -> xr.Dataset: + """ + Abstract method to sample a gridded xarray.Dataset of ocean model output + to produce a synthetic observations dataset. + + This is the required sampling method for all ObsSampler subclasses - + plugin authors must implement this method for use in OceanOSSE. + + Parameters + ---------- + ds : xarray.Dataset + Gridded ocean model output dataset. + + Returns + ------- + xarray.Dataset + Sampled synthetic observations dataset. + """ + ... + + def apply_errors(self, ds: xr.Dataset) -> xr.Dataset: + """ + Apply all registered `ErrorKernel` instances to synthetic + observations sequentially. + + If no kernels are registered, the synthetic observations + dataset is returned unchanged. + + Parameters + ---------- + ds : xarray.Dataset + Synthetic observations dataset. + + Returns + ------- + xarray.Dataset + Synthetic observations dataset with all error kernels + applied in order. 
+ """ + # -- Apply each Error Kernel sequentially -- # + if self._error_kernels is not None: + for kernel in self._error_kernels: + logger.debug(f"Applying ErrorKernel --> {repr(kernel)}") + ds = kernel.apply(ds) + logging.info( + "--> Completed: Applied ErrorKernels to synthetic observations." + ) + + return ds + + def sample(self, ds: xr.Dataset) -> xr.Dataset: + """ + Perform sampling pipeline for chosen ocean observing platform. + + Parameters + ---------- + ds : xarray.Dataset + Gridded ocean model dataset. + + Returns + ------- + xarray.Dataset + Synthetic observations dataset with errors applied. + """ + # -- Sample the gridded ocean model output -- # + ds_sampled = self.collect_samples(ds) + logging.info( + "--> Completed: Collected samples from ocean model dataset using ObsSampler." + ) + + # -- Apply error kernels sequentially to the synthetic observations -- # + ds_obs = self.apply_errors(ds_sampled) + + return ds_obs + + +# -- ObsSampler Implementations -- # + + +class TestObsSampler(ObsSampler): + """ + ObsSampler used for testing and scaffold validation. + + Returns the input gridded ocean model dataset unchanged as the synthetic + observations dataset. + """ + + @classmethod + def from_config(cls, config: dict) -> Self: + """ + Instantiate a TestObsSampler from the `[sampling]` table of + the .toml configuration file. + """ + # -- Collect ErrorKernel instances from configuration -- # + error_kernels = get_error_kernels(config=config) + + # -- Instantiate TestObsSampler with collected ErrorKernel instances -- # + return cls(error_kernels=error_kernels or None) + + def collect_samples(self, ds: xr.Dataset) -> xr.Dataset: + """ + Sample a gridded xarray.Dataset of ocean model output to produce a + synthetic observations dataset. + + Parameters + ---------- + ds : xarray.Dataset + Gridded ocean model output dataset. + + Returns + ------- + xarray.Dataset + Synthetic observations dataset (unchanged from input). 
+ """ + logger.debug( + "Collecting samples with TestObsSampler -> returns input dataset unchanged." + ) + return ds diff --git a/OceanOSSE/utils.py b/OceanOSSE/utils.py new file mode 100644 index 0000000..82d7db7 --- /dev/null +++ b/OceanOSSE/utils.py @@ -0,0 +1,115 @@ +""" +utils.py + +Description: Utility functions for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +import importlib +import tomllib +from pathlib import Path + +from OceanOSSE.validation import AppConfig + + +# -- Utility Functions -- # +def load_config(config_path: str) -> dict: + """ + Load and parse an OceanOSSE configuration file. + + Parameters + ---------- + config_path : str + Path to the OceanOSSE ``.toml`` configuration file. + + Returns + ------- + dict + Parsed configuration as a nested dictionary. + + Raises + ------ + FileNotFoundError + tomllib.TOMLDecodeError + """ + # -- Validate Inputs -- # + if not isinstance(config_path, str): + raise TypeError("``config_path`` must be a string.") + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path!r}") + + # -- Load and Parse Configuration -- # + # Open config .toml file: + with open(path, "rb") as f: + cfg_data = tomllib.load(f) + + # Parse and validate config data using Pydantic models: + config = AppConfig(**cfg_data) + # Convert config params to dict: + d_config = config.model_dump(mode="json") + + return d_config + + +def import_class( + module: str, + class_name: str, + class_type: type, +) -> type: + """ + Dynamically import a class from specified module path. + + This is used by the OceanOSSE pipeline to load third-party + plugin classes declared in a `.toml` configuration file. + + Parameters + ---------- + module : str + Import path to the module, e.g. ``"my_package.samplers"``. + class_name : str + Name of the class to import from module, e.g. ``"ArgoSampler"``. 
+ class_type : type + Expected base type of the class. + + Returns + ------- + type + Imported class object. + + Raises + ------ + ModuleNotFoundError + AttributeError + """ + # -- Validate Inputs -- # + if not isinstance(module, str): + raise TypeError("``module`` must be a string.") + if not isinstance(class_name, str): + raise TypeError("``class_name`` must be a string.") + + # -- Dynamically import class -- # + try: + module = importlib.import_module(module) + except ImportError as e: + raise ImportError(f"Failed to import module '{module}'") from e + try: + Kernel = getattr(module, class_name) + except AttributeError as e: + raise AttributeError( + f"Failed to import {class_type.__name__} '{class_name}' from module '{module}'" + ) from e + + # -- Verify that class is callable -- # + if not callable(Kernel): + raise TypeError(f"{class_type.__name__} '{class_name}' is not callable.") + + # -- Verify that class is a subclass of the expected type -- # + if not issubclass(Kernel, class_type): + raise TypeError(f"'{class_name}' is not a subclass of '{class_type.__name__}'.") + + return Kernel diff --git a/OceanOSSE/validation.py b/OceanOSSE/validation.py new file mode 100644 index 0000000..8a66dc1 --- /dev/null +++ b/OceanOSSE/validation.py @@ -0,0 +1,205 @@ +""" +validation.py + +Description: Validation functions for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field, field_validator, model_validator + + +# -- [inputs.variables.] -- # +class VariableConfig(BaseModel): + """ + Schema for `[inputs.variables.]`. 
+
+    Example
+    -------
+    [inputs.variables.thetao]
+    path = "/data/NEMO_thetao_*.nc"
+    open_kwargs = { engine = "netcdf4" }
+    """
+
+    path: str
+    open_kwargs: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("path")
+    @classmethod
+    def path_must_be_nonempty(cls, v: str) -> str:
+        if not v.strip():
+            raise ValueError("'path' must not be an empty string.")
+        return v
+
+
+# -- [inputs.dimensions] -- #
+class DimensionsConfig(BaseModel):
+    """
+    Schema for `[inputs.dimensions]`.
+
+    Example
+    -------
+    .. code-block:: toml
+
+        [inputs.dimensions]
+        time = "time_counter"
+        lev = "deptht"
+        j = "y"
+        i = "x"
+    """
+
+    time: str = "time"
+    lev: str = "lev"
+    j: str = "j"
+    i: str = "i"
+
+
+# -- [inputs.coordinates] -- #
+class CoordinatesConfig(BaseModel):
+    """
+    Schema for `[inputs.coordinates]`.
+
+    Example
+    -------
+    .. code-block:: toml
+
+        [inputs.coordinates]
+        time = "time_counter"
+        depth = "deptht"
+        lon = "nav_lon"
+        lat = "nav_lat"
+    """
+
+    time: str = "time"
+    depth: str = "depth"
+    lon: str = "lon"
+    lat: str = "lat"
+
+
+# -- [inputs] -- #
+class InputConfig(BaseModel):
+    """
+    Schema for `[inputs]`.
+
+    The `variables` field accepts any number of named sub-tables:
+
+    .. code-block:: toml
+
+        [inputs.variables.thetao]
+        path = "..."
+        open_kwargs = { engine = "netcdf4" }
+
+        [inputs.variables.so]
+        path = "..."
+        open_kwargs = { engine = "netcdf4" }
+    """
+
+    dimensions: DimensionsConfig = Field(default_factory=DimensionsConfig)
+    coordinates: CoordinatesConfig = Field(default_factory=CoordinatesConfig)
+    data_dir: str = ""
+    format: Literal["netcdf", "zarr"] = "netcdf"
+    variables: dict[str, VariableConfig]
+    name: str | None = None
+    module: str | None = None
+
+    @field_validator("variables")
+    @classmethod
+    def variables_must_not_be_empty(cls, v: dict) -> dict:
+        if not v:
+            raise ValueError(
+                "`[inputs.variables]` must contain at least one variable entry."
+            )
+        return v
+
+
+# -- [sampling] -- #
+class SamplingConfig(BaseModel):
+    """
+    Schema for `[sampling]`.
+
+    .. code-block:: toml
+
+        [sampling]
+        name = "..."
+
+        [[sampling.error_kernels]]
+        name = "..."
+
+        [[sampling.error_kernels]]
+        module = "..."
+        name = "..."
+    """
+
+    name: str = "test"
+    module: str | None = None
+    error_kernels: list[dict[str, Any]] = Field(default_factory=list)
+
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> SamplingConfig:
+        # NOTE(review): `name` is typed `str` with default "test", so it can
+        # never be None here and this guard can never fire. If the intent is
+        # to require an explicit name whenever `module` is set, `name` would
+        # need to be `str | None = None` (with the pipeline default handling
+        # adjusted) -- confirm intended behaviour.
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "`[sampling]`: 'module' and 'name' must both be specified "
+                "together for plugin loading."
+            )
+        return self
+
+
+# -- [regridding] -- #
+class RegriddingConfig(BaseModel):
+    """
+    Schema for `[regridding]`.
+
+    .. code-block:: toml
+
+        [regridding]
+        module = "..."
+        name = "..."
+    """
+
+    name: str = "test"
+    module: str | None = None
+
+    # NOTE(review): examples/example_config.toml sets `[regridding] kwargs`,
+    # which is not declared here and is therefore silently dropped during
+    # validation (pydantic ignores extra keys by default) -- declare a
+    # `kwargs` field if regridders are meant to receive it.
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> RegriddingConfig:
+        # NOTE(review): as in SamplingConfig, `name` defaults to "test" and is
+        # never None, so this guard cannot fire.
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "``[regridding]``: 'module' and 'name' must both be specified together."
+            )
+        return self
+
+
+# -- [outputs] -- #
+class OutputConfig(BaseModel):
+    # NOTE(review): `writer_kwargs` is not declared, yet the example config
+    # sets it and NetCDFDataWriter.from_config reads outputs["writer_kwargs"];
+    # with pydantic's default extra-key handling it is silently dropped from
+    # model_dump(), so the writer always falls back to its defaults -- declare
+    # a `writer_kwargs: dict[str, Any] | None = None` field.
+    format: Literal["netcdf"] = "netcdf"
+    output_dir: str
+    output_name: str
+    date_format: Literal["Y", "M", "D"] = "Y"
+    chunks: dict[str, int] = Field(default_factory=dict)
+    name: str | None = None
+    module: str | None = None
+
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> OutputConfig:
+        # NOTE(review): as in SamplingConfig, `name` defaults to a non-None
+        # str, so this guard cannot fire.
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "``[outputs]``: 'module' and 'name' must both be specified together."
+            )
+        return self
+
+
+# -- Top-level AppConfig -- #
+class AppConfig(BaseModel):
+    """
+    Top-level OceanOSSE configuration model.
+    Validates the entire parsed .toml dict.
+    """
+
+    inputs: InputConfig
+    sampling: SamplingConfig = Field(default_factory=SamplingConfig)
+    regridding: RegriddingConfig = Field(default_factory=RegriddingConfig)
+    outputs: OutputConfig
diff --git a/examples/example_config.toml b/examples/example_config.toml
new file mode 100644
index 0000000..0794038
--- /dev/null
+++ b/examples/example_config.toml
@@ -0,0 +1,54 @@
+# ============================================================
+# OceanOSSE Example Configuration: NEMO eORCA1 ERA5v1
+# ============================================================
+# To run:
+#   OceanOSSE run examples/example_config.toml
+#   OceanOSSE run examples/example_config.toml --dry-run
+# ============================================================
+
+# ========================== Inputs ==========================
+[inputs]
+data_dir = "/path/to/nemo/output"
+format = "netcdf"
+
+[inputs.dimensions]
+time = "time_counter"
+lev = "deptht"
+j = "y"
+i = "x"
+
+[inputs.coordinates]
+time = "time_counter"
+depth = "deptht"
+lon = "nav_lon"
+lat = "nav_lat"
+
+[inputs.variables.thetao_con]
+path = "/dssgfs01/scratch/npd/simulations/eORCA1_ERA5_v1/eORCA1_ERA5_1m_grid_T_202312-202312.nc"
+open_kwargs = { engine = "netcdf4" }
+
+[inputs.variables.so_abs]
+path = "/dssgfs01/scratch/npd/simulations/eORCA1_ERA5_v1/eORCA1_ERA5_1m_grid_T_202312-202312.nc"
+open_kwargs = { engine = "netcdf4" }
+
+# ========================== Sampling ==========================
+[sampling]
+name = "test"
+
+[[sampling.error_kernels]]
+name = "test"
+kwargs = { sigma = 1.0 }
+
+# ========================== Regridding ==========================
+[regridding]
+name = "test"
+kwargs = { method = "nearest_s2d" }
+
+# ========================== Outputs ==========================
+[outputs]
+format = "netcdf"
+output_dir = "/dssgfs01/working/otooth/Software/OceanOSSE/OceanOSSE"
+output_name = "OceanOSSE_TEST"
+date_format = "M"
+chunks = { time_counter = 12 }
+# NOTE(review): `writer_kwargs` (and the `kwargs` tables above) are not
+# declared in the validation schema and are currently dropped during
+# validation -- see the notes in OceanOSSE/validation.py in this change.
+writer_kwargs = { unlimited_dims = "time_counter", mode = "w" }
diff --git a/pixi.toml b/pixi.toml index 2a854b1..699c8e2 100644 --- a/pixi.toml +++ b/pixi.toml @@ -24,6 +24,7 @@ python = "*" dask = "*" gsw = "*" netCDF4 = "*" +typer = ">=0.24.1,<0.25" xarray = "*" zarr = "*" git = "*" # needed for dynamic versioning @@ -32,6 +33,7 @@ git = "*" # needed for dynamic versioning # Define OceanOSSE package as local path dependency: OceanOSSE = { path = "." } python = "3.13.*" +pydantic = ">=2.13.3,<3" # === Features === # @@ -42,6 +44,7 @@ ipykernel = "*" matplotlib = "*" cartopy = "*" nc-time-axis = "*" +polars = ">=1.39.3,<2" [feature.dev.pypi-dependencies] ruff = "*" @@ -76,4 +79,4 @@ preview-docs = { cmd = "cd docs/; zensical serve", description = "Run local prev default = { solve-group = "default" } dev = { features = ["dev"], solve-group = "default" } docs = { features = ["docs"], solve-group = "default" } -release = { features = ["release"], solve-group = "default" } \ No newline at end of file +release = { features = ["release"], solve-group = "default" } diff --git a/pyproject.toml b/pyproject.toml index 7467a1f..986eef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,15 +19,13 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ - "dask>=2025.7.0", - "flox>=0.10.6", - "gsw>=3.6.20", - "numbagg>=0.8", - "numba>=0.62", - "numpy>=2.3", # numba 0.62 added support for numpy 2.3 - "netCDF4<=1.7.3", - "xarray>=2025.07.1", - "zarr>=3.0.7", + "cftime", + "dask", + "netcdf4", + "pydantic", + "typer", + "xarray", + "zarr", ] [project.optional-dependencies] @@ -62,3 +60,6 @@ ignore = ["E501"] [tool.ruff.format] quote-style = "double" + +[project.scripts] +OceanOSSE = "OceanOSSE.cli:app" diff --git a/tests/unit/test_regridder.py b/tests/unit/test_regridder.py new file mode 100644 index 0000000..b52e75d --- /dev/null +++ b/tests/unit/test_regridder.py @@ -0,0 +1,142 @@ +# =================================================================== +# Copyright 2025 National Oceanography Centre +# Licensed under the Apache 
License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# ===================================================================
+"""
+test_regridder.py
+
+Description:
+This module includes unit tests for the SwapRegridder climatology and
+profile regridding.
+
+Author:
+Benjamin Barton (benbar@noc.ac.uk)
+"""
+
+# NOTE(review): the pytest import below is unused.
+import pytest
+import datetime as dt
+import numpy as np
+import xarray as xr
+
+from OceanOSSE.gridding.regridder_simple import SwapRegridder
+
+
+def test_climatology():
+    """
+    Test producing a daily climatology.
+    """
+    ds = construct_ds()
+
+    regrid = SwapRegridder(ds)
+    clim = regrid.ds_clim
+
+    ts = ds.votemper.mean(dim=["d", "j", "i"])
+    clim_mean = clim.votemper.mean(dim=["d", "j", "i"])
+
+    st_date = dt.datetime(2020, 5, 1)
+    test_sec1 = (dt.datetime(2020, 5, 30) - st_date).days
+    test_sec2 = (dt.datetime(2021, 5, 30) - st_date).days
+    test_temp1 = 15 - (0 * 0.4) + (0 * 0.2) - (0 * 0.2) + (test_sec1 * 0.000005)
+    test_temp2 = 15 - (0 * 0.4) + (0 * 0.2) - (0 * 0.2) + (test_sec2 * 0.000005)
+    test_temp = (test_temp1 + test_temp2) / 2
+    clim_day = clim.votemper.sel(t='2020-05-30').isel(d=0, j=0, i=0)
+
+    # NOTE(review): atol=1e8 makes np.isclose unconditionally True for these
+    # O(10) temperature values, so the first conjunct is vacuous -- likely
+    # meant atol=1e-8 (or a deliberate loose tolerance); confirm intended.
+    assert (np.isclose(ts.mean().to_numpy(), clim_mean.mean().to_numpy(), atol=1e8)
+            & (clim_day.to_numpy() == test_temp))
+
+
+def test_regrid():
+    """
+    Test replacing profiles in climatology with model data.
+    """
+    ds = construct_ds()
+    synth_profiles = construct_profile_ds()
+
+    regrid_data = SwapRegridder(ds)
+    ds_model = regrid_data.regrid(synth_profiles)
+
+    # NOTE(review): (ds_model != regrid_data.ds_clim) is an xarray.Dataset of
+    # booleans; asserting on `Dataset & Dataset` raises "the truth value of a
+    # Dataset is ambiguous" -- reduce explicitly, e.g. `.any()` on the first
+    # comparison and `.all()` on the second, per variable. Confirm against
+    # SwapRegridder's intended semantics.
+    assert ((ds_model != regrid_data.ds_clim)
+            & (ds_model.isel(i=3, j=5, t=31) == synth_profiles.isel(profile_id=0)))
+
+
+def construct_ds():
+    """
+    Build a dataset for testing.
+    """
+    lat = np.arange(0, 8)
+    lon = np.arange(0, 10)
+    depth = np.arange(0, 150, 10)
+    st_date = dt.datetime(2020, 5, 1)
+    num_days = 730
+    model_dates = np.array([st_date + dt.timedelta(days=x) for x in range(num_days)])
+    model_day = np.array([x for x in range(num_days)])
+
+    # Broadcast to 4D (time, depth, lat, lon)
+    t, d, y, x = np.meshgrid(model_day, depth, lat, lon, indexing='ij')
+
+    votemper = 15 - (y * 0.4) + (x * 0.2) - (d * 0.2) + (t * 0.000005)
+
+    # Build dataset
+    ds = xr.Dataset(
+        {
+            "votemper": (("t", "d", "j", "i"), votemper),
+            "lat": (("j", "i"), y[0, 0, :, :]),
+            "lon": (("j", "i"), x[0, 0, :, :]),
+            "depth": (("d", "j", "i"), d[0, :, :, :]),
+            "time": (("t"), t[:, 0, 0, 0])
+        },
+        coords={
+            "d": depth,
+            "j": lat,
+            "i": lon,
+            "t": model_dates
+        },
+    )
+
+    return ds
+
+
+def construct_profile_ds():
+    """Build a synthetic two-profile observations dataset for testing."""
+    d = np.arange(0, 150, 10)
+    profile_id = np.arange(2)
+
+    j = np.array([5, 6])
+    i = np.array([3, 8])
+
+    depth = np.tile(d[:, None], (1, profile_id.size))
+
+    # Time coordinate
+    st_date = dt.datetime(2020, 5, 1)
+    time = np.array([
+        dt.datetime(2020, 6, 1),
+        dt.datetime(2020, 7, 2),
+    ])
+    time_day = np.array([(x - st_date).days for x in time])
+
+    votemper = 15 - depth * 0.02 - j[None, :] * 0.2 + i[None, :] * 0.1 + (time_day * 0.000005)
+
+    ds = xr.Dataset(
+        data_vars={
+            "votemper": (("d", "profile_id"), votemper),
+            "lat": (("profile_id",), j),
+            "lon": (("profile_id",), i),
+            "depth": (("d", "profile_id"), depth),
+        },
+        coords={
+            "d": d,
+            "profile_id": profile_id,
+            "t": (("profile_id",), time),
+            "j": (("profile_id",), j),
+            "i": (("profile_id",), i),
+        },
+    )
+
+    return ds
\ No newline at end of 
file