Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions cellmap_flow/blockwise/blockwise_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import logging
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

import subprocess
from pathlib import Path

Expand All @@ -16,8 +22,9 @@
from cellmap_flow.utils.config_utils import build_models, load_config
from cellmap_flow.utils.serilization_utils import get_process_dataset
from cellmap_flow.utils.ds import generate_singlescale_metadata
from cellmap_flow.models.model_merger import get_model_merger


logger = logging.getLogger(__name__)


class CellMapFlowBlockwiseProcessor:
Expand Down Expand Up @@ -98,23 +105,41 @@ def __init__(self, yaml_config: str, create=False):

# Support multiple models with model_mode
self.models = models
self.model_mode = self.config.get("model_mode", "AND").upper()
if self.model_mode not in ["AND", "OR", "SUM"]:
raise Exception(
f"Invalid model_mode: {self.model_mode}. Must be one of: AND, OR, SUM"
)
self.model_mode_str = self.config.get("model_mode", "AND").upper()
try:
self.model_merger = get_model_merger(self.model_mode_str)
except ValueError as e:
raise Exception(str(e))

if len(models) > 1:
logger.info(
f"Using {len(models)} models with merge mode: {self.model_mode}"
f"Using {len(models)} models with merge mode: {self.model_mode_str}"
)

# Support cross-channel processing
self.process_only = self.config.get("process_only", None)
self.cross_channels_mode = self.config.get("cross_channels", None)

if self.cross_channels_mode:
self.cross_channels_mode = self.cross_channels_mode.upper()
try:
self.cross_channels_merger = get_model_merger(self.cross_channels_mode)
except ValueError as e:
raise Exception(f"Invalid cross_channels setting: {e}")
else:
self.cross_channels_merger = None

if self.process_only:
logger.info(
f"Processing only channels: {self.process_only} with merge mode: {self.cross_channels_mode}"
)

self.model_config = models[0]

# Block shape axes are ordered (z, y, x) — TODO confirm against model config

block_shape = [int(x) for x in self.model_config.config.block_shape][:3]
self.block_shape = self.config.get("block_size", block_shape)
self.block_shape = tuple(self.config.get("block_size", block_shape))

self.input_voxel_size = Coordinate(self.model_config.config.input_voxel_size)
self.output_voxel_size = Coordinate(self.model_config.config.output_voxel_size)
Expand All @@ -133,6 +158,9 @@ def __init__(self, yaml_config: str, create=False):
self.output_channel_names = self.channels
self.output_channel_indices = None

if not isinstance(self.output_channels, list):
self.output_channels = [self.output_channels]

if json_data:
g.input_norms, g.postprocess = get_process_dataset(json_data)

Expand Down Expand Up @@ -161,6 +189,12 @@ def __init__(self, yaml_config: str, create=False):

# Ensure we have output channels to iterate over
channels_to_create = self.output_channels if self.output_channels else []
if not isinstance(channels_to_create, list):
channels_to_create = [channels_to_create]

# Reject duplicate channel names in channels_to_create
if len(channels_to_create) != len(set(channels_to_create)):
raise Exception(f"output_channels has duplicated channel names. channels: {channels_to_create}")

for channel in channels_to_create:
if create:
Expand Down Expand Up @@ -190,7 +224,7 @@ def __init__(self, yaml_config: str, create=False):
chunk_shape=(
self.block_shape
if len(final_output_shape) == 3
else (len(channel_indices),) + tuple(self.block_shape)
else (len(channel_indices),) + self.block_shape
),
voxel_size=(
self.output_voxel_size
Expand Down Expand Up @@ -254,6 +288,8 @@ def __init__(self, yaml_config: str, create=False):
if "multiscales" in list(zg.attrs):
old_multiscales = zg.attrs["multiscales"]
if old_multiscales != zattrs["multiscales"]:
logger.info(f"Old multiscales: {old_multiscales}")
logger.info(f"New multiscales: {zattrs['multiscales']}")
raise ValueError(
f"multiscales attribute already exists in {z_store.path} and is different from the new one"
)
Expand Down Expand Up @@ -302,10 +338,15 @@ def process_fn(self, block):
model_outputs = []
for inferencer in self.inferencers:
output = inferencer.process_chunk(self.idi_raw, block.write_roi)
if self.process_only and self.cross_channels_merger:
# Extract only the specified channels
channel_outputs = [output[ch_idx] for ch_idx in self.process_only]
# Merge the extracted channels based on cross_channels mode
output = self.cross_channels_merger.merge(channel_outputs)
model_outputs.append(output)

# Merge outputs based on model_mode
chunk_data = self._merge_model_outputs(model_outputs)
chunk_data = self.model_merger.merge(model_outputs)

chunk_data = chunk_data.astype(self.dtype)

Expand Down Expand Up @@ -372,38 +413,6 @@ def process_fn(self, block):
continue
array[array_write_roi] = predictions.to_ndarray(array_write_roi)

def _merge_model_outputs(self, model_outputs):
"""
Merge outputs from multiple models based on the configured model_mode.

Args:
model_outputs: List of numpy arrays from different models

Returns:
Merged numpy array
"""
if self.model_mode == "AND":
# Element-wise minimum (logical AND for binary, minimum for continuous)
merged = model_outputs[0]
for output in model_outputs[1:]:
merged = np.minimum(merged, output)
return merged

elif self.model_mode == "OR":
# Element-wise maximum (logical OR for binary, maximum for continuous)
merged = model_outputs[0]
for output in model_outputs[1:]:
merged = np.maximum(merged, output)
return merged

elif self.model_mode == "SUM":
# Sum all outputs and normalize by number of models
merged = np.sum(model_outputs, axis=0) / len(model_outputs)
return merged

else:
raise ValueError(f"Unknown model_mode: {self.model_mode}")

def client(self):
client = daisy.Client()
while True:
Expand Down Expand Up @@ -484,7 +493,7 @@ def run_worker():
f"prediction_logs/out.out",
"-e",
f"prediction_logs/out.err",
"cellmap_flow_blockwise_processor",
"cellmap_flow_blockwise",
f"{yaml_config}",
"--client",
]
Expand Down
8 changes: 7 additions & 1 deletion cellmap_flow/blockwise/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import click
import logging

import logging
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
from cellmap_flow.blockwise import CellMapFlowBlockwiseProcessor


Expand Down Expand Up @@ -32,3 +34,7 @@ def cli(yaml_config, client, log_level):


logger = logging.getLogger(__name__)


if __name__ == "__main__":
cli()
10 changes: 5 additions & 5 deletions cellmap_flow/cli/yaml_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,23 @@ def main(config_path: str, log_level: str, list_types: bool, validate_only: bool

# Build model configuration objects dynamically
logger.info("Building model configurations...")
models = build_models(config["models"])
g.models_config = build_models(config["models"])

logger.info(f"Configured {len(models)} model(s):")
for i, model in enumerate(models, 1):
logger.info(f"Configured {len(g.models_config)} model(s):")
for i, model in enumerate(g.models_config, 1):
model_name = getattr(model, "name", None) or type(model).__name__
logger.info(f" {i}. {model_name} ({type(model).__name__})")

# Validation mode - exit without running
if validate_only:
click.echo("\n✓ Configuration is valid!")
click.echo(f" - Models: {len(models)}")
click.echo(f" - Models: {len(g.models_config)}")
click.echo(f" - Data path: {data_path}")
click.echo(f" - Queue: {queue}")
return

# Run the models
run_multiple(models, data_path, charge_group, queue)
run_multiple(g.models_config, data_path, charge_group, queue)


if __name__ == "__main__":
Expand Down
Loading