diff --git a/cellmap_flow/post/postprocessors.py b/cellmap_flow/post/postprocessors.py index ef814c9..47bcafc 100644 --- a/cellmap_flow/post/postprocessors.py +++ b/cellmap_flow/post/postprocessors.py @@ -7,7 +7,7 @@ import threading from scipy.ndimage import label import mwatershed as mws -from scipy.ndimage import measurements +from scipy.ndimage import measurements, gaussian_filter import fastremap from funlib.math import cantor_number import fastmorph @@ -131,7 +131,9 @@ def is_segmentation(self): class AffinityPostprocessor(PostProcessor): def __init__( self, - bias: float = 0.0, + adjacent_edge_bias: float = -0.4, + lr_bias_ratio: float = -0.175, + filter_val: float = 0.5, neighborhood: str = """[ [1, 0, 0], [0, 1, 0], @@ -145,36 +147,106 @@ def __init__( ]""", ): use_exact = "True" - self.bias = float(bias) + self.adjacent_edge_bias = float(adjacent_edge_bias) + self.lr_bias_ratio = float(lr_bias_ratio) + self.filter_val = float(filter_val) self.neighborhood = ast.literal_eval(neighborhood) - self.use_exact = use_exact == "True" + self.use_exact = use_exact == "False" self.num_previous_segments = 0 - def _process(self, data, chunk_num_voxels, chunk_corner): - data = data / 255.0 - n_channels = data.shape[0] - self.neighborhood = self.neighborhood[:n_channels] - # raise Exception(data.max(), data.min(), self.neighborhood) + import numpy as np + from scipy.ndimage import measurements - segmentation = mws.agglom( - data.astype(np.float64) - self.bias, - self.neighborhood, - ) + def filter_fragments( + self, affs_data: np.ndarray, fragments_data: np.ndarray, filter_val: float + ) -> None: + """Allows filtering of MWS fragments based on mean value of affinities & fragments. Will filter and update the fragment array in-place. - # filter fragments - average_affs = np.mean(data, axis=0) + Args: + aff_data (``np.ndarray``): + An array containing affinity data. + + fragments_data (``np.ndarray``): + An array containing fragment data. + + filter_val (``float``): + Threshold to filter if the average value falls below. + """ - filtered_fragments = [] + average_affs: float = np.mean(affs_data.data, axis=0) - fragment_ids = fastremap.unique(segmentation[segmentation > 0]) + filtered_fragments: list = [] + + fragment_ids: np.ndarray = np.unique(fragments_data) for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) + fragment_ids, measurements.mean(average_affs, fragments_data, fragment_ids) ): - if mean >= self.bias: + if mean < filter_val: filtered_fragments.append(fragment) - fastremap.mask_except(segmentation, filtered_fragments, in_place=True) + filtered_fragments: np.ndarray = np.array( + filtered_fragments, dtype=fragments_data.dtype + ) + # replace: np.ndarray = np.zeros_like(filtered_fragments) + fastremap.mask(fragments_data, filtered_fragments, in_place=True) + + def _process(self, data, chunk_num_voxels, chunk_corner): + data[data < self.filter_val] = 0 + if data.dtype == np.uint8: + logger.info("Assuming affinities are in [0,255]") + max_affinity_value: float = 255.0 + data = data.astype(np.float64) + else: + data = data.astype(np.float64) + max_affinity_value: float = 1.0 + + data /= max_affinity_value + + if data.max() < 1e-4: + segmentation = np.zeros( + data.shape, dtype=np.uint64 if self.use_exact else np.uint16 + ) + return np.expand_dims(segmentation, axis=0) + + channels = [ + channel for channel, ntp in enumerate(self.neighborhood) if ntp is not None + ] + neighborhood = [self.neighborhood[channel] for channel in channels] + + data = data[channels] + random_noise: float = np.random.randn(*data.shape) * 0.0001 + smoothed_affs: np.ndarray = ( + gaussian_filter(data, sigma=(0, *(np.amax(neighborhood, axis=0) / 3))) - 0.5 + ) * 0.001 + shift: np.ndarray = np.array( + [ + ( + self.adjacent_edge_bias + if max(offset) <= 1 + else np.linalg.norm(offset) * self.lr_bias_ratio + ) + for offset in neighborhood + ] + ).reshape((-1, *((1,) * (len(data.shape) - 1)))) + + # raise Exception(data.max(), data.min(), self.neighborhood) + + # segmentation = mws.agglom( + # data.astype(np.float64) - self.bias, + # self.neighborhood, + # ) + + # filter fragments + segmentation = mws.agglom( + data + shift + random_noise + smoothed_affs, + offsets=neighborhood, + ) + if self.filter_val > 0.0: + self.filter_fragments(data, segmentation, self.filter_val) + + # fragment_ids = fastremap.unique(segmentation[segmentation > 0]) + # fastremap.mask_except(segmentation, filtered_fragments, in_place=True) fastremap.renumber(segmentation, in_place=True) unique_increment = chunk_num_voxels * pymorton.interleave(*chunk_corner) if not self.use_exact: diff --git a/cellmap_flow/utils/data.py b/cellmap_flow/utils/data.py index c6e1c42..59c2b05 100644 --- a/cellmap_flow/utils/data.py +++ b/cellmap_flow/utils/data.py @@ -122,7 +122,7 @@ def _get_config(self): config.output_channels = len( config.channels ) # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld - config.block_shape = np.array(tuple(out_shape) + (len(channels),)) + config.block_shape = np.array(tuple(out_shape) + (config.output_channels,)) return config @@ -384,7 +384,8 @@ def get_dacapo_channels(task): if hasattr(task, "channels"): return task.channels elif type(task).__name__ == "AffinitiesTask": - return ["x", "y", "z"] + # to be backwards compatible in case .channels or .neighborhood doesn't exist + return [f"aff_{'.'.join(map(str, n))}" for n in task.predictor.neighborhood] else: return ["membrane"]