diff --git a/pyproject.toml b/pyproject.toml index 3bf9154..6a15145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ build-backend = "setuptools.build_meta" requires = [ "cython", - "numpy<2", + "numpy", "setuptools>=42", "setuptools_scm[toml]>=3.4", "wheel" @@ -31,15 +31,17 @@ classifiers = [# https://pypi.python.org/pypi?%3Aaction=list_classifiers ] dependencies = [ - "adjustText", "anndata>=0.12.4", # https://github.com/scverse/anndata/issues/2166 - "bioio<2", + "adjustText", "bioio-nd2", "bioio-tifffile", + "bioio-ome-tiff", + "bioio-ome-zarr>=3.2.2", # https://github.com/bioio-devs/bioio-ome-zarr/pull/130, https://github.com/bioio-devs/bioio-ome-zarr/issues/128 + "bioio", "centrosome", "cp-measure", + "dask!=2026.1.2", # https://github.com/dask/dask/issues/12265 "dask-image", - "dask", "decorator", "filelock", "flox", @@ -53,14 +55,13 @@ dependencies = [ "matplotlib", "natsort", "numcodecs", - "numpy<2", + "numpy", "ome-zarr", "pandas", "pint", "psutil", "pyarrow", "pydantic", - "pysam", "scikit-image", "scikit-learn", "scipy", @@ -69,9 +70,9 @@ dependencies = [ "stardist", "statsmodels", "tensorflow", - "tifffile==2024.8.30", + "tifffile", "xarray", - "zarr<3" + "zarr>=3" ] [project.optional-dependencies] @@ -88,10 +89,12 @@ cellpose = [ ufish = [ "ufish" ] +pysam = [ + "pysam" +] test = [ "miniwdl", - "pytest", - "pytest-xdist" + "pytest" ] dev = [ @@ -105,22 +108,7 @@ doc = [ "sphinx_argparse", "sphinx_rtd_theme", ] -all = [ - "napari", - "napari_ome_zarr", - "cellpose", - "dask-ml", - "ufish", - "pytest", - "pre-commit", - "ipython", - "myst_parser", - "nbsphinx", - "sphinx-copybutton", - "sphinx", - "sphinx_argparse", - "sphinx_rtd_theme" -] + [project.entry-points."miniwdl.plugin.container_backend"] miniwdl_test_local = "scallops.tests.miniwdl_local.local_runner:LocalRunner" diff --git a/requirements.doc.txt b/requirements.doc.txt index 6c56034..b75a0dd 100644 --- a/requirements.doc.txt +++ b/requirements.doc.txt @@ -1,4 +1,4 @@ -ipython==9.9.0 +ipython==9.10.0 nbsphinx==0.9.8 sphinx-copybutton==0.5.2 sphinx==9.1.0 diff --git a/requirements.txt b/requirements.txt index 038ff74..fdc03a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,29 +1,30 @@ anndata==0.12.10 adjustText==1.3.0 -bioio-nd2==1.1.0 -bioio-tifffile==1.1.0 -bioio-base==1.0.7 -bioio==1.6.1 +bioio==3.2.0 +bioio-nd2==1.6.2 +bioio-ome-tiff==1.4.0 +bioio-ome-zarr==3.2.2 +bioio-tifffile==1.3.0 centrosome==1.3.3 cp-measure==0.1.13 cython==3.2.4 dask-image==2025.11.0 -dask==2025.11.0 +dask==2026.1.1 decorator==5.2.1 filelock==3.20.3 -flox==0.11.0 +flox==0.11.1 fsspec==2026.2.0 igraph==1.0.0 -itk-elastix==0.23.0 -itk==5.4.3 +itk-elastix==0.24.0 +itk==5.4.5 joblib==1.5.3 kneed==0.8.5 mahotas==1.4.18 matplotlib==3.10.8 natsort==8.4.0 -numcodecs==0.15.1 -numpy==1.26.4 -ome-zarr==0.10.3 +numcodecs==0.16.5 +numpy==2.3.5 +ome-zarr==0.13.0 pandas==2.3.3 pint==0.25.2 psutil==7.2.2 @@ -37,7 +38,7 @@ seaborn==0.13.2 shapely==2.1.2 stardist==0.9.2 statsmodels==0.14.6 -tensorflow==2.19.0 -tifffile==2024.8.30 +tensorflow==2.20.0 +tifffile==2026.1.28 xarray==2026.1.0 -zarr==2.18.7 +zarr==3.1.5 diff --git a/scallops/_bioio_zarr_reader.py b/scallops/_bioio_zarr_reader.py deleted file mode 100644 index 6118c46..0000000 --- a/scallops/_bioio_zarr_reader.py +++ /dev/null @@ -1,263 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import xarray as xr -from bioio_base import constants, dimensions, exceptions, io, reader, types -from fsspec.spec import AbstractFileSystem -from ome_zarr.io import parse_url -from ome_zarr.reader import Reader as ZarrReader - -logger = logging.getLogger("scallops") - - -# Same as https://github.com/bioio-devs/bioio-ome-zarr/blob/main/bioio_ome_zarr/reader.py but fixes bug in channel names -# Also checks to see if zarr path is {zarr_path}/images/image1 with only 1 image -# See https://github.com/bioio-devs/bioio-ome-zarr/pull/22 -class ScallopsZarrReader(reader.Reader): - """The main class of each reader plugin. This class is subclass of the abstract class reader - (BaseReader) in bioio-base. - - Parameters - ---------- - image: types.PathLike - String or Path to the ZARR root - fs_kwargs: Dict[str, Any] - Ignored - """ - - _xarray_dask_data: Optional["xr.DataArray"] = None - _xarray_data: Optional["xr.DataArray"] = None - _mosaic_xarray_dask_data: Optional["xr.DataArray"] = None - _mosaic_xarray_data: Optional["xr.DataArray"] = None - _dims: Optional[dimensions.Dimensions] = None - _metadata: Optional[Any] = None - _scenes: Optional[Tuple[str, ...]] = None - _current_scene_index: int = 0 - # Do not provide default value because - # they may not need to be used by your reader (i.e. input param is an array) - _fs: "AbstractFileSystem" - _path: str - - # Required Methods - - def __init__( - self, - image: types.PathLike, - fs_kwargs: Dict[str, Any] = {}, - ): - # Expand details of provided image - self._fs, self._path = io.pathlike_to_fs( - image, - enforce_exists=False, - fs_kwargs=fs_kwargs, - ) - - # Enforce valid image - if not self._is_supported_image(self._fs, self._path): - raise exceptions.UnsupportedFileFormatError( - self.__class__.__name__, - self._path, - "Could not find a .zgroup or .zarray file at the provided path.", - ) - - self._zarr = get_zarr_reader(self._fs, self._path).zarr - self._physical_pixel_sizes: Optional[types.PhysicalPixelSizes] = None - self._channel_names: Optional[List[str]] = None - - @staticmethod - def _is_supported_image(fs: AbstractFileSystem, path: str, **kwargs: Any) -> bool: - try: - get_zarr_reader(fs, path) - return True - except AttributeError: - return False - - @classmethod - def is_supported_image( - cls, - image: types.ImageLike, - fs_kwargs: Dict[str, Any] = {}, - **kwargs: Any, - ) -> bool: - if isinstance(image, (str, Path)): - return cls._is_supported_image(None, str(image), **kwargs) - else: - return reader.Reader.is_supported_image( - cls, image, fs_kwargs=fs_kwargs, **kwargs - ) - - @property - def scenes(self) -> Tuple[str, ...]: - if self._scenes is None: - scenes = self._zarr.root_attrs["multiscales"] - - # if (each scene has a name) and (that name is unique) use name. - # otherwise generate scene names. - if all("name" in scene for scene in scenes) and ( - len({scene["name"] for scene in scenes}) == len(scenes) - ): - self._scenes = tuple(str(scene["name"]) for scene in scenes) - else: - self._scenes = tuple( - f"scene_{i}" - for i in range(len(self._zarr.root_attrs["multiscales"])) - ) - return self._scenes - - @property - def resolution_levels(self) -> Tuple[int, ...]: - """ - Returns - ------- - resolution_levels: Tuple[str, ...] - Return the available resolution levels for the current scene. - By default these are ordered from highest resolution to lowest - resolution. - """ - return tuple( - rl - for rl in range( - len( - self._zarr.root_attrs["multiscales"][self.current_scene_index][ - "datasets" - ] - ) - ) - ) - - def _read_delayed(self) -> xr.DataArray: - return self._xarr_format(delayed=True) - - def _read_immediate(self) -> xr.DataArray: - return self._xarr_format(delayed=False) - - def _xarr_format(self, delayed: bool) -> xr.DataArray: - data_path = self._zarr.root_attrs["multiscales"][self.current_scene_index][ - "datasets" - ][self.current_resolution_level]["path"] - image_data = self._zarr.load(data_path) - - axes = self._zarr.root_attrs["multiscales"][self.current_scene_index].get( - "axes" - ) - if axes: - dims = [sub["name"].upper() for sub in axes] - else: - dims = list(reader.Reader._guess_dim_order(image_data.shape)) - - if not delayed: - image_data = image_data.compute() - - coords = self._get_coords( - dims, - image_data.shape, - scene=self.current_scene, - channel_names=self.channel_names, - ) - - return xr.DataArray( - image_data, - dims=dims, - coords=coords, - attrs={constants.METADATA_UNPROCESSED: self._zarr.root_attrs}, - ) - - # Optional Methods - @property - def physical_pixel_sizes(self) -> types.PhysicalPixelSizes: - """Return the physical pixel sizes of the image.""" - if self._physical_pixel_sizes is None: - try: - z_size, y_size, x_size = self._get_pixel_size( - list(self.dims.order), - ) - except Exception as e: - logger.warning(f"Could not parse zarr pixel size: {e}") - z_size, y_size, x_size = None, None, None - - self._physical_pixel_sizes = types.PhysicalPixelSizes( - z_size, y_size, x_size - ) - return self._physical_pixel_sizes - - def _get_pixel_size( - self, - dims: List[str], - ) -> Tuple[Optional[float], Optional[float], Optional[float]]: - # OmeZarr file may contain an additional set of "coordinateTransformations" - # these coefficents are applied to all resolution levels. - if ( - "coordinateTransformations" - in self._zarr.root_attrs["multiscales"][self.current_scene_index] - ): - universal_res_consts = self._zarr.root_attrs["multiscales"][ - self.current_scene_index - ]["coordinateTransformations"][0]["scale"] - else: - universal_res_consts = [1.0 for _ in range(len(dims))] - - coord_transform = self._zarr.root_attrs["multiscales"][ - self.current_scene_index - ]["datasets"][self.current_resolution_level]["coordinateTransformations"] - - spatial_coeffs = {} - - for dim in [ - dimensions.DimensionNames.SpatialX, - dimensions.DimensionNames.SpatialY, - dimensions.DimensionNames.SpatialZ, - ]: - if dim in dims: - dim_index = dims.index(dim) - spatial_coeffs[dim] = ( - coord_transform[0]["scale"][dim_index] - * universal_res_consts[dim_index] - ) - else: - spatial_coeffs[dim] = None - - return ( - spatial_coeffs[dimensions.DimensionNames.SpatialZ], - spatial_coeffs[dimensions.DimensionNames.SpatialY], - spatial_coeffs[dimensions.DimensionNames.SpatialX], - ) - - @property - def channel_names(self) -> Optional[List[str]]: - if self._channel_names is None: - if "omero" in self._zarr.root_attrs: - self._channel_names = [ - str(channel["label"]) - for channel in self._zarr.root_attrs["omero"]["channels"] - ] - return self._channel_names - - @staticmethod - def _get_coords( - dims: List[str], - shape: Tuple[int, ...], - scene: str, - channel_names: Optional[List[str]], - ) -> Dict[str, Any]: - coords: Dict[str, Any] = {} - - # Use dims for coord determination - if dimensions.DimensionNames.Channel in dims: - # Generate channel names if no existing channel names - if channel_names is None: - coords[dimensions.DimensionNames.Channel] = [ - f"channel_{i}" - for i in range(shape[dims.index(dimensions.DimensionNames.Channel)]) - ] - else: - coords[dimensions.DimensionNames.Channel] = channel_names - - return coords - - -def get_zarr_reader(fs: AbstractFileSystem, path: str) -> ZarrReader: - if fs is not None: - path = fs.unstrip_protocol(path) - - return ZarrReader(parse_url(path, mode="r")) diff --git a/scallops/cli/pooled_if_sbs.py b/scallops/cli/pooled_if_sbs.py index 3522679..0ea00df 100644 --- a/scallops/cli/pooled_if_sbs.py +++ b/scallops/cli/pooled_if_sbs.py @@ -25,7 +25,7 @@ import pyarrow.parquet as pq import xarray as xr import zarr -from dask.delayed import Delayed, delayed +from dask.delayed import delayed from matplotlib import pyplot as plt from skimage.segmentation import expand_labels @@ -75,6 +75,7 @@ from scallops.zarr_io import ( _get_fs, _get_sep, + _get_store_path, _write_zarr_image, is_anndata_zarr, open_ome_zarr, @@ -178,8 +179,7 @@ def _peaks_to_bases( def spot_detection_pipeline( image_tuple: tuple[tuple[str, ...], list[str], dict], iss_channels: list[int], - file_separator: str, - root: zarr.Group | str, + output: str, max_filter_width: int, sigma_log: float | list[float], z_index: int | str, @@ -196,7 +196,7 @@ def spot_detection_pipeline( spot_detection_method: Literal["log", "spotiflow", "u-fish", "piscis"] = "log", spot_detection_n_cycles: int | None = None, expected_cycles: int | None = None, -) -> list[Delayed]: +): """Run the spot detection pipeline. This function processes a set of images, performs spot detection, and saves the @@ -204,8 +204,7 @@ def spot_detection_pipeline( :param image_tuple: A tuple containing information about the images. :param iss_channels: List of channel indices used for ISS sequencing. - :param file_separator: Separator used in file paths. - :param root: Root path or zarr group where the results will be stored. + :param output: Root path to where the results will be stored. :param max_filter_width: Maximum filter width used in spot detection. :param z_index: Either 'max' or z-index :param sigma_log: Sigma parameter for log transformation in spot detection. @@ -225,15 +224,19 @@ def spot_detection_pipeline( """ _, file_list, metadata = image_tuple image_key = metadata["id"] + output_fs = fsspec.url_to_fs(output)[0] + output_sep = output_fs.sep + output = output.rstrip(output_sep) + points_path = f"{output}{output_sep}points" + + points_protocol = _get_fs_protocol(output_fs) + if points_protocol != "file": + points_path = f"{points_protocol}://{points_path}" + peaks_path = f"{points_path}{output_sep}{image_key}-peaks.parquet" if not force: - points_path = f"{root.store.path.rstrip(_get_sep(root))}{_get_sep(root)}points" - points_protocol = _get_fs_protocol(_get_fs(root)) - if points_protocol != "file": - points_path = f"{points_protocol}://{points_path}" - peaks_path = f"{points_path}{_get_sep(root)}{image_key}-peaks.parquet" if is_parquet_file(peaks_path): logger.info(f"Skipping spot detection for {image_key}") - return [] + return image = _images2fov(file_list, metadata, dask=True) image = _z_projection(image, z_index) if expected_cycles is not None: @@ -293,10 +296,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-log", - root=root, + root=open_ome_zarr(output, mode="a"), image=loged, output_format=output_image_format, - file_separator=file_separator, zarr_format="zarr", compute=compute, ) @@ -308,10 +310,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-std", - root=root, + root=open_ome_zarr(output, mode="a"), image=std_arr, output_format=output_image_format, - file_separator=file_separator, metadata=dict(parent=image_key), compute=compute, ) @@ -323,10 +324,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-max", - root=root, + root=open_ome_zarr(output, mode="a"), image=maxed, output_format=output_image_format, - file_separator=file_separator, zarr_format="zarr", compute=compute, ) @@ -334,14 +334,10 @@ def spot_detection_pipeline( else: del maxed if "peaks" in save_keys: - points_path = f"{root.store.path.rstrip(_get_sep(root))}{_get_sep(root)}points" - protocol = _get_fs_protocol(_get_fs(root)) - if protocol != "file": - points_path = f"{protocol}://{points_path}" - _get_fs(root).makedirs(points_path, exist_ok=True) - peaks_path = f"{points_path}{_get_sep(root)}{image_key}-peaks.parquet" - if _get_fs(root).exists(peaks_path): - _get_fs(root).rm(peaks_path, recursive=True) + output_fs.makedirs(points_path, exist_ok=True) + + if output_fs.exists(peaks_path): + output_fs.rm(peaks_path, recursive=True) dask_delayed.append( _to_parquet( @@ -353,7 +349,6 @@ def spot_detection_pipeline( ) if not compute and len(dask_delayed) > 0: dask.compute(*dask_delayed) - return [] def _fix_cycles(sbs_cycles): @@ -803,19 +798,17 @@ def spot_detect_main(arguments: argparse.Namespace): chunks = (chunks, chunks) output = _add_suffix(output, ".zarr") - root = open_ome_zarr(output, mode="a") + exp_gen = _set_up_experiment(images, image_pattern, group_by, subset=subset) with ( _create_default_dask_config(), _create_dask_client(dask_scheduler_url, **dask_cluster_parameters), ): - delayed_results = [] for img in exp_gen: - delayed_results += spot_detection_pipeline( + spot_detection_pipeline( img, iss_channels=channels, - file_separator=None, - root=root, + output=output, z_index=z_index, output_image_format="zarr", max_filter_width=max_filter_width, @@ -833,8 +826,6 @@ def spot_detect_main(arguments: argparse.Namespace): spot_detection_n_cycles=spot_detection_n_cycles, expected_cycles=expected_cycles, ) - if len(delayed_results) > 0: - dask.compute(*delayed_results) def reads_pipeline( @@ -911,7 +902,7 @@ def reads_pipeline( logger.info(f"Running reads for {image_key}") spots_sep = _get_sep(spots_root) - points_path = f"{spots_root.store.path.rstrip(spots_sep)}{spots_sep}points" + points_path = f"{_get_store_path(spots_root).rstrip(spots_sep)}{spots_sep}points" spots_protocol = _get_fs_protocol(_get_fs(spots_root)) if spots_protocol != "file": points_path = f"{spots_protocol}://{points_path}" @@ -1229,8 +1220,8 @@ def reads_main(arguments: argparse.Namespace): for key in image_keys: reads_pipeline( key, - spots_root=zarr.open(spots, "r"), - labels_root=zarr.open(labels + labels_fs.sep + "labels", "r"), + spots_root=zarr.open(spots, mode="r"), + labels_root=zarr.open(labels + labels_fs.sep + "labels", mode="r"), barcodes_file=barcodes_file, file_separator=output_fs.sep, threshold_peaks=threshold_peaks, diff --git a/scallops/cli/register.py b/scallops/cli/register.py index ed081d9..bf9e89f 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -464,7 +464,7 @@ def get_matching_names( results = [] for path in paths: name = os.path.basename(path) - if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, "r")): + if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, mode="r")): results.append(path) return results diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 0d2da08..6782e78 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -199,7 +199,7 @@ def _write_image( root: zarr.Group | str, image: np.ndarray | xr.DataArray | da.Array, output_format: str, - file_separator: str, + file_separator: str = "/", metadata: dict | None = None, compute: bool = True, **kwargs, diff --git a/scallops/features/generate.py b/scallops/features/generate.py index 1dd2e4e..1c389b1 100644 --- a/scallops/features/generate.py +++ b/scallops/features/generate.py @@ -158,7 +158,7 @@ def label_features( if isinstance(intensity_image, da.Array): # y,x,c assert intensity_image.shape[:-1] == label_shape, ( - f"{intensity_image.shape} != {label_shape}" + f"{intensity_image.shape[:-1]} != {label_shape}" ) label_image = label_image.rechunk(intensity_image.chunksize[:-1]) diff --git a/scallops/features/image_quality.py b/scallops/features/image_quality.py index 162cd4f..5548d9f 100644 --- a/scallops/features/image_quality.py +++ b/scallops/features/image_quality.py @@ -1,6 +1,32 @@ -import centrosome.radial_power_spectrum import numpy as np import scipy +from scipy.fftpack import fft2 +from scipy.ndimage import sum as nd_sum + + +# copied from centrosome.radial_power_spectrum but use np.ptp instead of img.ptp for numpy 2 +def rps(img): + assert img.ndim == 2 + radii2 = (np.arange(img.shape[0]).reshape((img.shape[0], 1)) ** 2) + ( + np.arange(img.shape[1]) ** 2 + ) + radii2 = np.minimum(radii2, np.flipud(radii2)) + radii2 = np.minimum(radii2, np.fliplr(radii2)) + maxwidth = ( + min(img.shape[0], img.shape[1]) / 8.0 + ) # truncate early to avoid edge effects + if np.ptp(img) > 0: + img = img / np.median(abs(img - img.mean())) # intensity invariant + mag = abs(fft2(img - np.mean(img))) + power = mag**2 + radii = np.floor(np.sqrt(radii2)).astype(int) + 1 + labels = np.arange(2, np.floor(maxwidth)).astype(int).tolist() # skip DC component + if len(labels) > 0: + magsum = nd_sum(mag, radii, labels) + powersum = nd_sum(power, radii, labels) + return np.array(labels), np.array(magsum), np.array(powersum) + + return [2], [0], [0] def power_spectrum(image: np.ndarray) -> float: @@ -9,7 +35,7 @@ def power_spectrum(image: np.ndarray) -> float: gives a measure of image blur. A higher slope indicates more lower frequency components, and hence more blur. See https://cellprofiler-manual.s3.amazonaws.com/CellProfiler-4.0.5/modules/measurement.html#measureimagequality """ - radii, magnitude, power = centrosome.radial_power_spectrum.rps(image) + radii, magnitude, power = rps(image) if sum(magnitude) > 0 and len(np.unique(image)) > 1: valid = magnitude > 0 radii = radii[valid].reshape((-1, 1)) diff --git a/scallops/io.py b/scallops/io.py index 72c69d3..8767edc 100644 --- a/scallops/io.py +++ b/scallops/io.py @@ -29,6 +29,7 @@ import anndata import bioio +import bioio_ome_zarr import bioio_tifffile import dask import dask.array as da @@ -54,12 +55,11 @@ from xarray.core.utils import equivalent from zarr.storage import StoreLike -from scallops._bioio_zarr_reader import ScallopsZarrReader from scallops.experiment.elements import Experiment, _LazyLoadData from scallops.externals.tifffile2014 import imsave from scallops.utils import forceTCZYX, mlcs from scallops.xr import _crop -from scallops.zarr_io import _read_zarr_experiment, read_ome_zarr_array +from scallops.zarr_io import _get_store_path, _read_zarr_experiment, read_ome_zarr_array logger = logging.getLogger("scallops") @@ -234,7 +234,7 @@ def _create_image(path: str, **kwargs) -> bioio.BioImage: base_path_lc, ext = os.path.splitext(path_lc) if "reader" not in img_args: if ext in ["", ".zarr", "/", ".zarr/"]: - img_args["reader"] = ScallopsZarrReader + img_args["reader"] = bioio_ome_zarr.Reader elif ext in [".tiff", ".tif"] and os.path.splitext(base_path_lc)[1] != ".ome": img_args["reader"] = bioio_tifffile.Reader return bioio.BioImage(path, **img_args) @@ -1358,7 +1358,7 @@ def _images2fov( name = ( os.path.basename(file_list[i]) if not isinstance(file_list[i], zarr.Group) - else file_list[i].store.path + else _get_store_path(file_list[i]) ) src_metadata.append(dict(attrs=image_attrs[i], name=name)) @@ -1599,6 +1599,7 @@ def _get_image_key_func(group_by): lambda: [] ) # key is tuple -> value is tuple of group, dict maxdepth = None + for image_path in image_paths: if isinstance(image_path, Path): # IF URI DO NOT PROVIDE AS PATH @@ -1611,6 +1612,7 @@ def _get_image_key_func(group_by): pass else: root = image_path + if root is not None: if "0" not in root: # format: "path.zarr/images/" if "images" in root: @@ -1664,6 +1666,7 @@ def _get_image_key_func(group_by): if image_path in [".", "./"] and _get_fs_protocol(fs) == "file": image_path = fs.info(image_path)["name"].rstrip(".") image_prefix = None + if fs.isdir(image_path): image_path = image_path.rstrip(fs.sep) if maxdepth is None: @@ -1681,6 +1684,7 @@ def _get_image_key_func(group_by): withdirs=True, ) ) + paths = [p for p in all_paths if p.lower().endswith(extension)] if len(paths) == 0: # try with no maxdepth @@ -1718,7 +1722,7 @@ def _get_image_key_func(group_by): group_to_matches[group].append((x, d)) if len(group_to_matches) == 0: - message = [f"No files found matching pattern: {file_regex.pattern}"] + message = [f"No files found matching pattern: {files_pattern}"] if subset_ is not None: message.append(f", subset: {', '.join([str(s) for s in subset_])}") if len(group_by) > 0: @@ -1784,7 +1788,9 @@ def file_sort_key(x): src=file_list, common_src=mlcs( [ - Path(x).stem if not isinstance(x, zarr.Group) else x.store.path + Path(x).stem + if not isinstance(x, zarr.Group) + else _get_store_path(x) for x in file_list ] ), diff --git a/scallops/registration/itk.py b/scallops/registration/itk.py index 205c15b..5dbbef1 100644 --- a/scallops/registration/itk.py +++ b/scallops/registration/itk.py @@ -33,7 +33,12 @@ from scallops.registration.landmarks import _get_translation, find_landmarks from scallops.utils import _dask_from_array_no_copy from scallops.xr import _get_dims -from scallops.zarr_io import open_ome_zarr, write_zarr +from scallops.zarr_io import ( + default_zarr_format, + get_zarr_array_kwargs, + open_ome_zarr, + write_zarr, +) logger = logging.getLogger("scallops") @@ -328,15 +333,18 @@ def _init_callback(init_params: dict[str, Any]) -> dict[str, Any]: group = None if image_root is not None: images_group = image_root.require_group("images", overwrite=False) + fmt = default_zarr_format() group = images_group.create_group( image_name.replace("/", "-"), overwrite=True ) - zarr_dataset = group.create_dataset( + + zarr_dataset = group.create_array( "0", shape=shape, chunks=(1,) * (len(shape) - 2) + chunk_size, dtype=dtype, overwrite=True, + **get_zarr_array_kwargs(fmt), ) return { @@ -1164,12 +1172,15 @@ def _itk_transform_image_zarr( image_name.replace("/", "-"), overwrite=True ) chunks = (1,) * len(transform_dims) + (chunksize or (1024, 1024)) - data = group.create_dataset( + fmt = default_zarr_format() + + data = group.create_array( "0", shape=dim_sizes + output_size, chunks=chunks, dtype=image.dtype, overwrite=True, + **get_zarr_array_kwargs(fmt), ) _itk_transform_image( diff --git a/scallops/stitch/_stitch.py b/scallops/stitch/_stitch.py index acd7401..c294997 100644 --- a/scallops/stitch/_stitch.py +++ b/scallops/stitch/_stitch.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import Literal -import dask.array as da import fsspec import numpy as np import pandas as pd @@ -14,7 +13,6 @@ import pyarrow.parquet as pq import zarr from sklearn.cluster import AgglomerativeClustering -from zarr.errors import PathNotFoundError from scallops.cli.util import _get_cli_logger, cli_metadata from scallops.io import is_parquet_file, read_image @@ -32,7 +30,7 @@ tile_source_labels, ) from scallops.utils import _dask_from_array_no_copy -from scallops.zarr_io import is_ome_zarr_array +from scallops.zarr_io import is_ome_zarr_array, write_zarr logger = _get_cli_logger() @@ -82,14 +80,14 @@ def _single_stitch( if is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): logger.info(f"Skipping stitching for {image_key}.") return - except PathNotFoundError: + except: # noqa: E722 pass elif not no_save_labels: try: if is_ome_zarr_array(image_output_root.get(f"labels/{image_key}-mask")): logger.info(f"Skipping stitching for {image_key}.") return - except PathNotFoundError: + except: # noqa: E722 pass elif is_parquet_file(f"{other_output_path}{image_key}-positions.parquet"): logger.info(f"Skipping stitching for {image_key}.") @@ -341,10 +339,10 @@ def _single_stitch( tile_shape_no_crop[0] - fuse_crop_width * 2, tile_shape_no_crop[1] - fuse_crop_width * 2, ) - fused_y_size = ( + fused_y_size = int( np.round(stitch_positions_df["y"].max()).astype(int) + fused_tile_shape[0] ) - fused_x_size = ( + fused_x_size = int( np.round(stitch_positions_df["x"].max()).astype(int) + fused_tile_shape[1] ) @@ -358,8 +356,6 @@ def _single_stitch( blend, image_output_root, image_key, - fused_y_size, - fused_x_size, fused_tile_shape, chunk_size, image_spacing, @@ -384,8 +380,6 @@ def _write_arrays( blend, image_output_root, image_key, - fused_y_size, - fused_x_size, fused_tile_shape, chunk_size, image_spacing, @@ -405,17 +399,8 @@ def _write_arrays( labels_group = image_output_root.require_group("labels") group = labels_group.create_group(image_key + "-mask", overwrite=True) - array = group.create_dataset( - name="0", - shape=(fused_y_size, fused_x_size), - chunks=chunk_size, - dtype=np.uint8, - dimension_separator="/", - overwrite=True, - ) - - da.to_zarr( - arr=_dask_from_array_no_copy( + write_zarr( + data=_dask_from_array_no_copy( tile_overlap_mask( stitch_positions_df, fill=blend != "none", @@ -423,39 +408,39 @@ def _write_arrays( ), chunks=chunk_size, ), - url=array, + grp=group, + image_attrs=None, + coords=None, + dims=None, + scaler=None, compute=True, - dimension_separator="/", ) group.attrs.update( _create_label_ome_metadata(image_spacing, image_key + "-mask") ) if blend == "none": group = labels_group.create_group(image_key + "-tile", overwrite=True) - array = group.create_dataset( - name="0", - shape=(fused_y_size, fused_x_size), - chunks=chunk_size, - dtype=np.uint16, - dimension_separator="/", - overwrite=True, - ) - - da.to_zarr( - arr=_dask_from_array_no_copy( + write_zarr( + data=_dask_from_array_no_copy( tile_source_labels(stitch_positions_df, fused_tile_shape), chunks=chunk_size, ), - url=array, + grp=group, + image_attrs=None, + coords=None, + dims=None, + scaler=None, compute=True, - dimension_separator="/", ) label_metadata = _create_label_ome_metadata( image_spacing, image_key + "-tile" ) - label_metadata["multiscales"][0]["metadata"] = { - "source": f"../../images/{image_key}" - } + label_multiscales = ( + label_metadata["ome"]["multiscales"] + if "ome" in label_metadata + else label_metadata["multiscales"] + ) + label_multiscales[0]["metadata"] = {"source": f"../../images/{image_key}"} group.attrs.update(label_metadata) cleanup_paths = [] if not no_save_image: diff --git a/scallops/stitch/fuse.py b/scallops/stitch/fuse.py index a7343ae..ec16174 100644 --- a/scallops/stitch/fuse.py +++ b/scallops/stitch/fuse.py @@ -22,12 +22,14 @@ from scallops.stitch._radial import radial_correct from scallops.stitch.utils import dtype_convert from scallops.utils import _cpu_count, _dask_from_array_no_copy +from scallops.zarr_io import default_zarr_format, get_zarr_array_kwargs logger = logging.getLogger("scallops") def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: str): - return { + fmt = default_zarr_format() + d = { "multiscales": [ { "axes": [ @@ -38,10 +40,10 @@ def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: s { "coordinateTransformations": [ { - "scale": [ + "scale": ( float(image_spacing[0]), float(image_spacing[1]), - ], + ), "type": "scale", } ], @@ -49,10 +51,14 @@ def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: s } ], "name": f"/labels/{label_name}", - "version": "0.4", + "version": fmt.version, } ] } + if fmt.version in ("0.1", "0.2", "0.3", "0.4"): + return d + + return {"ome": d} def _create_ome_metadata( @@ -64,9 +70,10 @@ def _create_ome_metadata( metadata = {} metadata.update(**kwargs) metadata["stitch_coords"] = dict() + fmt = default_zarr_format() for c in stitch_coords: # convert to dict metadata["stitch_coords"][c] = stitch_coords[c].to_list() - return { + d = { "multiscales": [ { "metadata": metadata, @@ -79,11 +86,11 @@ def _create_ome_metadata( { "coordinateTransformations": [ { - "scale": [ + "scale": ( 1.0, float(image_spacing[0]), float(image_spacing[1]), - ], + ), "type": "scale", } ], @@ -91,10 +98,13 @@ def _create_ome_metadata( } ], "name": f"/images/{image_key}", - "version": "0.4", + "version": fmt.version, } ] } + if fmt.version in ("0.1", "0.2", "0.3", "0.4"): + return d + return {"ome": d} def _fuse( @@ -174,8 +184,8 @@ def _fuse( df["x"] = df["x"].round().values.astype(int) df["y"] = df["y"].round().values.astype(int) - fused_y_size = (df["y"] + ysize).max() - fused_x_size = (df["x"] + xsize).max() + fused_y_size = int((df["y"] + ysize).max()) + fused_x_size = int((df["x"] + xsize).max()) if channels_per_batch is None: if blend == "none": @@ -222,18 +232,16 @@ def _fuse( locks.append(threading.Lock()) locks = np.array(locks) partition_tree = shapely.STRtree(partition_boxes) + output_shape = (len(output_channels), fused_y_size, fused_x_size) + fmt = default_zarr_format() - result = group.create_dataset( - shape=( - len(output_channels), # c - fused_y_size, - fused_x_size, - ), + result = group.create_array( + shape=output_shape, dtype=target_dtype, chunks=(1,) + chunk_size, name="0", - dimension_separator="/", overwrite=True, + **get_zarr_array_kwargs(fmt), ) _fuse_image_delayed = delayed(_fuse_image) @@ -373,7 +381,6 @@ def _fuse( url=result, region=(slice(channel_batch, channel_batch + channels_per_batch),), compute=True, - dimension_separator="/", ) diff --git a/scallops/stitch/utils.py b/scallops/stitch/utils.py index c7bb063..12e5200 100644 --- a/scallops/stitch/utils.py +++ b/scallops/stitch/utils.py @@ -4,7 +4,7 @@ import logging import os import tempfile -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Literal import bioio @@ -20,6 +20,7 @@ from ome_types import from_xml from pint import UndefinedUnitError, UnitRegistry from skimage.util import img_as_float, img_as_ubyte, img_as_uint +from zarr.core.group import GroupMetadata from scallops.cli.util import _group_src_attrs from scallops.features.image_quality import power_spectrum @@ -311,10 +312,10 @@ def _get_ome(image: bioio.BioImage): return metadata except NotImplementedError: pass - - if isinstance(image.metadata, str): + image_metadata = _get_image_metadata(image) + if isinstance(image_metadata, str): try: - return from_xml(image.metadata) + return from_xml(image_metadata) except: # noqa: E722 pass return None @@ -323,7 +324,10 @@ def _get_ome(image: bioio.BioImage): def get_tile_position(image: bioio.BioImage, image_index: int = 0): ome_metadata = _get_ome(image) - if ome_metadata is not None: + if ( + ome_metadata is not None + and len(ome_metadata.images[image_index].pixels.planes) > 0 + ): values = [ ome_metadata.images[image_index].pixels.planes[0].position_y, ome_metadata.images[image_index].pixels.planes[0].position_x, @@ -334,36 +338,44 @@ def get_tile_position(image: bioio.BioImage, image_index: int = 0): physical_size_x_unit = ( ome_metadata.images[image_index].pixels.planes[0].position_x_unit.value ) - elif "multiscales" in image.metadata: - metadata = image.metadata["multiscales"][0]["metadata"] - values = [metadata["position_y"], metadata["position_x"]] - physical_size_y_unit = metadata["position_y_unit"] - physical_size_x_unit = metadata["position_x_unit"] else: - attrs = image.xarray_dask_data.attrs - if "unprocessed" in attrs: - if 51123 in attrs["unprocessed"]: - attrs = attrs["unprocessed"][51123] - return np.array([attrs["YPositionUm"], attrs["XPositionUm"]]) - elif 50839 in attrs["unprocessed"]: - attrs = attrs["unprocessed"][50839] - if "Info" in attrs: - attrs = json.loads(attrs["Info"]) + image_metadata = _get_image_metadata(image) + + if "ome" in image_metadata or "multiscales" in image_metadata: + metadata = ( + image_metadata["ome"]["multiscales"][0]["metadata"] + if "ome" in image_metadata + else image_metadata["multiscales"][0]["metadata"] + ) + + values = [metadata["position_y"], metadata["position_x"]] + physical_size_y_unit = metadata["position_y_unit"] + physical_size_x_unit = metadata["position_x_unit"] + else: + attrs = image.xarray_dask_data.attrs + if "unprocessed" in attrs: + if 51123 in attrs["unprocessed"]: + attrs = attrs["unprocessed"][51123] return np.array([attrs["YPositionUm"], attrs["XPositionUm"]]) - elif 270 in attrs["unprocessed"]: # IXM - attrs = attrs["unprocessed"][270] - import xml.etree.ElementTree as ET - - try: - tree = ET.fromstring(attrs) - stage_y = tree.findall(".//prop[@id='stage-position-y']") - stage_x = tree.findall(".//prop[@id='stage-position-x']") - if len(stage_y) == 1 and len(stage_x) == 1: - stage_y = stage_y[0].attrib["value"] - stage_x = stage_x[0].attrib["value"] - return np.array([stage_y, stage_x]) - except: # noqa: E722 - pass + elif 50839 in attrs["unprocessed"]: + attrs = attrs["unprocessed"][50839] + if "Info" in attrs: + attrs = json.loads(attrs["Info"]) + return np.array([attrs["YPositionUm"], attrs["XPositionUm"]]) + elif 270 in attrs["unprocessed"]: # IXM + attrs = attrs["unprocessed"][270] + import xml.etree.ElementTree as ET + + try: + tree = ET.fromstring(attrs) + stage_y = tree.findall(".//prop[@id='stage-position-y']") + stage_x = tree.findall(".//prop[@id='stage-position-x']") + if len(stage_y) == 1 and len(stage_x) == 1: + stage_y = stage_y[0].attrib["value"] + stage_x = stage_x[0].attrib["value"] + return np.array([stage_y, stage_x]) + except: # noqa: E722 + pass if physical_size_y_unit is not None and physical_size_x_unit is not None: try: values[0] = ( @@ -403,6 +415,15 @@ def get_pixel_size( return _pixel_size_from_image(_create_image(filepaths[0])) +def _get_image_metadata(image: bioio.BioImage) -> dict: + metadata = image.metadata # can be zarr GroupMetadata or dict + if isinstance(metadata, GroupMetadata): + metadata = metadata.attributes + if not isinstance(metadata, Mapping): + return dict() + return metadata + + def _pixel_size_from_image(image: bioio.BioImage) -> np.array: ome_metadata = _get_ome(image) values = None @@ -415,43 +436,51 @@ def _pixel_size_from_image(image: bioio.BioImage) -> np.array: ] physical_size_y_unit = ome_metadata.images[0].pixels.physical_size_y_unit.value physical_size_x_unit = ome_metadata.images[0].pixels.physical_size_x_unit.value - elif "multiscales" in image.metadata: - metadata = image.metadata["multiscales"][0]["metadata"] - values = [metadata["physical_size_y"], metadata["physical_size_x"]] - physical_size_y_unit = metadata["physical_size_y_unit"] - physical_size_x_unit = metadata["physical_size_x_unit"] else: - attrs = image.xarray_dask_data.attrs - if "unprocessed" in attrs: - attrs = attrs["unprocessed"] - if 51123 in attrs: - attrs = attrs[51123] - if "PixelSizeUm" in attrs: - pixel_size = attrs["PixelSizeUm"] - values = np.array([pixel_size, pixel_size]) - elif 270 in attrs: - import xml.etree.ElementTree as ET - - try: - tree = ET.fromstring(attrs[270]) - y = tree.findall(".//prop[@id='spatial-calibration-y']") - x = tree.findall(".//prop[@id='spatial-calibration-x']") - if len(y) == 1 and len(x) == 1: - y = y[0].attrib["value"] - x = x[0].attrib["value"] - values = np.array([y, x]).astype(float) - units = tree.findall(".//prop[@id='spatial-calibration-units']") - if len(units) == 1: - units = units[0].attrib["value"] - physical_size_y_unit = units - physical_size_x_unit = units - - except: # noqa: E722 - pass - if values is None and hasattr(image, "physical_pixel_sizes"): - values = np.array( - [image.physical_pixel_sizes.Y, image.physical_pixel_sizes.X] + image_metadata = _get_image_metadata(image) + if "ome" in image_metadata or "multiscales" in image_metadata: + metadata = ( + image_metadata["ome"]["multiscales"][0]["metadata"] + if "ome" in image_metadata + else image_metadata["multiscales"][0]["metadata"] ) + values = [metadata["physical_size_y"], metadata["physical_size_x"]] + physical_size_y_unit = metadata["physical_size_y_unit"] + physical_size_x_unit = metadata["physical_size_x_unit"] + else: + attrs = image.xarray_dask_data.attrs + if "unprocessed" in attrs: + attrs = attrs["unprocessed"] + if 51123 in attrs: + attrs = attrs[51123] + if "PixelSizeUm" in attrs: + pixel_size = attrs["PixelSizeUm"] + values = np.array([pixel_size, pixel_size]) + elif 270 in attrs: + import xml.etree.ElementTree as ET + + try: + tree = ET.fromstring(attrs[270]) + y = tree.findall(".//prop[@id='spatial-calibration-y']") + x = tree.findall(".//prop[@id='spatial-calibration-x']") + if len(y) == 1 and len(x) == 1: + y = y[0].attrib["value"] + x = x[0].attrib["value"] + values = np.array([y, x]).astype(float) + units = tree.findall( + ".//prop[@id='spatial-calibration-units']" + ) + if len(units) == 1: + units = units[0].attrib["value"] + physical_size_y_unit = units + physical_size_x_unit = units + + except: # noqa: E722 + pass + if values is None and hasattr(image, "physical_pixel_sizes"): + values = np.array( + [image.physical_pixel_sizes.Y, image.physical_pixel_sizes.X] + ) if physical_size_y_unit is not None and physical_size_x_unit is not None: try: values[0] = ( diff --git a/scallops/tests/test_features.py b/scallops/tests/test_features.py index 4f4c294..b509be3 100644 --- a/scallops/tests/test_features.py +++ b/scallops/tests/test_features.py @@ -60,12 +60,14 @@ def test_to_label_crops(tmp_path, array_A1_102_cells, array_A1_102_alnpheno): assert len(result_df) == 1 and result_df.index.values[0] == 2603 group = zarr.group() - intensity_image_zarr = group.create_dataset( - name="image", shape=intensity_image.shape + intensity_image_zarr = group.create_array( + name="image", shape=intensity_image.shape, dtype=intensity_image.dtype ) intensity_image_zarr[:] = intensity_image.compute() - label_image_zarr = group.create_dataset(name="label", shape=label_image.shape) + label_image_zarr = group.create_array( + name="label", shape=label_image.shape, dtype=label_image.dtype + ) label_image_zarr[:] = label_image.compute() to_label_crops( diff --git a/scallops/tests/test_illumination_correction.py b/scallops/tests/test_illumination_correction.py index d997bc4..3c6c12f 100644 --- a/scallops/tests/test_illumination_correction.py +++ b/scallops/tests/test_illumination_correction.py @@ -28,8 +28,8 @@ def test_illumination_correction_cli(tmp_path): ] subprocess.check_call(args) - store = zarr.ZipStore("scallops/tests/data/ops-illum-corr.zip", mode="r") - root = zarr.group(store=store) + store = zarr.storage.ZipStore("scallops/tests/data/ops-illum-corr.zip", mode="r") + root = zarr.open(store=store, mode="r") np.testing.assert_equal( root["data"][...], read_image(os.path.join(tmp_path, "images", "A1")).values.squeeze(), diff --git a/scallops/tests/test_io.py b/scallops/tests/test_io.py index 01e6a94..fb09fa4 100644 --- a/scallops/tests/test_io.py +++ b/scallops/tests/test_io.py @@ -217,12 +217,14 @@ def test_write_non_ome_zarr_image(tmp_path, dask): image.attrs["physical_pixel_sizes"] = (1, 1, 1) image.attrs["physical_pixel_units"] = ("mm", "mm", "mm") zarr_path = str(tmp_path / "test.zarr") - _write_zarr_image("foo", open_ome_zarr(zarr_path), image, zarr_format="zarr") - _write_zarr_image("foo2", open_ome_zarr(zarr_path), image) + _write_zarr_image("img_zarr", open_ome_zarr(zarr_path), image, zarr_format="zarr") + _write_zarr_image("img_ome_zarr", open_ome_zarr(zarr_path), image) + + data_zarr = read_image(f"{zarr_path}/images/img_zarr", dask=False) + data_ome_zarr = read_image(f"{zarr_path}/images/img_ome_zarr", dask=False) - data_zarr = read_image(f"{zarr_path}/images/foo", dask=False) - data_ome_zarr = read_image(f"{zarr_path}/images/foo2", dask=False) xr.testing.assert_equal(data_zarr, data_ome_zarr) + xr.testing.assert_equal(image, data_ome_zarr) @pytest.mark.io @@ -344,7 +346,7 @@ def test_read_write_labels(tmp_path, array_A1_102_nuclei): _write_zarr_labels( name="test", root=open_ome_zarr(str(tmp_path), "w"), labels=nuclei ) - test = read_ome_zarr_array(zarr.open(str(tmp_path / "labels" / "test"), "r")) + test = read_ome_zarr_array(zarr.open(str(tmp_path / "labels" / "test"), mode="r")) np.testing.assert_equal(nuclei, test.data) diff --git a/scallops/zarr_io.py b/scallops/zarr_io.py index 4210d36..780eb6c 100644 --- a/scallops/zarr_io.py +++ b/scallops/zarr_io.py @@ -24,7 +24,7 @@ from dask.delayed import Delayed from dask.graph_manipulation import bind from ome_zarr.axes import KNOWN_AXES -from ome_zarr.format import CurrentFormat +from ome_zarr.format import FormatV04 from ome_zarr.io import parse_url from ome_zarr.scale import Scaler from ome_zarr.types import JSONDict @@ -38,6 +38,18 @@ logger = logging.getLogger("scallops") +def default_zarr_format(): + return FormatV04() + + +def get_zarr_array_kwargs(fmt): + return ( + {"dimension_separator": "/"} + if fmt.version == 2 + else {"chunk_key_encoding": fmt.chunk_key_encoding} + ) + + def is_anndata_zarr(store: StoreLike) -> bool: """Determines whether store is an AnnData Zarr . @@ -76,13 +88,21 @@ def is_ome_zarr_array(node: zarr.Group) -> bool: result = is_ome_zarr_array(root) print(result) # Output: True """ - return node is not None and "multiscales" in node.attrs + return node is not None and ("ome" in node.attrs or "multiscales" in node.attrs) def _get_fs(group: zarr.Group): if hasattr(group.store, "fs"): return group.store.fs - return fsspec.url_to_fs(group.store.path)[0] + return fsspec.url_to_fs(_get_store_path(group))[0] + + +def _get_store_path(group: zarr.Group): + if hasattr(group.store, "root"): + return str(group.store.root) + if hasattr(group.store, "path"): + return group.store.path + return "" def _get_sep(group: zarr.Group) -> str: @@ -134,7 +154,7 @@ def _create_omero_metadata( # Napari requires that colors are specified if channel names are specified channels = ( [ - dict(label=channel_names[i], color=colors[i % len(colors)]) + dict(label=str(channel_names[i]), color=colors[i % len(colors)]) for i in range(len(channel_names)) ] if not np.isscalar(channel_names) @@ -181,7 +201,7 @@ def _fix_attrs(d: dict) -> None: elif isinstance(value, ome_types.OME): # Hack to prevent OverflowError: # Overlong 4 byte UTF-8 sequence detected when encoding string - d[key] = d[key].dict() + d[key] = d[key].model_dump(mode="json") elif isinstance(value, zarr.Group): d[key] = str(value) elif isinstance(value, list): @@ -189,7 +209,7 @@ def _fix_attrs(d: dict) -> None: if isinstance(value[i], dict): _fix_attrs(value[i]) elif isinstance(value[i], ome_types.OME): - value[i] = value[i].dict() + value[i] = value[i].model_dump(mode="json") elif isinstance(value[i], zarr.Group): value[i] = str(value) @@ -210,33 +230,8 @@ def _attrs_axes_coordinates( - Updated image attributes dictionary. - List of axes dictionaries. - List of coordinate transformations dictionaries or None. - - :example: - - .. code-block:: python - - import xarray as xr - import numpy as np - from scallops.zarr_io import _attrs_axes_coordinates - - data = np.random.rand(5, 10, 512, 512) - dims = ("c", "z", "y", "x") - coords = {"c": ["DAPI", "FITC", "TRITC", "Cy5", "Cy7"]} - array = xr.DataArray(data, dims=dims, coords=coords) - image_attrs = { - "physical_pixel_sizes": [0.1, 0.1, 0.5], - "physical_pixel_units": ["um", "um", "um"], - } - - # Prepare attributes, axes, and coordinate transformations - updated_attrs, axes, coord_transformations = _attrs_axes_coordinates( - image_attrs, array.coords, array.dims - ) - print(updated_attrs) - print(axes) - print(coord_transformations) """ - image_attrs = _fix_json(image_attrs) + omero = _create_omero_metadata(coords, dims) if omero is not None: image_attrs["omero"] = omero @@ -269,7 +264,9 @@ def _attrs_axes_coordinates( axis["unit"] = physical_pixel_units[space_index] space_index = space_index + 1 axes.append(axis) - + image_attrs = image_attrs.copy() + _fix_attrs(image_attrs) + image_attrs = _fix_json(image_attrs) return image_attrs, axes, coordinate_transformations @@ -404,49 +401,73 @@ def write_zarr( if image_attrs is not None: # Metadata can't be numpy arrays or python classes so do a round trip # conversion to convert to JSON serializable - _fix_attrs(image_attrs) + if metadata is not None: image_attrs.update(metadata) + image_attrs, axes, coordinate_transformations = _attrs_axes_coordinates( image_attrs, coords, dims ) + dask_delayed = [] + fmt = default_zarr_format() if zarr_format == "zarr": # No axis validation + zarr_array_kwargs = get_zarr_array_kwargs(fmt) if isinstance(data, da.Array): d = da.to_zarr( arr=data, url=grp.store, component=str(Path(grp.path, "0")), compute=compute, - dimension_separator=grp._store._dimension_separator, + zarr_array_kwargs=zarr_array_kwargs, ) if not compute: dask_delayed.append(d) elif not isinstance(data, zarr.Array): - grp.create_dataset("0", data=data, overwrite=True) - + grp.create_array("0", data=data, overwrite=True, **zarr_array_kwargs) + # v3 + # ome/omero for channel metadata + # ome/multiscales[0]/metadata for other metadata + + # v2: + # omero for channel metadata + # multiscales[0]/metadata for other metadata datasets = [{"path": "0"}] if coordinate_transformations is not None: datasets[0]["coordinateTransformations"] = coordinate_transformations - multiscales = [ - dict(version=CurrentFormat().version, datasets=datasets, name=grp.name) - ] - d = {"multiscales": multiscales} + multiscales = [dict(version=fmt.version, datasets=datasets, name=grp.name)] + zarr_attrs = ( + {"multiscales": multiscales} + if fmt.zarr_format == 2 + else {"ome": {"multiscales": multiscales}} + ) + if axes is not None: multiscales[0]["axes"] = axes if image_attrs is not None: - multiscales[0]["metadata"] = image_attrs if "omero" in image_attrs: - d["omero"] = image_attrs["omero"] + if fmt.zarr_format == 2: + omero = zarr_attrs.get("omero", {}) + omero.update(image_attrs.pop("omero")) + zarr_attrs["omero"] = omero + else: + omero = zarr_attrs["ome"].get("omero", {}) + omero.update(image_attrs.pop("omero")) + zarr_attrs["ome"]["omero"] = omero + + multiscales[0]["metadata"] = image_attrs + if len(dask_delayed) > 0: @dask.delayed def _write_metadata_delayed(grp, d): grp.attrs.update(d) - return dask_delayed + [bind(_write_metadata_delayed, dask_delayed)(grp, d)] + return dask_delayed + [ + bind(_write_metadata_delayed, dask_delayed)(grp, zarr_attrs) + ] else: - grp.attrs.update(d) + grp.attrs.update(zarr_attrs) return dask_delayed else: return write_image( @@ -454,8 +475,9 @@ def _write_metadata_delayed(grp, d): group=grp, scaler=scaler, axes=axes, + fmt=fmt, compute=compute, - metadata=image_attrs, + metadata=image_attrs if image_attrs is not None else {}, coordinate_transformations=( [coordinate_transformations] if coordinate_transformations is not None @@ -554,60 +576,56 @@ def _write_zarr_labels( isinstance(labels, xr.DataArray) and isinstance(labels.data, da.Array) ): labels = rechunk(labels) + fmt = default_zarr_format() return write_image( labels, grp, scaler=scaler, axes=label_axes, + fmt=fmt, metadata=metadata, compute=compute, + coordinate_transformations=None, storage_options=storage_options, ) -def _read_zarr_attrs(multiscale0: zarr.Group) -> tuple[dict, dict, list[str]]: - """Read attributes from a Zarr multiscale group. +def _read_zarr_attrs(attrs) -> tuple[dict, dict, list[str]]: + """Read attributes from Zarr. This function reads and processes the attributes, coordinates, and dimensions from the first multiscale dataset in a Zarr group. It also handles physical pixel sizes and units if available. - :param multiscale0: The Zarr group containing the multiscale dataset. + :param attrs: Zarr attributes. :return: A tuple containing: - coords: Dictionary of coordinates. - attrs: Dictionary of attributes. - dims: List of dimension names. - - :example: - - .. code-block:: python - - import zarr - from scallops.zarr_io import _read_zarr_attrs - - # Create a Zarr group with multiscale attributes - store = zarr.DirectoryStore("example.zarr") - root = zarr.group(store=store) - multiscale0 = root.create_group("multiscales") - multiscale0.attrs["axes"] = [{"name": "x"}, {"name": "y"}, {"name": "z"}] - multiscale0.attrs["datasets"] = [ - {"coordinateTransformations": [{"scale": [1.0, 0.5, 0.5]}]} - ] - - # Read attributes from the multiscale group - coords, attrs, dims = _read_zarr_attrs(multiscale0) - print(coords) - print(attrs) - print(dims) """ - attrs = multiscale0.get("metadata") - if attrs is None: - attrs = {} + # v3 + # ome/omero for channel metadata + # ome/multiscales[0]/metadata for other metadata + + # v2: + # omero for channel metadata + # multiscales[0]/metadata for other metadata + + if "ome" in attrs: + attrs = attrs["ome"] + multiscales = attrs["multiscales"] + if len(multiscales) > 0: + multiscale0 = multiscales[0] + else: + return None, None, None + axes = multiscale0["axes"] dims = [axis["name"] for axis in axes] - - coords = {d: attrs[d] for d in dims if d in attrs and d != "c"} - if "omero" in attrs and "c" in dims: + metadata = multiscale0.get("metadata") + if metadata is None: + metadata = {} + coords = {d: metadata[d] for d in dims if d in metadata and d != "c"} + if "c" in dims and "omero" in attrs: channel_names = attrs["omero"].get("channels") if channel_names is not None: coords["c"] = [c["label"] for c in channel_names] @@ -624,9 +642,9 @@ def _read_zarr_attrs(multiscale0: zarr.Group) -> tuple[dict, dict, list[str]]: if len(space_indices_with_units) > 0: scale = multiscale0["datasets"][0]["coordinateTransformations"][0]["scale"] physical_pixel_sizes = tuple([scale[d] for d in space_indices_with_units]) - attrs["physical_pixel_sizes"] = physical_pixel_sizes - attrs["physical_pixel_units"] = tuple(units) - return coords, attrs, dims + metadata["physical_pixel_sizes"] = physical_pixel_sizes + metadata["physical_pixel_units"] = tuple(units) + return coords, metadata, dims def _read_ome_zarr_array( @@ -643,14 +661,9 @@ def _read_ome_zarr_array( node = zarr.open(node, mode="r") if node is None: raise ValueError(f"{_node} not found") - if "multiscales" in node.attrs: - dims = None - coords = {} - attrs = {} - multiscales = node.attrs["multiscales"] - if len(multiscales) > 0: - multiscale0 = multiscales[0] - coords, attrs, dims = _read_zarr_attrs(multiscale0) + # For zarr v3, everything is under the "ome" namespace + if "ome" in node.attrs or "multiscales" in node.attrs: + coords, attrs, dims = _read_zarr_attrs(node.attrs) array = node["0"] return array, dims, coords, attrs else: # see if user passed test.zarr and zarr file only has one image @@ -659,7 +672,7 @@ def _read_ome_zarr_array( image_keys = list(images.keys()) if len(image_keys) == 1: return _read_ome_zarr_array(images[image_keys[0]]) - logger.warning("multiscales not found in attrs") + logger.warning(f"multiscales not found in attrs for {node} ") def read_ome_zarr_array( @@ -705,10 +718,11 @@ def open_ome_zarr(url: Path | str, mode: str = "a") -> zarr.Group | None: """ try: - loc = parse_url(url, mode=mode) + fmt = default_zarr_format() + loc = parse_url(url, mode=mode, fmt=fmt) if loc is None: return None - return zarr.open(loc.store, mode=mode) + return zarr.open(loc.store, mode=mode, zarr_format=fmt.zarr_format) except Exception as e: logger.error(f"Failed to open OME-Zarr store: {url}") raise e