diff --git a/OceanOSSE/__init__.py b/OceanOSSE/__init__.py
index f4a21a1..ebdae71 100644
--- a/OceanOSSE/__init__.py
+++ b/OceanOSSE/__init__.py
@@ -1,23 +1,33 @@
 """
 OceanOSSE
-Python toolbox for performing Observing System Simulation Experiments (OSSEs) in ocean general circulation models.
+Python toolbox for performing Observing System Simulation Experiments (OSSEs)
+in ocean general circulation models.
 """
+
 __author__ = "Ollie Tooth (oliver.tooth@noc.ac.uk)"
 __credits__ = "National Oceanography Centre (NOC), Southampton, UK"
 
 from importlib.metadata import version as _version
 
-from OceanOSSE import (
-    cli,
-    pipeline,
-)
-
 try:
     __version__ = _version("OceanOSSE")
 except Exception:
-    # Local copy or not installed with setuptools.
-    # Disable minimum version checks on downstream libraries.
     __version__ = "9999.0.0"
 
+# Import submodules after __version__ is defined so that the CLI can read
+# it from the package without triggering a circular import:
+from OceanOSSE import cli, pipeline
+from OceanOSSE.gridding.regridder import Regridder
+from OceanOSSE.io.dataloader import DataLoader
+from OceanOSSE.io.datawriter import DataWriter
+from OceanOSSE.sampling.sampler import ErrorKernel, ObsSampler
+from OceanOSSE.sampling.sampler_nearest_neighbour import NNSampler
+
-__all__ = ("cli", "pipeline")
\ No newline at end of file
+__all__ = (
+    "cli",
+    "pipeline",
+    "DataLoader",
+    "DataWriter",
+    "ErrorKernel",
+    "NNSampler",
+    "ObsSampler",
+    "Regridder",
+)
diff --git a/OceanOSSE/cli.py b/OceanOSSE/cli.py
index 441fd3f..e5b9bd7 100644
--- a/OceanOSSE/cli.py
+++ b/OceanOSSE/cli.py
@@ -6,4 +6,119 @@
 Created By: Ollie Tooth (oliver.tooth@noc.ac.uk)
 """
 
-# -- Import Dependencies -- #
\ No newline at end of file
+# -- Import dependencies -- #
+import logging
+import sys
+
+import typer
+from typing_extensions import Annotated, Optional
+
+from OceanOSSE.pipeline import describe_pipeline, run_pipeline
+
+from OceanOSSE import __version__
+
+app = typer.Typer()
+logger = logging.getLogger(__name__)
+
+
+# -- Define CLI Functions -- #
+def create_header(
+    config_path: str,
+    log_path: str,
+) -> None:
+    """
+    Add OceanOSSE header to log.
+
+    Parameters
+    ----------
+    config_path : str
+        Filepath to OceanOSSE config .toml file.
+    log_path : str
+        Filepath to OceanOSSE log file.
+    """
+    logger.info(
+        f"""
+╔══════════════════════════════════════════════════════════════╗
+║                          OceanOSSE                            ║
+║       Ocean Observing System Simulation Experiment Tool       ║
+╠══════════════════════════════════════════════════════════════╣
+  OceanOSSE Version : {__version__}
+  Python Version    : {sys.version.split()[0]}
+  Config File       : {config_path}
+  Log File          : {log_path}
+╚══════════════════════════════════════════════════════════════╝
+""",
+        extra={"simple": True},
+    )
+
+
+def init_logging(log_path: str) -> None:
+    """
+    Initialise OceanOSSE logging.
+
+    Parameters
+    ----------
+    log_path : str
+        Filepath to log file. The CLI supplies 'ocean_osse.log' by default.
+    """
+    # === Validate Inputs === #
+    if not isinstance(log_path, str):
+        raise TypeError("log_path must be a string.")
+
+    logging.basicConfig(
+        format="⦿══⦿ OceanOSSE ⦿══⦿ ║ %(levelname)10s ║ %(asctime)s ║ %(message)s",
+        level=logging.INFO,
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
+    )
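+# Example invocations (illustrative; these mirror the header of
+# examples/example_config.toml and assume the ``OceanOSSE`` entry point is
+# registered on install):
+#
+#   OceanOSSE run examples/example_config.toml
+#   OceanOSSE run examples/example_config.toml --dry-run
+#   OceanOSSE run examples/example_config.toml --log my_run.log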
+# === Create Typer App === #
+@app.callback()
+def main() -> None:
+    """
+    Main callback for the Typer app, allowing a
+    single `run` command to be defined.
+    """
+    pass
+
+
+@app.command()
+def run(
+    config: Annotated[str, typer.Argument(help="Path to OceanOSSE config .toml file")],
+    log: Annotated[
+        Optional[str],
+        typer.Option(
+            help="Path to write OceanOSSE log file", rich_help_panel="Options"
+        ),
+    ] = "ocean_osse.log",
+    dry_run: Annotated[
+        Optional[bool],
+        typer.Option(
+            help="Describe OceanOSSE workflow without execution.",
+            rich_help_panel="Options",
+        ),
+    ] = False,
+) -> None:
+    """
+    Run OceanOSSE workflow defined by configuration (.toml) file in current process.
+    """
+    # === Initialise Logging === #
+    init_logging(log_path=log)
+    create_header(config_path=config, log_path=log)
+
+    # === Run OceanOSSE === #
+    args = {
+        "config_file": config,
+        "log_filepath": log,
+    }
+    if dry_run:
+        describe_pipeline(args=args)
+    else:
+        run_pipeline(args=args)
+
+    logger.info("✔ OceanOSSE Completed ✔")
+
+
+if __name__ == "__main__":
+    app()
diff --git a/OceanOSSE/gridding/regridder.py b/OceanOSSE/gridding/regridder.py
new file mode 100644
index 0000000..f38f1d7
--- /dev/null
+++ b/OceanOSSE/gridding/regridder.py
@@ -0,0 +1,122 @@
+"""
+regridder.py
+
+Description: Regridding module for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import abc
+import logging
+from typing import Self
+
+import xarray as xr
+
+logger = logging.getLogger(__name__)
+
+
+# -- Regridder Abstract Base Class -- #
+class Regridder(abc.ABC):
+    """
+    Abstract base class for regridding synthetic ocean observations onto
+    the original model grid, using methods such as objective analysis
+    or interpolation.
+
+    Parameters
+    ----------
+    target_grid : xarray.Dataset or None, optional
+        Dataset describing the target grid (coordinates, masks, etc.).
+    """
+
+    def __init__(
+        self,
+        target_grid: xr.Dataset | None = None,
+    ) -> None:
+        if target_grid is not None and not isinstance(target_grid, xr.Dataset):
+            raise TypeError("``target_grid`` must be an xarray.Dataset or None.")
+        self._target_grid = target_grid
+
+    def __repr__(self) -> str:
+        has_grid = self._target_grid is not None
+        return f"{type(self).__name__}(target_grid={'<xarray.Dataset>' if has_grid else None})"
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Construct a Regridder from the `[regridding]` table of
+        the .toml configuration file.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised Regridder instance.
+        """
+        ...
+
+    @abc.abstractmethod
+    def regrid(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Regrid the synthetic observation dataset onto the target grid.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset of synthetic observations regridded onto target grid.
+        """
+        ...
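+# For reference, a third-party Regridder plugin might look like the sketch
+# below (illustrative only: the class name is hypothetical and
+# ``xarray.Dataset.interp_like`` stands in for a real regridding scheme,
+# assuming ``target_grid`` was supplied at construction):
+#
+#     class InterpRegridder(Regridder):
+#         @classmethod
+#         def from_config(cls, config: dict) -> Self:
+#             return cls()
+#
+#         def regrid(self, ds: xr.Dataset) -> xr.Dataset:
+#             return ds.interp_like(self._target_grid, method="linear")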
+# -- Regridder Implementations -- #
+class TestRegridder(Regridder):
+    """
+    Regridder used for testing and scaffold validation.
+
+    Returns the synthetic observations dataset unchanged.
+    """
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate a TestRegridder from the `[regridding]` table of
+        the .toml configuration file.
+        """
+        return cls()
+
+    def regrid(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Regrid the synthetic observation dataset onto the target grid.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset of synthetic observations (unchanged from input).
+        """
+        logger.debug(
+            "Regridding synthetic observations with TestRegridder -> returns input dataset unchanged."
+        )
+        logger.info(
+            "--> Completed: Regridded synthetic observations with TestRegridder."
+        )
+        return ds
diff --git a/OceanOSSE/io/dataloader.py b/OceanOSSE/io/dataloader.py
new file mode 100644
index 0000000..8e6b721
--- /dev/null
+++ b/OceanOSSE/io/dataloader.py
@@ -0,0 +1,354 @@
+"""
+dataloader.py
+
+Description: DataLoader module for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import abc
+import glob
+import logging
+from typing import Self
+
+import xarray as xr
+
+logger = logging.getLogger(__name__)
+
+
+# -- DataLoader Abstract Base Class -- #
+class DataLoader(abc.ABC):
+    """
+    DataLoader Base Class to load gridded ocean model data from
+    persistent local filesystem or cloud object storage.
+
+    Parameters
+    ----------
+    source : dict[str, dict]
+        Source dictionary of gridded ocean model variables to load.
+    dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset coordinate names.
+
+    Attributes
+    ----------
+    _source : dict[str, dict]
+        Source dictionary of gridded ocean model variables.
+    _dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    _coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset coordinate names.
+    """
+
+    def __init__(
+        self,
+        source: dict[str, dict],
+        dimensions: dict[str, str],
+        coordinates: dict[str, str],
+    ):
+        # -- Verify Inputs -- #
+        if not isinstance(source, dict):
+            raise TypeError("``source`` must be specified as a dictionary.")
+        if not isinstance(dimensions, dict):
+            raise TypeError("``dimensions`` must be specified as a dictionary.")
+        if not isinstance(coordinates, dict):
+            raise TypeError("``coordinates`` must be specified as a dictionary.")
+
+        # -- Class Attributes -- #
+        self._source = source
+        self._dimensions = dimensions
+        self._coordinates = coordinates
+
+    def __repr__(self) -> str:
+        return (
+            f"{type(self).__name__}("
+            f"source={self._source!r}, "
+            f"dimensions={self._dimensions!r}, "
+            f"coordinates={self._coordinates!r})"
+        )
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Abstract class method to instantiate a DataLoader from the `[inputs]` table
+        of the .toml configuration file.
+
+        This is the required constructor for all DataLoader subclasses - plugin
+        authors must implement this method for use in OceanOSSE.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised DataLoader instance.
+        """
+        ...
+
+    @abc.abstractmethod
+    def load_data(self) -> xr.Dataset:
+        """
+        Abstract method to load gridded ocean model data into a standardised xarray.Dataset.
+
+        The returned Dataset will be validated by the _validate_dataset() method.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset containing standardised gridded ocean model variables.
+        """
+        ...
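+    # Both mappings are declared standard-name -> input-name; e.g. for the
+    # NEMO example config:
+    #     dimensions  = {"time": "time_counter", "lev": "deptht", "j": "y", "i": "x"}
+    #     coordinates = {"time": "time_counter", "depth": "deptht",
+    #                    "lat": "nav_lat", "lon": "nav_lon"}
+    # ``_standardise_dataset`` below therefore inverts ``dimensions`` before
+    # renaming, while ``coordinates`` is used as-is to assign standard names.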
+    def _standardise_dataset(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Standardise gridded ocean model xarray.Dataset by renaming dimensions
+        and coordinates to OceanOSSE standard names.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Dataset with original dimension and coordinate names.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset with standardised dimension and coordinate names.
+        """
+        # -- Rename dimensions to standard dimensions names -- #
+        rename_dims = {value: key for key, value in self._dimensions.items()}
+        ds = ds.rename_dims(rename_dims)
+
+        # -- Assign standard coordinate names and drop any non-standard coordinates -- #
+        ds = ds.assign_coords(
+            {coord: ds[var] for coord, var in self._coordinates.items()}
+        )
+        drop_coords = [coord for coord in ds.coords if coord not in self._coordinates]
+        ds = ds.drop_vars(drop_coords)
+
+        return ds
+
+    def _validate_dataset(self, ds: xr.Dataset) -> None:
+        """
+        Validate the standardised xarray.Dataset.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Dataset containing standardised gridded ocean model variables.
+
+        Raises
+        ------
+        ValueError
+            If the dataset does not contain the required variables or dimensions.
+        """
+        # -- Validate Required Dimensions -- #
+        required_dims = ["time", "lev", "j", "i"]
+        missing_dims = [dim for dim in required_dims if dim not in ds.dims]
+        if missing_dims:
+            raise ValueError(
+                f"{type(self).__name__}: loaded dataset is missing required "
+                f"dimension(s): {missing_dims}. Found dimensions: {list(ds.dims)}."
+            )
+
+        # -- Validate Required Coordinates -- #
+        required_coords = ["time", "depth", "lat", "lon"]
+        missing_coords = [coord for coord in required_coords if coord not in ds.coords]
+        if missing_coords:
+            raise ValueError(
+                f"{type(self).__name__}: loaded dataset is missing required "
+                f"coordinate(s): {missing_coords}. Found coordinates: {list(ds.coords)}."
+            )
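+# A plugin loader for another storage backend can subclass DataLoader; a
+# minimal Zarr sketch is shown below (illustrative: the class name and layout
+# are assumptions, and per-variable ``open_kwargs`` are ignored for brevity):
+#
+#     class ZarrDataLoader(DataLoader):
+#         @classmethod
+#         def from_config(cls, config: dict) -> Self:
+#             inputs = config["inputs"]
+#             return cls(
+#                 source=inputs["variables"],
+#                 dimensions=inputs["dimensions"],
+#                 coordinates=inputs["coordinates"],
+#             )
+#
+#         def load_data(self) -> xr.Dataset:
+#             paths = [var["path"] for var in self._source.values()]
+#             ds = xr.open_mfdataset(paths, engine="zarr")
+#             ds = self._standardise_dataset(ds)
+#             self._validate_dataset(ds)
+#             return ds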
+# -- DataLoader Implementations -- #
+class NetCDFDataLoader(DataLoader):
+    """
+    DataLoader implementation to load gridded ocean model data from NetCDF files.
+
+    Parameters
+    ----------
+    source : dict[str, dict]
+        Source dictionary of gridded ocean model variables to load.
+    dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset coordinate names.
+    """
+
+    def __init__(
+        self,
+        source: dict[str, dict],
+        dimensions: dict[str, str],
+        coordinates: dict[str, str],
+    ) -> None:
+        # -- Initialise parent DataLoader class -- #
+        super().__init__(source, dimensions, coordinates)
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate a NetCDFDataLoader from the `[inputs]` table of the .toml configuration file.
+        """
+        # -- Verify Input -- #
+        if not isinstance(config, dict):
+            raise TypeError("config must be a dictionary.")
+
+        # -- Instantiate DataLoader with source dict from config -- #
+        source = config["inputs"].get("variables", None)
+        if source is None:
+            raise ValueError(
+                "Missing 'variables' entry in [inputs] table of config .toml file."
+            )
+        dimensions = config["inputs"].get("dimensions", None)
+        if dimensions is None:
+            raise ValueError(
+                "Missing 'dimensions' entry in [inputs] table of config .toml file."
+            )
+        coordinates = config["inputs"].get("coordinates", None)
+        if coordinates is None:
+            raise ValueError(
+                "Missing 'coordinates' entry in [inputs] table of config .toml file."
+            )
+
+        return cls(source=source, dimensions=dimensions, coordinates=coordinates)
+
+    def _open_dataset(
+        self,
+        filepath: str,
+        variables: list[str],
+        open_kwargs: dict,
+    ) -> xr.Dataset:
+        """
+        Open an input variable Dataset from netCDF file(s).
+
+        Parameters
+        ----------
+        filepath : str
+            Filepath pattern to input variable netCDF file(s).
+        variables : list[str]
+            Name of variable(s) to load from the dataset.
+        open_kwargs : dict
+            Additional keyword arguments to pass to xarray.open_dataset
+            or xarray.open_mfdataset.
+
+        Returns
+        -------
+        xr.Dataset
+            Standardised variable Dataset.
+        """
+        # -- Validate Inputs -- #
+        if not isinstance(filepath, str):
+            raise TypeError("filepath must be a string.")
+        if not isinstance(variables, list):
+            raise TypeError("variables must be a list of strings.")
+        if not isinstance(open_kwargs, dict):
+            raise TypeError("open_kwargs must be a dictionary.")
+
+        filepaths = glob.glob(filepath)
+        if len(filepaths) == 0:
+            raise FileNotFoundError(f"No files found matching filepath: {filepath}")
+
+        # Define CFDatetimeCoder to decode time coords:
+        coder = xr.coders.CFDatetimeCoder(time_unit="s")
+
+        # -- Open input variable dataset -- #
+        if len(filepaths) == 1:
+            if not open_kwargs:
+                # Apply default open_kwargs (an empty dict is passed when the
+                # config supplies none):
+                open_kwargs = {"engine": "netcdf4"}
+            try:
+                logger.info(
+                    f"In Progress: Opening variable(s) {variables} from netCDF file '{filepath}'"
+                )
+                ds_var = xr.open_dataset(
+                    filepaths[0], decode_times=coder, **open_kwargs
+                )[variables]
+                logger.info(f"--> Completed: Opened variable(s) {variables}.")
+
+            except FileNotFoundError as e:
+                raise FileNotFoundError(
+                    f"Failed to open netCDF file: {filepath}"
+                ) from e
+        else:
+            if not open_kwargs:
+                open_kwargs = {
+                    "data_vars": "minimal",
+                    "compat": "no_conflicts",
+                    "parallel": False,
+                    "engine": "netcdf4",
+                }
+            if variables is not None:
+                open_kwargs["preprocess"] = lambda ds: ds[variables]
+            try:
+                logger.info(
+                    f"In Progress: Opening variable(s) {variables} from multiple netCDF files '{filepath}'"
+                )
+                ds_var = xr.open_mfdataset(
+                    filepaths, decode_times=coder, **open_kwargs
+                )[variables]
+                logger.info(f"--> Completed: Opened variable(s) {variables}.")
+
+            except FileNotFoundError as e:
+                raise FileNotFoundError(
+                    f"Failed to open netCDF files: {filepaths}"
+                ) from e
+
+        return ds_var
+
+    def load_data(self) -> xr.Dataset:
+        """
+        Load gridded ocean model data from netCDF file(s) into a standardised xarray.Dataset.
+        """
+        # -- Define variable names, filepaths and open_kwargs from source dict -- #
+        variables = self._source.keys()
+        open_kwargs_list = [
+            d_var.get("open_kwargs", {}) for d_var in self._source.values()
+        ]
+        filepaths = [d_var.get("path", "") for d_var in self._source.values()]
+        n_filepaths = len(set(filepaths))
+
+        if n_filepaths == 1:
+            # -- All Variables Share Common Filepath Pattern -- #
+            # Merge open_kwargs for all variables sharing the same filepath pattern:
+            open_kwargs = {}
+            for d_open_kwargs in open_kwargs_list:
+                open_kwargs.update(d_open_kwargs)
+
+            # Load all variables into single xarray.Dataset:
+            ds = self._open_dataset(
+                filepath=filepaths[0],
+                variables=list(variables),
+                open_kwargs=open_kwargs,
+            )
+        else:
+            # -- Variables Have Different Filepath Patterns -- #
+            ds_list = []
+            for filepath, variable, open_kwargs in zip(
+                filepaths, variables, open_kwargs_list, strict=True
+            ):
+                ds_list.append(
+                    self._open_dataset(
+                        filepath=filepath, variables=[variable], open_kwargs=open_kwargs
+                    )
+                )
+            # Merge individual variable datasets into single dataset:
+            ds = xr.merge(ds_list, compat="no_conflicts", join="exact")
+
+        # -- Standardise dataset dimensions and coordinates -- #
+        ds = self._standardise_dataset(ds)
+        logger.info(
+            f"--> Completed: Standardised dataset dimensions {list(ds.sizes.keys())} and coordinates {list(ds.coords.keys())}."
+        )
+
+        # -- Validate standardised dataset -- #
+        self._validate_dataset(ds)
+        logger.info("--> Completed: Validated standardised dataset.")
+
+        return ds
diff --git a/OceanOSSE/io/datawriter.py b/OceanOSSE/io/datawriter.py
new file mode 100644
index 0000000..1b0a1f4
--- /dev/null
+++ b/OceanOSSE/io/datawriter.py
@@ -0,0 +1,363 @@
+"""
+datawriter.py
+
+Description: DataWriter module for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import abc
+import logging
+from typing import Any, Self
+
+import cftime
+import numpy as np
+import xarray as xr
+
+logger = logging.getLogger(__name__)
+
+
+# -- DataWriter Abstract Base Class -- #
+class DataWriter(abc.ABC):
+    """
+    Abstract base class for writing a processed xarray.Dataset to persistent
+    storage (netCDF, Zarr, etc.).
+
+    Parameters
+    ----------
+    dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset variable names.
+    output_dir : str
+        Directory in which to write the output file.
+    output_name : str
+        Name of the output file (without extension).
+    date_format : str
+        Date format for time dimension in output filename.
+        Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD).
+    chunks : dict[str, int], optional
+        Dictionary defining chunk sizes for output dataset.
+        Default is None, meaning no chunking is applied.
+    writer_kwargs : dict[str, Any], optional
+        Additional keyword arguments to pass to the underlying writing function
+        (e.g. ``xarray.Dataset.to_netcdf()`` or ``xarray.Dataset.to_zarr()``).
+        Default is None, meaning no additional keyword arguments are applied.
+
+    Attributes
+    ----------
+    _chunks : dict[str, int] or None
+        Dictionary defining chunk sizes for output dataset, or None if no
+        chunking is applied.
+    _coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset variable names.
+    _date_format : str
+        Date format for time dimension in output filename.
+    _dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    _output_name : str
+        Name of the output file (without extension).
+    _output_dir : str
+        Directory in which to write the output file.
+    _writer_kwargs : dict[str, Any] or None
+        Additional keyword arguments to pass to the underlying writing
+        function, or None if no additional keyword arguments are applied.
+    """
+
+    def __init__(
+        self,
+        dimensions: dict[str, str],
+        coordinates: dict[str, str],
+        output_dir: str,
+        output_name: str,
+        date_format: str,
+        chunks: dict[str, int] | None = None,
+        writer_kwargs: dict[str, Any] | None = None,
+    ) -> None:
+        # -- Validate Input -- #
+        if not isinstance(dimensions, dict):
+            raise TypeError("``dimensions`` must be specified as a dictionary.")
+        if not isinstance(coordinates, dict):
+            raise TypeError("``coordinates`` must be specified as a dictionary.")
+        if not isinstance(output_dir, str):
+            raise TypeError("``output_dir`` must be a string.")
+        if not isinstance(output_name, str):
+            raise TypeError("``output_name`` must be a string.")
+        if not isinstance(date_format, str):
+            raise TypeError("``date_format`` must be a string.")
+        if (chunks is not None) and not isinstance(chunks, dict):
+            raise TypeError("``chunks`` must be a dict or None.")
+        if (writer_kwargs is not None) and not isinstance(writer_kwargs, dict):
+            raise TypeError("``writer_kwargs`` must be a dict or None.")
+
+        # -- Class Attributes -- #
+        self._dimensions = dimensions
+        self._coordinates = coordinates
+        self._output_dir = output_dir
+        self._output_name = output_name
+        self._date_format = date_format
+        self._chunks = chunks
+        self._writer_kwargs = writer_kwargs
+
+    def __repr__(self) -> str:
+        return (
+            f"{type(self).__name__}("
+            f"output_dir={self._output_dir!r}, "
+            f"output_name={self._output_name!r}, "
+            f"date_format={self._date_format!r}, "
+            f"chunks={self._chunks!r}, "
+            f"writer_kwargs={self._writer_kwargs!r})"
+        )
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Abstract class method to instantiate a DataWriter from the `[outputs]` table
+        of the .toml configuration file.
+
+        This is the required constructor for all DataWriter subclasses - plugin
+        authors must implement this method for use in OceanOSSE.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised DataWriter instance.
+        """
+        ...
+
+    @abc.abstractmethod
+    def write_data(self, ds: xr.Dataset) -> str:
+        """
+        Abstract method to write OceanOSSE output xarray.Dataset to persistent storage.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Processed dataset to write.
+
+        Returns
+        -------
+        str
+            Resolved output path where data was written (for logging).
+        """
+        ...
+
+    def _reconstruct_dataset(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Reconstruct gridded ocean model xarray.Dataset from OceanOSSE output by
+        renaming dimensions and coordinates to input names.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Dataset with OceanOSSE standardised dimension and coordinate names.
+
+        Returns
+        -------
+        xarray.Dataset
+            Dataset with original gridded ocean model dimension and coordinate names.
+        """
+        # -- Rename dimensions to original gridded ocean model dimension names -- #
+        ds = ds.rename_dims(self._dimensions)
+
+        # -- Assign original coordinate names and drop any standard coordinates -- #
+        d_coords = {value: ds[key] for key, value in self._coordinates.items()}
+        ds = ds.assign_coords(d_coords)
+        drop_coords = [coord for coord in ds.coords if coord not in d_coords]
+        ds = ds.drop_vars(drop_coords)
+
+        return ds
+
+    def _get_output_filepath(
+        self,
+        ds: xr.Dataset,
+        output_dir: str,
+        output_name: str,
+        file_format: str,
+        date_format: str,
+    ) -> str:
+        """
+        Define resolved filepath to OceanOSSE output file(s).
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            Output xarray Dataset.
+        output_dir : str
+            Directory to save output file.
+        output_name : str
+            Prefix of output file name.
+        file_format : str
+            Output file format. Options are 'netcdf' or 'zarr'.
+        date_format : str
+            Date format for datetime limits in output filename.
+            Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD).
+
+        Returns
+        -------
+        str
+            Resolved output filepath.
+        """
+        # -- Validate Inputs -- #
+        if not isinstance(ds, xr.Dataset):
+            raise TypeError("ds must be an xr.Dataset.")
+        if not isinstance(output_dir, str):
+            raise TypeError("output_dir must be a string.")
+        if not isinstance(output_name, str):
+            raise TypeError("output_name must be a string.")
+        if file_format not in ["netcdf", "zarr"]:
+            raise ValueError("file_format must be either 'netcdf' or 'zarr'.")
+
+        # -- Create Date String for Output Filepath -- #
+        # Define time-limits of output dataset:
+        time_limits = ds["time"].values[[0, -1]]
+
+        # Create date string from CFTime datetime objects:
+        if isinstance(time_limits[0], cftime.datetime):
+            if date_format == "Y":
+                fmt = "%Y"
+            elif date_format == "M":
+                fmt = "%Y-%m"
+            elif date_format == "D":
+                fmt = "%Y-%m-%d"
+            else:
+                raise ValueError(
+                    f"Invalid date_format: '{date_format}'. Options are 'Y', 'M', 'D'."
+                )
+            date_str = f"{time_limits[0].strftime(fmt)}-{time_limits[1].strftime(fmt)}"
+
+        # Create date string from numpy datetime64:
+        elif isinstance(time_limits[0], np.datetime64):
+            date_str = f"{np.datetime_as_string(time_limits[0], unit=date_format)}-{np.datetime_as_string(time_limits[1], unit=date_format)}"
+        else:
+            raise TypeError(
+                f"Invalid type ({type(time_limits[0])}) for dates. Expected cftime.datetime or np.datetime64."
+            )
+
+        # -- Define Output Filepath -- #
+        if file_format == "netcdf":
+            output_filename = f"{output_dir}/{output_name}_{date_str}.nc"
+        elif file_format == "zarr":
+            output_filename = f"{output_dir}/{output_name}_{date_str}.zarr"
+
+        return output_filename
+
+
+# -- DataWriter Implementations -- #
+class NetCDFDataWriter(DataWriter):
+    """
+    DataWriter that serialises an xarray.Dataset to a netCDF file on disk.
+
+    Parameters
+    ----------
+    dimensions : dict[str, str]
+        Mapping of standard dimension names to input dataset dimension names.
+    coordinates : dict[str, str]
+        Mapping of standard coordinate names to input dataset variable names.
+    output_dir : str
+        Directory in which to write the output file.
+    output_name : str
+        Base filename (without ``.nc`` extension).
+    date_format : str
+        Date format for time dimension in output filename.
+        Options are 'Y' (YYYY), 'M' (YYYY-MM) or 'D' (YYYY-MM-DD).
+    chunks : dict[str, int], optional
+        Dictionary defining chunk sizes for output dataset.
+        Default is None, meaning no chunking is applied.
+    writer_kwargs : dict[str, Any], optional
+        Additional keyword arguments to pass to xarray.Dataset.to_netcdf.
+    """
+
+    def __init__(
+        self,
+        dimensions: dict[str, str],
+        coordinates: dict[str, str],
+        output_dir: str,
+        output_name: str,
+        date_format: str,
+        chunks: dict[str, int] | None = None,
+        writer_kwargs: dict[str, Any] | None = None,
+    ) -> None:
+        # -- Initialise parent DataWriter class -- #
+        super().__init__(
+            dimensions=dimensions,
+            coordinates=coordinates,
+            output_dir=output_dir,
+            output_name=output_name,
+            date_format=date_format,
+            chunks=chunks,
+            writer_kwargs=writer_kwargs,
+        )
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate a NetCDFDataWriter from the `[outputs]` table of the .toml configuration file.
+        """
+        # -- Verify Input -- #
+        if not isinstance(config, dict):
+            raise TypeError("config must be a dictionary.")
+
+        # -- Instantiate NetCDFDataWriter from config -- #
+        inputs = config["inputs"]
+        outputs = config["outputs"]
+        return cls(
+            dimensions=inputs["dimensions"],
+            coordinates=inputs["coordinates"],
+            output_dir=outputs["output_dir"],
+            output_name=outputs["output_name"],
+            date_format=outputs["date_format"],
+            chunks=outputs.get("chunks", None),
+            writer_kwargs=outputs.get("writer_kwargs", None),
+        )
+
+    def write_data(self, ds: xr.Dataset) -> str:
+        """
+        Write OceanOSSE output xarray.Dataset to a netCDF file.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Processed dataset to write.
+
+        Returns
+        -------
+        str
+            Resolved output path where data was written.
+        """
+        # -- Validate Inputs -- #
+        if not isinstance(ds, xr.Dataset):
+            raise TypeError("ds must be an xr.Dataset.")
+        if self._chunks is not None and not isinstance(self._chunks, dict):
+            raise TypeError("chunks must be a dictionary.")
+        if not isinstance(self._output_dir, str):
+            raise TypeError("output_dir must be a string.")
+        if not isinstance(self._output_name, str):
+            raise TypeError("output_name must be a string.")
+        if self._date_format not in ["Y", "M", "D"]:
+            raise ValueError("date_format must be 'Y', 'M' or 'D'.")
+
+        # -- Define Output Filepath -- #
+        output_filepath = self._get_output_filepath(
+            ds=ds,
+            output_dir=self._output_dir,
+            output_name=self._output_name,
+            file_format="netcdf",
+            date_format=self._date_format,
+        )
+
+        # -- Reconstruct Dataset with Original Dimension and Coordinate Names -- #
+        ds = self._reconstruct_dataset(ds)
+
+        # -- Optionally Apply Chunking -- #
+        if self._chunks is not None:
+            ds = ds.chunk(self._chunks)
+
+        # -- Write Dataset to NetCDF -- #
+        writer_kwargs = self._writer_kwargs
+        if writer_kwargs is None:
+            # Default writer_kwargs for netCDF output:
+            writer_kwargs = {
+                "unlimited_dims": self._dimensions.get("time"),
+                "mode": "w",
+            }
+        ds.to_netcdf(path=output_filepath, **writer_kwargs)
+        logger.info(
+            f"--> Completed: Written output dataset to netCDF file: {output_filepath}"
+        )
+
+        return output_filepath
diff --git a/OceanOSSE/pipeline.py b/OceanOSSE/pipeline.py
index f6962de..a33c3bb 100644
--- a/OceanOSSE/pipeline.py
+++ b/OceanOSSE/pipeline.py
@@ -6,4 +6,320 @@
 Created By: Ollie Tooth (oliver.tooth@noc.ac.uk)
 """
 
-# -- Import Dependencies -- #
\ No newline at end of file
+# -- Import Dependencies -- #
+import logging
+
+from OceanOSSE.gridding.regridder import Regridder, TestRegridder
+from OceanOSSE.io.dataloader import DataLoader, NetCDFDataLoader
+from OceanOSSE.io.datawriter import DataWriter, NetCDFDataWriter
+from OceanOSSE.sampling.sampler import ObsSampler, TestObsSampler
+from OceanOSSE.utils import import_class, load_config
+
+logger = logging.getLogger(__name__)
+
+# -- Registries -- #
+_DATA_LOADER_REGISTRY: dict[str, type[DataLoader]] = {"netcdf": NetCDFDataLoader}
+
+_OBS_SAMPLER_REGISTRY: dict[str, type[ObsSampler]] = {"test": TestObsSampler}
+
+_REGRIDDER_REGISTRY: dict[str, type[Regridder]] = {"test": TestRegridder}
+
+_DATA_WRITER_REGISTRY: dict[str, type[DataWriter]] = {"netcdf": NetCDFDataWriter}
+# -- Factory Functions -- #
+def _create_DataLoader(config: dict) -> DataLoader:
+    """
+    Instantiate a DataLoader from the `[inputs]` table of the .toml configuration file.
+
+    Options:
+    - Built-in Registry:
+        - format: "netcdf" -> NetCDFDataLoader
+
+    - Plugins:
+        - Custom DataLoader imported from `module` and `name` specified in config.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(config, dict):
+        raise TypeError("``config`` must be a dictionary.")
+
+    # -- Instantiate DataLoader -- #
+    inputs = config["inputs"]
+
+    # 1. Plugin DataLoader:
+    if (inputs.get("module") is not None) and (inputs.get("name") is not None):
+        # -- Import custom DataLoader class -- #
+        data_loader = import_class(
+            module=inputs["module"], class_name=inputs["name"], class_type=DataLoader
+        )
+        logger.info(
+            f"Completed: Created DataLoader from Plugin: {inputs['module']}.{inputs['name']}"
+        )
+
+    # 2. Registry DataLoader:
+    else:
+        # -- Use DataLoader class from registry -- #
+        format = inputs.get("format", "netcdf")
+        try:
+            data_loader = _DATA_LOADER_REGISTRY[format]
+        except KeyError as e:
+            raise KeyError(
+                f"DataLoader '{format}' not found in registry. Available options: {list(_DATA_LOADER_REGISTRY.keys())}"
+            ) from e
+        logger.info(
+            f"Completed: Created DataLoader from Registry -> {format}: {data_loader.__name__}"
+        )
+
+    return data_loader.from_config(config=config)
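+# The same resolution rule applies to every component table: a table carrying
+# both ``module`` and ``name`` is loaded as a plugin via ``import_class``,
+# otherwise ``name``/``format`` is looked up in the built-in registry, e.g.
+# (the plugin values are hypothetical):
+#
+#     [sampling]                        [sampling]
+#     name = "test"            vs.      module = "my_pkg.samplers"
+#                                       name = "ArgoSampler"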
+def _create_ObsSampler(config: dict) -> ObsSampler:
+    """
+    Instantiate an ObsSampler from the `[sampling]` table of the .toml configuration file.
+
+    Options:
+    - Built-in Registry:
+        - name: "test" -> TestObsSampler
+
+    - Plugins:
+        - Custom ObsSampler imported from `module` and `name` specified in config.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(config, dict):
+        raise TypeError("``config`` must be a dictionary.")
+
+    # -- Instantiate ObsSampler -- #
+    sampling = config["sampling"]
+
+    # 1. Plugin ObsSampler:
+    if (sampling.get("module") is not None) and (sampling.get("name") is not None):
+        # -- Import custom ObsSampler class -- #
+        obs_sampler = import_class(
+            module=sampling["module"],
+            class_name=sampling["name"],
+            class_type=ObsSampler,
+        )
+        logger.info(
+            f"Completed: Created ObsSampler from Plugin: {sampling['module']}.{sampling['name']}"
+        )
+
+    # 2. Registry ObsSampler:
+    else:
+        # -- Use ObsSampler class from registry -- #
+        name = sampling.get("name", "test")
+        try:
+            obs_sampler = _OBS_SAMPLER_REGISTRY[name]
+        except KeyError as e:
+            raise KeyError(
+                f"ObsSampler '{name}' not found in registry. Available options: {list(_OBS_SAMPLER_REGISTRY.keys())}"
+            ) from e
+        logger.info(
+            f"Completed: Created ObsSampler from Registry -> {name}: {obs_sampler.__name__}"
+        )
+
+    return obs_sampler.from_config(config=config)
+
+
+def _create_Regridder(config: dict) -> Regridder:
+    """
+    Instantiate a Regridder from the `[regridding]` table of the .toml configuration file.
+
+    Options:
+    - Built-in Registry:
+        - name: "test" -> TestRegridder
+
+    - Plugins:
+        - Custom Regridder imported from `module` and `name` specified in config.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(config, dict):
+        raise TypeError("``config`` must be a dictionary.")
+
+    # -- Instantiate Regridder -- #
+    regridding = config["regridding"]
+
+    # 1. Plugin Regridder:
+    if (regridding.get("module") is not None) and (regridding.get("name") is not None):
+        # -- Import custom Regridder class -- #
+        regridder = import_class(
+            module=regridding["module"],
+            class_name=regridding["name"],
+            class_type=Regridder,
+        )
+        logger.info(
+            f"Completed: Created Regridder from Plugin: {regridding['module']}.{regridding['name']}"
+        )
+
+    # 2. Registry Regridder:
+    else:
+        # -- Use Regridder class from registry -- #
+        name = regridding.get("name", "test")
+        try:
+            regridder = _REGRIDDER_REGISTRY[name]
+        except KeyError as e:
+            raise KeyError(
+                f"Regridder '{name}' not found in registry. Available options: {list(_REGRIDDER_REGISTRY.keys())}"
+            ) from e
+        logger.info(
+            f"Completed: Created Regridder from Registry -> {name}: {regridder.__name__}"
+        )
+
+    return regridder.from_config(config=config)
+
+
+def _create_DataWriter(config: dict) -> DataWriter:
+    """
+    Instantiate a DataWriter from the `[outputs]` table of the .toml configuration file.
+
+    Options:
+    - Built-in Registry:
+        - format: "netcdf" -> NetCDFDataWriter
+
+    - Plugins:
+        - Custom DataWriter imported from `module` and `name` specified in config.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(config, dict):
+        raise TypeError("``config`` must be a dictionary.")
+
+    # -- Instantiate DataWriter -- #
+    outputs = config["outputs"]
+
+    # 1. Plugin DataWriter:
+    if (outputs.get("module") is not None) and (outputs.get("name") is not None):
+        # -- Import custom DataWriter class -- #
+        data_writer = import_class(
+            module=outputs["module"], class_name=outputs["name"], class_type=DataWriter
+        )
+        logger.info(
+            f"Completed: Created DataWriter from Plugin: {outputs['module']}.{outputs['name']}"
+        )
+
+    # 2. Registry DataWriter:
+    else:
+        # -- Use DataWriter class from registry -- #
+        format = outputs.get("format", "netcdf")
+        try:
+            data_writer = _DATA_WRITER_REGISTRY[format]
+        except KeyError as e:
+            raise KeyError(
+                f"DataWriter '{format}' not found in registry. Available options: {list(_DATA_WRITER_REGISTRY.keys())}"
+            ) from e
+        logger.info(
+            f"Completed: Created DataWriter from Registry -> {format}: {data_writer.__name__}"
+        )
+
+    return data_writer.from_config(config=config)
+
+
+# -- Define Pipeline Functions -- #
+def run_pipeline(args: dict) -> None:
+    """
+    Run OceanOSSE pipeline using specified config .toml file.
+
+    Pipeline Steps:
+    1. Instantiate DataLoader -> Load standardised ocean model dataset.
+    2. Instantiate ObsSampler -> Sample synthetic ocean observations from model dataset.
+    3. Instantiate Regridder -> Regrid synthetic observations onto original model grid.
+    4. Instantiate DataWriter -> Write output dataset to file.
+
+    Parameters:
+    -----------
+    args : dict
+        Command line arguments.
+    """
+    # === Inputs === #
+    logger.info("==== Inputs ====")
+    # Load config .toml file:
+    config = load_config(config_path=args["config_file"])
+    logger.info(f"Completed: Read & validated config file -> {args['config_file']}")
+
+    # Load ocean model dataset using DataLoader:
+    data_loader = _create_DataLoader(config=config)
+    logger.info(f"In Progress: Loading ocean model dataset using {data_loader}...")
+    ds_mdl = data_loader.load_data()
+
+    # === Sampling === #
+    logger.info("==== Sampling ====")
+    # Sample synthetic ocean observations using ObsSampler:
+    obs_sampler = _create_ObsSampler(config=config)
+    logger.info(
+        f"In Progress: Sampling synthetic ocean observations using {obs_sampler}..."
+    )
+    ds_obs = obs_sampler.sample(ds=ds_mdl)
+
+    # === Regridding === #
+    logger.info("==== Regridding ====")
+    # Regrid synthetic ocean observations using Regridder:
+    regridder = _create_Regridder(config=config)
+    logger.info(
+        f"In Progress: Regridding synthetic ocean observations using {regridder}..."
+    )
+    ds_regridded = regridder.regrid(ds=ds_obs)
+
+    # === Outputs === #
+    logger.info("==== Outputs ====")
+    # Write output dataset to file using DataWriter:
+    data_writer = _create_DataWriter(config=config)
+    logger.info(f"In Progress: Writing output dataset to file using {data_writer}...")
+    data_writer.write_data(ds=ds_regridded)
+    # Close all files:
+    for ds in [ds_mdl, ds_obs, ds_regridded]:
+        ds.close()
+    logger.info("--> Completed: Closed all dataset files.")
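+# Programmatic use mirrors the CLI, which builds the same ``args`` dict
+# (illustrative):
+#
+#     run_pipeline(args={"config_file": "examples/example_config.toml",
+#                        "log_filepath": "ocean_osse.log"})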
+
+def describe_pipeline(args: dict) -> None:
+    """
+    Describe & validate OceanOSSE pipeline using config.
+
+    The pipeline components are instantiated and logged, but no data are
+    loaded, sampled, regridded or written.
+
+    Parameters:
+    -----------
+    args : dict
+        Command line arguments.
+    """
+    # === Inputs === #
+    logger.info("==== Inputs ====")
+    # Load config .toml file:
+    config = load_config(config_path=args["config_file"])
+    logger.info(f"Completed: Read & validated config file -> {args['config_file']}")
+
+    # Load ocean model dataset using DataLoader:
+    data_loader = _create_DataLoader(config=config)
+    logger.info("Action: Load ocean model dataset using...")
+    logger.info("* DataLoader : %r", data_loader)
+
+    # === Sampling === #
+    logger.info("==== Sampling ====")
+    # Sample synthetic ocean observations using ObsSampler:
+    obs_sampler = _create_ObsSampler(config=config)
+    logger.info("Action: Sample synthetic ocean observations using...")
+    logger.info("* ObsSampler : %r", obs_sampler)
+
+    # === Regridding === #
+    logger.info("==== Regridding ====")
+    # Regrid synthetic ocean observations using Regridder:
+    regridder = _create_Regridder(config=config)
+    logger.info("Action: Regrid synthetic ocean observations using...")
+    logger.info("* Regridder : %r", regridder)
+
+    # === Outputs === #
+    logger.info("==== Outputs ====")
+    # Write output dataset to file using DataWriter:
+    data_writer = _create_DataWriter(config=config)
+    logger.info("Action: Write output dataset to file using...")
+    logger.info("* DataWriter : %r", data_writer)
+    # Define output filepath:
+    outputs = config["outputs"]
+    if outputs["format"] == "netcdf":
+        extension = "nc"
+    else:
+        extension = "zarr"
+    logger.info(
+        f"* Output File = {outputs['output_dir']}/{outputs['output_name']}_YYYY-MM-YYYY-MM.{extension}"
+    )
diff --git a/OceanOSSE/sampling/sampler.py b/OceanOSSE/sampling/sampler.py
new file mode 100644
index 0000000..8e23363
--- /dev/null
+++ b/OceanOSSE/sampling/sampler.py
@@ -0,0 +1,334 @@
+"""
+sampler.py
+
+Description: Sampling module for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import abc
+import logging
+from typing import Self
+
+import xarray as xr
+
+from OceanOSSE.utils import import_class
+
+logger = logging.getLogger(__name__)
+
+
+# -- Utility Functions -- #
+def get_error_kernels(config: dict) -> list[ErrorKernel] | None:
+    """
+    Utility function to instantiate ErrorKernel instances from the `[sampling]`
+    table of the .toml configuration file.
+
+    Parameters
+    ----------
+    config : dict
+        Configuration dictionary containing input parameters from .toml
+        configuration file.
+
+    Returns
+    -------
+    list[ErrorKernel] | None
+        List of initialised ErrorKernel instances, or None if no kernels are
+        specified in the configuration.
+    """
+    error_kernels_config = config["sampling"].get("error_kernels", None)
+
+    if error_kernels_config is None:
+        return None
+
+    _ERROR_KERNEL_REGISTRY = {"test": TestErrorKernel}
+    kernels: list[ErrorKernel] = []
+    for kernel_cfg in error_kernels_config:
+        if ("module" in kernel_cfg) and ("name" in kernel_cfg):
+            # -- Import custom ErrorKernel class -- #
+            Kernel = import_class(
+                module=kernel_cfg["module"],
+                class_name=kernel_cfg["name"],
+                class_type=ErrorKernel,
+            )
+
+        else:
+            # -- Use ErrorKernel class from registry -- #
+            try:
+                Kernel = _ERROR_KERNEL_REGISTRY[kernel_cfg["name"]]
+            except KeyError as e:
+                raise KeyError(
+                    f"ErrorKernel name '{kernel_cfg['name']}' not found in registry."
+                ) from e
+
+        # -- Instantiate ErrorKernel from configuration -- #
+        kernels.append(Kernel.from_config(config=config))
+
+    return kernels
+
+
+# -- ErrorKernel Abstract Base Class -- #
+class ErrorKernel(abc.ABC):
+    """
+    Abstract base class for applying instrument or representation errors
+    to synthetic ocean observations.
+
+    An ErrorKernel transforms a sampled xarray.Dataset by adding noise,
+    applying a bias, convolving a point-spread function, etc.
+
+    Multiple kernels can be chained by an :class:`ObsSampler` and are applied
+    sequentially in declaration order.
+    """
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}()"
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Abstract class method to instantiate an ErrorKernel from the `[sampling]`
+        table of the .toml configuration file.
+
+        This is the required constructor for all ErrorKernel subclasses - plugin
+        authors must implement this method for use in OceanOSSE.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised ErrorKernel instance.
+        """
+        ...
+
+    @abc.abstractmethod
+    def apply(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Abstract method to apply the error kernel to an xarray.Dataset of
+        synthetic observations.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset produced by `ObsSampler.sample`.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset with error applied.
+        """
+        ...
+
+
+class TestErrorKernel(ErrorKernel):
+    """
+    ErrorKernel used for testing and scaffold validation.
+
+    Returns the synthetic observations xarray.Dataset unchanged.
+    """
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate a TestErrorKernel from the `[sampling]` table of
+        the .toml configuration file.
+        """
+        return cls()
+
+    def apply(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Apply the TestErrorKernel to an xarray.Dataset of synthetic observations.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset produced by `ObsSampler.sample`.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset unchanged.
+        """
+        logger.debug(
+            "Applying TestErrorKernel -> returns synthetic observations dataset unchanged."
+        )
+        return ds
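+# A custom ErrorKernel plugin might look like the sketch below (illustrative:
+# the class name and ``sigma`` handling are assumptions, loosely mirroring the
+# ``kwargs = { sigma = 1.0 }`` entry in examples/example_config.toml; a
+# ``numpy`` import would also be required):
+#
+#     class GaussianNoiseKernel(ErrorKernel):
+#         def __init__(self, sigma: float = 1.0):
+#             self._sigma = sigma
+#
+#         @classmethod
+#         def from_config(cls, config: dict) -> Self:
+#             return cls(sigma=1.0)
+#
+#         def apply(self, ds: xr.Dataset) -> xr.Dataset:
+#             rng = np.random.default_rng()
+#             return ds.map(lambda da: da + rng.normal(0.0, self._sigma, da.shape))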
+# -- ObsSampler Abstract Base Class -- #
+class ObsSampler(abc.ABC):
+    """
+    Abstract base class for sampling gridded ocean model output analogously
+    to an ocean observing platform (e.g., Argo floats).
+
+    Parameters
+    ----------
+    error_kernels : list[ErrorKernel], optional
+        List of ErrorKernel instances to apply sequentially to the sampled
+        synthetic observations dataset, by default None.
+    """
+
+    def __init__(self, error_kernels: list[ErrorKernel] | None = None):
+        # -- Validate Inputs -- #
+        if error_kernels is not None:
+            if not isinstance(error_kernels, list):
+                raise TypeError(
+                    "`error_kernels` must be a list of ErrorKernel instances."
+                )
+            for n, kernel in enumerate(error_kernels):
+                if not isinstance(kernel, ErrorKernel):
+                    raise TypeError(f"`error_kernels[{n}]` must be an ErrorKernel.")
+
+        # -- Class Attributes -- #
+        self._error_kernels = error_kernels
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(error_kernels={self._error_kernels!r})"
+
+    @classmethod
+    @abc.abstractmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Abstract class method to instantiate an ObsSampler from the `[sampling]`
+        table of the .toml configuration file.
+
+        This is the required constructor for all ObsSampler subclasses - plugin
+        authors must implement this method for use in OceanOSSE.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised ObsSampler instance.
+        """
+        ...
+
+    @abc.abstractmethod
+    def collect_samples(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Abstract method to sample a gridded xarray.Dataset of ocean model output
+        to produce a synthetic observations dataset.
+
+        This is the required sampling method for all ObsSampler subclasses -
+        plugin authors must implement this method for use in OceanOSSE.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model output dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Sampled synthetic observations dataset.
+        """
+        ...
+
+    def apply_errors(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Apply all registered `ErrorKernel` instances to synthetic
+        observations sequentially.
+
+        If no kernels are registered, the synthetic observations
+        dataset is returned unchanged.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset with all error kernels
+            applied in order.
+        """
+        # -- Apply each Error Kernel sequentially -- #
+        if self._error_kernels is not None:
+            for kernel in self._error_kernels:
+                logger.debug(f"Applying ErrorKernel --> {repr(kernel)}")
+                ds = kernel.apply(ds)
+            logger.info(
+                "--> Completed: Applied ErrorKernels to synthetic observations."
+            )
+
+        return ds
+
+    def sample(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Perform sampling pipeline for the chosen ocean observing platform.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset with errors applied.
+        """
+        # -- Sample the gridded ocean model output -- #
+        ds_sampled = self.collect_samples(ds)
+        logger.info(
+            "--> Completed: Collected samples from ocean model dataset using ObsSampler."
+        )
+
+        # -- Apply error kernels sequentially to the synthetic observations -- #
+        ds_obs = self.apply_errors(ds_sampled)
+
+        return ds_obs
+
+
+# -- ObsSampler Implementations -- #
+class TestObsSampler(ObsSampler):
+    """
+    ObsSampler used for testing and scaffold validation.
+
+    Returns the input gridded ocean model dataset unchanged as the synthetic
+    observations dataset.
+    """
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate a TestObsSampler from the `[sampling]` table of
+        the .toml configuration file.
+        """
+        # -- Collect ErrorKernel instances from configuration -- #
+        error_kernels = get_error_kernels(config=config)
+
+        # -- Instantiate TestObsSampler with collected ErrorKernel instances -- #
+        return cls(error_kernels=error_kernels or None)
+
+    def collect_samples(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Sample a gridded xarray.Dataset of ocean model output to produce a
+        synthetic observations dataset.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model output dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset (unchanged from input).
+        """
+        logger.debug(
+            "Collecting samples with TestObsSampler -> returns input dataset unchanged."
+        )
+        return ds
diff --git a/OceanOSSE/sampling/sampler_nearest_neighbour.py b/OceanOSSE/sampling/sampler_nearest_neighbour.py
new file mode 100644
index 0000000..eb00fc9
--- /dev/null
+++ b/OceanOSSE/sampling/sampler_nearest_neighbour.py
@@ -0,0 +1,280 @@
+# ===================================================================
+# Copyright 2025 National Oceanography Centre
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# ===================================================================
+
+"""
+sampler_nearest_neighbour.py
+
+Description: Nearest-neighbour sampling module for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import logging
+from typing import Self
+
+import numpy as np
+import xarray as xr
+
+from OceanOSSE.sampling.sampler import ErrorKernel, ObsSampler, get_error_kernels
+
+logger = logging.getLogger(__name__)
+
+
+class NNSampler(ObsSampler):
+    """
+    ObsSampler implementation that samples gridded ocean model output at the
+    nearest model gridpoint and time step to each observed profile,
+    analogously to an ocean observing platform (e.g., Argo floats).
+    """
+
+    def __init__(self, error_kernels: list[ErrorKernel] | None = None):
+        # -- Initialise parent ObsSampler class (validates error_kernels) -- #
+        super().__init__(error_kernels=error_kernels)
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Instantiate an NNSampler from the `[sampling]` table of
+        the .toml configuration file.
+
+        Parameters
+        ----------
+        config : dict
+            Configuration dictionary containing input parameters from .toml
+            configuration file.
+
+        Returns
+        -------
+        Self
+            Initialised NNSampler instance.
+        """
+        # -- Collect ErrorKernel instances from configuration -- #
+        error_kernels = get_error_kernels(config=config)
+
+        return cls(error_kernels=error_kernels or None)
+
+    def collect_samples(self, ds: xr.Dataset, profile: xr.Dataset) -> xr.Dataset:
+        """
+        Sample the gridded ocean model output at the nearest model gridpoint
+        and time step to each observed profile.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model output dataset.
+        profile : xarray.Dataset
+            Loaded observation data locations.
+
+        Returns
+        -------
+        xarray.Dataset
+            Sampled synthetic observations dataset.
+        """
+        profile = self.time_bounds(ds, profile)
+
+        i_nn, j_nn = self.find_nearest_ij(ds, profile)
+        t_nn = self.find_nearest_time(ds, profile)
+
+        ds_synth = self.extract_locations(ds, i_nn, j_nn, t_nn)
+
+        return ds_synth
+
+    def apply_errors(self, ds: xr.Dataset) -> xr.Dataset:
+        """
+        Apply all registered `ErrorKernel` instances to synthetic
+        observations sequentially.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Synthetic observations dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset with all error kernels
+            applied in order.
+        """
+        # -- Apply each Error Kernel sequentially -- #
+        if self._error_kernels is not None:
+            for kernel in self._error_kernels:
+                logger.debug(f"Applying ErrorKernel --> {repr(kernel)}")
+                ds = kernel.apply(ds)
+            logger.info(
+                "--> Completed: Applied ErrorKernels to synthetic observations."
+            )
+
+        return ds
+
+    def sample(self, ds: xr.Dataset, profile: xr.Dataset) -> xr.Dataset:
+        """
+        Perform sampling pipeline for chosen ocean observing platform.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+        profile : xarray.Dataset
+            Observation profile dataset.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset with errors applied.
+        """
+        # -- Sample the gridded ocean model output -- #
+        ds_sampled = self.collect_samples(ds, profile)
+        logger.info(
+            "--> Completed: Collected samples from ocean model dataset using ObsSampler."
+        )
+
+        # -- Apply error kernels sequentially to the synthetic observations -- #
+        ds_obs = self.apply_errors(ds_sampled)
+
+        return ds_obs
+
+    def time_bounds(self, ds: xr.Dataset, profile: xr.Dataset) -> xr.Dataset:
+        """
+        Remove profiles that are out of model bounds in time.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+        profile : xarray.Dataset
+            Observation profile dataset.
+
+        Returns
+        -------
+        profile : xarray.Dataset
+            Observation profile dataset restricted to the model time bounds.
+        """
+        st_date = ds.time.min(dim="t").to_numpy()
+        en_date = ds.time.max(dim="t").to_numpy()
+        p_time = profile.time.to_numpy()
+
+        t_index = (p_time >= st_date) & (p_time <= en_date)
+        n_reject = np.sum(np.invert(t_index).astype(int))
+        n_total = profile.time.size
+        logger.info(
+            f"Profiles rejected for being outside time bounds: {(n_reject / n_total) * 100:.2f}%"
+        )
+        if n_reject / n_total == 1:
+            raise ValueError("All profiles outside model time bounds.")
+
+        t_xa = xr.DataArray(t_index, coords={"profile_id": profile.coords["profile_id"]})
+        profile = profile.where(t_xa, drop=True)
+
+        return profile
+    def find_nearest_ij(self, ds: xr.Dataset, profile: xr.Dataset):
+        """
+        Turn observation lat and lon into model grid indices.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+        profile : xarray.Dataset
+            Observation profile dataset.
+
+        Returns
+        -------
+        i_nn, j_nn : xarray.DataArray
+            Indices of the nearest model gridpoints in i and j.
+        """
+        lon_sub = np.abs(ds.lon - profile.lon)
+        lat_sub = np.abs(ds.lat - profile.lat)
+        dist = (lon_sub + lat_sub) / 2
+        # Stack along new gridpoint dimension
+        dist = dist.stack(gridpoint=("j", "i"))
+
+        # Tiny tie-break penalties to sort dist, j, i.
+        # Gives consistent results and 0.5 rounds up: when two gridpoints are
+        # equidistant, the penalties make the larger (j, i) strictly smaller,
+        # so ``argmin`` selects it deterministically.
+        if (dist.min("gridpoint") == 0.5).any():
+            score = (
+                dist
+                - 1e-6 * dist["j"]
+                - 1e-9 * dist["i"]
+            )
+        else:
+            score = dist
+
+        # Find nearest
+        nearest = score.argmin("gridpoint")
+        ji = score["gridpoint"].isel(gridpoint=nearest)
+
+        i_nn = ji["i"]
+        j_nn = ji["j"]
+        i_nn = i_nn.drop_vars("gridpoint")
+        j_nn = j_nn.drop_vars("gridpoint")
+
+        return i_nn, j_nn
+
+    def find_nearest_time(self, ds: xr.Dataset, profile: xr.Dataset, thresh: int = 10):
+        """
+        Turn observation time into model time index.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+        profile : xarray.Dataset
+            Observation profile dataset.
+        thresh : int
+            Threshold in model timesteps for a profile being out of time bounds.
+
+        Returns
+        -------
+        t_nn : xarray.DataArray
+            Indices of the nearest model time steps.
+        """
+        # Time difference in microsec
+        time_delta = np.abs(ds.time - profile.time)
+
+        # Find nearest and take first occurrence (i.e. round down)
+        nearest = time_delta.argmin("t")
+        t_near = time_delta.isel(t=nearest)
+
+        t_nn = t_near["t"]
+
+        # Check for out of bounds (more than ``thresh`` model timesteps away):
+        n_profile = len(profile.coords["profile_id"])
+        for p in range(n_profile):
+            ps = profile.coords["profile_id"][p].to_numpy()
+            if time_delta.sel(profile_id=ps).min() > thresh * (
+                ds.time.isel(t=1) - ds.time.isel(t=0)
+            ):
+                raise ValueError("Profile time is outside model time bounds.")
+
+        return t_nn
+
+    def extract_locations(self, ds: xr.Dataset, i_index, j_index, t_index) -> xr.Dataset:
+        """
+        Extract a model profile at the specified model index.
+
+        Parameters
+        ----------
+        ds : xarray.Dataset
+            Gridded ocean model dataset.
+        i_index : xarray.DataArray
+            Observation index on model grid in i direction.
+        j_index : xarray.DataArray
+            Observation index on model grid in j direction.
+        t_index : xarray.DataArray
+            Observation index on model time axis.
+
+        Returns
+        -------
+        xarray.Dataset
+            Synthetic observations dataset.
+        """
+        ds_synth = ds.isel(i=i_index, j=j_index, t=t_index)
+
+        return ds_synth
diff --git a/OceanOSSE/utils.py b/OceanOSSE/utils.py
new file mode 100644
index 0000000..82d7db7
--- /dev/null
+++ b/OceanOSSE/utils.py
@@ -0,0 +1,115 @@
+"""
+utils.py
+
+Description: Utility functions for OceanOSSE package.
+
+Created By: OceanOSSE Development Team (NOC, UK)
+"""
+
+# -- Import Dependencies -- #
+from __future__ import annotations
+
+import importlib
+import tomllib
+from pathlib import Path
+
+from OceanOSSE.validation import AppConfig
+
+
+# -- Utility Functions -- #
+def load_config(config_path: str) -> dict:
+    """
+    Load and parse an OceanOSSE configuration file.
+
+    Parameters
+    ----------
+    config_path : str
+        Path to the OceanOSSE ``.toml`` configuration file.
+
+    Returns
+    -------
+    dict
+        Parsed configuration as a nested dictionary.
+
+    Raises
+    ------
+    FileNotFoundError
+    tomllib.TOMLDecodeError
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(config_path, str):
+        raise TypeError("``config_path`` must be a string.")
+    path = Path(config_path)
+    if not path.exists():
+        raise FileNotFoundError(f"Configuration file not found: {config_path!r}")
+
+    # -- Load and Parse Configuration -- #
+    # Open config .toml file:
+    with open(path, "rb") as f:
+        cfg_data = tomllib.load(f)
+
+    # Parse and validate config data using Pydantic models:
+    config = AppConfig(**cfg_data)
+    # Convert config params to dict:
+    d_config = config.model_dump(mode="json")
+
+    return d_config
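+# e.g. (illustrative; the module and class names are hypothetical, with
+# ``ObsSampler`` imported from ``OceanOSSE.sampling.sampler``):
+#
+#     ArgoSampler = import_class(
+#         module="my_pkg.samplers",
+#         class_name="ArgoSampler",
+#         class_type=ObsSampler,
+#     )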
+ + Raises + ------ + FileNotFoundError + tomllib.TOMLDecodeError + """ + # -- Validate Inputs -- # + if not isinstance(config_path, str): + raise TypeError("``config_path`` must be a string.") + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path!r}") + + # -- Load and Parse Configuration -- # + # Open config .toml file: + with open(path, "rb") as f: + cfg_data = tomllib.load(f) + + # Parse and validate config data using Pydantic models: + config = AppConfig(**cfg_data) + # Convert config params to dict: + d_config = config.model_dump(mode="json") + + return d_config + + +def import_class( + module: str, + class_name: str, + class_type: type, +) -> type: + """ + Dynamically import a class from specified module path. + + This is used by the OceanOSSE pipeline to load third-party + plugin classes declared in a `.toml` configuration file. + + Parameters + ---------- + module : str + Import path to the module, e.g. ``"my_package.samplers"``. + class_name : str + Name of the class to import from module, e.g. ``"ArgoSampler"``. + class_type : type + Expected base type of the class. + + Returns + ------- + type + Imported class object. + + Raises + ------ + ModuleNotFoundError + AttributeError + """ + # -- Validate Inputs -- # + if not isinstance(module, str): + raise TypeError("``module`` must be a string.") + if not isinstance(class_name, str): + raise TypeError("``class_name`` must be a string.") + + # -- Dynamically import class -- # + try: + module = importlib.import_module(module) + except ImportError as e: + raise ImportError(f"Failed to import module '{module}'") from e + try: + Kernel = getattr(module, class_name) + except AttributeError as e: + raise AttributeError( + f"Failed to import {class_type.__name__} '{class_name}' from module '{module}'" + ) from e + + # -- Verify that class is callable -- # + if not callable(Kernel): + raise TypeError(f"{class_type.__name__} '{class_name}' is not callable.") + + # -- Verify that class is a subclass of the expected type -- # + if not issubclass(Kernel, class_type): + raise TypeError(f"'{class_name}' is not a subclass of '{class_type.__name__}'.") + + return Kernel diff --git a/OceanOSSE/validation.py b/OceanOSSE/validation.py new file mode 100644 index 0000000..8a66dc1 --- /dev/null +++ b/OceanOSSE/validation.py @@ -0,0 +1,205 @@ +""" +validation.py + +Description: Validation functions for OceanOSSE package. + +Created By: OceanOSSE Development Team (NOC, UK) +""" + +# -- Import Dependencies -- # +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field, field_validator, model_validator + + +# -- [inputs.variables.] -- # +class VariableConfig(BaseModel): + """ + Schema for `[inputs.variables.]`. + + Example + ------- + [inputs.variables.thetao] + path = "/data/NEMO_thetao_*.nc" + chunks = {time_counter = 1} + """ + + path: str + open_kwargs: dict[str, Any] = Field(default_factory=dict) + + @field_validator("path") + @classmethod + def path_must_be_nonempty(cls, v: str) -> str: + if not v.strip(): + raise ValueError("'path' must not be an empty string.") + return v + + +# -- [inputs.dimensions] -- # +class DimensionsConfig(BaseModel): + """ + Schema for `[inputs.dimensions]`. + + Example + ------- + .. 
+
+        [inputs.dimensions]
+        time = "time_counter"
+        lev = "deptht"
+        j = "y"
+        i = "x"
+    """
+
+    time: str = "time"
+    lev: str = "lev"
+    j: str = "j"
+    i: str = "i"
+
+
+# -- [inputs.coordinates] -- #
+class CoordinatesConfig(BaseModel):
+    """
+    Schema for `[inputs.coordinates]`.
+
+    Example
+    -------
+    .. code-block:: toml
+
+        [inputs.coordinates]
+        time = "time_counter"
+        depth = "deptht"
+        lon = "nav_lon"
+        lat = "nav_lat"
+    """
+
+    time: str = "time"
+    depth: str = "depth"
+    lon: str = "lon"
+    lat: str = "lat"
+
+
+# -- [inputs] -- #
+class InputConfig(BaseModel):
+    """
+    Schema for `[inputs]`.
+
+    The `variables` field accepts any number of named sub-tables:
+
+    .. code-block:: toml
+
+        [inputs.variables.thetao]
+        path = "..."
+        open_kwargs = { engine = "netcdf4" }
+
+        [inputs.variables.so]
+        path = "..."
+        open_kwargs = { engine = "netcdf4" }
+    """
+
+    dimensions: DimensionsConfig = Field(default_factory=DimensionsConfig)
+    coordinates: CoordinatesConfig = Field(default_factory=CoordinatesConfig)
+    data_dir: str = ""
+    format: Literal["netcdf", "zarr"] = "netcdf"
+    variables: dict[str, VariableConfig]
+    name: str | None = None
+    module: str | None = None
+
+    @field_validator("variables")
+    @classmethod
+    def variables_must_not_be_empty(cls, v: dict) -> dict:
+        if not v:
+            raise ValueError(
+                "`[inputs.variables]` must contain at least one variable entry."
+            )
+        return v
+
+
+# -- [sampling] -- #
+class SamplingConfig(BaseModel):
+    """
+    Schema for `[sampling]`.
+
+    .. code-block:: toml
+
+        [sampling]
+        name = "..."
+
+        [[sampling.error_kernels]]
+        name = "..."
+
+        [[sampling.error_kernels]]
+        module = "..."
+        name = "..."
+    """
+
+    name: str = "test"
+    module: str | None = None
+    error_kernels: list[dict[str, Any]] = Field(default_factory=list)
+
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> SamplingConfig:
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "`[sampling]`: 'module' and 'name' must both be specified "
+                "together for plugin loading."
+            )
+        return self
+
+
+# -- [regridding] -- #
+class RegriddingConfig(BaseModel):
+    """
+    Schema for `[regridding]`.
+
+    .. code-block:: toml
+
+        [regridding]
+        module = "..."
+        name = "..."
+    """
+
+    name: str = "test"
+    module: str | None = None
+
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> RegriddingConfig:
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "``[regridding]``: 'module' and 'name' must both be specified together."
+            )
+        return self
+
+
+# -- [outputs] -- #
+class OutputConfig(BaseModel):
+    """
+    Schema for `[outputs]`.
+
+    .. code-block:: toml
+
+        [outputs]
+        format = "netcdf"
+        output_dir = "..."
+        output_name = "..."
+        date_format = "M"
+        chunks = { time_counter = 12 }
+    """
+
+    format: Literal["netcdf"] = "netcdf"
+    output_dir: str
+    output_name: str
+    date_format: Literal["Y", "M", "D"] = "Y"
+    chunks: dict[str, int] = Field(default_factory=dict)
+    name: str | None = None
+    module: str | None = None
+
+    @model_validator(mode="after")
+    def plugin_requires_both_module_and_name(self) -> OutputConfig:
+        if (self.module is not None) and (self.name is None):
+            raise ValueError(
+                "``[outputs]``: 'module' and 'name' must both be specified together."
+            )
+        return self
+
+
+# -- Top-level AppConfig -- #
+class AppConfig(BaseModel):
+    """
+    Top-level OceanOSSE configuration model.
+    Validates the entire parsed .toml dict.
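+
+    Example
+    -------
+    A sketch of programmatic use, assuming ``cfg_data`` is a dict
+    parsed from a ``.toml`` configuration file:
+
+    .. code-block:: python
+
+        config = AppConfig(**cfg_data)
+        d_config = config.model_dump(mode="json")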
+ """ + + inputs: InputConfig + sampling: SamplingConfig = Field(default_factory=SamplingConfig) + regridding: RegriddingConfig = Field(default_factory=RegriddingConfig) + outputs: OutputConfig diff --git a/examples/example_config.toml b/examples/example_config.toml new file mode 100644 index 0000000..0794038 --- /dev/null +++ b/examples/example_config.toml @@ -0,0 +1,54 @@ +# ============================================================ +# OceanOSSE Example Configuration: NEMO eORCA1 ERA5v1 +# ============================================================ +# To run: +# OceanOSSE run examples/example_config.toml +# OceanOSSE run examples/example_config.toml --dry-run +# ============================================================ + +# ========================== Inputs ========================== +[inputs] +data_dir = "/path/to/nemo/output" +format = "netcdf" + +[inputs.dimensions] +time = "time_counter" +lev = "deptht" +j = "y" +i = "x" + +[inputs.coordinates] +time = "time_counter" +depth = "deptht" +lon = "nav_lon" +lat = "nav_lat" + +[inputs.variables.thetao_con] +path = "/dssgfs01/scratch/npd/simulations/eORCA1_ERA5_v1/eORCA1_ERA5_1m_grid_T_202312-202312.nc" +open_kwargs = { engine = "netcdf4" } + +[inputs.variables.so_abs] +path = "/dssgfs01/scratch/npd/simulations/eORCA1_ERA5_v1/eORCA1_ERA5_1m_grid_T_202312-202312.nc" +open_kwargs = { engine = "netcdf4" } + +# ========================== Sampling ========================== +[sampling] +name = "test" + +[[sampling.error_kernels]] +name = "test" +kwargs = { sigma = 1.0 } + +# ========================== Regridding ========================== +[regridding] +name = "test" +kwargs = { method = "nearest_s2d" } + +# ========================== Outputs ========================== +[outputs] +format = "netcdf" +output_dir = "/dssgfs01/working/otooth/Software/OceanOSSE/OceanOSSE" +output_name = "OceanOSSE_TEST" +date_format = "M" +chunks = { time_counter = 12 } +writer_kwargs = { unlimited_dims = "time_counter", mode = "w" } diff --git a/pixi.toml b/pixi.toml index 2a854b1..699c8e2 100644 --- a/pixi.toml +++ b/pixi.toml @@ -24,6 +24,7 @@ python = "*" dask = "*" gsw = "*" netCDF4 = "*" +typer = ">=0.24.1,<0.25" xarray = "*" zarr = "*" git = "*" # needed for dynamic versioning @@ -32,6 +33,7 @@ git = "*" # needed for dynamic versioning # Define OceanOSSE package as local path dependency: OceanOSSE = { path = "." 
}
python = "3.13.*"
+pydantic = ">=2.13.3,<3"


# === Features === #

@@ -42,6 +44,7 @@ ipykernel = "*"
matplotlib = "*"
cartopy = "*"
nc-time-axis = "*"
+polars = ">=1.39.3,<2"

[feature.dev.pypi-dependencies]
ruff = "*"

@@ -76,4 +79,4 @@ preview-docs = { cmd = "cd docs/; zensical serve", description = "Run local prev
default = { solve-group = "default" }
dev = { features = ["dev"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
-release = { features = ["release"], solve-group = "default" }
\ No newline at end of file
+release = { features = ["release"], solve-group = "default" }
diff --git a/pyproject.toml b/pyproject.toml
index 7467a1f..986eef3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,15 +19,13 @@ readme = "README.md"
requires-python = ">=3.12"

dependencies = [
-    "dask>=2025.7.0",
-    "flox>=0.10.6",
-    "gsw>=3.6.20",
-    "numbagg>=0.8",
-    "numba>=0.62",
-    "numpy>=2.3", # numba 0.62 added support for numpy 2.3
-    "netCDF4<=1.7.3",
-    "xarray>=2025.07.1",
-    "zarr>=3.0.7",
+    "cftime",
+    "dask",
+    "netcdf4",
+    "pydantic",
+    "typer",
+    "xarray",
+    "zarr",
]

[project.optional-dependencies]
@@ -62,3 +60,6 @@ ignore = ["E501"]

[tool.ruff.format]
quote-style = "double"
+
+[project.scripts]
+OceanOSSE = "OceanOSSE.cli:app"
diff --git a/tests/unit/test_sampler.py b/tests/unit/test_sampler.py
new file mode 100644
index 0000000..724846e
--- /dev/null
+++ b/tests/unit/test_sampler.py
@@ -0,0 +1,241 @@
+# ===================================================================
+# Copyright 2025 National Oceanography Centre
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# ===================================================================
+"""
+test_sampler.py
+
+Description:
+This module includes unit tests for extracting profiles.
+
+Author:
+Benjamin Barton (benbar@noc.ac.uk)
+"""
+import datetime as dt
+
+import numpy as np
+import pytest
+import xarray as xr
+
+from OceanOSSE.sampling.sampler_nearest_neighbour import NNSampler
+
+
+def test_sampler():
+    """
+    Tests for extracting a profile that falls on a model grid point.
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0])
+    profile_lon = np.array([3])
+    profile_lat = np.array([5])
+    profile_time = np.array([dt.datetime(2020, 5, 4)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    model_t = sampler.sample(ds, profile)
+
+    assert (model_t.votemper.to_numpy().squeeze() == ds.votemper[3, :, 5, 3]).all()
+
+
+def test_sampler_multi():
+    """
+    Tests for extracting multiple profiles that fall on model grid points.
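+
+    Each profile is matched to its own grid point and time step; the
+    assertion below checks the second profile (profile_id=1) against
+    the model field.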
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0, 1])
+    profile_lon = np.array([3, 8])
+    profile_lat = np.array([5, 6])
+    profile_time = np.array([dt.datetime(2020, 5, 4), dt.datetime(2020, 8, 23)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    model_t = sampler.sample(ds, profile)
+
+    assert (model_t.votemper.sel(profile_id=1) == ds.votemper[114, :, 6, 8]).all()
+
+
+def test_sampler_nn():
+    """
+    Tests for extracting profiles that fall between model grid points,
+    exercising nearest-neighbour selection against the analytic form.
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0, 1])
+    profile_lon = np.array([3.5, 1.2])
+    profile_lat = np.array([5.5, 2.2])
+    profile_time = np.array([dt.datetime(2020, 5, 4), dt.datetime(2020, 8, 23)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    model_t = sampler.sample(ds, profile)
+
+    assert (
+        (model_t.votemper.sel(profile_id=0) == ds.votemper[3, :, 6, 4]).all()
+        & (model_t.votemper.sel(profile_id=1) == ds.votemper[114, :, 2, 1]).all()
+    )
+
+
+def test_sampler_time():
+    """
+    Tests for extracting a profile that falls on a model grid point
+    but in between two time steps.
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0])
+    profile_lon = np.array([3])
+    profile_lat = np.array([5])
+    profile_time = np.array([dt.datetime(2020, 5, 6, 12)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    model_t = sampler.sample(ds, profile)
+
+    assert (model_t.votemper.to_numpy().squeeze() == ds.votemper[5, :, 5, 3]).all()
+
+
+def test_sampler_time_out_bounds():
+    """
+    Tests for extracting a profile that is outside model time bounds.
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0])
+    profile_lon = np.array([3])
+    profile_lat = np.array([5])
+    profile_time = np.array([dt.datetime(2021, 5, 1)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    with pytest.raises(ValueError, match=r".*time bounds\.") as exc_info:
+        sampler.sample(ds, profile)
+
+    assert exc_info.type is ValueError
+
+
+def test_sampler_time_subset():
+    """
+    Tests for extracting profiles where some are outside model time bounds.
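+
+    Only the in-bounds profile (profile_id=1) is asserted against the
+    model field below.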
+    """
+    # Build dataset
+    ds = construct_ds()
+
+    # Synthetic profile
+    prof_id = np.array([0, 1])
+    profile_lon = np.array([3, 8])
+    profile_lat = np.array([5, 6])
+    profile_time = np.array([dt.datetime(2021, 5, 1), dt.datetime(2020, 5, 6)])
+    profile = xr.Dataset(
+        {
+            "lon": (("profile_id"), profile_lon),
+            "lat": (("profile_id"), profile_lat),
+            "time": (("profile_id"), profile_time)
+        },
+        coords={
+            "profile_id": prof_id,
+        },
+    )
+
+    sampler = NNSampler()
+    model_t = sampler.sample(ds, profile)
+
+    assert (model_t.votemper.sel(profile_id=1) == ds.votemper[5, :, 6, 8]).all()
+
+
+def construct_ds():
+    """
+    Build an idealised model dataset for testing.
+
+    The analytic temperature field decreases with latitude index and
+    depth, and increases with longitude index and (very slowly) time.
+    """
+    lat = np.arange(0, 8)
+    lon = np.arange(0, 10)
+    depth = np.arange(0, 150, 10)
+    st_date = dt.datetime(2020, 5, 1)
+    num_days = 180
+    model_dates = np.array([st_date + dt.timedelta(days=x) for x in range(num_days)])
+    model_day = np.arange(num_days)
+
+    # Broadcast to 4D (time, depth, lat, lon)
+    t, d, y, x = np.meshgrid(model_day, depth, lat, lon, indexing="ij")
+
+    votemper = 15 - (y * 0.4) + (x * 0.2) - (d * 0.05) + (t * 0.000005)
+
+    # Build dataset
+    ds = xr.Dataset(
+        {
+            "votemper": (("t", "d", "j", "i"), votemper),
+            "lat": (("j", "i"), y[0, 0, :, :]),
+            "lon": (("j", "i"), x[0, 0, :, :]),
+            "depth": (("d", "j", "i"), d[0, :, :, :]),
+            "time": (("t"), model_dates)
+        },
+        coords={
+            "d": depth,
+            "j": lat,
+            "i": lon,
+            "t": model_day
+        },
+    )
+
+    return ds
\ No newline at end of file