diff --git a/.gitignore b/.gitignore index acd9206..f6917a5 100644 --- a/.gitignore +++ b/.gitignore @@ -162,4 +162,13 @@ cython_debug/ #.idea/ # Misc -.vscode/ \ No newline at end of file +.vscode/ + +# Project-specific +ignore/ +*.zarr/ +.claude/ +test_corrections.zarr/ +correction_slices/ +corrections/ +output/ \ No newline at end of file diff --git a/cellmap_flow/cli/server_cli.py b/cellmap_flow/cli/server_cli.py index a86e98c..1c52b05 100644 --- a/cellmap_flow/cli/server_cli.py +++ b/cellmap_flow/cli/server_cli.py @@ -23,7 +23,6 @@ from cellmap_flow.utils.plugin_manager import load_plugins -logging.basicConfig() logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ def cli(log_level): cellmap_flow_server script -s /path/to/script.py -d /path/to/data cellmap_flow_server cellmap-model -f /path/to/model -n mymodel -d /path/to/data """ - logging.basicConfig(level=getattr(logging, log_level.upper())) + logging.basicConfig(level=getattr(logging, log_level.upper()), force=True) @cli.command(name="list-models") @@ -82,6 +81,9 @@ def create_dynamic_server_command(cli_name: str, config_class: Type[ModelConfig] except: type_hints = {} + # Track used short names to avoid collisions with common options. 
+ used_short_names = {"-d", "-p"} + # Create the command function def command_func(**kwargs): # Separate model config kwargs from server kwargs @@ -141,7 +143,9 @@ def command_func(**kwargs): # Add model-specific options based on constructor parameters for param_name, param_info in reversed(list(sig.parameters.items())): - option_config = create_click_option_from_param(param_name, param_info) + option_config = create_click_option_from_param( + param_name, param_info, used_short_names + ) if option_config: command_func = click.option( *option_config.pop("param_decls"), **option_config diff --git a/cellmap_flow/cli/viewer_cli.py b/cellmap_flow/cli/viewer_cli.py new file mode 100644 index 0000000..7bd5d3d --- /dev/null +++ b/cellmap_flow/cli/viewer_cli.py @@ -0,0 +1,77 @@ +""" +Simple CLI for viewing datasets with CellMap Flow without requiring model configs. +""" + +import click +import logging +import neuroglancer +from cellmap_flow.dashboard.app import create_and_run_app +from cellmap_flow.globals import g +from cellmap_flow.utils.scale_pyramid import get_raw_layer + +logging.basicConfig() +logger = logging.getLogger(__name__) + + +@click.command() +@click.option( + "-d", + "--dataset", + required=True, + type=str, + help="Path to the dataset (zarr or n5)", +) +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + default="INFO", + help="Set the logging level", +) +def main(dataset, log_level): + """ + Start CellMap Flow viewer with a dataset. 
+ + Example: + cellmap_flow_viewer -d /path/to/dataset.zarr + """ + logging.basicConfig(level=getattr(logging, log_level.upper())) + + logger.info(f"Starting CellMap Flow viewer with dataset: {dataset}") + + # Set up neuroglancer server + neuroglancer.set_server_bind_address("0.0.0.0") + + # Create viewer + viewer = neuroglancer.Viewer() + + # Set dataset path in globals + g.dataset_path = dataset + g.viewer = viewer + + # Add dataset layer to viewer + with viewer.txn() as s: + # Set coordinate space + s.dimensions = neuroglancer.CoordinateSpace( + names=["z", "y", "x"], + units="nm", + scales=[8, 8, 8], + ) + + # Add data layer + s.layers["data"] = get_raw_layer(dataset) + + # Print viewer URL + logger.info(f"Neuroglancer viewer URL: {viewer}") + print(f"\n{'='*80}") + print(f"Neuroglancer viewer: {viewer}") + print(f"Dataset: {dataset}") + print(f"{'='*80}\n") + + # Start the dashboard app + create_and_run_app(neuroglancer_url=str(viewer), inference_servers=None) + + +if __name__ == "__main__": + main() diff --git a/cellmap_flow/dashboard/app.py b/cellmap_flow/dashboard/app.py index 6210495..b261870 100644 --- a/cellmap_flow/dashboard/app.py +++ b/cellmap_flow/dashboard/app.py @@ -5,8 +5,7 @@ from flask import Flask from flask_cors import CORS -from cellmap_flow.dashboard import state -from cellmap_flow.dashboard.state import LogHandler +from cellmap_flow.globals import g, LogHandler from cellmap_flow.dashboard.routes.logging_routes import logging_bp from cellmap_flow.dashboard.routes.index_page import index_bp from cellmap_flow.dashboard.routes.pipeline_builder_page import pipeline_builder_bp @@ -14,6 +13,7 @@ from cellmap_flow.dashboard.routes.pipeline import pipeline_bp from cellmap_flow.dashboard.routes.blockwise import blockwise_bp from cellmap_flow.dashboard.routes.bbx_generator import bbx_bp +from cellmap_flow.dashboard.routes.finetune_routes import finetune_bp logger = logging.getLogger(__name__) @@ -37,11 +37,12 @@ app.register_blueprint(pipeline_bp) 
app.register_blueprint(blockwise_bp) app.register_blueprint(bbx_bp) +app.register_blueprint(finetune_bp) def create_and_run_app(neuroglancer_url=None, inference_servers=None): - state.NEUROGLANCER_URL = neuroglancer_url - state.INFERENCE_SERVER = inference_servers + g.NEUROGLANCER_URL = neuroglancer_url + g.INFERENCE_SERVER = inference_servers hostname = socket.gethostname() port = 0 logger.warning(f"Host name: {hostname}") diff --git a/cellmap_flow/dashboard/finetune_utils.py b/cellmap_flow/dashboard/finetune_utils.py new file mode 100644 index 0000000..487a070 --- /dev/null +++ b/cellmap_flow/dashboard/finetune_utils.py @@ -0,0 +1,1001 @@ +""" +Helper functions for finetuning annotation workflows. + +Handles MinIO server management, annotation zarr creation, and +periodic synchronization of annotations between MinIO and local disk. +""" + +import json +import os +import re +import socket +import subprocess +import time +import logging +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path + +import numpy as np +import s3fs +import zarr + +from cellmap_flow.globals import g + +minio_state = g.minio_state +annotation_volumes = g.annotation_volumes +output_sessions = g.output_sessions + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Session management +# --------------------------------------------------------------------------- + +def get_or_create_session_path(base_output_path: str) -> str: + """ + Get or create a timestamped session directory for the given base output path. + + If a session already exists for this base path, reuse it. + Otherwise, create a new timestamped subdirectory. 
+ + Args: + base_output_path: Base output directory (e.g., "output/to/here") + + Returns: + Timestamped session path (e.g., "output/to/here/20260213_123456") + """ + base_output_path = os.path.expanduser(base_output_path) + + if base_output_path not in output_sessions: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + session_path = os.path.join(base_output_path, timestamp) + output_sessions[base_output_path] = session_path + logger.info(f"Created new session path: {session_path}") + + return output_sessions[base_output_path] + + +# --------------------------------------------------------------------------- +# Network helpers +# --------------------------------------------------------------------------- + +def get_local_ip(): + """Get the local IP address for MinIO server.""" + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + return local_ip + except Exception: + return "127.0.0.1" + + +def find_available_port(start_port=9000): + """Find an available port pair for MinIO server (API on port, console on port+1).""" + for port in range(start_port, start_port + 100): + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: + s1.bind(("", port)) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: + s2.bind(("", port + 1)) + return port + except OSError: + continue + raise RuntimeError("Could not find available port for MinIO") + + +# --------------------------------------------------------------------------- +# Zarr creation +# --------------------------------------------------------------------------- + +def create_correction_zarr( + zarr_path, + raw_crop_shape, + raw_voxel_size, + raw_offset, + annotation_crop_shape, + annotation_voxel_size, + annotation_offset, + dataset_path, + model_name, + output_channels, + raw_dtype="uint8", + create_mask=False, +): + """ + Create a correction zarr with OME-NGFF v0.4 metadata. 
+ + Structure: + crop_id.zarr/ + raw/s0/ (uint8, shape=raw_crop_shape) + annotation/s0/ (uint8, shape=annotation_crop_shape) + mask/s0/ (optional, uint8, shape=annotation_crop_shape) + .zattrs (metadata) + + Returns: + (success: bool, info: str) + """ + try: + def add_ome_ngff_metadata(group, name, voxel_size, translation_offset=None): + """Add OME-NGFF v0.4 metadata.""" + if translation_offset is not None: + physical_translation = [ + float(o * v) for o, v in zip(translation_offset, voxel_size) + ] + else: + physical_translation = [0.0, 0.0, 0.0] + + transforms = [{"type": "scale", "scale": [float(v) for v in voxel_size]}] + + if translation_offset is not None: + transforms.append( + {"type": "translation", "translation": physical_translation} + ) + + group.attrs["multiscales"] = [ + { + "version": "0.4", + "name": name, + "axes": [ + {"name": "z", "type": "space", "unit": "nanometer"}, + {"name": "y", "type": "space", "unit": "nanometer"}, + {"name": "x", "type": "space", "unit": "nanometer"}, + ], + "datasets": [ + {"path": "s0", "coordinateTransformations": transforms} + ], + } + ] + + root = zarr.open(zarr_path, mode="w") + + # Raw group + raw_group = root.create_group("raw") + raw_group.create_dataset( + "s0", + shape=tuple(raw_crop_shape), + chunks=(64, 64, 64), + dtype=raw_dtype, + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + add_ome_ngff_metadata(raw_group, "raw", raw_voxel_size, raw_offset) + + # Annotation group + annotation_group = root.create_group("annotation") + annotation_group.create_dataset( + "s0", + shape=tuple(annotation_crop_shape), + chunks=(64, 64, 64), + dtype="uint8", + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + add_ome_ngff_metadata( + annotation_group, "annotation", annotation_voxel_size, annotation_offset + ) + + # Optional mask group + if create_mask: + mask_group = root.create_group("mask") + mask_group.create_dataset( + "s0", + 
shape=tuple(annotation_crop_shape), + chunks=(64, 64, 64), + dtype="uint8", + compressor=zarr.Blosc( + cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE + ), + fill_value=0, + ) + add_ome_ngff_metadata( + mask_group, "mask", annotation_voxel_size, annotation_offset + ) + + # Root metadata + root.attrs["roi"] = { + "raw_offset": ( + raw_offset.tolist() + if hasattr(raw_offset, "tolist") + else list(raw_offset) + ), + "raw_shape": ( + raw_crop_shape.tolist() + if hasattr(raw_crop_shape, "tolist") + else list(raw_crop_shape) + ), + "annotation_offset": ( + annotation_offset.tolist() + if hasattr(annotation_offset, "tolist") + else list(annotation_offset) + ), + "annotation_shape": ( + annotation_crop_shape.tolist() + if hasattr(annotation_crop_shape, "tolist") + else list(annotation_crop_shape) + ), + } + root.attrs["raw_voxel_size"] = ( + raw_voxel_size.tolist() + if hasattr(raw_voxel_size, "tolist") + else list(raw_voxel_size) + ) + root.attrs["annotation_voxel_size"] = ( + annotation_voxel_size.tolist() + if hasattr(annotation_voxel_size, "tolist") + else list(annotation_voxel_size) + ) + root.attrs["model_name"] = model_name + root.attrs["dataset_path"] = dataset_path + root.attrs["created_at"] = datetime.now().isoformat() + + logger.info(f"Created correction zarr at {zarr_path}") + + return True, zarr_path + + except Exception as e: + logger.error(f"Error creating zarr: {e}") + return False, str(e) + + +def create_annotation_volume_zarr( + zarr_path, + dataset_shape_voxels, + output_voxel_size, + dataset_offset_nm, + chunk_size, + dataset_path, + model_name, + input_size, + input_voxel_size, +): + """ + Create a sparse annotation volume zarr covering the full dataset extent. + + The volume has chunk_size = model output_size so each chunk maps to one + training sample. Only metadata files are created (no chunk data), so the + zarr is tiny regardless of dataset size. + + Label scheme: 0=unannotated (ignored), 1=background, 2=foreground. 
+ + Returns: + (success: bool, info: str) + """ + try: + root = zarr.open(zarr_path, mode="w") + + annotation_group = root.create_group("annotation") + annotation_group.create_dataset( + "s0", + shape=tuple(dataset_shape_voxels), + chunks=tuple(chunk_size), + dtype="uint8", + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + + # OME-NGFF v0.4 metadata with translation for dataset offset + physical_translation = [float(o) for o in dataset_offset_nm] + transforms = [ + {"type": "scale", "scale": [float(v) for v in output_voxel_size]}, + {"type": "translation", "translation": physical_translation}, + ] + annotation_group.attrs["multiscales"] = [ + { + "version": "0.4", + "name": "annotation", + "axes": [ + {"name": "z", "type": "space", "unit": "nanometer"}, + {"name": "y", "type": "space", "unit": "nanometer"}, + {"name": "x", "type": "space", "unit": "nanometer"}, + ], + "datasets": [ + {"path": "s0", "coordinateTransformations": transforms} + ], + } + ] + + # Root metadata + root.attrs["type"] = "annotation_volume" + root.attrs["model_name"] = model_name + root.attrs["dataset_path"] = dataset_path + root.attrs["chunk_size"] = ( + chunk_size.tolist() if hasattr(chunk_size, "tolist") else list(chunk_size) + ) + root.attrs["output_voxel_size"] = ( + output_voxel_size.tolist() + if hasattr(output_voxel_size, "tolist") + else list(output_voxel_size) + ) + root.attrs["input_size"] = ( + input_size.tolist() if hasattr(input_size, "tolist") else list(input_size) + ) + root.attrs["input_voxel_size"] = ( + input_voxel_size.tolist() + if hasattr(input_voxel_size, "tolist") + else list(input_voxel_size) + ) + root.attrs["dataset_offset_nm"] = ( + dataset_offset_nm.tolist() + if hasattr(dataset_offset_nm, "tolist") + else list(dataset_offset_nm) + ) + root.attrs["dataset_shape_voxels"] = ( + dataset_shape_voxels.tolist() + if hasattr(dataset_shape_voxels, "tolist") + else list(dataset_shape_voxels) + ) + root.attrs["created_at"] = 
datetime.now().isoformat() + + logger.info( + f"Created annotation volume zarr at {zarr_path} " + f"(shape={dataset_shape_voxels}, chunks={chunk_size})" + ) + + return True, zarr_path + + except Exception as e: + logger.error(f"Error creating annotation volume zarr: {e}") + return False, str(e) + + +# --------------------------------------------------------------------------- +# MinIO management +# --------------------------------------------------------------------------- + +def ensure_minio_serving(zarr_path, crop_id, output_base_dir=None): + """ + Ensure MinIO is running and upload zarr file. + + Args: + zarr_path: Path to zarr file to upload + crop_id: Unique identifier for the crop + output_base_dir: Base output directory (MinIO will use output_base_dir/.minio) + + Returns: + MinIO URL for the zarr file + """ + if minio_state["process"] is None or minio_state["process"].poll() is not None: + # Determine MinIO storage location + if output_base_dir: + minio_root = Path(output_base_dir) / ".minio" + minio_state["output_base"] = output_base_dir + else: + minio_root = Path("~/.minio-server").expanduser() + minio_state["output_base"] = None + + minio_root.mkdir(parents=True, exist_ok=True) + minio_state["minio_root"] = str(minio_root) + + ip = get_local_ip() + port = find_available_port() + + env = os.environ.copy() + env["MINIO_ROOT_USER"] = "minio" + env["MINIO_ROOT_PASSWORD"] = "minio123" + env["MINIO_API_CORS_ALLOW_ORIGIN"] = "*" + + minio_cmd = [ + "minio", + "server", + str(minio_root), + "--address", + f"{ip}:{port}", + "--console-address", + f"{ip}:{port+1}", + ] + + logger.info(f"Starting MinIO server at {ip}:{port}") + minio_proc = subprocess.Popen( + minio_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + time.sleep(3) + + if minio_proc.poll() is not None: + stderr = minio_proc.stderr.read().decode() if minio_proc.stderr else "" + raise RuntimeError(f"MinIO failed to start: {stderr}") + + minio_state["process"] = minio_proc + 
minio_state["port"] = port + minio_state["ip"] = ip + + logger.info(f"MinIO started (PID: {minio_proc.pid})") + + # Configure mc client + subprocess.run( + [ + "mc", + "alias", + "set", + "myserver", + f"http://{ip}:{port}", + "minio", + "minio123", + ], + check=True, + capture_output=True, + ) + + # Create bucket if needed + result = subprocess.run( + ["mc", "mb", f"myserver/{minio_state['bucket']}"], + capture_output=True, + text=True, + ) + if result.returncode != 0 and "already" not in result.stderr.lower(): + logger.warning(f"Bucket creation returned: {result.stderr}") + + # Make bucket public + subprocess.run( + ["mc", "anonymous", "set", "public", f"myserver/{minio_state['bucket']}"], + check=True, + capture_output=True, + ) + + # Start periodic sync thread + start_periodic_sync() + + # Upload zarr file + zarr_name = Path(zarr_path).name + target = f"myserver/{minio_state['bucket']}/{zarr_name}" + + logger.info(f"Uploading {zarr_name} to MinIO") + result = subprocess.run( + ["mc", "mirror", "--overwrite", zarr_path, target], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to upload to MinIO: {result.stderr}") + + logger.info(f"Uploaded {zarr_name} to MinIO") + + minio_url = ( + f"http://{minio_state['ip']}:{minio_state['port']}" + f"/{minio_state['bucket']}/{zarr_name}" + ) + return minio_url + + +# --------------------------------------------------------------------------- +# S3 / MinIO sync helpers +# --------------------------------------------------------------------------- + +def _safe_epoch_timestamp(value) -> float: + """Convert LastModified-like values to epoch seconds, best-effort.""" + if value is None: + return 0.0 + if isinstance(value, datetime): + return float(value.timestamp()) + if isinstance(value, (int, float)): + return float(value) + try: + parsed = datetime.fromisoformat(str(value)) + return float(parsed.timestamp()) + except Exception: + return 0.0 + + +def _get_sync_worker_count() -> 
int: + """ + Determine thread count for chunk sync. + + Prefer scheduler-provided CPU counts (e.g., LSF bsub -n), then fall back + to process CPU affinity / system CPU count. + """ + env_candidates = [ + "LSB_DJOB_NUMPROC", + "LSB_MAX_NUM_PROCESSORS", + "NSLOTS", + "SLURM_CPUS_PER_TASK", + "OMP_NUM_THREADS", + ] + for key in env_candidates: + raw = os.environ.get(key) + if not raw: + continue + try: + value = int(raw) + if value > 0: + return value + except ValueError: + continue + + try: + return max(1, len(os.sched_getaffinity(0))) + except Exception: + return max(1, os.cpu_count() or 1) + + +def _copy_chunks_parallel(s3, copy_pairs): + """ + Copy chunk files from MinIO in parallel. + + Args: + s3: s3fs filesystem instance + copy_pairs: list of (src_chunk_path, dst_chunk_path_str) + """ + if not copy_pairs: + return + + available_workers = _get_sync_worker_count() + workers = max(1, min(len(copy_pairs), available_workers)) + + def _copy_one(src_dst): + src_chunk_path, dst_chunk_path = src_dst + s3.get(src_chunk_path, dst_chunk_path) + return src_chunk_path + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(_copy_one, pair) for pair in copy_pairs] + for fut in as_completed(futures): + try: + fut.result() + except Exception as e: + logger.debug(f"Error syncing chunk in parallel copy: {e}") + + +def _make_s3_filesystem(): + """Create an s3fs filesystem pointed at the local MinIO instance.""" + return s3fs.S3FileSystem( + anon=False, + key="minio", + secret="minio123", + client_kwargs={ + "endpoint_url": f"http://{minio_state['ip']}:{minio_state['port']}", + "region_name": "us-east-1", + }, + ) + + +def _sync_zarr_group_metadata(s3, src_path, dst_path): + """Sync zarr group structure and metadata from S3 to local disk. + + Ensures destination arrays exist with correct shape/dtype and copies attrs. 
+ """ + src_store = s3fs.S3Map(root=src_path, s3=s3) + src_group = zarr.open_group(store=src_store, mode="r") + + dst_store = zarr.DirectoryStore(str(dst_path)) + dst_group = zarr.open_group(store=dst_store, mode="a") + + for key in src_group.array_keys(): + src_array = src_group[key] + if key not in dst_group: + dst_group.create_dataset( + key, + shape=src_array.shape, + chunks=src_array.chunks, + dtype=src_array.dtype, + fill_value=0, + overwrite=True, + ) + dst_group[key].attrs.update(src_array.attrs) + + dst_group.attrs.update(src_group.attrs) + + +def _diff_and_sync_chunks(s3, s0_path, dst_s0_path, known_chunk_state, force=False): + """Diff remote vs known chunk state and sync changed chunks to local disk. + + Returns: + (changed_keys, removed_keys, remote_chunk_state) + """ + try: + chunk_files = s3.ls(s0_path) + except FileNotFoundError: + chunk_files = [] + + remote_chunk_state = {} + for chunk_file in chunk_files: + chunk_key = Path(chunk_file).name + if not re.match(r"^\d+\.\d+\.\d+$", chunk_key): + continue + try: + info = s3.info(chunk_file) + remote_chunk_state[chunk_key] = _safe_epoch_timestamp(info.get("LastModified")) + except Exception: + remote_chunk_state[chunk_key] = 0.0 + + if force: + changed_keys = list(remote_chunk_state.keys()) + else: + changed_keys = [k for k, v in remote_chunk_state.items() if known_chunk_state.get(k) != v] + removed_keys = [k for k in known_chunk_state if k not in remote_chunk_state] + + if not changed_keys and not removed_keys: + return [], [], remote_chunk_state + + # Copy changed chunks + dst_s0_path = Path(dst_s0_path) + dst_s0_path.mkdir(parents=True, exist_ok=True) + copy_pairs = [(f"{s0_path}/{k}", str(dst_s0_path / k)) for k in changed_keys] + _copy_chunks_parallel(s3, copy_pairs) + + # Remove stale local chunks + for k in removed_keys: + local_chunk = dst_s0_path / k + try: + if local_chunk.exists(): + local_chunk.unlink() + except Exception as e: + logger.debug(f"Error removing stale chunk {k}: {e}") + + 
return changed_keys, removed_keys, remote_chunk_state + + +# --------------------------------------------------------------------------- +# Annotation sync (crop-based) +# --------------------------------------------------------------------------- + +def sync_annotation_from_minio(crop_id, force=False): + """ + Sync a single annotation crop from MinIO to local filesystem. + + Args: + crop_id: Crop ID to sync + force: Force sync even if not modified + + Returns: + bool: True if synced successfully + """ + if not minio_state["ip"] or not minio_state["port"] or not minio_state["output_base"]: + return False + + try: + s3 = _make_s3_filesystem() + + zarr_name = f"{crop_id}.zarr" + src_path = f"{minio_state['bucket']}/{zarr_name}/annotation" + dst_path = Path(minio_state["output_base"]) / zarr_name / "annotation" + + if not s3.exists(src_path): + return False + + known_chunk_state = minio_state["chunk_sync_state"].get(crop_id, {}) + s0_path = f"{src_path}/s0" + changed, removed, remote_chunk_state = _diff_and_sync_chunks( + s3, s0_path, dst_path / "s0", known_chunk_state, force=force + ) + + if not changed and not removed: + return False + + logger.info( + f"Syncing annotation for {crop_id} " + f"(changed={len(changed)}, removed={len(removed)})" + ) + + _sync_zarr_group_metadata(s3, src_path, dst_path) + + minio_state["last_sync"][crop_id] = datetime.now() + minio_state["chunk_sync_state"][crop_id] = remote_chunk_state + + logger.info(f"Successfully synced annotation for {crop_id}") + return True + + except Exception as e: + logger.error(f"Error syncing annotation for {crop_id}: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + + +# --------------------------------------------------------------------------- +# Annotation sync (full-dataset sync) +# --------------------------------------------------------------------------- + +def sync_all_annotations_from_minio(force: bool = True): + """Sync all annotations from MinIO to local disk. 
+ + Returns: + Number of annotations synced, or -1 if MinIO is not initialized. + """ + if not minio_state.get("ip") or not minio_state.get("port"): + logger.info("MinIO not initialized, skipping annotation sync") + return -1 + + logger.info(f"Syncing all annotations from MinIO (force={force})...") + s3 = _make_s3_filesystem() + zarrs = s3.ls(minio_state["bucket"]) + zarr_ids = [Path(c).name.replace(".zarr", "") for c in zarrs if c.endswith(".zarr")] + synced = 0 + for zid in zarr_ids: + try: + zarr_name = f"{zid}.zarr" + attrs_path = f"{minio_state['bucket']}/{zarr_name}/.zattrs" + if s3.exists(attrs_path): + root_attrs = json.loads(s3.cat(attrs_path)) + if root_attrs.get("type") == "annotation_volume": + if sync_annotation_volume_from_minio(zid, force=force): + synced += 1 + continue + except Exception: + pass + if sync_annotation_from_minio(zid, force=force): + synced += 1 + logger.info(f"Synced {synced}/{len(zarr_ids)} annotations") + return synced + + +# --------------------------------------------------------------------------- +# Volume metadata helpers +# --------------------------------------------------------------------------- + +def _get_volume_metadata(volume_id, zarr_path=None): + """ + Get volume metadata from in-memory cache or reconstruct from zarr attrs. + + Used for server restart recovery -- if annotation_volumes dict was lost, + reconstruct metadata from the zarr's stored attributes. 
+ """ + if volume_id in annotation_volumes: + return annotation_volumes[volume_id] + + if zarr_path is None: + return None + + try: + root = zarr.open(zarr_path, mode="r") + attrs = dict(root.attrs) + if attrs.get("type") != "annotation_volume": + return None + + metadata = { + "zarr_path": zarr_path, + "model_name": attrs.get("model_name", ""), + "output_size": attrs.get("chunk_size", [56, 56, 56]), + "input_size": attrs.get("input_size", [178, 178, 178]), + "input_voxel_size": attrs.get("input_voxel_size", [16, 16, 16]), + "output_voxel_size": attrs.get("output_voxel_size", [16, 16, 16]), + "dataset_path": attrs.get("dataset_path", ""), + "dataset_offset_nm": attrs.get("dataset_offset_nm", [0, 0, 0]), + "corrections_dir": str(Path(zarr_path).parent), + "extracted_chunks": set(), + "chunk_sync_state": {}, + } + annotation_volumes[volume_id] = metadata + return metadata + except Exception as e: + logger.error(f"Error reconstructing volume metadata for {volume_id}: {e}") + return None + + +def extract_correction_from_chunk(volume_id, chunk_indices, volume_metadata): + """ + Extract a correction entry from a single annotated chunk in a sparse volume. + + Reads the annotation chunk, extracts raw data with context padding, and + creates a standard correction zarr entry compatible with CorrectionDataset. 
+ + Args: + volume_id: Volume identifier + chunk_indices: Tuple (cz, cy, cx) of chunk indices + volume_metadata: Volume metadata dict + + Returns: + bool: True if correction was created (chunk had annotations) + """ + from cellmap_flow.image_data_interface import ImageDataInterface + from funlib.geometry import Roi, Coordinate + + cz, cy, cx = chunk_indices + chunk_size = np.array(volume_metadata["output_size"]) + output_voxel_size = np.array(volume_metadata["output_voxel_size"]) + input_size = np.array(volume_metadata["input_size"]) + input_voxel_size = np.array(volume_metadata["input_voxel_size"]) + dataset_offset_nm = np.array(volume_metadata["dataset_offset_nm"]) + corrections_dir = volume_metadata["corrections_dir"] + + vol_zarr_path = volume_metadata["zarr_path"] + vol = zarr.open(vol_zarr_path, mode="r") + + z_start = cz * chunk_size[0] + y_start = cy * chunk_size[1] + x_start = cx * chunk_size[2] + + annotation_data = vol["annotation/s0"][ + z_start : z_start + chunk_size[0], + y_start : y_start + chunk_size[1], + x_start : x_start + chunk_size[2], + ] + + # Skip if all zeros (unannotated or erased) + if not np.any(annotation_data): + return False + + # Compute physical position of this chunk's center + chunk_offset_nm = dataset_offset_nm + np.array( + [z_start, y_start, x_start] + ) * output_voxel_size + chunk_center_nm = chunk_offset_nm + (chunk_size * output_voxel_size) / 2 + + # Extract raw data with full context padding + read_shape_nm = input_size * input_voxel_size + raw_roi = Roi( + offset=Coordinate(chunk_center_nm - read_shape_nm / 2), + shape=Coordinate(read_shape_nm), + ) + + logger.info( + f"Extracting raw for chunk ({cz},{cy},{cx}): " + f"ROI offset={raw_roi.offset}, shape={raw_roi.shape}" + ) + + idi = ImageDataInterface( + volume_metadata["dataset_path"], voxel_size=input_voxel_size + ) + raw_data = idi.to_ndarray_ts(raw_roi) + + # Create correction entry + correction_id = f"{volume_id}_chunk_{cz}_{cy}_{cx}" + correction_zarr_path = 
os.path.join(corrections_dir, f"{correction_id}.zarr") + + raw_offset_voxels = ( + (chunk_center_nm - read_shape_nm / 2) / input_voxel_size + ).astype(int) + annotation_offset_voxels = (chunk_offset_nm / output_voxel_size).astype(int) + + success, zarr_info = create_correction_zarr( + zarr_path=correction_zarr_path, + raw_crop_shape=input_size, + raw_voxel_size=input_voxel_size, + raw_offset=raw_offset_voxels, + annotation_crop_shape=chunk_size, + annotation_voxel_size=output_voxel_size, + annotation_offset=annotation_offset_voxels, + dataset_path=volume_metadata["dataset_path"], + model_name=volume_metadata["model_name"], + output_channels=1, + raw_dtype=str(raw_data.dtype), + create_mask=False, + ) + + if not success: + logger.error(f"Failed to create correction zarr for chunk ({cz},{cy},{cx})") + return False + + # Write data + corr_zarr = zarr.open(correction_zarr_path, mode="r+") + corr_zarr["raw/s0"][:] = raw_data + corr_zarr["annotation/s0"][:] = annotation_data + + corr_zarr.attrs["source"] = "sparse_volume" + corr_zarr.attrs["volume_id"] = volume_id + corr_zarr.attrs["chunk_indices"] = [cz, cy, cx] + + logger.info(f"Created correction {correction_id} from chunk ({cz},{cy},{cx})") + return True + + +# --------------------------------------------------------------------------- +# Annotation volume sync +# --------------------------------------------------------------------------- + +def sync_annotation_volume_from_minio(volume_id, force=False): + """ + Sync an annotation volume from MinIO, detect annotated chunks, extract corrections. + + Steps: + 1. Sync the full annotation zarr from MinIO to local disk + 2. List chunk files in MinIO to find annotated chunks + 3. 
For each new annotated chunk, extract raw data and create correction entry + + Returns: + bool: True if any corrections were created + """ + if not minio_state["ip"] or not minio_state["port"] or not minio_state["output_base"]: + logger.warning("MinIO not initialized, skipping volume sync") + return False + + try: + zarr_name = f"{volume_id}.zarr" + local_zarr_path = os.path.join(minio_state["output_base"], zarr_name) + volume_meta = _get_volume_metadata(volume_id, local_zarr_path) + + if volume_meta is None: + logger.warning(f"No metadata for volume {volume_id}, skipping") + return False + + s3 = _make_s3_filesystem() + + bucket = minio_state["bucket"] + src_annotation_path = f"{bucket}/{zarr_name}/annotation" + + if not s3.exists(src_annotation_path): + return False + + # Sync zarr group metadata + dst_annotation_path = Path(local_zarr_path) / "annotation" + dst_annotation_path.mkdir(parents=True, exist_ok=True) + _sync_zarr_group_metadata(s3, src_annotation_path, dst_annotation_path) + + # Diff and sync chunks + s0_path = f"{bucket}/{zarr_name}/annotation/s0" + known_chunk_state = volume_meta.get("chunk_sync_state", {}) + changed_chunk_keys, removed_chunk_keys, remote_chunk_state = _diff_and_sync_chunks( + s3, s0_path, dst_annotation_path / "s0", known_chunk_state, force=force + ) + + if not changed_chunk_keys and not removed_chunk_keys: + minio_state["last_sync"][volume_id] = datetime.now() + return False + + logger.info( + f"Synced {len(changed_chunk_keys)} changed chunks for volume {volume_id}" + ) + + # Extract corrections for changed chunks + extracted_chunks = volume_meta.get("extracted_chunks", set()) + changed_chunk_indices = [ + tuple(map(int, k.split("."))) + for k in changed_chunk_keys + ] + created_any = False + + for chunk_idx in changed_chunk_indices: + try: + created = extract_correction_from_chunk( + volume_id, chunk_idx, volume_meta + ) + if created: + extracted_chunks.add(chunk_idx) + created_any = True + else: + 
extracted_chunks.discard(chunk_idx) + except Exception as e: + logger.error(f"Error extracting correction for chunk {chunk_idx}: {e}") + import traceback + logger.error(traceback.format_exc()) + + # Update tracked state + volume_meta["extracted_chunks"] = extracted_chunks + volume_meta["chunk_sync_state"] = remote_chunk_state + minio_state["last_sync"][volume_id] = datetime.now() + + if created_any or changed_chunk_keys or removed_chunk_keys: + logger.info( + f"Volume {volume_id}: {len(extracted_chunks)} total chunks extracted" + ) + + return bool(created_any or changed_chunk_keys or removed_chunk_keys) + + except Exception as e: + logger.error(f"Error syncing annotation volume {volume_id}: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + + +# --------------------------------------------------------------------------- +# Periodic sync +# --------------------------------------------------------------------------- + +def periodic_sync_annotations(): + """Background thread function to periodically sync annotations from MinIO.""" + while True: + try: + time.sleep(30) + if not minio_state["output_base"]: + continue + if not minio_state["ip"] or not minio_state["port"]: + continue + sync_all_annotations_from_minio(force=False) + except Exception as e: + logger.debug(f"Error in periodic sync: {e}") + + +def start_periodic_sync(): + """Start the periodic annotation sync thread if not already running.""" + if minio_state["sync_thread"] is None or not minio_state["sync_thread"].is_alive(): + thread = threading.Thread(target=periodic_sync_annotations, daemon=True) + thread.start() + minio_state["sync_thread"] = thread + logger.info("Started periodic annotation sync thread") + + diff --git a/cellmap_flow/dashboard/routes/bbx_generator.py b/cellmap_flow/dashboard/routes/bbx_generator.py index 17ce1af..0599b95 100644 --- a/cellmap_flow/dashboard/routes/bbx_generator.py +++ b/cellmap_flow/dashboard/routes/bbx_generator.py @@ -4,7 +4,9 @@ from flask 
import Blueprint, request, jsonify from cellmap_flow.utils.scale_pyramid import get_raw_layer -from cellmap_flow.dashboard.state import bbx_generator_state +from cellmap_flow.globals import g + +bbx_generator_state = g.bbx_generator_state logger = logging.getLogger(__name__) diff --git a/cellmap_flow/dashboard/routes/blockwise.py b/cellmap_flow/dashboard/routes/blockwise.py index 881461e..f3ebded 100644 --- a/cellmap_flow/dashboard/routes/blockwise.py +++ b/cellmap_flow/dashboard/routes/blockwise.py @@ -11,7 +11,7 @@ from cellmap_flow.globals import g from cellmap_flow.utils.web_utils import INPUT_NORM_DICT_KEY, POSTPROCESS_DICT_KEY -from cellmap_flow.dashboard.state import get_blockwise_tasks_dir +from cellmap_flow.globals import get_blockwise_tasks_dir logger = logging.getLogger(__name__) diff --git a/cellmap_flow/dashboard/routes/finetune_routes.py b/cellmap_flow/dashboard/routes/finetune_routes.py new file mode 100644 index 0000000..3caeb1a --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune_routes.py @@ -0,0 +1,1235 @@ +import json +import os +import subprocess +import time +import uuid +import logging +from datetime import datetime +from pathlib import Path + +import numpy as np +import zarr +import neuroglancer +from flask import Blueprint, request, jsonify, Response + +from cellmap_flow.globals import g +from cellmap_flow.utils.load_py import load_safe_config +from cellmap_flow.globals import g +from cellmap_flow.dashboard.finetune_utils import ( + get_or_create_session_path, + create_correction_zarr, + create_annotation_volume_zarr, + ensure_minio_serving, + sync_annotation_from_minio, + sync_all_annotations_from_minio, +) + +logger = logging.getLogger(__name__) + +finetune_bp = Blueprint("finetune", __name__) + + +@finetune_bp.route("/api/finetune/models", methods=["GET"]) +def get_finetune_models(): + """Get available models for finetuning with their configurations.""" + try: + models = [] + + if hasattr(g, "models_config") and g.models_config: + 
for model_config in g.models_config: + try: + config = model_config.config + models.append( + { + "name": model_config.name, + "write_shape": list(config.write_shape), + "output_voxel_size": list(config.output_voxel_size), + "output_channels": config.output_channels, + } + ) + except Exception as e: + logger.warning( + f"Could not extract config for {model_config.name}: {e}" + ) + + if len(models) == 0 and hasattr(g, "jobs") and g.jobs: + logger.warning("No models in g.models_config, checking running jobs") + for job in g.jobs: + if hasattr(job, "model_name"): + job_model_name = job.model_name + if ( + hasattr(g, "pipeline_model_configs") + and job_model_name in g.pipeline_model_configs + ): + config_dict = g.pipeline_model_configs[job_model_name] + try: + models.append( + { + "name": job_model_name, + "write_shape": config_dict.get("write_shape", []), + "output_voxel_size": config_dict.get( + "output_voxel_size", [] + ), + "output_channels": config_dict.get( + "output_channels", 1 + ), + } + ) + logger.info( + f"Found config for {job_model_name} in pipeline_model_configs" + ) + except Exception as e: + logger.warning( + f"Could not extract config for {job_model_name}: {e}" + ) + else: + logger.warning( + f"No configuration found for running job: {job_model_name}" + ) + + selected = models[0]["name"] if len(models) == 1 else None + + return jsonify({"models": models, "selected_model": selected}) + + except Exception as e: + logger.error(f"Error getting finetune models: {e}") + return jsonify({"error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/view-center", methods=["GET"]) +def get_view_center(): + """Get current view center position from Neuroglancer viewer.""" + try: + if not hasattr(g, "viewer") or g.viewer is None: + return jsonify({"success": False, "error": "Viewer not initialized"}), 400 + + with g.viewer.txn() as s: + position = s.position + + dimensions = s.dimensions + scales_nm = None + + if dimensions and hasattr(dimensions, "scales"): + 
scales_nm = list(dimensions.scales) + logger.info(f"Viewer scales (raw): {scales_nm}") + + if hasattr(dimensions, "units"): + units = dimensions.units + if isinstance(units, str): + units = [units] * len(scales_nm) + + converted_scales = [] + for scale, unit in zip(scales_nm, units): + if unit == "m": + converted_scales.append(scale * 1e9) + elif unit == "nm": + converted_scales.append(scale) + else: + logger.warning(f"Unknown unit: {unit}, assuming nm") + converted_scales.append(scale) + scales_nm = converted_scales + + logger.info(f"Viewer scales (nm): {scales_nm}") + else: + logger.warning("Could not extract scales from viewer dimensions") + + if hasattr(position, "tolist"): + position = position.tolist() + elif hasattr(position, "__iter__"): + position = list(position) + + logger.info(f"Got view center position: {position}") + + return jsonify( + {"success": True, "position": position, "scales_nm": scales_nm} + ) + + except Exception as e: + logger.error(f"Error getting view center position: {e}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/create-crop", methods=["POST"]) +def create_annotation_crop(): + """Create an annotation crop centered at view center position.""" + try: + from cellmap_flow.image_data_interface import ImageDataInterface + from funlib.geometry import Roi, Coordinate + + data = request.get_json() + model_name = data.get("model_name") + output_path = data.get("output_path") + + if not hasattr(g, "models_config") or not g.models_config: + return jsonify({"success": False, "error": "No models loaded"}), 400 + + if not hasattr(g, "viewer") or g.viewer is None: + return jsonify({"success": False, "error": "Viewer not initialized"}), 400 + + with g.viewer.txn() as s: + position = s.position + + dimensions = s.dimensions + viewer_scales_nm = None + + if dimensions and hasattr(dimensions, "scales"): + scales_nm = list(dimensions.scales) + 
+ if hasattr(dimensions, "units"): + units = dimensions.units + if isinstance(units, str): + units = [units] * len(scales_nm) + + converted_scales = [] + for scale, unit in zip(scales_nm, units): + if unit == "m": + converted_scales.append(scale * 1e9) + elif unit == "nm": + converted_scales.append(scale) + else: + logger.warning(f"Unknown unit: {unit}, assuming nm") + converted_scales.append(scale) + viewer_scales_nm = converted_scales + else: + viewer_scales_nm = scales_nm + + if hasattr(position, "tolist"): + view_center = position.tolist() + elif hasattr(position, "__iter__"): + view_center = list(position) + else: + view_center = position + + view_center = np.array(view_center) + + logger.info(f"Auto-detected view center: {view_center}") + logger.info(f"Auto-detected viewer scales: {viewer_scales_nm} nm") + + # Find model config + model_config = None + for mc in g.models_config: + if mc.name == model_name: + model_config = mc + break + + if not model_config: + return ( + jsonify({"success": False, "error": f"Model {model_name} not found"}), + 404, + ) + + config = model_config.config + read_shape = np.array(config.read_shape) + write_shape = np.array(config.write_shape) + input_voxel_size = np.array(config.input_voxel_size) + output_voxel_size = np.array(config.output_voxel_size) + output_channels = config.output_channels + + if viewer_scales_nm is not None: + viewer_scales_nm = np.array(viewer_scales_nm) + view_center_nm = view_center * viewer_scales_nm + logger.info( + f"Converted view center from {view_center} (viewer coords) to {view_center_nm} nm" + ) + else: + view_center_nm = view_center + logger.warning( + "No viewer scales provided, assuming view center is already in nm" + ) + + raw_crop_shape_voxels = (read_shape / input_voxel_size).astype(int) + annotation_crop_shape_voxels = (write_shape / output_voxel_size).astype(int) + + half_read_shape = read_shape / 2 + raw_crop_offset_nm = view_center_nm - half_read_shape + raw_crop_offset_voxels = 
(raw_crop_offset_nm / input_voxel_size).astype(int) + + half_write_shape = write_shape / 2 + annotation_crop_offset_nm = view_center_nm - half_write_shape + annotation_crop_offset_voxels = ( + annotation_crop_offset_nm / output_voxel_size + ).astype(int) + + crop_id = f"{uuid.uuid4().hex[:8]}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + + if output_path: + session_path = get_or_create_session_path(output_path) + corrections_dir = os.path.join(session_path, "corrections") + os.makedirs(corrections_dir, exist_ok=True) + zarr.open_group(corrections_dir, mode="a") + zarr_path = os.path.join(corrections_dir, f"{crop_id}.zarr") + logger.info(f"Using session path: {session_path}") + logger.info(f"Corrections directory: {corrections_dir}") + else: + corrections_dir = os.path.expanduser("~/.cellmap_flow/corrections") + os.makedirs(corrections_dir, exist_ok=True) + zarr.open_group(corrections_dir, mode="a") + zarr_path = os.path.join(corrections_dir, f"{crop_id}.zarr") + + dataset_path = getattr(g, "dataset_path", "unknown") + + logger.info(f"Creating ImageDataInterface for {dataset_path}") + logger.info(f"Using input voxel size: {input_voxel_size} nm") + try: + idi = ImageDataInterface(dataset_path, voxel_size=input_voxel_size) + raw_dtype = str(idi.ts.dtype) + logger.info(f"Dataset dtype: {raw_dtype}") + except Exception as e: + logger.error(f"Error creating ImageDataInterface: {e}") + return ( + jsonify( + {"success": False, "error": f"Failed to access dataset: {str(e)}"} + ), + 500, + ) + + success, zarr_info = create_correction_zarr( + zarr_path=zarr_path, + raw_crop_shape=raw_crop_shape_voxels, + raw_voxel_size=input_voxel_size, + raw_offset=raw_crop_offset_voxels, + annotation_crop_shape=annotation_crop_shape_voxels, + annotation_voxel_size=output_voxel_size, + annotation_offset=annotation_crop_offset_voxels, + dataset_path=dataset_path, + model_name=model_name, + output_channels=output_channels, + raw_dtype=raw_dtype, + create_mask=False, + ) + + if not success: + 
return jsonify({"success": False, "error": zarr_info}), 500 + + logger.info(f"Reading raw data from {dataset_path}") + try: + roi = Roi( + offset=Coordinate(view_center_nm - read_shape / 2), + shape=Coordinate(read_shape), + ) + logger.info(f"Reading ROI: offset={roi.offset}, shape={roi.shape}") + + raw_data = idi.to_ndarray_ts(roi) + logger.info( + f"Read raw data with shape: {raw_data.shape}, dtype: {raw_data.dtype}" + ) + + raw_zarr = zarr.open(zarr_path, mode="r+") + raw_zarr["raw/s0"][:] = raw_data + logger.info(f"Wrote raw data to {zarr_path}/raw/s0") + + except Exception as e: + logger.error(f"Error reading/writing raw data: {e}") + import traceback + logger.error(traceback.format_exc()) + return ( + jsonify( + {"success": False, "error": f"Failed to read raw data: {str(e)}"} + ), + 500, + ) + + minio_url = ensure_minio_serving( + zarr_path, crop_id, output_base_dir=corrections_dir + ) + + return jsonify( + { + "success": True, + "crop_id": crop_id, + "zarr_path": zarr_path, + "minio_url": minio_url, + "neuroglancer_url": f"{minio_url}/annotation", + "metadata": { + "center_position_nm": view_center_nm.tolist(), + "raw_crop_offset": raw_crop_offset_voxels.tolist(), + "raw_crop_shape": raw_crop_shape_voxels.tolist(), + "raw_voxel_size": input_voxel_size.tolist(), + "annotation_crop_offset": annotation_crop_offset_voxels.tolist(), + "annotation_crop_shape": annotation_crop_shape_voxels.tolist(), + "annotation_voxel_size": output_voxel_size.tolist(), + }, + } + ) + + except Exception as e: + logger.error(f"Error creating annotation crop: {e}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/create-volume", methods=["POST"]) +def create_annotation_volume(): + """Create a sparse annotation volume covering the full dataset extent.""" + try: + from cellmap_flow.image_data_interface import ImageDataInterface + from funlib.geometry import Coordinate + + data = 
request.get_json() + model_name = data.get("model_name") + output_path = data.get("output_path") + + if not hasattr(g, "models_config") or not g.models_config: + return jsonify({"success": False, "error": "No models loaded"}), 400 + + model_config = None + for mc in g.models_config: + if mc.name == model_name: + model_config = mc + break + + if not model_config: + return ( + jsonify({"success": False, "error": f"Model {model_name} not found"}), + 404, + ) + + config = model_config.config + read_shape = np.array(config.read_shape) + write_shape = np.array(config.write_shape) + input_voxel_size = np.array(config.input_voxel_size) + output_voxel_size = np.array(config.output_voxel_size) + + output_size = (write_shape / output_voxel_size).astype(int) + input_size = (read_shape / input_voxel_size).astype(int) + + dataset_path = getattr(g, "dataset_path", None) + if not dataset_path: + return ( + jsonify({"success": False, "error": "No dataset path configured"}), + 400, + ) + + logger.info(f"Getting dataset extent from {dataset_path}") + try: + idi = ImageDataInterface(dataset_path, voxel_size=output_voxel_size) + dataset_roi = idi.roi + dataset_offset_nm = np.array(dataset_roi.offset) + dataset_shape_nm = np.array(dataset_roi.shape) + + dataset_shape_voxels = (dataset_shape_nm / output_voxel_size).astype(int) + dataset_shape_voxels = ( + np.ceil(dataset_shape_voxels / output_size).astype(int) * output_size + ) + + logger.info( + f"Dataset extent: offset={dataset_offset_nm} nm, " + f"shape={dataset_shape_voxels} voxels (at {output_voxel_size} nm/voxel)" + ) + except Exception as e: + logger.error(f"Error getting dataset extent: {e}") + return ( + jsonify( + { + "success": False, + "error": f"Failed to access dataset: {str(e)}", + } + ), + 500, + ) + + volume_id = ( + f"vol-{uuid.uuid4().hex[:8]}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + ) + + if output_path: + session_path = get_or_create_session_path(output_path) + corrections_dir = os.path.join(session_path, 
"corrections") + os.makedirs(corrections_dir, exist_ok=True) + zarr.open_group(corrections_dir, mode="a") + zarr_path = os.path.join(corrections_dir, f"{volume_id}.zarr") + logger.info(f"Using session path: {session_path}") + else: + corrections_dir = os.path.expanduser("~/.cellmap_flow/corrections") + os.makedirs(corrections_dir, exist_ok=True) + zarr.open_group(corrections_dir, mode="a") + zarr_path = os.path.join(corrections_dir, f"{volume_id}.zarr") + + success, zarr_info = create_annotation_volume_zarr( + zarr_path=zarr_path, + dataset_shape_voxels=dataset_shape_voxels, + output_voxel_size=output_voxel_size, + dataset_offset_nm=dataset_offset_nm, + chunk_size=output_size, + dataset_path=dataset_path, + model_name=model_name, + input_size=input_size, + input_voxel_size=input_voxel_size, + ) + + if not success: + return jsonify({"success": False, "error": zarr_info}), 500 + + minio_url = ensure_minio_serving( + zarr_path, volume_id, output_base_dir=corrections_dir + ) + + g.annotation_volumes[volume_id] = { + "zarr_path": zarr_path, + "model_name": model_name, + "output_size": output_size.tolist(), + "input_size": input_size.tolist(), + "input_voxel_size": input_voxel_size.tolist(), + "output_voxel_size": output_voxel_size.tolist(), + "dataset_path": dataset_path, + "dataset_offset_nm": dataset_offset_nm.tolist(), + "corrections_dir": corrections_dir, + "extracted_chunks": set(), + "chunk_sync_state": {}, + } + + return jsonify( + { + "success": True, + "volume_id": volume_id, + "zarr_path": zarr_path, + "minio_url": minio_url, + "neuroglancer_url": f"{minio_url}/annotation", + "metadata": { + "dataset_shape_voxels": dataset_shape_voxels.tolist(), + "chunk_size": output_size.tolist(), + "output_voxel_size": output_voxel_size.tolist(), + "dataset_offset_nm": dataset_offset_nm.tolist(), + }, + } + ) + + except Exception as e: + logger.error(f"Error creating annotation volume: {e}") + import traceback + logger.error(traceback.format_exc()) + return 
jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/add-to-viewer", methods=["POST"]) +def add_crop_to_viewer(): + """Add annotation crop or volume layer to Neuroglancer viewer.""" + try: + data = request.get_json() + crop_id = data.get("crop_id") + minio_url = data.get("minio_url") + + if not hasattr(g, "viewer") or g.viewer is None: + return jsonify({"success": False, "error": "Viewer not initialized"}), 400 + + with g.viewer.txn() as s: + layer_name = data.get("layer_name", f"annotation_{crop_id}") + source_config = { + "url": f"s3+{minio_url}", + "subsources": {"default": {"writingEnabled": True}, "bounds": {}}, + } + layer = neuroglancer.SegmentationLayer(source=source_config) + s.layers[layer_name] = layer + + logger.info(f"Added layer {layer_name} to viewer") + + return jsonify( + { + "success": True, + "message": "Layer added to viewer", + "layer_name": layer_name, + } + ) + + except Exception as e: + logger.error(f"Error adding layer to viewer: {e}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/sync-annotations", methods=["POST"]) +def sync_annotations_manually(): + """Manually trigger sync of annotations from MinIO to local disk.""" + try: + data = request.get_json() + crop_id = data.get("crop_id", None) + force = data.get("force", True) + + if crop_id: + success = sync_annotation_from_minio(crop_id, force=force) + if success: + return jsonify( + {"success": True, "message": f"Synced annotation for {crop_id}"} + ) + else: + return jsonify( + {"success": False, "message": f"No updates to sync for {crop_id}"} + ) + else: + synced = sync_all_annotations_from_minio(force=force) + if synced == -1: + return ( + jsonify({"success": False, "error": "MinIO not initialized"}), + 400, + ) + return jsonify( + { + "success": True, + "message": f"Synced {synced} annotations", + "synced_count": synced, + } + ) + + except 
Exception as e: + logger.error(f"Error in sync endpoint: {e}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/submit", methods=["POST"]) +def submit_finetuning(): + """Submit a finetuning job to the LSF cluster.""" + try: + data = request.get_json() + model_name = data.get("model_name") + corrections_path_str = data.get("corrections_path") + lora_r = data.get("lora_r", 8) + num_epochs = data.get("num_epochs", 10) + batch_size = data.get("batch_size", 2) + learning_rate = data.get("learning_rate", 1e-4) + checkpoint_path_override = data.get("checkpoint_path") + auto_serve = data.get("auto_serve", True) + loss_type = data.get("loss_type", "mse") + label_smoothing = data.get("label_smoothing", 0.1) + distillation_lambda = data.get("distillation_lambda", 0.0) + distillation_scope = data.get("distillation_scope", "unlabeled") + margin = data.get("margin", 0.3) + balance_classes = data.get("balance_classes", False) + queue = data.get("queue", "gpu_h100") + output_type = data.get("output_type", None) # None = auto-detect + select_channel = data.get("select_channel", None) + offsets = data.get("offsets", None) + + if not model_name: + return jsonify({"success": False, "error": "model_name is required"}), 400 + + if not corrections_path_str: + return ( + jsonify( + { + "success": False, + "error": "corrections_path is required. 
Please specify the output path where annotation crops are saved.", + } + ), + 400, + ) + + model_config = None + for config in g.models_config: + if config.name == model_name: + model_config = config + break + + if not model_config: + return ( + jsonify({"success": False, "error": f"Model {model_name} not found"}), + 404, + ) + + base_corrections_path = Path(corrections_path_str) + + actual_corrections_path = None + if ( + base_corrections_path.name == "corrections" + and base_corrections_path.exists() + ): + actual_corrections_path = base_corrections_path + session_path = base_corrections_path.parent + else: + session_path = get_or_create_session_path(str(base_corrections_path)) + actual_corrections_path = Path(session_path) / "corrections" + + if not actual_corrections_path.exists(): + return ( + jsonify( + { + "success": False, + "error": f"Corrections path does not exist: {actual_corrections_path}. Please create annotation crops first.", + } + ), + 400, + ) + + output_base = Path(session_path) + logger.info(f"Using session path for finetuning: {session_path}") + logger.info(f"Corrections path: {actual_corrections_path}") + + try: + sync_all_annotations_from_minio(force=False) + except Exception as e: + logger.warning(f"Error syncing annotations before training: {e}") + + # Detect sparse annotations + has_sparse = False + try: + for p in actual_corrections_path.iterdir(): + if p.suffix == ".zarr" and (p / ".zattrs").exists(): + attrs = json.loads((p / ".zattrs").read_text()) + if attrs.get("source") == "sparse_volume": + has_sparse = True + break + except Exception as e: + logger.warning(f"Error checking for sparse annotations: {e}") + + sparse_auto_switched = False + if has_sparse: + logger.info("Detected sparse annotations, will use mask_unannotated=True") + if loss_type == "mse": + loss_type = "margin" + distillation_lambda = 0.5 + sparse_auto_switched = True + logger.info( + "Auto-switched to margin loss + distillation (lambda=0.5) for sparse annotations" + 
) + + # Auto-detect output_type and offsets from model script + from cellmap_flow.finetune.finetune_cli import _read_offsets_from_script + if output_type is None: + # Try to auto-detect from model script + if hasattr(model_config, 'script_path'): + script_offsets = _read_offsets_from_script(model_config.script_path) + if script_offsets is not None: + output_type = "affinities" + offsets = json.dumps(script_offsets) + logger.info( + f"Auto-detected output_type='affinities' with " + f"{len(script_offsets)} offsets from model script" + ) + if output_type is None: + output_type = "binary" + + if output_type == "affinities" and offsets is None: + if hasattr(model_config, 'script_path'): + offsets = _read_offsets_from_script(model_config.script_path) + if offsets is not None: + logger.info(f"Auto-detected {len(offsets)} offsets from model script") + offsets = json.dumps(offsets) + if offsets is None: + return jsonify({ + "success": False, + "error": "output_type='affinities' requires offsets. " + "Define 'offsets' in the model script or pass them in the request." 
+ }), 400 + elif isinstance(offsets, list): + offsets = json.dumps(offsets) + + finetune_job = g.finetune_job_manager.submit_finetuning_job( + model_config=model_config, + corrections_path=actual_corrections_path, + lora_r=lora_r, + num_epochs=num_epochs, + batch_size=batch_size, + learning_rate=learning_rate, + output_base=output_base, + checkpoint_path_override=( + Path(checkpoint_path_override) if checkpoint_path_override else None + ), + auto_serve=auto_serve, + mask_unannotated=has_sparse, + loss_type=loss_type, + label_smoothing=label_smoothing, + distillation_lambda=distillation_lambda, + distillation_scope=distillation_scope, + margin=margin, + balance_classes=balance_classes, + queue=queue, + output_type=output_type, + select_channel=select_channel, + offsets=offsets, + ) + + logger.info(f"Submitted finetuning job: {finetune_job.job_id}") + + lsf_job_id = None + if finetune_job.lsf_job: + if hasattr(finetune_job.lsf_job, "job_id"): + lsf_job_id = finetune_job.lsf_job.job_id + elif hasattr(finetune_job.lsf_job, "process"): + lsf_job_id = f"PID:{finetune_job.lsf_job.process.pid}" + + response = { + "success": True, + "job_id": finetune_job.job_id, + "lsf_job_id": lsf_job_id, + "output_dir": str(finetune_job.output_dir), + "message": "Finetuning job submitted successfully", + } + if sparse_auto_switched: + response["note"] = ( + "Auto-switched to margin loss + distillation (lambda=0.5) for sparse annotations" + ) + + return jsonify(response) + + except ValueError as e: + logger.error(f"Validation error: {e}") + return jsonify({"success": False, "error": str(e)}), 400 + except Exception as e: + logger.error(f"Error submitting finetuning job: {e}") + import traceback + logger.error(traceback.format_exc()) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/jobs", methods=["GET"]) +def get_finetuning_jobs(): + """Get list of all finetuning jobs.""" + try: + jobs = g.finetune_job_manager.list_jobs() + return 
jsonify({"success": True, "jobs": jobs})
+    except Exception as e:
+        logger.error(f"Error getting jobs: {e}")
+        return jsonify({"success": False, "error": str(e)}), 500
+
+
+@finetune_bp.route("/api/finetune/job/<job_id>/status", methods=["GET"])
+def get_job_status(job_id):
+    """Get detailed status of a specific job."""
+    try:
+        status = g.finetune_job_manager.get_job_status(job_id)
+        if status is None:
+            return jsonify({"success": False, "error": "Job not found"}), 404
+
+        return jsonify({"success": True, **status})
+    except Exception as e:
+        logger.error(f"Error getting job status: {e}")
+        return jsonify({"success": False, "error": str(e)}), 500
+
+
+@finetune_bp.route("/api/finetune/job/<job_id>/logs", methods=["GET"])
+def get_job_logs(job_id):
+    """Get training logs for a specific job."""
+    try:
+        logs = g.finetune_job_manager.get_job_logs(job_id)
+        if logs is None:
+            return jsonify({"success": False, "error": "Job not found"}), 404
+
+        return jsonify({"success": True, "logs": logs})
+    except Exception as e:
+        logger.error(f"Error getting job logs: {e}")
+        return jsonify({"success": False, "error": str(e)}), 500
+
+
+@finetune_bp.route("/api/finetune/job/<job_id>/logs/stream", methods=["GET"])
+def stream_job_logs(job_id):
+    """Server-Sent Events stream for live training logs."""
+
+    import re as _re
+
+    _log_filters = [
+        _re.compile(r"^\s+base_model\.\S+\.lora_"),
+        _re.compile(r"^INFO:werkzeug:"),
+        _re.compile(r"^Array metadata \(scale="),
+        _re.compile(r"^Host name:"),
+        _re.compile(r"^DEBUG trainer:"),
+    ]
+
+    def _should_show(line):
+        for pat in _log_filters:
+            if pat.search(line):
+                return False
+        return True
+
+    def _iter_visible_lines(text):
+        for line in text.splitlines():
+            if line and _should_show(line):
+                yield line
+
+    def _sse_data_block(lines):
+        if not lines:
+            return None
+        payload = "\n".join(lines)
+        return "data: " + payload.replace("\n", "\ndata: ") + "\n\n"
+
+    def _read_bpeek_content(lsf_job_id):
+        try:
+            result = subprocess.run(
+                ["bpeek", str(lsf_job_id)], 
+ capture_output=True, + text=True, + timeout=5, + ) + except Exception as e: + logger.debug(f"bpeek call failed for job {lsf_job_id}: {e}") + return None + + output = result.stdout or "" + stderr = (result.stderr or "").strip() + if stderr and "Not yet started" not in stderr: + logger.debug(f"bpeek stderr for job {lsf_job_id}: {stderr}") + return output + + def generate(): + heartbeat_interval_s = 1.0 + last_heartbeat = time.perf_counter() + + fjm = g.finetune_job_manager + if job_id not in fjm.jobs: + yield f"data: Job {job_id} not found\n\n" + return + + finetune_job = fjm.jobs[job_id] + lsf_job_id = None + if finetune_job.lsf_job and hasattr(finetune_job.lsf_job, "job_id"): + lsf_job_id = finetune_job.lsf_job.job_id + + use_bpeek = lsf_job_id is not None + last_bpeek_line_count = 0 + last_bpeek_poll = 0.0 + bpeek_poll_interval_s = 0.25 + + # Send existing content first + if use_bpeek: + initial = _read_bpeek_content(lsf_job_id) + if initial is None: + use_bpeek = False + else: + initial_lines = initial.splitlines() + last_bpeek_line_count = len(initial_lines) + block = _sse_data_block(list(_iter_visible_lines(initial))) + if block: + yield block + + if not use_bpeek and finetune_job.log_file.exists(): + try: + with open(finetune_job.log_file, "r") as f: + existing_content = f.read() + block = _sse_data_block(list(_iter_visible_lines(existing_content))) + if block: + yield block + except Exception as e: + logger.error(f"Error reading log file: {e}") + + last_position = ( + finetune_job.log_file.stat().st_size + if finetune_job.log_file.exists() + else 0 + ) + + while finetune_job.status.value in ["PENDING", "RUNNING"]: + try: + now = time.perf_counter() + + if ( + use_bpeek + and lsf_job_id + and now - last_bpeek_poll >= bpeek_poll_interval_s + ): + last_bpeek_poll = now + content = _read_bpeek_content(lsf_job_id) + if content is None: + use_bpeek = False + else: + current_lines = content.splitlines() + if len(current_lines) < last_bpeek_line_count: + 
delta_lines = current_lines
+                        else:
+                            delta_lines = current_lines[last_bpeek_line_count:]
+                        last_bpeek_line_count = len(current_lines)
+                        if delta_lines:
+                            delta_text = "\n".join(delta_lines)
+                            block = _sse_data_block(
+                                list(_iter_visible_lines(delta_text))
+                            )
+                            if block:
+                                yield block
+
+                if not use_bpeek and finetune_job.log_file.exists():
+                    with open(finetune_job.log_file, "r") as f:
+                        f.seek(last_position)
+                        new_content = f.read()
+                        last_position = f.tell()
+                        if new_content:
+                            block = _sse_data_block(
+                                list(_iter_visible_lines(new_content))
+                            )
+                            if block:
+                                yield block
+
+                if now - last_heartbeat >= heartbeat_interval_s:
+                    yield ": ping\n\n"
+                    last_heartbeat = now
+
+                time.sleep(0.1)
+
+            except Exception as e:
+                logger.error(f"Error streaming logs: {e}")
+                break
+
+        yield f"data: === Training {finetune_job.status.value} ===\n\n"
+
+    return Response(
+        generate(),
+        mimetype="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "X-Accel-Buffering": "no",
+            "Connection": "keep-alive",
+        },
+    )
+
+
+@finetune_bp.route("/api/finetune/job/<job_id>/cancel", methods=["POST"])
+def cancel_job(job_id):
+    """Cancel a running finetuning job."""
+    try:
+        success = g.finetune_job_manager.cancel_job(job_id)
+
+        if success:
+            return jsonify({"success": True, "message": f"Job {job_id} cancelled"})
+        else:
+            return jsonify({"success": False, "error": "Failed to cancel job"}), 400
+
+    except Exception as e:
+        logger.error(f"Error cancelling job: {e}")
+        return jsonify({"success": False, "error": str(e)}), 500
+
+
+@finetune_bp.route("/api/finetune/job/<job_id>/inference-server", methods=["GET"])
+def get_inference_server_status(job_id):
+    """Get inference server status for a finetuning job."""
+    try:
+        job = g.finetune_job_manager.get_job(job_id)
+        if not job:
+            return jsonify({"success": False, "error": "Job not found"}), 404
+
+        return jsonify(
+            {
+                "success": True,
+                "ready": job.inference_server_ready,
+                "url": job.inference_server_url,
+                "model_name": 
job.finetuned_model_name, + "model_script_path": ( + str(job.model_script_path) if job.model_script_path else None + ), + } + ) + + except Exception as e: + logger.error(f"Error getting inference server status: {e}") + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/viewer/add-finetuned-layer", methods=["POST"]) +def add_finetuned_layer_to_viewer(): + """Add finetuned model layer to Neuroglancer viewer and register model in system.""" + try: + data = request.get_json() + server_url = data.get("server_url") + model_name = data.get("model_name") + model_script_path = data.get("model_script_path") + + if not server_url or not model_name: + return ( + jsonify( + {"success": False, "error": "Missing server_url or model_name"} + ), + 400, + ) + + logger.info(f"Registering finetuned model: {model_name} at {server_url}") + + # 1. Load model config from script if provided + if model_script_path and Path(model_script_path).exists(): + try: + model_config = load_safe_config(model_script_path) + + if not hasattr(g, "models_config"): + g.models_config = [] + + base_model_name = ( + model_name.rsplit("_finetuned_", 1)[0] + if "_finetuned_" in model_name + else model_name + ) + g.models_config = [ + mc + for mc in g.models_config + if not ( + hasattr(mc, "name") + and mc.name.startswith(f"{base_model_name}_finetuned") + ) + ] + + g.models_config.append(model_config) + logger.info(f"Loaded model config from {model_script_path}") + except Exception as e: + logger.warning(f"Could not load model config: {e}") + + # 2. 
Add to model_catalog under "Finetuned" group + if not hasattr(g, "model_catalog"): + g.model_catalog = {} + + if "Finetuned" not in g.model_catalog: + g.model_catalog["Finetuned"] = {} + + base_model_name = ( + model_name.rsplit("_finetuned_", 1)[0] + if "_finetuned_" in model_name + else model_name + ) + g.model_catalog["Finetuned"] = { + name: path + for name, path in g.model_catalog["Finetuned"].items() + if not name.startswith(f"{base_model_name}_finetuned") + } + + g.model_catalog["Finetuned"][model_name] = ( + model_script_path if model_script_path else "" + ) + logger.info(f"Added to model catalog: Finetuned/{model_name}") + + # 3. Create a Job object for the running inference server + from cellmap_flow.utils.bsub_utils import LSFJob + + finetune_job = None + for job_id, ft_job in g.finetune_job_manager.jobs.items(): + if ft_job.finetuned_model_name == model_name: + finetune_job = ft_job + break + + if finetune_job and finetune_job.job_id: + inference_job = LSFJob( + job_id=finetune_job.job_id, model_name=model_name + ) + inference_job.host = server_url + inference_job.status = finetune_job.status + + g.jobs = [ + j + for j in g.jobs + if not ( + hasattr(j, "model_name") + and j.model_name + and j.model_name.startswith(f"{base_model_name}_finetuned") + ) + ] + + g.jobs.append(inference_job) + logger.info( + f"Created Job object for {model_name} with job_id {finetune_job.job_id}" + ) + else: + logger.warning( + f"Could not find finetune job for {model_name}, Job object not created" + ) + + # 4. 
Add neuroglancer layer + layer_name = model_name + + with g.viewer.txn() as s: + if layer_name in s.layers: + logger.info(f"Removing old finetuned layer: {layer_name}") + del s.layers[layer_name] + + from cellmap_flow.utils.neuroglancer_utils import get_norms_post_args + from cellmap_flow.utils.web_utils import ARGS_KEY + + st_data = get_norms_post_args(g.input_norms, g.postprocess) + + layer_source = f"zarr://{server_url}/{model_name}{ARGS_KEY}{st_data}{ARGS_KEY}" + s.layers[layer_name] = neuroglancer.ImageLayer( + source=layer_source, + shader="""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]); + #uicontrol vec3 color color(default="red"); + void main(){emitRGB(color * normalized());}""", + ) + + logger.info(f"Added neuroglancer layer: {layer_name} -> {server_url}") + + return jsonify( + { + "success": True, + "layer_name": layer_name, + "model_name": model_name, + "reload_page": True, + } + ) + + except Exception as e: + logger.error(f"Error adding finetuned layer: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), 500 + + +@finetune_bp.route("/api/finetune/job//restart", methods=["POST"]) +def restart_finetuning_job(job_id): + """Restart training on the same GPU via in-process control channel.""" + try: + restart_t0 = time.perf_counter() + data = request.get_json() or {} + + # Sync annotations from MinIO before restarting training + try: + sync_t0 = time.perf_counter() + synced = sync_all_annotations_from_minio(force=False) + sync_elapsed = time.perf_counter() - sync_t0 + logger.info( + f"Restart pre-sync complete for job {job_id}: synced={synced}, " + f"elapsed={sync_elapsed:.2f}s" + ) + except Exception as e: + logger.warning(f"Error syncing annotations before restart: {e}") + + updated_params = {} + passthrough_keys = [ + "lora_r", + "lora_alpha", + "num_epochs", + "batch_size", + "learning_rate", + "loss_type", + "label_smoothing", + "distillation_lambda", + "margin", + "balance_classes", + "mask_unannotated", + 
"gradient_accumulation_steps", + "num_workers", + "no_augment", + "no_mixed_precision", + "patch_shape", + "output_type", + "select_channel", + "offsets", + ] + for key in passthrough_keys: + if key in data and data[key] is not None: + updated_params[key] = data[key] + + # UI uses distillation_scope; CLI expects distillation_all_voxels. + if "distillation_scope" in data and data["distillation_scope"] is not None: + scope = str(data["distillation_scope"]).lower() + if scope in {"all", "unlabeled"}: + updated_params["distillation_all_voxels"] = scope == "all" + else: + logger.warning( + f"Ignoring invalid distillation_scope on restart for job {job_id}: {data['distillation_scope']}" + ) + + signal_t0 = time.perf_counter() + job = g.finetune_job_manager.restart_finetuning_job( + job_id=job_id, updated_params=updated_params + ) + signal_elapsed = time.perf_counter() - signal_t0 + total_elapsed = time.perf_counter() - restart_t0 + logger.info( + f"Restart request processed for job {job_id}: " + f"signal_write={signal_elapsed:.2f}s total={total_elapsed:.2f}s" + ) + + return jsonify( + { + "success": True, + "job_id": job.job_id, + "message": "Restart request sent. 
Training will restart on the same GPU.", + } + ) + + except Exception as e: + logger.error(f"Error restarting job: {e}") + return jsonify({"success": False, "error": str(e)}), 500 diff --git a/cellmap_flow/dashboard/routes/index_page.py b/cellmap_flow/dashboard/routes/index_page.py index 6752553..dedd947 100644 --- a/cellmap_flow/dashboard/routes/index_page.py +++ b/cellmap_flow/dashboard/routes/index_page.py @@ -6,7 +6,6 @@ from cellmap_flow.post.postprocessors import get_postprocessors_list from cellmap_flow.models.model_merger import get_model_mergers_list from cellmap_flow.globals import g -import cellmap_flow.dashboard.state as state logger = logging.getLogger(__name__) @@ -29,8 +28,8 @@ def index(): return render_template( "index.html", - neuroglancer_url=state.NEUROGLANCER_URL, - inference_servers=state.INFERENCE_SERVER, + neuroglancer_url=g.NEUROGLANCER_URL, + inference_servers=g.INFERENCE_SERVER, input_normalizers=input_norms, output_postprocessors=output_postprocessors, model_mergers=model_mergers, diff --git a/cellmap_flow/dashboard/routes/logging_routes.py b/cellmap_flow/dashboard/routes/logging_routes.py index 73ff32f..711f115 100644 --- a/cellmap_flow/dashboard/routes/logging_routes.py +++ b/cellmap_flow/dashboard/routes/logging_routes.py @@ -4,7 +4,7 @@ from flask import Blueprint, Response -from cellmap_flow.dashboard.state import log_buffer, log_clients +from cellmap_flow.globals import g logger = logging.getLogger(__name__) @@ -16,12 +16,12 @@ def stream_logs(): """Stream logs via Server-Sent Events (SSE)""" def generate(): # Send existing log buffer first - for log_line in log_buffer: + for log_line in g.log_buffer: yield f"data: {log_line}\n\n" # Create a queue for this client client_queue = queue.Queue(maxsize=100) - log_clients.append(client_queue) + g.log_clients.append(client_queue) try: while True: @@ -33,8 +33,8 @@ def generate(): yield ": keepalive\n\n" finally: # Clean up when client disconnects - if client_queue in log_clients: - 
log_clients.remove(client_queue) + if client_queue in g.log_clients: + g.log_clients.remove(client_queue) return Response(generate(), mimetype="text/event-stream", headers={ "Cache-Control": "no-cache", diff --git a/cellmap_flow/dashboard/routes/pipeline.py b/cellmap_flow/dashboard/routes/pipeline.py index 1bb2f9e..c6af045 100644 --- a/cellmap_flow/dashboard/routes/pipeline.py +++ b/cellmap_flow/dashboard/routes/pipeline.py @@ -16,7 +16,6 @@ from cellmap_flow.utils.load_py import load_safe_config from cellmap_flow.utils.scale_pyramid import get_raw_layer from cellmap_flow.utils.web_utils import encode_to_str, ARGS_KEY -from cellmap_flow.dashboard.state import CUSTOM_CODE_FOLDER logger = logging.getLogger(__name__) @@ -141,7 +140,7 @@ def process(): # Save custom code to a file with date and time timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"custom_code_{timestamp}.py" - filepath = os.path.join(CUSTOM_CODE_FOLDER, filename) + filepath = os.path.join(g.CUSTOM_CODE_FOLDER, filename) with open(filepath, "w") as file: file.write(custom_code) diff --git a/cellmap_flow/dashboard/state.py b/cellmap_flow/dashboard/state.py index 07dd777..389c95b 100644 --- a/cellmap_flow/dashboard/state.py +++ b/cellmap_flow/dashboard/state.py @@ -1,53 +1,15 @@ -import os -import logging -import queue -from collections import deque - -logger = logging.getLogger(__name__) - -# Global log buffer for streaming to frontend -log_buffer = deque(maxlen=1000) # Keep last 1000 lines -log_clients = [] # List of queues for connected clients - - -# Custom handler to capture logs -class LogHandler(logging.Handler): - def emit(self, record): - log_entry = self.format(record) - log_buffer.append(log_entry) - # Send to all connected clients - for client_queue in log_clients: - try: - client_queue.put_nowait(log_entry) - except queue.Full: - pass - - -NEUROGLANCER_URL = None -INFERENCE_SERVER = None - -CUSTOM_CODE_FOLDER = os.path.expanduser( - os.environ.get( - "CUSTOM_CODE_FOLDER", - 
"~/Desktop/cellmap/cellmap-flow/example/example_norm", - ) -) - -# Blockwise task directory will be set from globals or use default -def get_blockwise_tasks_dir(): - from cellmap_flow.globals import g - tasks_dir = getattr(g, 'blockwise_tasks_dir', None) or os.path.expanduser("~/.cellmap_flow/blockwise_tasks") - os.makedirs(tasks_dir, exist_ok=True) - return tasks_dir - - -# Global state for BBX generator -bbx_generator_state = { - "dataset_path": None, - "num_boxes": 0, - "bounding_boxes": [], - "viewer": None, - "viewer_process": None, - "viewer_url": None, - "viewer_state": None -} +# Re-export all dashboard state from the globals singleton for backward compatibility. +# New code should import directly from cellmap_flow.globals. + +from cellmap_flow.globals import g, LogHandler, get_blockwise_tasks_dir # noqa: F401 + +log_buffer = g.log_buffer +log_clients = g.log_clients +NEUROGLANCER_URL = g.NEUROGLANCER_URL +INFERENCE_SERVER = g.INFERENCE_SERVER +CUSTOM_CODE_FOLDER = g.CUSTOM_CODE_FOLDER +bbx_generator_state = g.bbx_generator_state +finetune_job_manager = g.finetune_job_manager +minio_state = g.minio_state +annotation_volumes = g.annotation_volumes +output_sessions = g.output_sessions diff --git a/cellmap_flow/dashboard/static/css/dark.css b/cellmap_flow/dashboard/static/css/dark.css index 13d2028..b9c34ab 100644 --- a/cellmap_flow/dashboard/static/css/dark.css +++ b/cellmap_flow/dashboard/static/css/dark.css @@ -146,4 +146,74 @@ background-color: #1a4971; border-color: #2980b9; color: #cce5ff; + } + + /* Modal styling for dark mode */ + .modal-content { + background-color: #1e1e1e; + color: #ffffff; + border: 1px solid #555; + } + + .modal-header { + border-bottom-color: #555; + } + + .modal-footer { + border-top-color: #555; + } + + /* Form controls in dark mode */ + .form-control, .form-select { + background-color: #2a2a2a; + color: #ffffff; + border-color: #555; + } + + .form-control:focus, .form-select:focus { + background-color: #2a2a2a; + color: 
#ffffff; + border-color: #0d6efd; + } + + /* Muted text readable on dark backgrounds */ + .text-muted, .form-text { + color: #adb5bd !important; + } + + /* Card styling for dark mode */ + .card { + background-color: #1e1e1e; + border-color: #555; + color: #ffffff; + } + + .card-header { + background-color: #2a2a2a; + border-bottom-color: #555; + color: #ffffff; + } + + .card-body { + color: #ffffff; + } + + /* Labels and headings in dark mode */ + .form-label, label { + color: #e0e0e0 !important; + } + + h1, h2, h3, h4, h5, h6 { + color: #ffffff; + } + + /* Badge secondary needs contrast */ + .badge.bg-secondary { + color: #ffffff; + } + + /* Placeholder text */ + .form-control::placeholder { + color: #888 !important; + opacity: 1; } \ No newline at end of file diff --git a/cellmap_flow/dashboard/templates/_dashboard.html b/cellmap_flow/dashboard/templates/_dashboard.html index 955c18a..6974a0c 100644 --- a/cellmap_flow/dashboard/templates/_dashboard.html +++ b/cellmap_flow/dashboard/templates/_dashboard.html @@ -51,6 +51,22 @@

Dashboard

Postprocess + + + @@ -81,6 +97,16 @@

Dashboard

> {% include "_output_tab.html" %} + + +
+ {% include "_finetune_tab.html" %} +
\ No newline at end of file diff --git a/cellmap_flow/dashboard/templates/_finetune_tab.html b/cellmap_flow/dashboard/templates/_finetune_tab.html new file mode 100644 index 0000000..1a97acc --- /dev/null +++ b/cellmap_flow/dashboard/templates/_finetune_tab.html @@ -0,0 +1,1364 @@ + + + + + +
+ +
+ +
+ +
+ +
+ + +
+ + Crop will be sized to model's output inference size + +
+ + + + + +
+ + + Directory where annotation crops will be saved (must be accessible to MinIO). Crop will be created at current view center position. +
+ + +
+ + + +
+ + Crop: Small region at current view center (dense, paint 1=foreground). + Volume: Full dataset extent (sparse, paint 1=background, 2=foreground). + + + + + + +
+ + +
+
+
+ + +
+ + +
+
+
Training Configuration
+
+
+ +
+ + + + Override the base model checkpoint to finetune from. If left empty, the system will attempt to extract it from the model configuration or script. + +
+ +
+
+
+ + + Higher rank = more trainable parameters +
+
+
+
+ + + Typical range: 10-20 epochs +
+
+
+ +
+
+
+ + + Higher = faster but uses more GPU memory +
+
+
+
+ + + LoRA typically uses higher learning rates +
+
+
+ +
+
+
+ + + Margin is recommended for sparse annotations +
+
+
+
+ + + Keeps model close to original predictions +
+
+
+
+
+
+ + + Where to apply distillation loss +
+
+
+
+
+ + +
+ Weight fg and bg equally in loss regardless of scribble ratio. Prevents foreground overprediction. +
+
+
+ +
+
+
+ + +
+
+
+ + +
+ + +
+ +
+ +
+
+
+ + + + + +
+
+
Training Logs
+
+ + +
+
+
+
+
+ Loss vs Epoch + No loss data yet +
+ +
+ +
+
+
+
+ + + + + diff --git a/cellmap_flow/finetune/__init__.py b/cellmap_flow/finetune/__init__.py new file mode 100644 index 0000000..2077344 --- /dev/null +++ b/cellmap_flow/finetune/__init__.py @@ -0,0 +1,38 @@ +""" +Human-in-the-loop finetuning for CellMap-Flow models. + +This package provides lightweight LoRA-based finetuning for pre-trained models +using user corrections as training data. +""" + +from cellmap_flow.finetune.lora_wrapper import ( + detect_adaptable_layers, + wrap_model_with_lora, + print_lora_parameters, + load_lora_adapter, + save_lora_adapter, +) + +from cellmap_flow.finetune.correction_dataset import ( + CorrectionDataset, + create_dataloader, +) + +from cellmap_flow.finetune.lora_trainer import ( + LoRAFinetuner, + DiceLoss, + CombinedLoss, +) + +__all__ = [ + "detect_adaptable_layers", + "wrap_model_with_lora", + "print_lora_parameters", + "load_lora_adapter", + "save_lora_adapter", + "CorrectionDataset", + "create_dataloader", + "LoRAFinetuner", + "DiceLoss", + "CombinedLoss", +] diff --git a/cellmap_flow/finetune/correction_dataset.py b/cellmap_flow/finetune/correction_dataset.py new file mode 100644 index 0000000..96ed451 --- /dev/null +++ b/cellmap_flow/finetune/correction_dataset.py @@ -0,0 +1,353 @@ +""" +PyTorch Dataset for loading user corrections. + +This module provides a Dataset class that loads 3D EM data and correction +masks from Zarr files for training LoRA adapters. +""" + +import logging +from pathlib import Path +from typing import List, Tuple, Optional +import numpy as np +import zarr +import torch +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +class CorrectionDataset(Dataset): + """ + PyTorch Dataset for user corrections stored in Zarr format. + + Loads raw EM data and corrected masks from corrections.zarr/, with + optional 3D augmentation. 
+ + Args: + corrections_zarr_path: Path to corrections.zarr directory + patch_shape: Shape of patches to extract (Z, Y, X) + If None, uses full correction size + augment: Whether to apply 3D augmentation + model_name: If specified, only load corrections for this model + + Examples: + >>> dataset = CorrectionDataset( + ... "test_corrections.zarr", + ... patch_shape=(64, 64, 64), + ... augment=True + ... ) + >>> print(f"Dataset size: {len(dataset)}") + >>> raw, target = dataset[0] + >>> print(f"Raw shape: {raw.shape}, Target shape: {target.shape}") + """ + + def __init__( + self, + corrections_zarr_path: str, + patch_shape: Optional[Tuple[int, int, int]] = None, + augment: bool = True, + model_name: Optional[str] = None, + ): + self.corrections_path = Path(corrections_zarr_path) + self.patch_shape = patch_shape + self.augment = augment + self.model_name = model_name + + # Load corrections + self.corrections = self._load_corrections() + + if len(self.corrections) == 0: + raise ValueError( + f"No corrections found in {corrections_zarr_path}. " + f"Generate corrections first." 
+ ) + + logger.info( + f"Loaded {len(self.corrections)} corrections from {corrections_zarr_path}" + ) + + def _load_corrections(self) -> List[dict]: + """Load correction metadata from Zarr.""" + corrections = [] + + logger.info(f"Loading corrections from: {self.corrections_path}") + + if not self.corrections_path.exists(): + logger.error(f"Corrections path does not exist: {self.corrections_path}") + return corrections + + path_str = str(self.corrections_path) + z = zarr.open_group(path_str, mode="r") + + for correction_id in z.keys(): + corr_group = z[correction_id] + + # Check if correction has required data + # Support both 'mask' (from test scripts) and 'annotation' (from dashboard) + has_raw = "raw" in corr_group + has_mask = "mask" in corr_group + has_annotation = "annotation" in corr_group + + has_raw_s0 = has_raw and "s0" in corr_group["raw"] + has_mask_s0 = has_mask and "s0" in corr_group["mask"] + has_annotation_s0 = has_annotation and "s0" in corr_group["annotation"] + + if not has_raw_s0 or not (has_mask_s0 or has_annotation_s0): + logger.warning( + f"Skipping {correction_id}: missing raw/s0 or mask|annotation/s0" + ) + continue + + # Use 'mask' if available, otherwise use 'annotation' + mask_key = "mask" if has_mask_s0 else "annotation" + + # Get metadata + attrs = dict(corr_group.attrs) + + # Filter by model name if specified + if self.model_name and attrs.get("model_name") != self.model_name: + continue + + raw_path = self.corrections_path / correction_id / "raw" / "s0" + mask_path = self.corrections_path / correction_id / mask_key / "s0" + + if not raw_path.exists() or not mask_path.exists(): + logger.warning( + f"Skipping {correction_id}: missing paths " + f"raw_path={raw_path} (exists={raw_path.exists()}), " + f"mask_path={mask_path} (exists={mask_path.exists()})" + ) + continue + + corrections.append( + { + "id": correction_id, + "raw_path": str(raw_path), + "mask_path": str(mask_path), + "metadata": attrs, + } + ) + + return corrections + + def 
__len__(self) -> int: + """Return number of corrections.""" + return len(self.corrections) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Load a correction pair (raw, target). + + Args: + idx: Index of correction + + Returns: + Tuple of (raw, target) tensors: + - raw: (1, Z, Y, X) float32 tensor + - target: (1, Z, Y, X) float32 tensor, values in [0, 1] + """ + correction = self.corrections[idx] + + # Load data using ImageDataInterface for consistent data loading + from cellmap_flow.image_data_interface import ImageDataInterface + + try: + raw = ImageDataInterface( + correction["raw_path"], normalize=False + ).to_ndarray_ts() + mask = ImageDataInterface( + correction["mask_path"], normalize=False + ).to_ndarray_ts() + except Exception as e: + raise FileNotFoundError( + f"Failed loading correction '{correction.get('id', idx)}' " + f"raw_path='{correction.get('raw_path')}' " + f"mask_path='{correction.get('mask_path')}': {e}" + ) from e + + # Convert to float + raw = raw.astype(np.float32) + mask = mask.astype(np.float32) + + # Normalize mask to [0, 1] + # Only normalize pixel-intensity masks (0-255 range), not class labels (0, 1, 2) + # Class labels are small integers used by mask_unannotated logic in trainer + if mask.max() > 2.0: + mask = mask / 255.0 + + # For models with different input/output sizes, we keep raw at full size + # Patching is disabled for this case - use full corrections + # Apply augmentation (only if raw and mask have same shape) + if self.augment and raw.shape == mask.shape: + raw, mask = self._augment_3d(raw, mask) + elif self.augment and raw.shape != mask.shape: + logger.debug( + f"Skipping augmentation: raw {raw.shape} != mask {mask.shape}. " + "Augmentation requires matching sizes." 
+ ) + + # Add channel dimension and convert to torch + raw = torch.from_numpy(raw[np.newaxis, ...]) # (1, Z, Y, X) + mask = torch.from_numpy(mask[np.newaxis, ...]) # (1, Z, Y, X) + + return raw, mask + + def _random_crop( + self, raw: np.ndarray, mask: np.ndarray, patch_shape: Tuple[int, int, int] + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Extract a random patch from the volumes. + + Args: + raw: Raw data (Z, Y, X) + mask: Mask data (Z, Y, X) + patch_shape: Desired patch shape (Z, Y, X) + + Returns: + Cropped (raw, mask) pair + """ + z, y, x = raw.shape + pz, py, px = patch_shape + + # If volume is smaller than patch, pad it + if z < pz or y < py or x < px: + pad_z = max(0, pz - z) + pad_y = max(0, py - y) + pad_x = max(0, px - x) + + raw = np.pad(raw, ((0, pad_z), (0, pad_y), (0, pad_x)), mode="reflect") + mask = np.pad(mask, ((0, pad_z), (0, pad_y), (0, pad_x)), mode="reflect") + z, y, x = raw.shape + + # Random offset + z_offset = np.random.randint(0, max(1, z - pz + 1)) + y_offset = np.random.randint(0, max(1, y - py + 1)) + x_offset = np.random.randint(0, max(1, x - px + 1)) + + # Crop + raw_crop = raw[ + z_offset : z_offset + pz, y_offset : y_offset + py, x_offset : x_offset + px + ] + mask_crop = mask[ + z_offset : z_offset + pz, y_offset : y_offset + py, x_offset : x_offset + px + ] + + return raw_crop, mask_crop + + def _augment_3d( + self, raw: np.ndarray, mask: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Apply 3D augmentation to raw and mask. 
+ + Augmentations: + - Random flips on Z/Y/X axes (50% each) + - Random 90° rotations in XY plane (0°, 90°, 180°, 270°) + - Random intensity scaling for raw (×0.8 to ×1.2) + - Random Gaussian noise for raw (σ=0.01) + + Args: + raw: Raw data (Z, Y, X) + mask: Mask data (Z, Y, X) + + Returns: + Augmented (raw, mask) pair + """ + # Random flips + if np.random.rand() > 0.5: + raw = np.flip(raw, axis=0).copy() # Flip Z + mask = np.flip(mask, axis=0).copy() + + if np.random.rand() > 0.5: + raw = np.flip(raw, axis=1).copy() # Flip Y + mask = np.flip(mask, axis=1).copy() + + if np.random.rand() > 0.5: + raw = np.flip(raw, axis=2).copy() # Flip X + mask = np.flip(mask, axis=2).copy() + + # Random 90° rotation in XY plane + k = np.random.randint(0, 4) # 0, 1, 2, or 3 (0°, 90°, 180°, 270°) + if k > 0: + raw = np.rot90(raw, k=k, axes=(1, 2)).copy() + mask = np.rot90(mask, k=k, axes=(1, 2)).copy() + + # Intensity augmentation for raw only + # Random scaling (×0.8 to ×1.2) + scale = np.random.uniform(0.8, 1.2) + raw = np.clip(raw * scale, 0, 1) + + # Random Gaussian noise (σ=0.01) + noise = np.random.normal(0, 0.01, raw.shape).astype(np.float32) + raw = np.clip(raw + noise, 0, 1) + + return raw, mask + + +def create_dataloader( + corrections_zarr_path: str, + batch_size: int = 2, + patch_shape: Optional[Tuple[int, int, int]] = None, + augment: bool = True, + num_workers: int = 4, + shuffle: bool = True, + model_name: Optional[str] = None, +) -> torch.utils.data.DataLoader: + """ + Create a DataLoader for corrections. + + Args: + corrections_zarr_path: Path to corrections.zarr directory + batch_size: Batch size (2-4 recommended for 3D data) + patch_shape: Shape of patches to extract (Z, Y, X) + augment: Whether to apply augmentation + num_workers: Number of data loading workers + shuffle: Whether to shuffle data + model_name: If specified, only load corrections for this model + + Returns: + DataLoader instance + + Examples: + >>> dataloader = create_dataloader( + ... 
"test_corrections.zarr", + ... batch_size=2, + ... patch_shape=(64, 64, 64) + ... ) + >>> for raw, target in dataloader: + ... print(f"Batch: raw={raw.shape}, target={target.shape}") + ... break + Batch: raw=torch.Size([2, 1, 64, 64, 64]), target=torch.Size([2, 1, 64, 64, 64]) + """ + dataset = CorrectionDataset( + corrections_zarr_path, + patch_shape=patch_shape, + augment=augment, + model_name=model_name, + ) + + # Clamp batch size to number of samples so DataLoader doesn't error + actual_batch_size = ( + min(batch_size, len(dataset)) if len(dataset) > 0 else batch_size + ) + if actual_batch_size != batch_size: + logger.info( + f"Clamped batch_size from {batch_size} to {actual_batch_size} " + f"(only {len(dataset)} samples available)" + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=actual_batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, # Faster GPU transfer + persistent_workers=num_workers > 0, # Keep workers alive between epochs + ) + + logger.info( + f"Created DataLoader with {len(dataset)} samples, " + f"batch_size={actual_batch_size}, num_workers={num_workers}" + ) + + return dataloader diff --git a/cellmap_flow/finetune/finetune_cli.py b/cellmap_flow/finetune/finetune_cli.py new file mode 100644 index 0000000..4f4baba --- /dev/null +++ b/cellmap_flow/finetune/finetune_cli.py @@ -0,0 +1,846 @@ +#!/usr/bin/env python +""" +Command-line interface for LoRA finetuning. 
+ +Usage: + python -m cellmap_flow.finetune.finetune_cli \ + --model-checkpoint /path/to/checkpoint \ + --corrections corrections.zarr \ + --output-dir output/fly_organelles_v1.1 + + # With custom settings + python -m cellmap_flow.finetune.finetune_cli \ + --model-checkpoint /path/to/checkpoint \ + --corrections corrections.zarr \ + --output-dir output/fly_organelles_v1.1 \ + --lora-r 16 \ + --batch-size 4 \ + --num-epochs 20 \ + --learning-rate 2e-4 +""" + +import argparse +import gc +import json +import logging +import os +import socket +import sys +import threading +import time +from contextlib import closing +from datetime import datetime +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn + +from cellmap_flow.models.models_config import FlyModelConfig, DaCapoModelConfig, ModelConfig +from cellmap_flow.finetune.lora_wrapper import wrap_model_with_lora +from cellmap_flow.finetune.correction_dataset import create_dataloader +from cellmap_flow.finetune.lora_trainer import LoRAFinetuner + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + force=True, +) +logger = logging.getLogger(__name__) + + +class RestartController: + """In-memory restart control shared between training loop and server endpoint.""" + + def __init__(self): + self._event = threading.Event() + self._lock = threading.Lock() + self._pending = None + + def request_restart(self, payload: Optional[dict]) -> bool: + signal_data = { + "restart": True, + "timestamp": datetime.now().isoformat(), + "params": {}, + } + if isinstance(payload, dict): + if "timestamp" in payload and payload["timestamp"]: + signal_data["timestamp"] = payload["timestamp"] + if isinstance(payload.get("params"), dict): + signal_data["params"] = payload["params"] + + with self._lock: + self._pending = signal_data + self._event.set() + return True + + def get_if_triggered(self) -> Optional[dict]: + if not 
self._event.is_set(): + return None + with self._lock: + signal_data = self._pending + self._pending = None + self._event.clear() + return signal_data + + +def _wait_for_port_ready(host: str, port: int, timeout_s: float = 30.0, interval_s: float = 0.1) -> bool: + """Wait until a TCP port is accepting connections.""" + deadline = time.perf_counter() + timeout_s + while time.perf_counter() < deadline: + try: + with closing(socket.create_connection((host, port), timeout=0.5)): + return True + except OSError: + time.sleep(interval_s) + return False + + +def _start_inference_server_background( + args, model_config: ModelConfig, trained_model, restart_controller: Optional[RestartController] = None +): + """ + Start inference server in a background daemon thread. + + The server shares the same model object, so retraining updates weights + automatically without needing to restart the server. + + Args: + args: Command-line arguments + model_config: Base model configuration + trained_model: The trained LoRA model + + Returns: + (thread, port) tuple + """ + logger.info("=" * 60) + logger.info("Starting inference server with finetuned model...") + logger.info("=" * 60) + + startup_t0 = time.perf_counter() + + # Clear GPU cache from training + cleanup_t0 = time.perf_counter() + logger.info("Clearing GPU cache...") + torch.cuda.empty_cache() + gc.collect() + cleanup_elapsed = time.perf_counter() - cleanup_t0 + + # Validate serve data path + if not args.serve_data_path: + raise ValueError("--serve-data-path is required when --auto-serve is enabled") + + if not Path(args.serve_data_path).exists(): + raise ValueError(f"Data path not found: {args.serve_data_path}") + + # Use the already-trained model + logger.info("Using trained LoRA model for inference...") + + from cellmap_flow.models.models_config import _get_device + device = _get_device() + trained_model.eval() + logger.info(f"Model set to eval mode on {device}") + + # Replace the model in the config with our finetuned version 
+ model_config.config.model = trained_model + + # Start server + from cellmap_flow.server import CellMapFlowServer + from cellmap_flow.utils.web_utils import get_free_port + + setup_t0 = time.perf_counter() + logger.info(f"Creating server for dataset: {model_config.name}_finetuned") + restart_callback = restart_controller.request_restart if restart_controller is not None else None + server = CellMapFlowServer(args.serve_data_path, model_config, restart_callback=restart_callback) + + # Get port + port = args.serve_port if args.serve_port != 0 else get_free_port() + + # Start in daemon thread (server.run() prints CELLMAP_FLOW_SERVER_IP marker automatically) + server_thread = threading.Thread( + target=server.run, + kwargs={'port': port, 'debug': False}, + daemon=True + ) + server_thread.start() + setup_elapsed = time.perf_counter() - setup_t0 + + wait_t0 = time.perf_counter() + server_ready = _wait_for_port_ready("127.0.0.1", port) + wait_elapsed = time.perf_counter() - wait_t0 + + host_url = f"http://{socket.gethostname()}:{port}" + total_elapsed = time.perf_counter() - startup_t0 + logger.info("=" * 60) + if server_ready: + logger.info(f"Inference server port is ready on 127.0.0.1:{port}") + else: + logger.warning(f"Inference server did not become ready within timeout on 127.0.0.1:{port}") + logger.info(f"Inference server running at {host_url}") + logger.info( + f"Startup timings (s): cleanup={cleanup_elapsed:.2f}, setup={setup_elapsed:.2f}, " + f"wait_for_bind={wait_elapsed:.2f}, total={total_elapsed:.2f}" + ) + logger.info("Server is running in background. Watching for restart signals...") + logger.info("=" * 60) + + return server_thread, port + + +def _wait_for_restart_signal( + signal_file: Optional[Path], + check_interval: float = 1.0, + restart_controller: Optional[RestartController] = None, +): + """ + Watch for a restart signal file. Blocks until signal appears. 
+ + Prefers in-memory restart events from the control endpoint, and + falls back to a signal file for backward compatibility. + + Args: + signal_file: Optional path to watch for legacy signal file + check_interval: Seconds between checks + + Returns: + Dict with restart parameters, or None if signal file is malformed + """ + logger.info(f"Watching for restart signal (controller + file fallback: {signal_file})") + + while True: + if restart_controller is not None: + in_memory_signal = restart_controller.get_if_triggered() + if in_memory_signal is not None: + logger.info(f"Restart signal received via HTTP control endpoint: {in_memory_signal}") + return in_memory_signal + + if signal_file and signal_file.exists(): + try: + with open(signal_file) as f: + signal_data = json.load(f) + signal_file.unlink() # Remove signal file + logger.info(f"Restart signal received: {signal_data}") + return signal_data + except Exception as e: + logger.error(f"Error reading restart signal: {e}") + # Remove malformed signal file + try: + signal_file.unlink() + except OSError: + pass + return None + time.sleep(check_interval) + + +def _apply_restart_params(args, signal_data: dict): + """ + Update args with parameters from restart signal. + + Args: + args: argparse Namespace to update + signal_data: Dict from restart signal file + """ + params = signal_data.get("params", {}) + for key, value in params.items(): + if hasattr(args, key) and value is not None: + old_value = getattr(args, key) + setattr(args, key, value) + if old_value != value: + logger.info(f"Updated {key}: {old_value} -> {value}") + + +def _generate_model_files(args, model_config, timestamp): + """ + Generate model script and YAML files after training. 
def _generate_model_files(args, model_config, timestamp):
    """
    Generate the inference script and YAML registration files after training.

    Args:
        args: Command-line arguments
        model_config: Model configuration
        timestamp: Timestamp string for naming

    Returns:
        (finetuned_model_name, script_path, yaml_path) tuple
    """
    from cellmap_flow.finetune.finetuned_model_templates import (
        generate_finetuned_model_script,
        generate_finetuned_model_yaml
    )

    finetuned_model_name = f"{model_config.name}_finetuned_{timestamp}"

    # The run directory sits three levels below the session root, so walk up
    # to create the session-level "models" directory next to the runs.
    run_dir = Path(args.output_dir)
    models_dir = run_dir.parent.parent.parent / "models"
    models_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Generating model files for {finetuned_model_name}...")

    # Base script is optional; only pass it through when present on args.
    base_script = (
        args.model_script
        if hasattr(args, 'model_script') and args.model_script
        else None
    )
    script_path = generate_finetuned_model_script(
        base_checkpoint=args.model_checkpoint if args.model_checkpoint else None,
        lora_adapter_path=str(run_dir / "lora_adapter"),
        model_name=finetuned_model_name,
        channels=args.channels,
        input_voxel_size=tuple(args.input_voxel_size),
        output_voxel_size=tuple(args.output_voxel_size),
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        output_path=models_dir / f"{finetuned_model_name}.py",
        base_script_path=base_script,
    )
    logger.info(f"Generated script: {script_path}")

    # Recover the original dataset path from the corrections metadata so the
    # YAML can point the inference server at the same data.
    data_path = None
    zarr_dirs = list(Path(args.corrections).glob("*.zarr"))
    if zarr_dirs:
        zattrs_file = zarr_dirs[0] / ".zattrs"
        if zattrs_file.exists():
            with open(zattrs_file) as fh:
                data_path = json.load(fh).get("dataset_path")

    if not data_path:
        logger.warning("Could not extract data_path from corrections, using serve_data_path")
        data_path = args.serve_data_path if args.auto_serve else "/path/to/data.zarr"

    yaml_path = generate_finetuned_model_yaml(
        script_path=script_path,
        model_name=finetuned_model_name,
        resolution=args.output_voxel_size[0],
        output_path=models_dir / f"{finetuned_model_name}.yaml",
        data_path=data_path,
    )
    logger.info(f"Generated YAML: {yaml_path}")

    return finetuned_model_name, script_path, yaml_path
generate_finetuned_model_yaml( + script_path=script_path, + model_name=finetuned_model_name, + resolution=args.output_voxel_size[0], + output_path=models_dir / f"{finetuned_model_name}.yaml", + data_path=data_path + ) + logger.info(f"Generated YAML: {yaml_path}") + + return finetuned_model_name, script_path, yaml_path + + +def _build_target_transform(args, model_config): + """Build a TargetTransform based on CLI args.""" + from cellmap_flow.finetune.target_transforms import ( + BinaryTargetTransform, + BroadcastBinaryTargetTransform, + AffinityTargetTransform, + ) + + output_type = args.output_type + num_channels = model_config.config.output_channels + + if output_type == "binary": + if num_channels > 1 and args.select_channel is None: + logger.warning( + f"Model has {num_channels} output channels but --output-type is 'binary' " + f"and --select-channel is not set. Consider using --select-channel or " + f"--output-type binary_broadcast." + ) + return BinaryTargetTransform() + + elif output_type == "binary_broadcast": + logger.info(f"Broadcasting binary target to {num_channels} channels") + return BroadcastBinaryTargetTransform(num_channels) + + elif output_type == "affinities": + offsets = None + + # Try CLI arg first + if args.offsets: + offsets = json.loads(args.offsets) + + # Try reading from model script + if offsets is None and args.model_script: + offsets = _read_offsets_from_script(args.model_script) + + if offsets is None: + raise ValueError( + "Affinity output type requires offsets. Provide --offsets as a JSON list " + "(e.g. '[[1,0,0],[0,1,0],[0,0,1]]') or define an 'offsets' variable in " + "the model script." + ) + + if len(offsets) > num_channels: + raise ValueError( + f"Number of offsets ({len(offsets)}) exceeds model output channels " + f"({num_channels})." + ) + + if len(offsets) < num_channels: + logger.info( + f"Model has {num_channels} output channels but only {len(offsets)} affinity offsets. 
" + f"Remaining {num_channels - len(offsets)} channels (e.g. LSDs) will be masked out." + ) + + logger.info(f"Using affinity target transform with {len(offsets)} offsets: {offsets}") + return AffinityTargetTransform(offsets, num_channels=num_channels) + + else: + raise ValueError(f"Unknown output type: {output_type}") + + +def _read_offsets_from_script(script_path): + """Try to read an 'offsets' variable from a model script via AST parsing.""" + import ast + + try: + with open(script_path, "r") as f: + tree = ast.parse(f.read()) + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "offsets": + return ast.literal_eval(node.value) + except Exception as e: + logger.debug(f"Could not read offsets from {script_path}: {e}") + + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Finetune CellMap-Flow models with LoRA using user corrections" + ) + + # Model arguments + parser.add_argument( + "--model-type", + type=str, + default="fly", + choices=["fly", "dacapo"], + help="Model type (fly or dacapo)" + ) + parser.add_argument( + "--model-checkpoint", + type=str, + required=False, + default=None, + help="Path to model checkpoint (optional - can train from scratch)" + ) + parser.add_argument( + "--model-script", + type=str, + required=False, + default=None, + help="Path to model script (alternative to checkpoint)" + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Model name (for filtering corrections)" + ) + parser.add_argument( + "--channels", + type=str, + nargs="+", + default=["mito"], + help="Model output channels" + ) + parser.add_argument( + "--input-voxel-size", + type=int, + nargs=3, + default=[16, 16, 16], + help="Input voxel size (Z Y X)" + ) + parser.add_argument( + "--output-voxel-size", + type=int, + nargs=3, + default=[16, 16, 16], + help="Output voxel size (Z Y X)" + ) + + # LoRA arguments + 
parser.add_argument( + "--lora-r", + type=int, + default=8, + help="LoRA rank (default: 8)" + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha scaling (default: 16)" + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.1, + help="LoRA dropout (default: 0.1)" + ) + + # Data arguments + parser.add_argument( + "--corrections", + type=str, + required=True, + help="Path to corrections.zarr directory" + ) + parser.add_argument( + "--patch-shape", + type=int, + nargs=3, + default=None, + help="Patch shape for training (Z Y X). Default: None (use full corrections)" + ) + parser.add_argument( + "--no-augment", + action="store_true", + help="Disable data augmentation" + ) + + # Training arguments + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for checkpoints and adapter" + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size (default: 2)" + ) + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of training epochs (default: 10)" + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-4, + help="Learning rate (default: 1e-4)" + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + default=1, + help="Gradient accumulation steps (default: 1)" + ) + parser.add_argument( + "--loss-type", + type=str, + default="combined", + choices=["dice", "bce", "combined", "mse", "margin"], + help="Loss function (default: combined)" + ) + parser.add_argument( + "--label-smoothing", + type=float, + default=0.0, + help="Label smoothing factor (e.g., 0.1 maps targets from 0/1 to 0.05/0.95). " + "Helps preserve gradual distance-like outputs. (default: 0.0)" + ) + parser.add_argument( + "--distillation-lambda", + type=float, + default=0.0, + help="Teacher distillation weight. Keeps model close to base on unlabeled voxels. " + "0.0=disabled, try 0.5-1.0 for sparse scribbles. 
(default: 0.0)" + ) + parser.add_argument( + "--distillation-all-voxels", + action="store_true", + help="Apply distillation loss on all voxels instead of only unlabeled voxels. (default: unlabeled only)" + ) + parser.add_argument( + "--margin", + type=float, + default=0.3, + help="Margin threshold for margin loss. " + "Foreground must exceed 1-margin, background must stay below margin. (default: 0.3)" + ) + parser.add_argument( + "--balance-classes", + action="store_true", + help="Balance fg/bg loss contribution so each class is weighted equally, " + "regardless of scribble voxel counts. Helps prevent foreground overprediction. (default: off)" + ) + parser.add_argument( + "--no-mixed-precision", + action="store_true", + help="Disable mixed precision (FP16) training" + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="DataLoader num_workers (default: 4)" + ) + + # Resuming + parser.add_argument( + "--resume", + type=str, + default=None, + help="Path to checkpoint to resume from" + ) + + # Auto-serve arguments + parser.add_argument( + "--auto-serve", + action="store_true", + help="Automatically start inference server after training completes" + ) + parser.add_argument( + "--serve-data-path", + type=str, + default=None, + help="Dataset path for inference server (required if --auto-serve is used)" + ) + parser.add_argument( + "--serve-port", + type=int, + default=0, + help="Port for inference server (0 for auto-assignment)" + ) + parser.add_argument( + "--mask-unannotated", + action="store_true", + help="Enable masked loss for sparse annotations (0=ignore, 1=bg, 2+=fg)" + ) + + # Output type and target transform arguments + parser.add_argument( + "--output-type", + type=str, + default="binary", + choices=["binary", "binary_broadcast", "affinities"], + help="How to generate training targets from annotations. " + "'binary': single-channel fg/bg (use with --select-channel for multi-channel models). 
" + "'binary_broadcast': broadcast binary target to all output channels. " + "'affinities': compute affinity targets from instance labels (requires offsets). " + "(default: binary)" + ) + parser.add_argument( + "--select-channel", + type=int, + default=None, + help="Select a single channel from multi-channel model output for binary training. " + "Only used with --output-type binary. (default: None, use all channels)" + ) + parser.add_argument( + "--offsets", + type=str, + default=None, + help="JSON list of [dz,dy,dx] offsets for affinity target generation. " + "Example: '[[1,0,0],[0,1,0],[0,0,1]]'. " + "If not provided with --output-type affinities, will try to read 'offsets' " + "from the model script." + ) + + args = parser.parse_args() + + # Print configuration + logger.info("=" * 60) + logger.info("LoRA Finetuning Configuration") + logger.info("=" * 60) + logger.info(f"Model type: {args.model_type}") + logger.info(f"Model checkpoint: {args.model_checkpoint}") + logger.info(f"Corrections: {args.corrections}") + logger.info(f"Output directory: {args.output_dir}") + logger.info(f"LoRA rank: {args.lora_r}") + logger.info(f"Batch size: {args.batch_size}") + logger.info(f"Epochs: {args.num_epochs}") + logger.info(f"Learning rate: {args.learning_rate}") + logger.info("") + + # === Load model (once) === + logger.info("Loading model...") + + if args.model_script: + from cellmap_flow.models.models_config import ScriptModelConfig + logger.info(f"Using script-based model: {args.model_script}") + model_config = ScriptModelConfig( + script_path=args.model_script, + name=args.model_name or "script_model" + ) + elif args.model_type == "fly": + if not args.model_checkpoint: + raise ValueError( + "For fly models, either --model-checkpoint or --model-script must be provided" + ) + model_config = FlyModelConfig( + checkpoint_path=args.model_checkpoint, + channels=args.channels, + input_voxel_size=tuple(args.input_voxel_size), + output_voxel_size=tuple(args.output_voxel_size), + 
name=args.model_name, + ) + elif args.model_type == "dacapo": + if not args.model_checkpoint: + raise ValueError("For dacapo models, --model-checkpoint is required") + checkpoint_path = Path(args.model_checkpoint) + iteration = int(checkpoint_path.stem.split('_')[-1]) + run_name = checkpoint_path.parent.name + + model_config = DaCapoModelConfig( + run_name=run_name, + iteration=iteration, + ) + else: + raise ValueError(f"Unknown model type: {args.model_type}") + + base_model = model_config.config.model + logger.info(f"Model loaded: {type(base_model).__name__}") + + # === Wrap with LoRA (once - same object is reused across restarts) === + logger.info(f"Wrapping model with LoRA (r={args.lora_r})...") + lora_model = wrap_model_with_lora( + base_model, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + ) + + # === Training loop (supports restart via signal file) === + server_started = False + restart_controller = RestartController() + iteration = 0 + + while True: + iteration += 1 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if iteration > 1: + logger.info("") + logger.info("=" * 60) + logger.info(f"Training Iteration {iteration}") + logger.info("=" * 60) + + # Create dataloader (re-created each iteration to pick up new annotations) + logger.info(f"Loading corrections from {args.corrections}...") + dataloader = create_dataloader( + args.corrections, + batch_size=args.batch_size, + patch_shape=tuple(args.patch_shape) if args.patch_shape is not None else None, + augment=not args.no_augment, + num_workers=args.num_workers, + shuffle=True, + model_name=args.model_name, + ) + logger.info(f"DataLoader created: {len(dataloader.dataset)} corrections") + + # Build target transform (re-built each iteration to pick up restart params) + select_channel = args.select_channel + target_transform = _build_target_transform(args, model_config) + logger.info(f"output_type={args.output_type}, select_channel={select_channel}") + + # Create 
trainer (re-created each iteration for fresh optimizer/scheduler) + logger.info("Creating trainer...") + trainer = LoRAFinetuner( + lora_model, + dataloader, + output_dir=args.output_dir, + learning_rate=args.learning_rate, + num_epochs=args.num_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + use_mixed_precision=not args.no_mixed_precision, + loss_type=args.loss_type, + select_channel=select_channel, + mask_unannotated=args.mask_unannotated, + label_smoothing=args.label_smoothing, + distillation_lambda=args.distillation_lambda, + distillation_all_voxels=args.distillation_all_voxels, + margin=args.margin, + balance_classes=args.balance_classes, + target_transform=target_transform, + ) + + # Resume from checkpoint if specified (first iteration only) + if args.resume and iteration == 1: + logger.info(f"Resuming from checkpoint: {args.resume}") + trainer.load_checkpoint(args.resume) + + # Train + try: + stats = trainer.train() + + # Save final adapter + logger.info("\nSaving LoRA adapter...") + trainer.save_adapter() + + logger.info("\n" + "=" * 60) + logger.info("Finetuning Complete!") + logger.info(f"Best loss: {stats['best_loss']:.6f}") + logger.info(f"Adapter saved to: {args.output_dir}/lora_adapter") + logger.info("=" * 60) + + # Generate model files + finetuned_model_name, script_path, yaml_path = _generate_model_files( + args, model_config, timestamp + ) + + # Print completion marker with timestamp (for job manager to detect) + print(f"TRAINING_ITERATION_COMPLETE: {finetuned_model_name}", flush=True) + + # Auto-serve if requested + if args.auto_serve: + if not server_started: + # First time: start inference server in background thread + try: + _start_inference_server_background( + args, model_config, lora_model, restart_controller=restart_controller + ) + server_started = True + except Exception as e: + logger.error(f"Failed to start inference server: {e}", exc_info=True) + print(f"INFERENCE_SERVER_FAILED: {e}", flush=True) + return 0 + 
else: + # Server already running - just set model back to eval mode + # The server shares the same model object, so it automatically + # serves with the updated weights + lora_model.eval() + logger.info("Model updated and set to eval mode. Server continuing with new weights.") + + # Watch for restart signal + signal_file = Path(args.output_dir) / "restart_signal.json" + restart_data = _wait_for_restart_signal( + signal_file=signal_file, + check_interval=1.0, + restart_controller=restart_controller, + ) + + if restart_data is None: + logger.error("Malformed restart signal, exiting") + return 1 + + # Apply updated parameters + _apply_restart_params(args, restart_data) + + # Prepare for retraining + lora_model.train() + torch.cuda.empty_cache() + gc.collect() + logger.info("Restarting training with updated parameters...") + print("RESTARTING_TRAINING", flush=True) + continue # Loop back to retrain + + # No auto-serve: just exit after training + return 0 + + except KeyboardInterrupt: + logger.info("\nTraining interrupted by user") + logger.info("Saving current state...") + trainer.save_checkpoint(is_best=False) + return 1 + + except Exception as e: + logger.error(f"Training failed: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cellmap_flow/finetune/finetune_job_manager.py b/cellmap_flow/finetune/finetune_job_manager.py new file mode 100644 index 0000000..e6ae5c9 --- /dev/null +++ b/cellmap_flow/finetune/finetune_job_manager.py @@ -0,0 +1,1228 @@ +""" +Job manager for orchestrating finetuning jobs on LSF cluster. 
class JobStatus(Enum):
    """Lifecycle states of a finetuning job."""
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


@dataclass
class FinetuneJob:
    """Metadata, live status, and training progress for one finetuning job.

    Tracks a job from submission through completion, including the inference
    server state and the previous/next links of a restart chain.
    """
    job_id: str
    lsf_job: Optional[LSFJob]
    model_name: str
    output_dir: Path
    params: Dict[str, Any]
    status: JobStatus
    created_at: datetime
    log_file: Path
    finetuned_model_name: Optional[str] = None
    model_script_path: Optional[Path] = None
    model_yaml_path: Optional[Path] = None
    current_epoch: int = 0
    total_epochs: int = 10
    latest_loss: Optional[float] = None
    inference_server_url: Optional[str] = None
    inference_server_ready: bool = False
    previous_job_id: Optional[str] = None
    next_job_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the job to a JSON-friendly dict."""
        # The scheduler handle is either a real LSF job (exposes .job_id) or
        # a local subprocess wrapper (exposes .process); collapse both into a
        # single identifier field, or None when no handle exists.
        handle = None
        if self.lsf_job:
            if hasattr(self.lsf_job, 'job_id'):
                handle = self.lsf_job.job_id
            elif hasattr(self.lsf_job, 'process'):
                handle = f"PID:{self.lsf_job.process.pid}"

        return {
            "job_id": self.job_id,
            "lsf_job_id": handle,
            "model_name": self.model_name,
            "output_dir": str(self.output_dir),
            "params": self.params,
            "status": self.status.value,
            "created_at": self.created_at.isoformat(),
            "log_file": str(self.log_file),
            "finetuned_model_name": self.finetuned_model_name,
            "model_script_path": str(self.model_script_path) if self.model_script_path else None,
            "model_yaml_path": str(self.model_yaml_path) if self.model_yaml_path else None,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "latest_loss": self.latest_loss,
            "inference_server_url": self.inference_server_url,
            "inference_server_ready": self.inference_server_ready,
            "previous_job_id": self.previous_job_id,
            "next_job_id": self.next_job_id,
        }
Job cancellation and cleanup + """ + + def __init__(self): + """Initialize the job manager.""" + self.jobs: Dict[str, FinetuneJob] = {} + self.logger = logging.getLogger(__name__) + self._monitor_threads: Dict[str, threading.Thread] = {} + + def _get_model_metadata(self, model_config, attr_name: str, default=None): + """ + Get metadata from model config, checking both direct attributes and loaded config. + + Args: + model_config: The model configuration object + attr_name: Name of the attribute to retrieve + default: Default value if attribute not found + + Returns: + The attribute value if found, otherwise the default value + """ + # First try direct attribute access + if hasattr(model_config, attr_name): + value = getattr(model_config, attr_name, None) + if value is not None: + return value + + # Then try loading config and checking there + try: + config = model_config.config + if hasattr(config, attr_name): + value = getattr(config, attr_name, None) + if value is not None: + return value + except Exception as e: + self.logger.debug(f"Could not load config to check for {attr_name}: {e}") + + return default + + def _extract_data_path_from_corrections(self, corrections_path: Path) -> str: + """Extract dataset path from corrections metadata.""" + # Look for first .zarr directory + zarr_dirs = list(corrections_path.glob("*.zarr")) + if not zarr_dirs: + raise ValueError("No .zarr directories found in corrections") + + # Read .zattrs + zattrs_file = zarr_dirs[0] / ".zattrs" + if not zattrs_file.exists(): + raise ValueError("No .zattrs metadata found in corrections") + + with open(zattrs_file) as f: + metadata = json.load(f) + + if "dataset_path" not in metadata: + raise ValueError("No 'dataset_path' found in corrections metadata") + + return metadata["dataset_path"] + + def submit_finetuning_job( + self, + model_config, + corrections_path: Path, + lora_r: int = 8, + num_epochs: int = 10, + batch_size: int = 2, + learning_rate: float = 1e-4, + output_base: 
Optional[Path] = None, + queue: str = "gpu_h100", + charge_group: str = "cellmap", + checkpoint_path_override: Optional[Path] = None, + auto_serve: bool = True, + mask_unannotated: bool = False, + loss_type: str = "combined", + label_smoothing: float = 0.0, + distillation_lambda: float = 0.0, + distillation_scope: str = "unlabeled", + margin: float = 0.3, + balance_classes: bool = False, + output_type: str = "binary", + select_channel: Optional[int] = None, + offsets: Optional[str] = None, + ) -> FinetuneJob: + """ + Submit finetuning job to LSF cluster. + + Args: + model_config: Model configuration object (FlyModelConfig, etc.) + corrections_path: Path to corrections.zarr directory + lora_r: LoRA rank (default: 8) + num_epochs: Number of training epochs (default: 10) + batch_size: Training batch size (default: 2) + learning_rate: Learning rate (default: 1e-4) + output_base: Base directory for outputs (default: output/finetuning) + queue: LSF queue name (default: gpu_h100) + charge_group: LSF charge group (default: cellmap) + checkpoint_path_override: Optional path to override checkpoint detection (default: None) + auto_serve: Automatically start inference server after training (default: True) + + Returns: + FinetuneJob object tracking the submitted job + + Raises: + ValueError: If validation fails + RuntimeError: If job submission fails + """ + # === Validation === + + # 1. Check model config + if not model_config: + raise ValueError("Model config is required") + + # 2. 
Get checkpoint path if available (optional) + # For script models: we'll pass the script path instead + # For fly/dacapo models: we need the checkpoint path + checkpoint_path = None + + # Check for checkpoint override first + if checkpoint_path_override: + checkpoint_path = Path(checkpoint_path_override) + self.logger.info(f"Using checkpoint path override: {checkpoint_path}") + # For FlyModelConfig, get checkpoint_path attribute + elif hasattr(model_config, 'checkpoint_path') and model_config.checkpoint_path: + checkpoint_path = Path(model_config.checkpoint_path) + self.logger.info(f"Found checkpoint_path: {checkpoint_path}") + + # Validate checkpoint exists if specified + if checkpoint_path and not checkpoint_path.exists(): + raise ValueError( + f"Model checkpoint not found: {checkpoint_path}\n" + f"Please verify the path exists and is accessible." + ) + + # 3. Check corrections path exists + if not corrections_path.exists(): + raise ValueError(f"Corrections path does not exist: {corrections_path}") + + # 4. Count corrections (warn if few) + correction_dirs = list(corrections_path.glob("*/")) + num_corrections = len([d for d in correction_dirs if (d / ".zattrs").exists()]) + + if num_corrections == 0: + raise ValueError(f"No corrections found in {corrections_path}") + + if num_corrections < 5: + self.logger.warning( + f"Only {num_corrections} corrections found. " + "Recommend at least 5-10 for meaningful finetuning." 
+ ) + + self.logger.info(f"Found {num_corrections} corrections for training") + + # === Setup output directory === + + if output_base is None: + output_base = Path("output/finetuning") + else: + output_base = Path(output_base) + + # Create timestamped run directory inside finetuning subdirectory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_basename = model_config.name.replace("/", "_").replace(" ", "_") + run_dir_name = f"{model_basename}_{timestamp}" + output_dir = output_base / "finetuning" / "runs" / run_dir_name + output_dir.mkdir(parents=True, exist_ok=True) + + log_file = output_dir / "training_log.txt" + + self.logger.info(f"Output directory: {output_dir}") + + # === Build training command === + + # Get model metadata - try multiple sources + model_type = self._get_model_metadata(model_config, "model_type", "fly") + if model_type == "fly" and "dacapo" in model_config.name.lower(): + model_type = "dacapo" + + # Get channels - try multiple attribute names + channels = None + for attr_name in ["channels", "classes", "class_names"]: + channels = self._get_model_metadata(model_config, attr_name, None) + if channels: + break + if channels is None: + channels = ["mito"] # Default fallback + if isinstance(channels, str): + channels = [channels] + + # Get voxel sizes + input_voxel_size = self._get_model_metadata(model_config, "input_voxel_size", [16, 16, 16]) + output_voxel_size = self._get_model_metadata(model_config, "output_voxel_size", [16, 16, 16]) + + # Convert to list if needed (in case they're Coordinate objects) + if not isinstance(input_voxel_size, list): + input_voxel_size = list(input_voxel_size) + if not isinstance(output_voxel_size, list): + output_voxel_size = list(output_voxel_size) + + # Extract data path for inference server if auto-serve is enabled + serve_data_path = None + if auto_serve: + try: + serve_data_path = self._extract_data_path_from_corrections(corrections_path) + self.logger.info(f"Extracted dataset path for 
inference: {serve_data_path}") + except Exception as e: + self.logger.warning(f"Could not extract dataset path from corrections: {e}") + self.logger.warning("Auto-serve will be disabled") + auto_serve = False + + # Build CLI command + cli_command = f"python -m cellmap_flow.finetune.finetune_cli " + cli_command += f"--model-type {model_type} " + + # Add checkpoint or script path depending on what's available + if checkpoint_path: + cli_command += f"--model-checkpoint {checkpoint_path} " + elif hasattr(model_config, 'script_path'): + cli_command += f"--model-script {model_config.script_path} " + + cli_command += ( + f"--corrections {corrections_path} " + f"--output-dir {output_dir} " + f"--model-name {model_config.name} " + f"--channels {' '.join(channels)} " + f"--input-voxel-size {' '.join(map(str, input_voxel_size))} " + f"--output-voxel-size {' '.join(map(str, output_voxel_size))} " + f"--lora-r {lora_r} " + f"--lora-alpha {lora_r * 2} " + f"--num-epochs {num_epochs} " + f"--batch-size {batch_size} " + f"--learning-rate {learning_rate} " + f"--loss-type {loss_type} " + ) + + # Add label smoothing if specified + if label_smoothing > 0: + cli_command += f"--label-smoothing {label_smoothing} " + + # Add distillation lambda if specified + if distillation_lambda > 0: + cli_command += f"--distillation-lambda {distillation_lambda} " + if distillation_scope == "all": + cli_command += "--distillation-all-voxels " + + # Add margin if using margin loss + if loss_type == "margin": + cli_command += f"--margin {margin} " + + # Add auto-serve flags if enabled + if auto_serve and serve_data_path: + cli_command += f"--auto-serve --serve-data-path {serve_data_path} " + + # Add mask_unannotated flag for sparse annotations + if mask_unannotated: + cli_command += "--mask-unannotated " + + # Add class balancing flag + if balance_classes: + cli_command += "--balance-classes " + + # Add output type and related args + if output_type != "binary": + cli_command += f"--output-type 
{output_type} " + if select_channel is not None: + cli_command += f"--select-channel {select_channel} " + if offsets is not None: + cli_command += f"--offsets '{offsets}' " + + cli_command = f"stdbuf -oL {cli_command} 2>&1 | tee {log_file}" + + self.logger.info(f"Training command: {cli_command}") + + # === Save job metadata === + + metadata = { + "job_id": str(uuid.uuid4()), + "model_name": model_config.name, + "model_type": model_type, + "model_checkpoint": str(checkpoint_path) if checkpoint_path else None, + "model_script": str(model_config.script_path) if hasattr(model_config, 'script_path') else None, + "corrections_path": str(corrections_path), + "num_corrections": num_corrections, + "output_dir": str(output_dir), + "params": { + "model_checkpoint": str(checkpoint_path) if checkpoint_path else None, + "lora_r": lora_r, + "lora_alpha": lora_r * 2, + "num_epochs": num_epochs, + "batch_size": batch_size, + "learning_rate": learning_rate, + "loss_type": loss_type, + "label_smoothing": label_smoothing, + "distillation_lambda": distillation_lambda, + "distillation_scope": distillation_scope, + "margin": margin, + "balance_classes": balance_classes, + "channels": channels, + "input_voxel_size": input_voxel_size, + "output_voxel_size": output_voxel_size, + }, + "queue": queue, + "charge_group": charge_group, + "created_at": datetime.now().isoformat(), + "command": cli_command, + } + + metadata_file = output_dir / "metadata.json" + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + self.logger.info(f"Saved metadata to {metadata_file}") + + # === Submit job (LSF or local) === + + job_name = f"finetune_{model_basename}_{timestamp}" + + # Check if bsub is available + if is_bsub_available(): + self.logger.info("Submitting to LSF cluster via bsub") + try: + lsf_job = submit_bsub_job( + command=cli_command, + queue=queue, + charge_group=charge_group, + job_name=job_name, + num_gpus=1, + num_cpus=4 + ) + self.logger.info(f"Submitted LSF job 
{lsf_job.job_id} for finetuning") + except Exception as e: + self.logger.error(f"Failed to submit job to LSF: {e}") + raise RuntimeError(f"Job submission to LSF failed: {e}") + else: + # Fallback to local execution + self.logger.info("bsub not available - running finetuning locally") + try: + lsf_job = run_locally( + command=cli_command, + name=job_name + ) + self.logger.info(f"Started local finetuning job (PID: {lsf_job.process.pid})") + except Exception as e: + self.logger.error(f"Failed to start local job: {e}") + raise RuntimeError(f"Local job execution failed: {e}") + + # === Create FinetuneJob tracking object === + + job_id = metadata["job_id"] + + finetune_job = FinetuneJob( + job_id=job_id, + lsf_job=lsf_job, + model_name=model_config.name, + output_dir=output_dir, + params=metadata["params"], + status=JobStatus.PENDING, + created_at=datetime.now(), + log_file=log_file, + total_epochs=num_epochs + ) + + self.jobs[job_id] = finetune_job + + # === Start monitoring thread === + + monitor_thread = threading.Thread( + target=self.monitor_job, + args=(finetune_job,), + daemon=True + ) + monitor_thread.start() + self._monitor_threads[job_id] = monitor_thread + + self.logger.info(f"Started monitoring thread for job {job_id}") + + return finetune_job + + def monitor_job(self, finetune_job: FinetuneJob): + """ + Background thread for job monitoring. + + Polls LSF status and tails log file to track training progress. + Triggers completion when job finishes. 
        Args:
            finetune_job: The FinetuneJob to monitor
        """
        job_id = finetune_job.job_id
        self.logger.info(f"Monitoring job {job_id}...")

        # Byte offset of the last position read from the log file; lets us
        # tail the file incrementally instead of re-reading it each cycle.
        last_log_position = 0
        check_interval = 3  # seconds between LSF-status / log-tail polls

        try:
            while True:
                # === Check LSF job status ===

                if finetune_job.lsf_job:
                    lsf_status = finetune_job.lsf_job.get_status()

                    # Map LSF status to FinetuneJob status; terminal LSF states
                    # break out of the polling loop.
                    if lsf_status == LSFJobStatus.RUNNING:
                        if finetune_job.status == JobStatus.PENDING:
                            self.logger.info(f"Job {job_id} started running")
                            finetune_job.status = JobStatus.RUNNING
                    elif lsf_status == LSFJobStatus.PENDING:
                        finetune_job.status = JobStatus.PENDING
                    elif lsf_status == LSFJobStatus.COMPLETED:
                        self.logger.info(f"Job {job_id} completed according to LSF")
                        finetune_job.status = JobStatus.COMPLETED
                        break
                    elif lsf_status == LSFJobStatus.FAILED:
                        self.logger.error(f"Job {job_id} failed according to LSF")
                        finetune_job.status = JobStatus.FAILED
                        break
                    elif lsf_status == LSFJobStatus.KILLED:
                        self.logger.warning(f"Job {job_id} was killed")
                        finetune_job.status = JobStatus.CANCELLED
                        break

                # === Tail log file for progress updates ===

                if finetune_job.log_file.exists():
                    try:
                        # Check if file was truncated (e.g., during restart archival)
                        file_size = finetune_job.log_file.stat().st_size
                        if file_size < last_log_position:
                            self.logger.info(f"Log file truncated (size {file_size} < position {last_log_position}), resetting")
                            last_log_position = 0

                        with open(finetune_job.log_file, "r") as f:
                            # Seek to last read position
                            f.seek(last_log_position)
                            new_content = f.read()
                            last_log_position = f.tell()

                        if new_content:
                            # Parse for epoch and loss information
                            self._parse_training_progress(finetune_job, new_content)
                            # Parse for inference server ready marker
                            self._parse_inference_server_ready(finetune_job, new_content)

                        # Always check for restart/iteration markers (reads full log).
                        # This must run every cycle, not just when there's new content,
                        # because the marker may have been at the end of the previous
                        # chunk and we need to detect it even if no new output follows.
                        self._parse_training_restart(finetune_job, new_content if new_content else "")
                    except Exception as e:
                        # Best-effort tailing: a transient read failure should not
                        # kill the monitor thread.
                        self.logger.debug(f"Error reading log file: {e}")

                # Sleep before next check
                time.sleep(check_interval)

        except Exception as e:
            self.logger.error(f"Error monitoring job {job_id}: {e}")
            finetune_job.status = JobStatus.FAILED

        finally:
            # === Post-completion actions ===

            if finetune_job.status == JobStatus.COMPLETED:
                try:
                    self.complete_job(finetune_job)
                except Exception as e:
                    self.logger.error(f"Error in post-completion for job {job_id}: {e}")
                    finetune_job.status = JobStatus.FAILED

            self.logger.info(f"Stopped monitoring job {job_id}. Final status: {finetune_job.status.value}")

    def _parse_training_progress(self, finetune_job: FinetuneJob, log_content: str):
        """
        Parse log content for training progress (epoch, loss).
        Args:
            finetune_job: Job to update
            log_content: New log content to parse
        """
        # Look for patterns like "Epoch 5/10" and "Loss: 0.1234"

        # Match: Epoch X/Y — the LAST match wins so the job reflects the most
        # recent line in the tailed chunk.
        epoch_pattern = r"Epoch\s+(\d+)/(\d+)"
        epoch_matches = re.findall(epoch_pattern, log_content, re.IGNORECASE)
        if epoch_matches:
            last_match = epoch_matches[-1]
            finetune_job.current_epoch = int(last_match[0])
            finetune_job.total_epochs = int(last_match[1])

        # Match: Loss: X.XXXX (various formats)
        loss_patterns = [
            r"Loss:\s+([\d.]+)",
            r"loss:\s+([\d.]+)",
            r"avg_loss:\s+([\d.]+)",
        ]

        # First pattern that yields a parseable float wins; later patterns are
        # only tried if the earlier ones produced nothing usable.
        for pattern in loss_patterns:
            loss_matches = re.findall(pattern, log_content, re.IGNORECASE)
            if loss_matches:
                try:
                    finetune_job.latest_loss = float(loss_matches[-1])
                    break
                except ValueError:
                    # e.g. a bare "." matched by [\d.]+ — skip and try next pattern
                    pass

    def _add_finetuned_neuroglancer_layer(self, finetune_job: FinetuneJob, model_name: str):
        """
        Add (or replace) the finetuned model's neuroglancer layer.

        Mirrors run_model() from cellmap_flow/models/run.py:
        1. Create/update Job object in g.jobs
        2. Add neuroglancer ImageLayer with pre/post processing args

        Args:
            finetune_job: Job with inference_server_url set
            model_name: Layer name (e.g. "mito_finetuned_20240101_120000")
        """
        from cellmap_flow.globals import g
        from cellmap_flow.utils.web_utils import get_norms_post_args, ARGS_KEY
        import neuroglancer

        server_url = finetune_job.inference_server_url

        # Create a Job object for the running server
        inference_job = LSFJob(
            job_id=finetune_job.lsf_job.job_id if finetune_job.lsf_job else "local",
            model_name=model_name
        )
        inference_job.host = server_url
        inference_job.status = LSFJobStatus.RUNNING

        # Remove any old finetuned jobs for this base model (prefix match on
        # "<base>_finetuned" so stale iterations don't accumulate)
        g.jobs = [
            j for j in g.jobs
            if not (hasattr(j, 'model_name') and j.model_name
                    and j.model_name.startswith(f"{finetune_job.model_name}_finetuned"))
        ]

        # Add to g.jobs
        g.jobs.append(inference_job)
        self.logger.info(f"Added finetuned job to g.jobs: {model_name}")

        # Get pre/post processing args (same hash as other models)
        st_data = get_norms_post_args(g.input_norms, g.postprocess)

        if g.viewer is None:
            self.logger.error("g.viewer is None - neuroglancer not initialized yet")
            return

        source_url = f"zarr://{server_url}/{model_name}{ARGS_KEY}{st_data}{ARGS_KEY}"
        self.logger.info(f"Adding neuroglancer layer: {model_name}")
        self.logger.info(f"  source: {source_url}")

        with g.viewer.txn() as s:
            # Remove old finetuned layer if it exists (exact name match)
            old_layer_name = finetune_job.finetuned_model_name
            if old_layer_name and old_layer_name in s.layers:
                self.logger.info(f"Removing old finetuned layer: {old_layer_name}")
                del s.layers[old_layer_name]

            # Also remove by current name in case of re-add
            if model_name in s.layers:
                del s.layers[model_name]

            # Add new layer - exact same format as run_model()
            s.layers[model_name] = neuroglancer.ImageLayer(
                source=source_url,
                shader=f"""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]);
#uicontrol vec3 color color(default="red");
void main(){{emitRGB(color * normalized());}}""",
            )

        # Update the stored name
        finetune_job.finetuned_model_name = model_name
        self.logger.info(f"Successfully added neuroglancer layer: {model_name}")

    def _parse_inference_server_ready(self, finetune_job: FinetuneJob, log_content: str):
        """
        Parse log for CELLMAP_FLOW_SERVER_IP marker and add finetuned model
        to neuroglancer exactly like a normal inference model.

        Idempotent per training iteration: returns early once
        inference_server_ready has been set.

        Args:
            finetune_job: Job to update
            log_content: New log content to parse
        """
        if finetune_job.inference_server_ready:
            return

        # Look for the standard server IP marker (same one start_hosts() uses)
        from cellmap_flow.utils.web_utils import IP_PATTERN
        ip_start = IP_PATTERN[0]
        ip_end = IP_PATTERN[1]

        pattern = re.escape(ip_start) + r"(.+?)" + re.escape(ip_end)
        matches = re.findall(pattern, log_content)
        if not matches:
            return

        # Last marker wins in case the server restarted within this chunk
        server_url = matches[-1]
        finetune_job.inference_server_url = server_url
        finetune_job.inference_server_ready = True
        self.logger.info(f"Finetuned inference server detected at {server_url}")

        try:
            # Read the FULL log file to find TRAINING_ITERATION_COMPLETE marker.
            # This marker is printed BEFORE the server starts, so it's typically
            # in an earlier log chunk than the server IP marker.
            iter_pattern = r"TRAINING_ITERATION_COMPLETE:\s+(\S+)"
            full_log = finetune_job.log_file.read_text()
            iter_matches = re.findall(iter_pattern, full_log)
            if iter_matches:
                model_name = iter_matches[-1]
            else:
                # Fallback layer name when no iteration marker was emitted
                model_name = f"{finetune_job.model_name}_finetuned"

            self._add_finetuned_neuroglancer_layer(finetune_job, model_name)

        except Exception as e:
            # Layer registration is best-effort; the server itself is up.
            self.logger.error(f"Failed to add finetuned model to neuroglancer: {e}", exc_info=True)

    def _parse_training_restart(self, finetune_job: FinetuneJob, log_content: str):
        """
        Parse log for RESTARTING_TRAINING and TRAINING_ITERATION_COMPLETE markers
        to handle iterative training restarts.

        On RESTARTING_TRAINING: reset training progress counters.
        On TRAINING_ITERATION_COMPLETE: update the neuroglancer layer name with new timestamp.

        Args:
            finetune_job: Job to update
            log_content: New log content to parse
        """
        # Check for restart marker - reset progress
        if "RESTARTING_TRAINING" in log_content:
            self.logger.info(f"Training restart detected for job {finetune_job.job_id}")
            finetune_job.current_epoch = 0
            finetune_job.latest_loss = None
            finetune_job.status = JobStatus.RUNNING
            finetune_job.inference_server_ready = False

        # Check for iteration complete marker - update neuroglancer layer.
        # Read full log in case the marker was in a previous chunk.
        iter_pattern = r"TRAINING_ITERATION_COMPLETE:\s+(\S+)"
        try:
            full_log = finetune_job.log_file.read_text()
        except Exception:
            # Fall back to the tailed chunk if the file is unreadable
            full_log = log_content
        iter_matches = re.findall(iter_pattern, full_log)
        if iter_matches:
            # For in-process restarts, the inference server usually stays on the same
            # URL and does not emit a fresh CELLMAP_FLOW_SERVER_IP marker. Mark the
            # server as ready once we see a completed training iteration if URL exists.
            if finetune_job.inference_server_url:
                finetune_job.inference_server_ready = True

            new_model_name = iter_matches[-1]
            if new_model_name != finetune_job.finetuned_model_name:
                self.logger.info(f"New training iteration complete: {new_model_name}")
                try:
                    self._add_finetuned_neuroglancer_layer(finetune_job, new_model_name)
                except Exception as e:
                    self.logger.error(f"Failed to update neuroglancer layer: {e}", exc_info=True)
                # Still update the stored name so the frontend reflects the new model
                # and we don't retry the failed neuroglancer update every cycle
                finetune_job.finetuned_model_name = new_model_name

    def complete_job(self, finetune_job: FinetuneJob):
        """
        Post-training actions after job completes successfully.

        1. Verify adapter files exist
        2. Generate model script and YAML
        3. Register in g.models_config
        4.
Update job status and metadata + + Args: + finetune_job: The completed job + + Raises: + RuntimeError: If adapter files missing or registration fails + """ + job_id = finetune_job.job_id + self.logger.info(f"Running post-completion for job {job_id}...") + + # === Verify adapter files exist === + + adapter_path = finetune_job.output_dir / "lora_adapter" + + # Check for adapter model (supports both .bin and .safetensors formats) + adapter_model_bin = adapter_path / "adapter_model.bin" + adapter_model_safetensors = adapter_path / "adapter_model.safetensors" + + if not (adapter_model_bin.exists() or adapter_model_safetensors.exists()): + raise RuntimeError( + f"Training completed but adapter model not found. " + f"Checked: {adapter_model_bin} and {adapter_model_safetensors}" + ) + + adapter_config_file = adapter_path / "adapter_config.json" + if not adapter_config_file.exists(): + raise RuntimeError( + f"Training completed but adapter config not found: {adapter_config_file}" + ) + + self.logger.info(f"Verified LoRA adapter files exist in {adapter_path}") + + # === Generate finetuned model name === + + timestamp = finetune_job.created_at.strftime("%Y%m%d_%H%M%S") + model_basename = finetune_job.model_name.replace("/", "_").replace(" ", "_") + finetuned_model_name = f"{model_basename}_finetuned_{timestamp}" + + finetune_job.finetuned_model_name = finetuned_model_name + + self.logger.info(f"Generated finetuned model name: {finetuned_model_name}") + + # === Generate model script and YAML === + + # Import here to avoid circular dependencies + from cellmap_flow.finetune.finetuned_model_templates import ( + generate_finetuned_model_script, + generate_finetuned_model_yaml + ) + + # Models output directory (at session level, not in finetuning subdirectory) + # output_dir structure: session_path/finetuning/runs/model_timestamp/ + # So parent.parent.parent gets us to session_path + models_dir = finetune_job.output_dir.parent.parent.parent / "models" + + try: + 
models_dir.mkdir(parents=True, exist_ok=True) + self.logger.info(f"Models directory ready: {models_dir}") + except Exception as e: + self.logger.error(f"Failed to create models directory {models_dir}: {e}") + raise RuntimeError(f"Failed to create models directory: {e}") + + # Check if files already exist (generated by CLI with auto-serve) + expected_script = models_dir / f"{finetuned_model_name}.py" + expected_yaml = models_dir / f"{finetuned_model_name}.yaml" + files_already_generated = expected_script.exists() and expected_yaml.exists() + + if files_already_generated: + self.logger.info(f"Model files already generated by CLI, skipping generation") + finetune_job.model_script_path = expected_script + finetune_job.model_yaml_path = expected_yaml + script_path = expected_script + yaml_path = expected_yaml + # Skip to registration + else: + self.logger.info(f"Generating model files...") + + # Get base model script path from metadata if available + metadata_file = finetune_job.output_dir / "metadata.json" + base_script_path = None + if metadata_file.exists(): + try: + with open(metadata_file, "r") as f: + metadata = json.load(f) + base_script_path = metadata.get("model_script", None) + self.logger.info(f"Found base model script in metadata: {base_script_path}") + except Exception as e: + self.logger.warning(f"Could not read base script from metadata: {e}") + + try: + # Generate .py script + self.logger.info(f"Generating finetuned model script for {finetuned_model_name}...") + self.logger.info(f" Base script path: {base_script_path}") + self.logger.info(f" LoRA adapter path: {adapter_path}") + self.logger.info(f" Output path: {models_dir / f'{finetuned_model_name}.py'}") + + script_path = generate_finetuned_model_script( + base_checkpoint=finetune_job.params.get("model_checkpoint", ""), + lora_adapter_path=str(adapter_path), + model_name=finetuned_model_name, + channels=finetune_job.params.get("channels", ["mito"]), + 
input_voxel_size=tuple(finetune_job.params.get("input_voxel_size", [16, 16, 16])), + output_voxel_size=tuple(finetune_job.params.get("output_voxel_size", [16, 16, 16])), + lora_r=finetune_job.params.get("lora_r", 8), + lora_alpha=finetune_job.params.get("lora_alpha", 16), + num_epochs=finetune_job.params.get("num_epochs", 10), + learning_rate=finetune_job.params.get("learning_rate", 1e-4), + output_path=models_dir / f"{finetuned_model_name}.py", + base_script_path=base_script_path + ) + + finetune_job.model_script_path = script_path + self.logger.info(f"Generated model script: {script_path}") + + # === Extract configuration from base model and corrections === + + data_path = None + json_data = None + base_scale = "s0" # Default scale (only safe default) + + # 1. Get dataset_path from corrections metadata (REQUIRED) + corrections_dir = Path(metadata.get("corrections_path", "")) + try: + data_path = self._extract_data_path_from_corrections(corrections_dir) + self.logger.info(f"Found dataset_path from corrections: {data_path}") + except (ValueError, Exception) as e: + self.logger.error(f"Could not extract dataset_path: {e}") + + # 2. 
Get normalization and preprocessing from base model YAML + if base_script_path: + self.logger.info("Extracting normalization from base model YAML...") + import yaml + base_yaml_path = Path(base_script_path).with_suffix('.yaml') + if base_yaml_path.exists(): + try: + with open(base_yaml_path, 'r') as f: + base_config = yaml.safe_load(f) + + # Get json_data (normalization and postprocessing) + if 'json_data' in base_config: + json_data = base_config['json_data'] + self.logger.info(f"✓ Found json_data from base model YAML") + else: + self.logger.warning(f"No json_data in base model YAML: {base_yaml_path}") + + # Get data_path from base model (fallback if not in corrections) + if not data_path and 'data_path' in base_config: + data_path = base_config['data_path'] + self.logger.info(f"✓ Using data_path from base model YAML: {data_path}") + + # Get scale + if 'models' in base_config and len(base_config['models']) > 0: + base_scale = base_config['models'][0].get('scale', 's0') + self.logger.info(f"✓ Found scale from base model: {base_scale}") + except Exception as e: + self.logger.error(f"Failed to read base model YAML {base_yaml_path}: {e}") + else: + self.logger.warning(f"Base model YAML not found: {base_yaml_path}") + + # 3. Validate we have required data (NO PLACEHOLDERS!) + if not data_path: + raise RuntimeError( + "Could not determine dataset_path for finetuned model. " + "Checked corrections metadata and base model YAML. " + "Cannot generate model config without actual dataset path." + ) + + if not json_data: + self.logger.warning( + "No json_data (normalization/postprocessing) found. " + "Finetuned model may not work correctly without proper normalization. " + "Consider adding json_data to base model YAML." 
+ ) + + # Generate .yaml config + yaml_path = generate_finetuned_model_yaml( + script_path=script_path, + model_name=finetuned_model_name, + resolution=finetune_job.params.get("input_voxel_size", [16, 16, 16])[0], + output_path=models_dir / f"{finetuned_model_name}.yaml", + data_path=data_path, + queue=finetune_job.params.get("queue", "gpu_h100"), + json_data=json_data, + scale=base_scale + ) + + finetune_job.model_yaml_path = yaml_path + self.logger.info(f"Generated model YAML: {yaml_path}") + + except Exception as e: + import traceback + self.logger.error(f"Error generating model files: {e}") + self.logger.error(f"Traceback:\n{traceback.format_exc()}") + raise RuntimeError(f"Failed to generate model files: {e}") + + # === Update metadata file with completion info === + + metadata_file = finetune_job.output_dir / "metadata.json" + if metadata_file.exists(): + with open(metadata_file, "r") as f: + metadata = json.load(f) + + metadata["completed_at"] = datetime.now().isoformat() + metadata["status"] = "COMPLETED" + metadata["finetuned_model_name"] = finetuned_model_name + metadata["model_script_path"] = str(script_path) + metadata["model_yaml_path"] = str(yaml_path) + metadata["final_epoch"] = finetune_job.current_epoch + metadata["final_loss"] = finetune_job.latest_loss + + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + self.logger.info(f"Updated metadata file: {metadata_file}") + + self.logger.info(f"Job {job_id} completed successfully!") + + def cancel_job(self, job_id: str) -> bool: + """ + Cancel a running job. 
        Args:
            job_id: Job ID to cancel

        Returns:
            True if successfully cancelled, False otherwise
        """
        if job_id not in self.jobs:
            self.logger.error(f"Job {job_id} not found")
            return False

        finetune_job = self.jobs[job_id]

        # Terminal states cannot be cancelled again
        if finetune_job.status in [JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]:
            self.logger.warning(f"Job {job_id} already finished with status {finetune_job.status}")
            return False

        self.logger.info(f"Cancelling job {job_id}...")

        if finetune_job.lsf_job:
            try:
                finetune_job.lsf_job.kill()
                finetune_job.status = JobStatus.CANCELLED
                self.logger.info(f"Successfully cancelled job {job_id}")
                return True
            except Exception as e:
                self.logger.error(f"Error cancelling job {job_id}: {e}")
                return False
        else:
            self.logger.error(f"No LSF job associated with {job_id}")
            return False

    def get_job_status(self, job_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed status of a specific job.

        Args:
            job_id: Job ID to query

        Returns:
            Dictionary with job status details, or None if not found
        """
        if job_id not in self.jobs:
            return None

        finetune_job = self.jobs[job_id]
        result = finetune_job.to_dict()
        # Rename "latest_loss" -> "loss" for the API payload
        result["loss"] = result.pop("latest_loss", None)
        # Guard against division by zero when total_epochs is unknown
        result["progress_percent"] = (
            finetune_job.current_epoch / finetune_job.total_epochs * 100
        ) if finetune_job.total_epochs > 0 else 0
        return result

    def list_jobs(self) -> List[Dict[str, Any]]:
        """
        Get list of all jobs with their status.

        Returns:
            List of job status dictionaries
        """
        return [self.get_job_status(job_id) for job_id in self.jobs.keys()]

    def get_job_logs(self, job_id: str) -> Optional[str]:
        """
        Get full log content for a job.

        Args:
            job_id: Job ID

        Returns:
            Log file content as string, or None if not found
        """
        if job_id not in self.jobs:
            return None

        finetune_job = self.jobs[job_id]

        if not finetune_job.log_file.exists():
            return "Log file not yet created..."
        try:
            with open(finetune_job.log_file, "r") as f:
                return f.read()
        except Exception as e:
            # Surface the read error to the caller instead of raising
            self.logger.error(f"Error reading log file: {e}")
            return f"Error reading log file: {e}"

    def get_job(self, job_id: str) -> Optional[FinetuneJob]:
        """
        Get a FinetuneJob object by ID.

        Args:
            job_id: Job ID to retrieve

        Returns:
            FinetuneJob object, or None if not found
        """
        return self.jobs.get(job_id)

    def _archive_job_logs(self, job: FinetuneJob):
        """
        Archive logs before restart.

        Args:
            job: The job whose logs to archive
        """
        log_file = job.log_file
        metadata_file = job.output_dir / "metadata.json"

        # Find next archive number
        archive_num = 1
        while (job.output_dir / f"training_log_{archive_num}.txt").exists():
            archive_num += 1

        # Archive log (copy only - do NOT truncate, as tee still has an open file descriptor)
        # NOTE(review): shutil is imported locally in both branches below —
        # a single module-level import would be cleaner; verify file header.
        if log_file.exists():
            import shutil
            archive_log = job.output_dir / f"training_log_{archive_num}.txt"
            shutil.copy(log_file, archive_log)
            self.logger.info(f"Archived log to {archive_log}")

        # Archive metadata
        if metadata_file.exists():
            import shutil
            archive_meta = job.output_dir / f"metadata_{archive_num}.json"
            shutil.copy(metadata_file, archive_meta)
            self.logger.info(f"Archived metadata to {archive_meta}")

    def restart_finetuning_job(
        self,
        job_id: str,
        updated_params: Optional[Dict[str, Any]] = None
    ) -> FinetuneJob:
        """
        Restart training on the same GPU via control endpoint.

        Primary path sends an HTTP restart request to the running
        inference server in the same process as the training loop.
        Falls back to file signal if control endpoint is unavailable.

        Args:
            job_id: ID of job to restart
            updated_params: Dict of updated training parameters

        Returns:
            Same FinetuneJob object (updated in-place)

        Raises:
            ValueError: If job not found or not in a restartable state
        """
        restart_t0 = time.perf_counter()

        if job_id not in self.jobs:
            raise ValueError(f"Job {job_id} not found")

        job = self.jobs[job_id]

        # Only allow restart if the job is running (serving after training)
        if job.status not in [JobStatus.RUNNING, JobStatus.COMPLETED]:
            raise ValueError(
                f"Job {job_id} is in state {job.status.value} - "
                f"can only restart jobs that are RUNNING (serving) or COMPLETED"
            )

        if not job.inference_server_ready:
            raise ValueError(
                f"Job {job_id} inference server not ready - "
                f"training must complete and server must start before restarting"
            )

        # 1. Archive current logs
        self.logger.info(f"Archiving logs for job {job_id}...")
        archive_t0 = time.perf_counter()
        self._archive_job_logs(job)
        archive_elapsed = time.perf_counter() - archive_t0

        signal_data = {
            "restart": True,
            "timestamp": datetime.now().isoformat(),
            "params": updated_params or {}
        }

        # 2. Send restart request to running inference server (primary path)
        signal_write_mode = "http_control"
        write_t0 = time.perf_counter()
        http_error = None
        if job.inference_server_url:
            try:
                control_url = job.inference_server_url.rstrip("/") + "/__control__/restart"
                response = requests.post(control_url, json=signal_data, timeout=5)
                response.raise_for_status()
                data = response.json()
                if not data.get("success", False):
                    raise RuntimeError(data.get("error", "Unknown restart control failure"))
                self.logger.info(f"Sent restart request via HTTP control endpoint: {control_url}")
            except Exception as e:
                # Any HTTP failure triggers the file-signal fallback below
                http_error = e
                self.logger.warning(f"HTTP restart control failed for job {job_id}: {e}")
        else:
            http_error = RuntimeError("No inference_server_url for HTTP restart control")

        # 3.
Fallback to signal file if HTTP control endpoint is unavailable + if http_error is not None: + signal_write_mode = "file_signal_fallback" + signal_file = job.output_dir / "restart_signal.json" + with open(signal_file, 'w') as f: + json.dump(signal_data, f, indent=2) + self.logger.info(f"Wrote fallback restart signal to {signal_file}") + write_elapsed = time.perf_counter() - write_t0 + + # 4. Reset training progress (keep inference server info) + job.current_epoch = 0 + job.latest_loss = None + job.status = JobStatus.RUNNING + job.inference_server_ready = False + + # 5. Update stored params + if updated_params: + job.params.update(updated_params) + + total_elapsed = time.perf_counter() - restart_t0 + self.logger.info( + f"Restart signal timings for job {job_id}: " + f"archive={archive_elapsed:.2f}s write={write_elapsed:.2f}s " + f"mode={signal_write_mode} total={total_elapsed:.2f}s" + ) + self.logger.info(f"Job {job_id} restart request sent, waiting for CLI to pick it up") + + return job diff --git a/cellmap_flow/finetune/finetuned_model_templates.py b/cellmap_flow/finetune/finetuned_model_templates.py new file mode 100644 index 0000000..80062f0 --- /dev/null +++ b/cellmap_flow/finetune/finetuned_model_templates.py @@ -0,0 +1,535 @@ +""" +Templates for generating finetuned model scripts and YAML configurations. + +This module provides functions to auto-generate the necessary files for serving +finetuned models, based on the patterns in my_yamls/jrc_c-elegans-bw-1_finetuned.py/yaml. +""" + +import ast +import logging +import re +from pathlib import Path +from typing import List, Tuple, Optional + +logger = logging.getLogger(__name__) + + +def extract_shapes_from_script(script_path: str) -> Tuple[Optional[Tuple], Optional[Tuple]]: + """ + Safely extract input_size and output_size from a Python script using AST parsing. + + This avoids executing the script (which may try to load models on GPU). 
+ + Args: + script_path: Path to the Python script + + Returns: + Tuple of (input_size, output_size) or (None, None) if extraction fails + """ + try: + with open(script_path, 'r') as f: + source = f.read() + + # Parse the source code into an AST + tree = ast.parse(source) + + input_size = None + output_size = None + + # Walk through all assignment nodes + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + # Check if this is an assignment to input_size or output_size + for target in node.targets: + if isinstance(target, ast.Name): + if target.id == 'input_size': + # Try to evaluate the value + try: + input_size = ast.literal_eval(node.value) + except: + pass + elif target.id == 'output_size': + try: + output_size = ast.literal_eval(node.value) + except: + pass + + logger.info(f"Extracted shapes from {script_path}: input_size={input_size}, output_size={output_size}") + return input_size, output_size + + except Exception as e: + logger.warning(f"AST extraction failed for {script_path}: {e}") + + # Fallback to regex parsing + try: + with open(script_path, 'r') as f: + content = f.read() + + # Match patterns like: input_size = (56, 56, 56) + input_match = re.search(r'input_size\s*=\s*\((\d+),\s*(\d+),\s*(\d+)\)', content) + output_match = re.search(r'output_size\s*=\s*\((\d+),\s*(\d+),\s*(\d+)\)', content) + + if input_match: + input_size = tuple(map(int, input_match.groups())) + if output_match: + output_size = tuple(map(int, output_match.groups())) + + if input_size or output_size: + logger.info(f"Regex extracted shapes from {script_path}: input_size={input_size}, output_size={output_size}") + return input_size, output_size + + except Exception as e2: + logger.warning(f"Regex extraction also failed for {script_path}: {e2}") + + return None, None + + +def generate_finetuned_model_script( + base_checkpoint: str, + lora_adapter_path: str, + model_name: str, + channels: List[str], + input_voxel_size: Tuple[int, int, int], + output_voxel_size: Tuple[int, int, 
int], + lora_r: int, + lora_alpha: int, + num_epochs: int, + learning_rate: float, + output_path: Path, + base_script_path: str = None +) -> Path: + """ + Generate .py script for loading and serving a finetuned model. + + Based on template: my_yamls/jrc_c-elegans-bw-1_finetuned.py + + Args: + base_checkpoint: Path to base model checkpoint (for checkpoint-based models) + lora_adapter_path: Path to LoRA adapter directory + model_name: Name of the finetuned model + channels: List of output channels (e.g., ["mito"]) + input_voxel_size: Input voxel size (z, y, x) in nm + output_voxel_size: Output voxel size (z, y, x) in nm + lora_r: LoRA rank used in training + lora_alpha: LoRA alpha used in training + num_epochs: Number of training epochs + learning_rate: Learning rate used + output_path: Where to write the .py file + base_script_path: Path to base model script (for script-based models) + + Returns: + Path to the generated script file + """ + # Calculate lora_dropout (typically 0.0 or 0.1) + lora_dropout = 0.0 # Default used in training + + # Format voxel sizes as tuples + input_voxel_str = f"({input_voxel_size[0]}, {input_voxel_size[1]}, {input_voxel_size[2]})" + output_voxel_str = f"({output_voxel_size[0]}, {output_voxel_size[1]}, {output_voxel_size[2]})" + + # Format channels list + channels_str = ", ".join([f'"{c}"' for c in channels]) + + # Determine if this is checkpoint-based or script-based + is_script_based = bool(base_script_path and not base_checkpoint) + + # Handle model source info + if is_script_based: + base_model_info = f"Script: {base_script_path}" + base_checkpoint_var = "" + base_script_var = base_script_path + else: + base_model_info = base_checkpoint if base_checkpoint else "N/A (trained from scratch)" + base_checkpoint_var = base_checkpoint if base_checkpoint else "" + base_script_var = "" + + # Get shapes from base script using safe AST parsing (doesn't execute the script) + if is_script_based and base_script_path: + extracted_input_size, 
extracted_output_size = extract_shapes_from_script(base_script_path) + base_input_size = extracted_input_size if extracted_input_size else (178, 178, 178) + base_output_size = extracted_output_size if extracted_output_size else (56, 56, 56) + + if not extracted_input_size or not extracted_output_size: + logger.warning( + f"Could not extract shapes from {base_script_path}. " + f"Using defaults: input_size={base_input_size}, output_size={base_output_size}" + ) + else: + base_input_size = (178, 178, 178) + base_output_size = (56, 56, 56) + + # Format shapes as strings + input_size_str = f"{base_input_size}" + output_size_str = f"{base_output_size}" + + # Generate different templates based on model type + if is_script_based: + # Template for script-based models + script_content = f'''""" +LoRA finetuned model: {model_name} + +This model is based on: +{base_model_info} + +Finetuned with LoRA on user corrections with parameters: +- LoRA rank (r): {lora_r} +- LoRA alpha: {lora_alpha} +- LoRA dropout: {lora_dropout} +- Training epochs: {num_epochs} +- Learning rate: {learning_rate} + +Auto-generated by CellMap-Flow finetuning workflow. 
+""" + +import torch +import torch.nn as nn +from pathlib import Path +import logging +import time + +import gunpowder as gp +import numpy as np +from funlib.geometry.coordinate import Coordinate +from cellmap_flow.utils.load_py import load_safe_config + +logger = logging.getLogger(__name__) + +# Model configuration +classes = [{channels_str}] +output_channels = len(classes) + +# Paths +BASE_SCRIPT = "{base_script_var}" +LORA_ADAPTER_PATH = "{lora_adapter_path}" + +# Voxel sizes and shapes +input_voxel_size = Coordinate{input_voxel_str} +output_voxel_size = Coordinate{output_voxel_str} + +# Model input/output shapes (from base model) +input_size = {input_size_str} +output_size = {output_size_str} + +# Gunpowder shapes +read_shape = gp.Coordinate(*input_size) * Coordinate(input_voxel_size) +write_shape = gp.Coordinate(*output_size) * Coordinate(output_voxel_size) + +# Block shape for processing +block_shape = np.array((*output_size, output_channels)) + +# Load base model ONCE at module level +logger.info(f"Loading base model from: {{BASE_SCRIPT}}") +_load_t0 = time.perf_counter() +_base_config = load_safe_config(BASE_SCRIPT, force_safe=False) +_base_model = _base_config.model +_base_elapsed = time.perf_counter() - _load_t0 +logger.info(f"Base model/script load time: {{_base_elapsed:.2f}}s") + +# Initialize device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +logger.info(f"Using device: {{device}}") + +# Apply LoRA adapter to base model +from cellmap_flow.finetune.lora_wrapper import load_lora_adapter +logger.info(f"Loading LoRA adapter from: {{LORA_ADAPTER_PATH}}") +_lora_t0 = time.perf_counter() +model = load_lora_adapter( + _base_model, + LORA_ADAPTER_PATH, + is_trainable=False # Inference mode +) +_lora_elapsed = time.perf_counter() - _lora_t0 +model = model.to(device) +model.eval() +_total_elapsed = time.perf_counter() - _load_t0 + +logger.info("LoRA finetuned model loaded successfully") +logger.info( + f"Model load timings (s): 
base={{_base_elapsed:.2f}}, " + f"lora={{_lora_elapsed:.2f}}, total={{_total_elapsed:.2f}}" +) +logger.info(f"Model classes: {{classes}}") +logger.info(f"Input shape: {{input_size}}, Output shape: {{output_size}}") +logger.info(f"Voxel sizes - Input: {{input_voxel_size}}, Output: {{output_voxel_size}}") +''' + else: + # Template for checkpoint-based models (original template) + script_content = f'''""" +LoRA finetuned model: {model_name} + +This model is based on: +{base_model_info} + +Finetuned with LoRA on user corrections with parameters: +- LoRA rank (r): {lora_r} +- LoRA alpha: {lora_alpha} +- LoRA dropout: {lora_dropout} +- Training epochs: {num_epochs} +- Learning rate: {learning_rate} + +Auto-generated by CellMap-Flow finetuning workflow. +""" + +import torch +import torch.nn as nn +from pathlib import Path +import logging +import time + +import gunpowder as gp +import numpy as np +from funlib.geometry.coordinate import Coordinate + +logger = logging.getLogger(__name__) + +# Model configuration +classes = [{channels_str}] +output_channels = len(classes) + +# Paths +BASE_CHECKPOINT = "{base_checkpoint_var}" +LORA_ADAPTER_PATH = "{lora_adapter_path}" + +# Voxel sizes and shapes +input_voxel_size = Coordinate{input_voxel_str} +output_voxel_size = Coordinate{output_voxel_str} + +# Model input/output shapes (fly model defaults) +# Note: These may need adjustment based on your specific model architecture +input_size = (178, 178, 178) +output_size = (56, 56, 56) + +# Gunpowder shapes +read_shape = gp.Coordinate(*input_size) * Coordinate(input_voxel_size) +write_shape = gp.Coordinate(*output_size) * Coordinate(output_voxel_size) + +# Block shape for processing +block_shape = np.array((*output_size, output_channels)) + + +def load_base_model(checkpoint_path: str, num_channels: int, device) -> nn.Module: + """Load the base fly model from checkpoint.""" + from fly_organelles.model import StandardUnet + + logger.info(f"Loading base model from: {{checkpoint_path}}") + 
t0 = time.perf_counter() + + # Load the base model + model_backbone = StandardUnet(num_channels) + checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu") + model_backbone.load_state_dict(checkpoint["model_state_dict"]) + + # Wrap with sigmoid + model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid()) + elapsed = time.perf_counter() - t0 + logger.info(f"Base checkpoint load time: {{elapsed:.2f}}s") + + return model + + +def load_finetuned_model(device) -> nn.Module: + """Load the base model and apply LoRA adapter.""" + from cellmap_flow.finetune.lora_wrapper import load_lora_adapter + t0 = time.perf_counter() + + # Load base model + if BASE_CHECKPOINT: + base_t0 = time.perf_counter() + base_model = load_base_model(BASE_CHECKPOINT, len(classes), device) + base_elapsed = time.perf_counter() - base_t0 + else: + # Model was trained from scratch - create fresh model + logger.warning("No base checkpoint specified - model was trained from scratch") + base_t0 = time.perf_counter() + from fly_organelles.model import StandardUnet + model_backbone = StandardUnet(len(classes)) + base_model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid()) + base_model.to(device) + base_elapsed = time.perf_counter() - base_t0 + + # Load LoRA adapter + logger.info(f"Loading LoRA adapter from: {{LORA_ADAPTER_PATH}}") + lora_t0 = time.perf_counter() + model = load_lora_adapter( + base_model, + LORA_ADAPTER_PATH, + is_trainable=False # Inference mode + ) + lora_elapsed = time.perf_counter() - lora_t0 + total_elapsed = time.perf_counter() - t0 + logger.info( + f"Model load timings (s): base={{base_elapsed:.2f}}, " + f"lora={{lora_elapsed:.2f}}, total={{total_elapsed:.2f}}" + ) + + return model + + +# Initialize device and model +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +logger.info(f"Using device: {{device}}") + +model = load_finetuned_model(device) +model = model.to(device) +model.eval() + +logger.info("LoRA finetuned model loaded 
successfully") +logger.info(f"Model classes: {{classes}}") +logger.info(f"Input shape: {{input_size}}, Output shape: {{output_size}}") +logger.info(f"Voxel sizes - Input: {{input_voxel_size}}, Output: {{output_voxel_size}}") +''' + + # Write to file + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + f.write(script_content) + + logger.info(f"Generated finetuned model script: {output_path}") + + return output_path + + +def generate_finetuned_model_yaml( + script_path: Path, + model_name: str, + resolution: int, + output_path: Path, + data_path: str, + queue: str = "gpu_h100", + charge_group: str = "cellmap", + json_data: dict = None, + scale: str = "s0" +) -> Path: + """ + Generate .yaml configuration for serving a finetuned model. + + Based on template: my_yamls/jrc_c-elegans-bw-1_finetuned.yaml + + Args: + script_path: Path to the generated .py script + model_name: Name of the finetuned model + resolution: Voxel resolution in nm + output_path: Where to write the .yaml file + data_path: Path to actual dataset used for training (REQUIRED - no placeholders) + queue: LSF queue name + charge_group: LSF charge group + json_data: Optional dict with input_norm and postprocess from base model + scale: Scale level (e.g., "s0", "s1") from base model + + Returns: + Path to the generated YAML file + """ + # Validate inputs - no placeholders allowed! + if not data_path or data_path == "/path/to/your/data.zarr": + raise ValueError( + "data_path is required and cannot be a placeholder. " + "Must provide actual dataset path from training corrections." 
+ ) + + # Data path comment (always from corrections) + data_path_comment = "# Data path from training corrections\n#\n" + + # Format json_data - use provided or warn if missing + import yaml as yaml_lib + if json_data: + json_data_comment = "# Normalization and postprocessing from base model\n" + json_data_str = yaml_lib.dump({'json_data': json_data}, default_flow_style=False, sort_keys=False).strip() + else: + # Missing json_data is a warning case - provide generic defaults + # but log a warning (already done in job_manager) + json_data_comment = "# WARNING: No normalization found in base model!\n# Using generic defaults - model may not work correctly.\n# Update these values based on your data.\n" + json_data_str = '''json_data: + input_norm: + MinMaxNormalizer: + min_value: 0 + max_value: 65535 + invert: false + LambdaNormalizer: + expression: x*2-1 + postprocess: + DefaultPostprocessor: + clip_min: 0 + clip_max: 1.0''' + + # Convert script_path to absolute path + script_path_abs = Path(script_path).resolve() + + yaml_content = f'''# Finetuned model configuration: {model_name} +# Auto-generated by CellMap-Flow finetuning workflow +# +{data_path_comment} +data_path: "{data_path}" + +charge_group: "{charge_group}" +queue: "{queue}" + +{json_data_comment}{json_data_str} + +# Model configuration +models: + - type: "script" + scale: "{scale}" + resolution: {resolution} + script_path: "{script_path_abs}" + name: "{model_name}" +''' + + # Write to file + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + f.write(yaml_content) + + logger.info(f"Generated finetuned model YAML: {output_path}") + + return output_path + + +def register_finetuned_model(yaml_path: Path): + """ + Load YAML config and register the finetuned model in g.models_config. + + This allows the model to appear in the dashboard immediately. 
+ + Args: + yaml_path: Path to the generated YAML config + + Returns: + The newly created ScriptModelConfig object + """ + from cellmap_flow.utils.config_utils import build_model_from_entry + from cellmap_flow import globals as g + import yaml + + logger.info(f"Registering finetuned model from: {yaml_path}") + + # Load YAML + with open(yaml_path, "r") as f: + config = yaml.safe_load(f) + + # Extract model entry + if "models" not in config or len(config["models"]) == 0: + raise ValueError(f"No models found in YAML config: {yaml_path}") + + model_entry = config["models"][0] + + # Build ModelConfig object + try: + model_config = build_model_from_entry(model_entry) + + # Add to global models config + if not hasattr(g, "models_config"): + g.models_config = [] + + g.models_config.append(model_config) + + logger.info(f"Successfully registered finetuned model: {model_config.name}") + + return model_config + + except Exception as e: + logger.error(f"Failed to register finetuned model: {e}") + raise RuntimeError(f"Model registration failed: {e}") diff --git a/cellmap_flow/finetune/lora_trainer.py b/cellmap_flow/finetune/lora_trainer.py new file mode 100644 index 0000000..de17a28 --- /dev/null +++ b/cellmap_flow/finetune/lora_trainer.py @@ -0,0 +1,595 @@ +""" +LoRA finetuning trainer for CellMap-Flow models. + +This module provides a trainer class for finetuning models using user +corrections with mixed-precision training and gradient accumulation. +""" + +import logging +from pathlib import Path +from typing import Optional, Dict, Any +import time + +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.cuda.amp import autocast, GradScaler +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +class DiceLoss(nn.Module): + """ + Dice Loss for segmentation tasks. + + Dice loss is effective for imbalanced datasets where the target class + may be sparse (e.g., mitochondria in EM images). 
+ + Formula: 1 - (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth) + """ + + def __init__(self, smooth: float = 1.0): + """ + Args: + smooth: Smoothing factor to avoid division by zero (default: 1.0) + """ + super().__init__() + self.smooth = smooth + + def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute Dice loss. + + Args: + pred: Predictions (B, C, Z, Y, X) - raw logits or probabilities + target: Targets (B, C, Z, Y, X) - binary masks [0, 1] + mask: Optional mask (B, 1, Z, Y, X) - if provided, only compute loss on masked regions + + Returns: + Dice loss value (scalar) + """ + # Flatten spatial dimensions + pred = pred.reshape(pred.size(0), pred.size(1), -1) # (B, C, N) + target = target.reshape(target.size(0), target.size(1), -1) # (B, C, N) + + # Apply sigmoid if pred is logits (not already in [0, 1]) + if pred.min() < 0 or pred.max() > 1: + pred = torch.sigmoid(pred) + + # Apply mask if provided + if mask is not None: + mask = mask.reshape(mask.size(0), 1, -1) # (B, 1, N) + pred = pred * mask + target = target * mask + + # Compute intersection and union + intersection = (pred * target).sum(dim=2) # (B, C) + union = pred.sum(dim=2) + target.sum(dim=2) # (B, C) + + # Dice coefficient + dice = (2.0 * intersection + self.smooth) / (union + self.smooth) + + # Dice loss (1 - dice) + return 1.0 - dice.mean() + + +class CombinedLoss(nn.Module): + """ + Combined Dice + BCE loss for better convergence. + + Uses both Dice loss (for overlap) and BCE loss (for pixel-wise accuracy). 
+ """ + + def __init__(self, dice_weight: float = 0.5, bce_weight: float = 0.5): + """ + Args: + dice_weight: Weight for Dice loss + bce_weight: Weight for BCE loss + """ + super().__init__() + self.dice_loss = DiceLoss() + self.bce_loss = nn.BCEWithLogitsLoss(reduction='none') + self.dice_weight = dice_weight + self.bce_weight = bce_weight + + def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute combined loss. + + Args: + pred: Predictions (B, C, Z, Y, X) - raw logits + target: Targets (B, C, Z, Y, X) - binary masks [0, 1] + mask: Optional mask (B, 1, Z, Y, X) - if provided, only compute loss on masked regions + + Returns: + Combined loss value (scalar) + """ + dice = self.dice_loss(pred, target, mask) + + # For BCE, manually apply mask if provided + bce = self.bce_loss(pred, target) + if mask is not None: + bce = bce * mask + bce = bce.sum() / mask.sum().clamp(min=1) # Average over masked regions + else: + bce = bce.mean() + + return self.dice_weight * dice + self.bce_weight * bce + + +class MarginLoss(nn.Module): + """ + Margin-based loss for sparse/scribble annotations. + + Only penalizes predictions on the wrong side of a margin threshold. + For post-sigmoid outputs in [0, 1]: + - Foreground (target=1): loss = relu(threshold - pred)^2, threshold = 1 - margin + - Background (target=0): loss = relu(pred - margin)^2 + - No loss when prediction is already correct with sufficient confidence. 
+ """ + + def __init__(self, margin: float = 0.3, balance_classes: bool = False): + super().__init__() + self.margin = margin + self.balance_classes = balance_classes + + def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + threshold_high = 1.0 - self.margin # e.g., 0.7 + threshold_low = self.margin # e.g., 0.3 + + # Foreground loss: penalize if pred < threshold_high + fg_loss = torch.relu(threshold_high - pred) ** 2 + # Background loss: penalize if pred > threshold_low + bg_loss = torch.relu(pred - threshold_low) ** 2 + + if self.balance_classes and mask is not None: + # Average each class separately so fg/bg contribute equally + # regardless of how many scribble voxels each has + fg_mask = target * mask + bg_mask = (1.0 - target) * mask + fg_count = fg_mask.sum().clamp(min=1) + bg_count = bg_mask.sum().clamp(min=1) + fg_contrib = (fg_loss * fg_mask).sum() / fg_count + bg_contrib = (bg_loss * bg_mask).sum() / bg_count + return (fg_contrib + bg_contrib) / 2.0 + + # Blend by target: target=1 -> fg_loss, target=0 -> bg_loss + loss = target * fg_loss + (1.0 - target) * bg_loss + + if mask is not None: + loss = loss * mask + return loss.sum() / mask.sum().clamp(min=1) + return loss.mean() + + +class LoRAFinetuner: + """ + Trainer for finetuning models with LoRA adapters. 
+ + Features: + - Mixed precision (FP16) training for memory efficiency + - Gradient accumulation to simulate larger batch sizes + - Checkpointing with best model tracking + - Progress logging + - Partial annotation support (mask unannotated regions) + + Args: + model: PEFT model with LoRA adapters + dataloader: DataLoader for training data + output_dir: Directory to save checkpoints and logs + learning_rate: Learning rate (default: 1e-4) + num_epochs: Number of training epochs (default: 10) + gradient_accumulation_steps: Steps to accumulate gradients (default: 1) + use_mixed_precision: Enable FP16 training (default: True) + loss_type: Loss function ("dice", "bce", or "combined") + device: Training device ("cuda" or "cpu", auto-detected if None) + select_channel: Optional channel index to select from multi-channel output (default: None) + mask_unannotated: If True (default), only compute loss on annotated regions (target > 0). + Targets are shifted down by 1 (e.g., 1->0, 2->1) after masking. + This allows partial annotations where 0=unannotated, 1=background, 2=foreground, etc. + Ignored if target_transform is provided. + target_transform: Optional TargetTransform instance that converts raw annotations + to (target, mask) pairs. Overrides mask_unannotated when provided. + See cellmap_flow.finetune.target_transforms. + + Examples: + >>> lora_model = wrap_model_with_lora(model) + >>> dataloader = create_dataloader("corrections.zarr") + >>> trainer = LoRAFinetuner( + ... lora_model, + ... dataloader, + ... output_dir="output/fly_organelles_v1.1" + ... 
) + >>> trainer.train() + >>> trainer.save_adapter() + """ + + def __init__( + self, + model: nn.Module, + dataloader: DataLoader, + output_dir: str, + learning_rate: float = 1e-4, + num_epochs: int = 10, + gradient_accumulation_steps: int = 1, + use_mixed_precision: bool = True, + loss_type: str = "combined", + device: Optional[str] = None, + select_channel: Optional[int] = None, + mask_unannotated: bool = True, + label_smoothing: float = 0.0, + distillation_lambda: float = 0.0, + distillation_all_voxels: bool = False, + margin: float = 0.3, + balance_classes: bool = False, + target_transform=None, + ): + self.model = model + self.dataloader = dataloader + self.output_dir = Path(output_dir) + self.num_epochs = num_epochs + self.gradient_accumulation_steps = gradient_accumulation_steps + self.use_mixed_precision = use_mixed_precision + self.select_channel = select_channel + self.mask_unannotated = mask_unannotated + self.label_smoothing = label_smoothing + self.distillation_lambda = distillation_lambda + self.distillation_all_voxels = distillation_all_voxels + self.balance_classes = balance_classes + self.target_transform = target_transform + + # Create output directory + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Device + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = torch.device(device) + + logger.info(f"Using device: {self.device}") + + # Move model to device + self.model = self.model.to(self.device) + + # Optimizer (only LoRA parameters) + self.optimizer = AdamW( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=learning_rate, + ) + + # Loss function + self._use_bce = False + self._use_mse = False + if loss_type == "dice": + self.criterion = DiceLoss() + elif loss_type == "bce": + # Use reduction='none' so we can manually apply mask if needed + self.criterion = nn.BCEWithLogitsLoss(reduction='none') + self._use_bce = True + elif loss_type == "combined": + 
self.criterion = CombinedLoss() + elif loss_type == "mse": + self.criterion = nn.MSELoss(reduction='none') + self._use_mse = True + elif loss_type == "margin": + self.criterion = MarginLoss(margin=margin, balance_classes=balance_classes) + else: + raise ValueError(f"Unknown loss_type: {loss_type}") + + # Label smoothing is redundant with margin loss + if loss_type == "margin" and self.label_smoothing > 0: + logger.warning("Label smoothing is redundant with margin loss, setting to 0") + self.label_smoothing = 0.0 + + if self.balance_classes: + logger.info("Class balancing enabled: fg and bg scribble voxels weighted equally") + + logger.info(f"Using {loss_type} loss") + if self.label_smoothing > 0: + logger.info(f"Label smoothing: {self.label_smoothing} (targets: {self.label_smoothing/2:.3f} to {1-self.label_smoothing/2:.3f})") + if self.distillation_lambda > 0: + scope_str = "all voxels" if self.distillation_all_voxels else "unlabeled voxels only" + logger.info(f"Teacher distillation enabled: lambda={self.distillation_lambda} ({scope_str})") + + # Mixed precision scaler + self.scaler = GradScaler(enabled=use_mixed_precision) + + # Training state + self.current_epoch = 0 + self.global_step = 0 + self.best_loss = float('inf') + self.training_stats = [] + + def train(self) -> Dict[str, Any]: + """ + Run the training loop. 
+ + Returns: + Training statistics dictionary with: + - final_loss: Final epoch loss + - best_loss: Best loss achieved + - total_epochs: Number of epochs trained + - total_steps: Total training steps + """ + # Create log file + log_file = self.output_dir / "training_log.txt" + + def log_message(msg): + """Log to console (tee handles writing to log file).""" + print(msg, flush=True) + + log_message("="*60) + log_message("Starting LoRA Finetuning") + log_message("="*60) + log_message(f"Epochs: {self.num_epochs}") + log_message(f"Batches per epoch: {len(self.dataloader)}") + log_message(f"Gradient accumulation: {self.gradient_accumulation_steps}") + log_message(f"Effective batch size: {self.dataloader.batch_size * self.gradient_accumulation_steps}") + log_message(f"Mixed precision: {self.use_mixed_precision}") + log_message(f"Mask unannotated regions: {self.mask_unannotated}") + log_message(f"Log file: {log_file}") + log_message("") + + self.model.train() + start_time = time.time() + + # Store log function for use in _train_epoch + self._log_message = log_message + + for epoch in range(self.num_epochs): + self.current_epoch = epoch + epoch_loss = self._train_epoch() + + # Log epoch results + self._log_message( + f"Epoch {epoch+1}/{self.num_epochs} - " + f"Loss: {epoch_loss:.6f} - " + f"Best: {self.best_loss:.6f}" + ) + + # Save checkpoint if best + if epoch_loss < self.best_loss: + self.best_loss = epoch_loss + self.save_checkpoint(is_best=True) + self._log_message(f" → Saved best checkpoint") + + # Save regular checkpoint every 5 epochs + if (epoch + 1) % 5 == 0: + self.save_checkpoint(is_best=False) + + self.training_stats.append({ + 'epoch': epoch + 1, + 'loss': epoch_loss, + 'best_loss': self.best_loss, + }) + + # Final checkpoint + self.save_checkpoint(is_best=False) + + total_time = time.time() - start_time + self._log_message("") + self._log_message("="*60) + self._log_message("Training Complete!") + self._log_message(f"Total time: {total_time/60:.2f} minutes") 
+ self._log_message(f"Best loss: {self.best_loss:.6f}") + self._log_message(f"Final loss: {epoch_loss:.6f}") + self._log_message(f"Output directory: {self.output_dir}") + self._log_message("="*60) + + return { + 'final_loss': epoch_loss, + 'best_loss': self.best_loss, + 'total_epochs': self.num_epochs, + 'total_steps': self.global_step, + 'training_time': total_time, + } + + def _train_epoch(self) -> float: + """Train for one epoch and return average loss.""" + epoch_loss = 0.0 + epoch_supervised_loss = 0.0 + epoch_distill_loss = 0.0 + num_batches = len(self.dataloader) + + for batch_idx, (raw, target) in enumerate(self.dataloader): + # Move to device + raw = raw.to(self.device, non_blocking=True) + target = target.to(self.device, non_blocking=True) + + # Handle partial annotations: create mask and shift labels + mask = None + if self.target_transform is not None: + target, mask = self.target_transform(target) + elif self.mask_unannotated: + # Legacy behavior: binary single-channel + mask = (target > 0).float() # (B, C, Z, Y, X) + # Shift labels down by 1 (but keep 0 as 0) + # e.g., 0->0 (unannotated), 1->0 (background), 2->1 (foreground) + target = torch.clamp(target - 1, min=0) + + # Apply label smoothing: 0 -> s/2, 1 -> 1-s/2 + # This prevents the model from being pushed to extreme 0/1 outputs, + # preserving gradual distance-like predictions + if self.label_smoothing > 0: + target = target * (1 - self.label_smoothing) + self.label_smoothing / 2 + + # Teacher forward pass for distillation (before student pass) + # Uses the base model without LoRA adapters as the teacher + teacher_pred = None + if self.distillation_lambda > 0: + with torch.no_grad(): + self.model.disable_adapter_layers() + try: + with autocast(enabled=self.use_mixed_precision): + teacher_pred = self.model(raw) + if self.select_channel is not None: + teacher_pred = teacher_pred[:, self.select_channel:self.select_channel+1, :, :, :] + teacher_pred = teacher_pred.detach() + finally: + 
self.model.enable_adapter_layers() + + # Student forward pass with mixed precision + with autocast(enabled=self.use_mixed_precision): + pred = self.model(raw) + + # Select specific channel if requested (e.g., mito = channel 2 from 8-channel output) + if self.select_channel is not None: + pred = pred[:, self.select_channel:self.select_channel+1, :, :, :] + + # Compute supervised loss with optional mask + if (self._use_bce or self._use_mse) and mask is not None: + # For per-element losses (BCE, MSE), manually apply mask + per_element_loss = self.criterion(pred, target) + if self.balance_classes: + # Average fg and bg separately so each contributes equally + fg_mask = target * mask + bg_mask = (1.0 - target) * mask + fg_count = fg_mask.sum().clamp(min=1) + bg_count = bg_mask.sum().clamp(min=1) + fg_contrib = (per_element_loss * fg_mask).sum() / fg_count + bg_contrib = (per_element_loss * bg_mask).sum() / bg_count + supervised_loss = (fg_contrib + bg_contrib) / 2.0 + else: + supervised_loss = (per_element_loss * mask).sum() / mask.sum().clamp(min=1) + elif hasattr(self.criterion, 'forward') and 'mask' in self.criterion.forward.__code__.co_varnames: + # For custom losses that support masking (DiceLoss, CombinedLoss, MarginLoss) + supervised_loss = self.criterion(pred, target, mask) + else: + # No masking needed + supervised_loss = self.criterion(pred, target) + if self._use_bce or self._use_mse: + supervised_loss = supervised_loss.mean() + + loss = supervised_loss + + # Compute distillation loss + distillation_loss = torch.tensor(0.0, device=self.device) + if self.distillation_lambda > 0 and teacher_pred is not None: + distill_loss_map = (pred - teacher_pred) ** 2 # per-element MSE + if self.distillation_all_voxels or mask is None: + # Apply on all voxels + distillation_loss = distill_loss_map.mean() + else: + # Apply only on unlabeled voxels + unlabeled_mask = 1.0 - mask # 1 where unlabeled + distillation_loss = (distill_loss_map * unlabeled_mask).sum() / 
unlabeled_mask.sum().clamp(min=1) + loss = loss + self.distillation_lambda * distillation_loss + + # Scale loss for gradient accumulation + loss = loss / self.gradient_accumulation_steps + + # Backward pass + self.scaler.scale(loss).backward() + + # Update weights after accumulation + if (batch_idx + 1) % self.gradient_accumulation_steps == 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + self.global_step += 1 + + # Accumulate losses (unscaled) + epoch_loss += loss.item() * self.gradient_accumulation_steps + epoch_supervised_loss += supervised_loss.item() + epoch_distill_loss += distillation_loss.item() + + # Log progress every batch (since we have few batches) + avg_loss = epoch_loss / (batch_idx + 1) + if hasattr(self, '_log_message'): + if self.distillation_lambda > 0: + avg_sup = epoch_supervised_loss / (batch_idx + 1) + avg_distill = epoch_distill_loss / (batch_idx + 1) + self._log_message( + f" Batch {batch_idx+1}/{num_batches} - " + f"Loss: {avg_loss:.6f} (sup: {avg_sup:.6f}, distill: {avg_distill:.6f})" + ) + else: + self._log_message( + f" Batch {batch_idx+1}/{num_batches} - " + f"Loss: {avg_loss:.6f}" + ) + else: + # Fallback if _log_message not set + msg = f" Batch {batch_idx+1}/{num_batches} - Loss: {avg_loss:.6f}" + print(msg) + logger.info(msg) + + # Handle leftover accumulated gradients at end of epoch + # (in case num_batches is not divisible by gradient_accumulation_steps) + if num_batches % self.gradient_accumulation_steps != 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + self.global_step += 1 + + return epoch_loss / num_batches + + def save_checkpoint(self, is_best: bool = False): + """ + Save training checkpoint. 
+ + Args: + is_best: If True, saves as "best_model.pth" + """ + checkpoint_name = "best_checkpoint.pth" if is_best else f"checkpoint_epoch_{self.current_epoch+1}.pth" + checkpoint_path = self.output_dir / checkpoint_name + + # Save only trainable (LoRA) parameters to avoid writing the full + # 800M+ param base model to disk every checkpoint. + trainable_keys = {n for n, p in self.model.named_parameters() if p.requires_grad} + trainable_state = {k: v for k, v in self.model.state_dict().items() if k in trainable_keys} + checkpoint = { + 'epoch': self.current_epoch, + 'global_step': self.global_step, + 'model_state_dict': trainable_state, + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scaler_state_dict': self.scaler.state_dict(), + 'best_loss': self.best_loss, + 'training_stats': self.training_stats, + 'lora_only': True, + } + + torch.save(checkpoint, checkpoint_path) + logger.debug(f"Checkpoint saved: {checkpoint_path}") + + def save_adapter(self, adapter_path: Optional[str] = None): + """ + Save only the LoRA adapter (not the full model). + + Args: + adapter_path: Path to save adapter. If None, uses output_dir/lora_adapter + """ + from cellmap_flow.finetune.lora_wrapper import save_lora_adapter + + if adapter_path is None: + adapter_path = str(self.output_dir / "lora_adapter") + + save_lora_adapter(self.model, adapter_path) + logger.info(f"LoRA adapter saved to: {adapter_path}") + + def load_checkpoint(self, checkpoint_path: str): + """ + Load training checkpoint to resume training. 
+ + Args: + checkpoint_path: Path to checkpoint file + """ + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + if checkpoint.get('lora_only', False): + # Checkpoint contains only trainable (LoRA) params — merge into full state + self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scaler.load_state_dict(checkpoint['scaler_state_dict']) + + self.current_epoch = checkpoint['epoch'] + self.global_step = checkpoint['global_step'] + self.best_loss = checkpoint['best_loss'] + self.training_stats = checkpoint.get('training_stats', []) + + logger.info(f"Checkpoint loaded from: {checkpoint_path}") + logger.info(f"Resuming from epoch {self.current_epoch+1}") diff --git a/cellmap_flow/finetune/lora_wrapper.py b/cellmap_flow/finetune/lora_wrapper.py new file mode 100644 index 0000000..1e5ff8c --- /dev/null +++ b/cellmap_flow/finetune/lora_wrapper.py @@ -0,0 +1,394 @@ +""" +Generic LoRA wrapper for PyTorch models. + +This module provides automatic detection of adaptable layers and wraps +PyTorch models with LoRA (Low-Rank Adaptation) adapters using the +HuggingFace PEFT library. + +LoRA enables efficient finetuning by training only a small number of +additional parameters (typically 1-2% of the original model) while +keeping the base model frozen. +""" + +import logging +from typing import List, Optional, Union +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +def detect_adaptable_layers( + model: nn.Module, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, +) -> List[str]: + """ + Automatically detect layers suitable for LoRA adaptation. + + Searches for Conv2d, Conv3d, and Linear layers, filtering by name patterns. + By default, excludes batch norm, layer norm, and final output layers. 
def detect_adaptable_layers(
    model: nn.Module,
    include_patterns: Optional[List[str]] = None,
    exclude_patterns: Optional[List[str]] = None,
) -> List[str]:
    """Find Conv/Linear layer names in *model* suitable for LoRA adaptation.

    Args:
        model: PyTorch model to inspect.
        include_patterns: Regex patterns a layer name must match to be kept.
            When None, all Conv/Linear layers are considered.
        exclude_patterns: Substrings that disqualify a layer name.
            Defaults to ['bn', 'norm', 'final', 'head', 'output'].

    Returns:
        List of layer names suitable for LoRA adaptation.
    """
    import re

    excludes = (
        exclude_patterns
        if exclude_patterns is not None
        else ['bn', 'norm', 'final', 'head', 'output']
    )

    adaptable = []
    for layer_name, layer in model.named_modules():
        # Only convolutional and linear layers are LoRA candidates.
        if not isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            continue
        # Include filter (regex, anchored at start via re.match).
        if include_patterns is not None and not any(
            re.match(pattern, layer_name) for pattern in include_patterns
        ):
            continue
        # Exclude filter (case-insensitive substring match).
        if any(token in layer_name.lower() for token in excludes):
            logger.debug(f"Excluding layer: {layer_name} (matched exclude pattern)")
            continue
        adaptable.append(layer_name)

    logger.info(f"Detected {len(adaptable)} adaptable layers")
    if adaptable:
        if len(adaptable) > 5:
            logger.debug(f"Adaptable layers: {adaptable[:5]}...")
        else:
            logger.debug(f"Adaptable layers: {adaptable}")

    return adaptable


class SequentialWrapper(nn.Module):
    """Adapter giving ``nn.Sequential`` models a PEFT-compatible forward.

    PEFT expects models to accept keyword arguments (e.g. ``input_ids``),
    while ``Sequential`` only takes a positional input; this wrapper
    bridges the two calling conventions.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x=None, input_ids=None, **kwargs):
        # PEFT (transformers-style) may hand the input over as `input_ids`;
        # vision models pass `x`. Accept either, ignore all other kwargs.
        tensor = input_ids if x is None else x
        if tensor is None:
            raise ValueError("Input tensor not provided")
        return self.model(tensor)


def wrap_model_with_lora(
    model: nn.Module,
    target_modules: Optional[List[str]] = None,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    modules_to_save: Optional[List[str]] = None,
    task_type: str = "FEATURE_EXTRACTION",
) -> nn.Module:
    """Wrap *model* with LoRA adapters using HuggingFace PEFT.

    The base model is frozen; only the injected low-rank adapter weights
    (and any ``modules_to_save``) remain trainable.

    Args:
        model: Model to adapt (e.g. UNet, CNN).
        target_modules: Layer names to adapt; auto-detected when None.
        lora_r: LoRA rank — higher means more capacity/parameters
            (typical 4-32).
        lora_alpha: LoRA scaling factor (typically 2 * r).
        lora_dropout: Dropout probability for LoRA layers (0.0-0.5).
        modules_to_save: Extra modules to keep fully trainable
            (e.g. a final layer).
        task_type: PEFT task type — "FEATURE_EXTRACTION" (default),
            "SEQ_CLS", "TOKEN_CLS", or "CAUSAL_LM".

    Returns:
        PEFT model with LoRA adapters attached.

    Raises:
        ImportError: If the peft library is not installed.
        ValueError: If no adaptable layers can be found.
    """
    try:
        from peft import LoraConfig, get_peft_model, TaskType
    except ImportError:
        raise ImportError(
            "peft library is required for LoRA finetuning. "
            "Install with: pip install peft"
        )

    # Sequential models need a kwargs-accepting forward for PEFT.
    if isinstance(model, nn.Sequential):
        logger.info("Wrapping Sequential model for PEFT compatibility")
        model = SequentialWrapper(model)

    if target_modules is None:
        target_modules = detect_adaptable_layers(model)
        if not target_modules:
            raise ValueError(
                "No adaptable layers found in model. "
                "Specify target_modules manually or check model architecture."
            )
        logger.info(f"Auto-detected {len(target_modules)} target modules for LoRA")

    valid_task_types = {
        "FEATURE_EXTRACTION": TaskType.FEATURE_EXTRACTION,
        "SEQ_CLS": TaskType.SEQ_CLS,
        "TOKEN_CLS": TaskType.TOKEN_CLS,
        "CAUSAL_LM": TaskType.CAUSAL_LM,
    }
    if task_type not in valid_task_types:
        logger.warning(
            f"Unknown task_type '{task_type}', using FEATURE_EXTRACTION. "
            f"Valid options: {list(valid_task_types.keys())}"
        )
        task_type = "FEATURE_EXTRACTION"

    config = LoraConfig(
        task_type=valid_task_types[task_type],
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        modules_to_save=modules_to_save,
        bias="none",  # leave bias terms untouched
    )

    logger.info(
        f"Creating LoRA model with r={lora_r}, alpha={lora_alpha}, "
        f"dropout={lora_dropout}"
    )

    peft_model = get_peft_model(model, config)
    logger.info("LoRA model created successfully")
    print_lora_parameters(peft_model)
    return peft_model


def print_lora_parameters(model: nn.Module):
    """Log trainable vs. total parameter counts for a (PEFT) model.

    Uses PEFT's own reporting when *model* is a ``PeftModel``; otherwise
    falls back to counting parameters directly.
    """
    try:
        from peft import PeftModel
    except ImportError:
        PeftModel = None

    if PeftModel is not None and isinstance(model, PeftModel):
        model.print_trainable_parameters()
        return

    # Fallback for plain torch models.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())

    if total == 0:
        logger.warning("Model has no parameters")
        return

    percentage = 100 * trainable / total
    logger.info(
        f"Trainable params: {trainable:,} ({percentage:.2f}% of total)"
    )
    logger.info(f"Total params: {total:,}")


def load_lora_adapter(
    model: nn.Module,
    adapter_path: str,
    is_trainable: bool = False,
) -> nn.Module:
    """Load a pretrained LoRA adapter into a base model.

    Args:
        model: Base PyTorch model (without LoRA).
        adapter_path: Directory containing a saved LoRA adapter.
        is_trainable: Load the adapter unfrozen (for continued training)
            instead of frozen (inference).

    Returns:
        PEFT model with the adapter applied.

    Raises:
        ImportError: If the peft library is not installed.
    """
    try:
        from peft import PeftModel
    except ImportError:
        raise ImportError(
            "peft library is required. Install with: pip install peft"
        )

    logger.info(f"Loading LoRA adapter from: {adapter_path}")

    # Sequential models need a kwargs-accepting forward for PEFT.
    if isinstance(model, nn.Sequential):
        logger.info("Wrapping Sequential model for PEFT compatibility")
        model = SequentialWrapper(model)

    peft_model = PeftModel.from_pretrained(
        model, adapter_path, is_trainable=is_trainable
    )

    mode_message = (
        "Adapter loaded in trainable mode"
        if is_trainable
        else "Adapter loaded in inference mode (frozen)"
    )
    logger.info(mode_message)

    print_lora_parameters(peft_model)
    return peft_model


def save_lora_adapter(
    model: nn.Module,
    output_path: str,
):
    """Persist only the LoRA adapter weights (a few MB, not the full model).

    Args:
        model: PEFT model with LoRA adapters.
        output_path: Directory to write the adapter to.

    Raises:
        ImportError: If the peft library is not installed.
        ValueError: If *model* is not a PeftModel.
    """
    try:
        from peft import PeftModel
    except ImportError:
        raise ImportError(
            "peft library is required. Install with: pip install peft"
        )

    if not isinstance(model, PeftModel):
        raise ValueError(
            "Model must be a PeftModel. Use wrap_model_with_lora() first."
        )

    logger.info(f"Saving LoRA adapter to: {output_path}")
    model.save_pretrained(output_path)
    logger.info("Adapter saved successfully")


def merge_lora_into_base(model: nn.Module) -> nn.Module:
    """Fold LoRA weights into the base model, removing the PEFT wrapper.

    The result is a standalone model that no longer needs peft at
    inference time — at the cost of full-model size on disk.

    Args:
        model: PEFT model with LoRA adapters.

    Returns:
        Base model with the adapter weights merged in.

    Raises:
        ImportError: If the peft library is not installed.
        ValueError: If *model* is not a PeftModel.
    """
    try:
        from peft import PeftModel
    except ImportError:
        raise ImportError(
            "peft library is required. Install with: pip install peft"
        )

    if not isinstance(model, PeftModel):
        raise ValueError(
            "Model must be a PeftModel to merge adapters"
        )

    logger.info("Merging LoRA adapters into base model")
    merged = model.merge_and_unload()
    logger.info("Adapters merged successfully")
    return merged
"""
Target transforms for converting user annotations to training targets.

Each transform takes a raw annotation tensor (B, 1, Z, Y, X) with values:
    0 = unannotated (ignored in loss)
    1 = background
    2 = first foreground object
    3 = second foreground object, etc.

And produces:
    target: (B, C, Z, Y, X) — training target matching model output channels
    mask: (B, C, Z, Y, X) or (B, 1, Z, Y, X) — valid loss mask
"""

from typing import List, Optional, Tuple

import torch
from torch import Tensor


class TargetTransform:
    """Base class for target transforms."""

    def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert annotation to (target, mask) pair."""
        raise NotImplementedError


class BinaryTargetTransform(TargetTransform):
    """Standard binary segmentation transform (current default behavior).

    Produces single-channel binary target: bg=0, fg=1.
    Mask marks annotated regions.
    """

    def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]:
        # Annotated voxels are those with label > 0.
        mask = (annotation > 0).float()
        # Shift labels down by one so bg (1) -> 0, fg (>=2) -> >=1,
        # then binarize: every foreground instance collapses to 1.
        target = torch.clamp(annotation - 1, min=0)
        target = (target > 0).float()
        return target, mask


class BroadcastBinaryTargetTransform(TargetTransform):
    """Binary target broadcast to N channels.

    All output channels receive the same fg/bg target.
    Useful for treating multi-channel models (affinities, distances)
    as simple binary segmentation.

    Args:
        num_channels: Number of output channels to broadcast to.
    """

    def __init__(self, num_channels: int):
        self.num_channels = num_channels

    def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]:
        mask = (annotation > 0).float()
        target = (torch.clamp(annotation - 1, min=0) > 0).float()
        # expand is lazy (no memory copy), contiguous() ensures safe downstream use
        target = target.expand(-1, self.num_channels, -1, -1, -1).contiguous()
        mask = mask.expand(-1, self.num_channels, -1, -1, -1).contiguous()
        return target, mask


class AffinityTargetTransform(TargetTransform):
    """Compute affinity targets from instance labels.

    For each offset, affinity is:
        1 if both voxels belong to the same foreground object (same label > 1)
        0 if different objects, or either is background

    The loss mask requires both voxels in each pair to be annotated (label > 0),
    producing a per-channel mask since each offset shifts differently.

    Args:
        offsets: List of [dz, dy, dx] offset tuples defining neighbor relationships.
        num_channels: Total number of model output channels. If greater than
            len(offsets), extra channels (e.g. LSDs) are masked out
            (mask=0) so they receive no gradient. If None, defaults
            to len(offsets).
    """

    def __init__(self, offsets: List[List[int]], num_channels: Optional[int] = None):
        self.offsets = offsets
        self.num_channels = num_channels if num_channels is not None else len(offsets)

    def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]:
        B, _C, Z, Y, X = annotation.shape
        # Allocate for all output channels; non-affinity channels stay zero
        # (target=0 and mask=0), so they contribute no gradient.
        target = torch.zeros(B, self.num_channels, Z, Y, X, device=annotation.device)
        mask = torch.zeros(B, self.num_channels, Z, Y, X, device=annotation.device)

        labels = annotation[:, 0]  # (B, Z, Y, X)
        annotated = labels > 0  # bool

        for i, offset in enumerate(self.offsets):
            dz, dy, dx = offset
            src_slices, dst_slices = _offset_slices(Z, Y, X, dz, dy, dx)

            src_labels = labels[(slice(None), *src_slices)]
            dst_labels = labels[(slice(None), *dst_slices)]
            src_ann = annotated[(slice(None), *src_slices)]
            dst_ann = annotated[(slice(None), *dst_slices)]

            # Affinity = 1 iff same foreground object (labels > 1 are objects).
            same_fg = (src_labels == dst_labels) & (src_labels > 1)
            # Only pairs where BOTH voxels were annotated contribute to loss.
            both_annotated = src_ann & dst_ann

            # Written at the source positions; boundary voxels with no pair
            # keep target=0, mask=0.
            target[(slice(None), i, *src_slices)] = same_fg.float()
            mask[(slice(None), i, *src_slices)] = both_annotated.float()

        return target, mask


def _offset_slices(Z, Y, X, dz, dy, dx):
    """Compute source and destination slices for an offset.

    For a volume of shape (Z, Y, X) and offset (dz, dy, dx),
    returns slices such that:
        volume[src_slices] and volume[dst_slices]
    are aligned views offset by (dz, dy, dx).
    """

    def _dim_slices(size, d):
        # Positive offset: source drops the last d entries, dest the first d.
        if d > 0:
            return slice(None, size - d), slice(d, None)
        # Negative offset: mirrored — source drops the first |d| entries.
        elif d < 0:
            return slice(-d, None), slice(None, size + d)
        else:
            return slice(None), slice(None)

    sz, dz_s = _dim_slices(Z, dz)
    sy, dy_s = _dim_slices(Y, dy)
    sx, dx_s = _dim_slices(X, dx)

    return (sz, sy, sx), (dz_s, dy_s, dx_s)
"~/Desktop/cellmap/cellmap-flow/example/example_norm", + ) + ) + cls._instance.bbx_generator_state = { + "dataset_path": None, + "num_boxes": 0, + "bounding_boxes": [], + "viewer": None, + "viewer_process": None, + "viewer_url": None, + "viewer_state": None, + } + cls._instance.minio_state = { + "process": None, + "port": None, + "ip": None, + "bucket": "annotations", + "minio_root": None, + "output_base": None, + "last_sync": {}, + "chunk_sync_state": {}, + "sync_thread": None, + } + cls._instance.annotation_volumes = {} + cls._instance.output_sessions = {} + cls._instance._finetune_job_manager = None + return cls._instance + @property + def finetune_job_manager(self): + if self._finetune_job_manager is None: + from cellmap_flow.finetune.finetune_job_manager import FinetuneJobManager + self._finetune_job_manager = FinetuneJobManager() + return self._finetune_job_manager + + @finetune_job_manager.setter + def finetune_job_manager(self, value): + self._finetune_job_manager = value + def to_dict(self): return self.__dict__.items() @@ -187,3 +248,24 @@ def delete(cls): g = Flow() + + +# Custom handler to capture logs into Flow singleton +class LogHandler(logging.Handler): + def emit(self, record): + log_entry = self.format(record) + g.log_buffer.append(log_entry) + # Send to all connected clients + for client_queue in g.log_clients: + try: + client_queue.put_nowait(log_entry) + except queue.Full: + pass + + +def get_blockwise_tasks_dir(): + tasks_dir = g.blockwise_tasks_dir or os.path.expanduser( + "~/.cellmap_flow/blockwise_tasks" + ) + os.makedirs(tasks_dir, exist_ok=True) + return tasks_dir diff --git a/cellmap_flow/models/models_config.py b/cellmap_flow/models/models_config.py index fa04fb3..c609dfd 100644 --- a/cellmap_flow/models/models_config.py +++ b/cellmap_flow/models/models_config.py @@ -250,6 +250,12 @@ def load_eval_model(self, num_channels, checkpoint_path): if checkpoint_path.endswith(".ts"): model_backbone = torch.jit.load(checkpoint_path, 
map_location=device) + elif checkpoint_path.endswith("model.pt"): + # Load full model directly (for trusted fly_organelles models) + model = torch.load(checkpoint_path, weights_only=False, map_location=device) + model.to(device) + model.eval() + return model else: from fly_organelles.model import StandardUnet @@ -574,7 +580,7 @@ def __init__(self, folder_path, name, scale=None): @property def command(self) -> str: - return f"cellmap-model --folder-path {self.cellmap_model.folder_path} --name {self.name}" + return f"cellmap --folder-path {self.cellmap_model.folder_path} --name {self.name}" def _get_config(self) -> Config: config = Config() diff --git a/cellmap_flow/server.py b/cellmap_flow/server.py index 3b3cfce..4d8769d 100644 --- a/cellmap_flow/server.py +++ b/cellmap_flow/server.py @@ -3,7 +3,7 @@ from http import HTTPStatus import numpy as np import numcodecs -from flask import Flask, jsonify, redirect +from flask import Flask, jsonify, redirect, request from flask_cors import CORS from flasgger import Swagger from funlib.geometry import Roi @@ -33,7 +33,7 @@ class CellMapFlowServer: All routes are defined via Flask decorators for convenience. """ - def __init__(self, dataset_name: str, model_config: ModelConfig): + def __init__(self, dataset_name: str, model_config: ModelConfig, restart_callback=None): """ Initialize the server and set up routes via decorators. 
""" @@ -47,6 +47,7 @@ def __init__(self, dataset_name: str, model_config: ModelConfig): self.model_output_axes = model_config.chunk_output_axes self.inferencer = Inferencer(model_config) + self.restart_callback = restart_callback # Load or initialize your dataset self.idi_raw = ImageDataInterface( @@ -88,6 +89,20 @@ def __init__(self, dataset_name: str, model_config: ModelConfig): def home(): return redirect("/apidocs/") + @self.app.route("/__control__/restart", methods=["POST"]) + def control_restart(): + if self.restart_callback is None: + return jsonify({"success": False, "error": "Restart control not enabled"}), HTTPStatus.NOT_IMPLEMENTED + try: + payload = request.get_json(silent=True) or {} + accepted = self.restart_callback(payload) + if not accepted: + return jsonify({"success": False, "error": "Restart request rejected"}), HTTPStatus.CONFLICT + return jsonify({"success": True}), HTTPStatus.OK + except Exception as e: + logger.error(f"Failed to process restart control request: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR + @self.app.route("//.zattrs", methods=["GET"]) def top_level_attributes(dataset): self.refresh_dataset(dataset) diff --git a/cellmap_flow/utils/bsub_utils.py b/cellmap_flow/utils/bsub_utils.py index 796e3ee..ed0de42 100644 --- a/cellmap_flow/utils/bsub_utils.py +++ b/cellmap_flow/utils/bsub_utils.py @@ -121,13 +121,13 @@ def get_status(self) -> JobStatus: else: return JobStatus.FAILED - def wait_for_host(self, timeout: int = 60) -> Optional[str]: + def wait_for_host(self, timeout: int = 180) -> Optional[str]: """ Monitor process output for host information. 
- + Args: - timeout: Maximum time to wait in seconds - + timeout: Maximum time to wait in seconds (default 180s for model loading) + Returns: Host URL if found, None otherwise """ @@ -464,18 +464,19 @@ def run_locally(command: str, name: str) -> LocalJob: LocalJob object with process information """ logger.info(f"Running locally: {command}") - + try: process = subprocess.Popen( - command.split(), + command, + shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True ) - + local_job = LocalJob(process=process, model_name=name) return local_job - + except Exception as e: logger.error(f"Error starting local process: {e}") raise diff --git a/cellmap_flow/utils/ds.py b/cellmap_flow/utils/ds.py index f8de279..11b241a 100644 --- a/cellmap_flow/utils/ds.py +++ b/cellmap_flow/utils/ds.py @@ -153,7 +153,7 @@ def split_dataset_path(dataset_path, scale=None) -> tuple[str, str]: ".zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else ".n5" ) - filename, dataset = dataset_path.split(splitter) + filename, dataset = dataset_path.rsplit(splitter, 1) if dataset.startswith("/"): dataset = dataset[1:] # include scale if present diff --git a/cellmap_flow/utils/load_py.py b/cellmap_flow/utils/load_py.py index deb8179..7b82b74 100644 --- a/cellmap_flow/utils/load_py.py +++ b/cellmap_flow/utils/load_py.py @@ -42,12 +42,9 @@ def analyze_script(filepath): # If function is a direct name (e.g., `eval()`) if isinstance(node.func, ast.Name) and node.func.id in DISALLOWED_FUNCTIONS: issues.append(f"Disallowed function call detected: {node.func.id}") - # If function is an attribute call (e.g., `os.system()`) - elif ( - isinstance(node.func, ast.Attribute) - and node.func.attr in DISALLOWED_FUNCTIONS - ): - issues.append(f"Disallowed function call detected: {node.func.attr}") + # Note: We intentionally do NOT flag method calls like `model.eval()` here + # Method calls on objects (e.g., model.eval()) are safe - only direct calls + # to dangerous builtin functions (e.g., 
eval()) are a security risk # Return whether the script is safe (no issues found) and the list of issues is_safe = len(issues) == 0 diff --git a/docs/finetuning.md b/docs/finetuning.md new file mode 100644 index 0000000..d03c9ce --- /dev/null +++ b/docs/finetuning.md @@ -0,0 +1,92 @@ +# Finetuning Guide + +This guide walks through the full finetuning workflow in CellMap-Flow: loading data, creating annotations, and training a finetuned model — all from the dashboard. + +## 1. Launch the Dashboard + +Start by loading your data and model with a YAML configuration file: + +```bash +cellmap_flow_yaml my_yamls/jrc_c-elegans-bw-1_affinities.yaml +``` + +This starts the dashboard with your dataset and model loaded into the Neuroglancer viewer. + +## 2. Create an Annotation Volume +![Annotation Crops tab](screenshots/finetune_annotation_crops.png) + +Navigate to the **Finetune** tab in the dashboard. + +Under **Annotation Crops**, you will see your model configuration (name, output size, voxel size, crop shape, channels). + +1. Set the **Output Path for Zarr Files** to a directory where annotation data will be saved. This must be accessible to the MinIO server that the dashboard starts. +2. Click **Create Annotation Volume**. + - This creates a sparse annotation zarr covering the full dataset extent, where each chunk maps to one training sample. + - A MinIO server will start automatically to serve the zarr for editing in Neuroglancer. + + +## 3. Set Up Annotation Tools in Neuroglancer +![Draw tab with bound keys](screenshots/finetune_draw_tab.png) + +Once the annotation volume is created and added to the viewer: + +1. **Select the annotation layer** by right-clicking on it in the layer list (it will be named something like `sparse_annotation_vol-XXXX`). +2. Go to the **Draw** tab for that layer. +3. **Bind keyboard shortcuts** to the drawing tools: + - Click the small box next to each tool name (e.g. `[A] Brush`, `[S] Flood Fill`, `[D] Seg Picker`). 
+ - Press the letter you want to assign to that tool. + - Once bound, activate a tool by pressing **Shift + the assigned letter**. + + + +## 4. Annotate + +When you start drawing, Neuroglancer will ask if you want to write to the file — click **Yes**. + +### Annotation label rules + +- **Paint Value 1** = **background** (this voxel is not the object of interest) +- **Paint Value 2** = **foreground** (this voxel is the object of interest) +- For **affinities models** with multiple object IDs, use higher paint values (3, 4, ...) for distinct object instances. The finetuning pipeline will automatically convert these instance IDs into affinity targets using the offsets defined in the model script. +- **Paint Value 0** = **unannotated / ignored** — these voxels are excluded from the loss during training. + +You can change the paint value in the Draw tab by editing the **Paint Value** field, or click **Random** next to **New Random Value** to pick a new instance ID. + +Annotate as many chunks as you like across the dataset. Only chunks with non-zero annotations will be used for training. + +## 5. Training + +Switch to the **Training** tab in the Finetune section. + +![Training tab](screenshots/finetune_training_tab.png) + +### Training configuration options + +| Parameter | Description | +|---|---| +| **Checkpoint Path** | (Optional, Advanced) Override the base model checkpoint to finetune from. Leave empty to auto-detect from the model configuration or script. | +| **LoRA Rank** | Controls the number of trainable parameters. Higher rank = more capacity but more memory. Typical values: 4 (low), 8 (default), 16 (high). | +| **Number of Epochs** | How many passes over the training data. Typical range: 10–20. | +| **Batch Size** | Number of samples per training step. Higher = faster but uses more GPU memory. | +| **Learning Rate** | Step size for optimization. 1e-4 is a good starting point for LoRA. | +| **Loss Function** | The training objective. 
**MSE** for standard regression. **Margin** is recommended for sparse annotations (auto-selected when sparse volumes are detected). | +| **Distillation Weight** | Keeps the finetuned model close to the original model's predictions. 0.5 is a good default. Set to 0.0 to disable. | +| **Distillation Scope** | (Advanced) Where to apply distillation loss — **Unlabeled** (only on unannotated voxels) or **All** (everywhere). | +| **Balance fg/bg classes** | Weights foreground and background equally in the loss regardless of how much of each you've annotated. Prevents the model from overpredicting whichever class has more annotations. | +| **GPU Queue** | Which GPU queue to submit the training job to (e.g. H100, H200). | +| **Auto-load model after training** | When checked, the finetuned model will automatically start an inference server and be added to the Neuroglancer viewer once training completes. | + +### Start training + +Click **Start Finetuning** to submit the training job to the GPU cluster. You can monitor training progress via the live log stream in the Training tab. + +## 6. Iterative Refinement + +After reviewing the finetuned model's predictions in Neuroglancer: + +1. Add more annotations or correct existing ones in the annotation volume. +2. Go back to the **Training** tab. +3. Click **Restart Finetuning** — this retrains on the same GPU using your updated annotations without needing to resubmit a new job. +4. Updated parameters (epochs, learning rate, loss, etc.) can be changed before restarting. + +Repeat this annotate-train-review cycle until the model performs well on your data. 
diff --git a/docs/screenshots/finetune_annotation_crops.png b/docs/screenshots/finetune_annotation_crops.png new file mode 100644 index 0000000..00f84fd Binary files /dev/null and b/docs/screenshots/finetune_annotation_crops.png differ diff --git a/docs/screenshots/finetune_draw_tab.png b/docs/screenshots/finetune_draw_tab.png new file mode 100644 index 0000000..a55ed1b Binary files /dev/null and b/docs/screenshots/finetune_draw_tab.png differ diff --git a/docs/screenshots/finetune_training_tab.png b/docs/screenshots/finetune_training_tab.png new file mode 100644 index 0000000..ad21af3 Binary files /dev/null and b/docs/screenshots/finetune_training_tab.png differ diff --git a/example/model_spec_affinities.py b/example/model_spec_affinities.py index 46c4e7b..1f5ee71 100644 --- a/example/model_spec_affinities.py +++ b/example/model_spec_affinities.py @@ -1,15 +1,17 @@ -#%% +# %% # pip install fly-organelles from funlib.geometry.coordinate import Coordinate import torch import funlib.learn.torch import numpy as np + voxel_size = (16, 16, 16) read_shape = Coordinate((178, 178, 178)) * Coordinate(voxel_size) write_shape = Coordinate((56, 56, 56)) * Coordinate(voxel_size) output_voxel_size = Coordinate((16, 16, 16)) -#%% + +# %% class StandardUnet(torch.nn.Module): def __init__( self, @@ -43,19 +45,23 @@ def __init__( constant_upsample=True, ) - self.final_conv = torch.nn.Conv3d(num_fmaps, out_channels, (1, 1, 1), padding="valid") + self.final_conv = torch.nn.Conv3d( + num_fmaps, out_channels, (1, 1, 1), padding="valid" + ) def forward(self, raw): x = self.unet_backbone(raw) return self.final_conv(x) -#%% + + +# %% def load_eval_model(num_labels, checkpoint_path): model_backbone = StandardUnet(num_labels) if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - print("device:", device) + print("device:", device) checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device) 
model_backbone.load_state_dict(checkpoint["model_state_dict"]) model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid()) @@ -64,9 +70,11 @@ def load_eval_model(num_labels, checkpoint_path): return model -classes = ["mito",]*9 +classes = [ + "mito", +] * 9 CHECKPOINT_PATH = "/groups/cellmap/cellmap/zouinkhim/c-elegen/v2/train/fly_run/all/affinities/new/run04_mito/model_checkpoint_65000" output_channels = len(classes) model = load_eval_model(output_channels, CHECKPOINT_PATH) -block_shape = np.array((56, 56, 56,output_channels)) +block_shape = np.array((56, 56, 56, output_channels)) # %% diff --git a/pyproject.toml b/pyproject.toml index 388ae47..a77a3cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,14 @@ postprocess = ["edt", "mwatershed @ git+https://github.com/pattonw/mwatershed", "funlib.math @ git+https://github.com/funkelab/funlib.math.git",] +finetune = [ + "peft>=0.7.0", # HuggingFace Parameter-Efficient Fine-Tuning + "transformers>=4.35.0", # Required by peft + "accelerate>=0.20.0", # Training utilities + "minio-client", # For annotations + "minio-server" # For annotations +] + [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" @@ -92,4 +100,5 @@ cellmap_flow_server = "cellmap_flow.cli.server_cli:cli" cellmap_flow_yaml = "cellmap_flow.cli.yaml_cli:main" cellmap_flow_blockwise = "cellmap_flow.blockwise.cli:cli" cellmap_flow_blockwise_multiple = "cellmap_flow.blockwise.multiple_cli:cli" -cellmap_flow_app = "cellmap_flow.dashboard.app:create_and_run_app" \ No newline at end of file +cellmap_flow_app = "cellmap_flow.dashboard.app:create_and_run_app" +cellmap_flow_viewer = "cellmap_flow.cli.viewer_cli:main" \ No newline at end of file diff --git a/tests/finetune/__init__.py b/tests/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/finetune/test_target_transforms.py b/tests/finetune/test_target_transforms.py new file mode 100644 index 0000000..04e48f8 --- /dev/null +++ 
"""Tests for target transforms."""

import torch
from cellmap_flow.finetune.target_transforms import (
    BinaryTargetTransform,
    BroadcastBinaryTargetTransform,
    AffinityTargetTransform,
    _offset_slices,
)


def test_binary_transform_basic():
    """BinaryTargetTransform yields the expected target and mask."""
    # annotation: 0=unannotated, 1=bg, 2=fg — shape (1, 1, 1, 1, 5)
    ann = torch.tensor([[[[[0, 1, 2, 0, 1]]]]]).float()
    target, mask = BinaryTargetTransform()(ann)

    # Mask is 1 exactly where annotated (label > 0).
    assert mask.tolist() == [[[[[0, 1, 1, 0, 1]]]]]
    # Target is 1 only for foreground (was 2); bg and unannotated become 0.
    assert target.tolist() == [[[[[0, 0, 1, 0, 0]]]]]


def test_binary_transform_multi_object():
    """Labels 2 and 3 both become foreground (1)."""
    ann = torch.tensor([[[[[1, 2, 3]]]]]).float()
    target, mask = BinaryTargetTransform()(ann)

    assert target.tolist() == [[[[[0, 1, 1]]]]]
    assert mask.tolist() == [[[[[1, 1, 1]]]]]


def test_broadcast_transform():
    """The binary target is replicated identically across channels."""
    ann = torch.tensor([[[[[0, 1, 2]]]]]).float()  # (1, 1, 1, 1, 3)
    target, mask = BroadcastBinaryTargetTransform(num_channels=3)(ann)

    assert target.shape == (1, 3, 1, 1, 3)
    assert mask.shape == (1, 3, 1, 1, 3)
    for channel in range(3):
        assert target[0, channel].tolist() == [[[0, 0, 1]]]
        assert mask[0, channel].tolist() == [[[0, 1, 1]]]


def test_affinity_transform_same_object():
    """Two adjacent voxels of the same object get affinity=1."""
    # 1D-like layout along X: [bg, obj2, obj2, bg]
    ann = torch.zeros(1, 1, 1, 1, 4)
    ann[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float()

    target, mask = AffinityTargetTransform([[0, 0, 1]])(ann)

    assert target.shape == (1, 1, 1, 1, 4)

    # Pairs (along X, offset +1):
    #   (0,1): bg-obj2   -> affinity 0, both annotated -> mask 1
    #   (1,2): obj2-obj2 -> affinity 1, both annotated -> mask 1
    #   (2,3): obj2-bg   -> affinity 0, both annotated -> mask 1
    # Position 3 has no partner (boundary): target 0, mask 0.
    assert target[0, 0, 0, 0, :3].tolist() == [0, 1, 0]
    assert mask[0, 0, 0, 0, :3].tolist() == [1, 1, 1]
    assert mask[0, 0, 0, 0, 3].item() == 0


def test_affinity_transform_different_objects():
    """Adjacent voxels of different objects get affinity=0."""
    ann = torch.zeros(1, 1, 1, 1, 3)
    ann[0, 0, 0, 0, :] = torch.tensor([2, 3, 2]).float()

    target, mask = AffinityTargetTransform([[0, 0, 1]])(ann)

    # (0,1): obj2-obj3 -> 0; (1,2): obj3-obj2 -> 0
    assert target[0, 0, 0, 0, :2].tolist() == [0, 0]
    assert mask[0, 0, 0, 0, :2].tolist() == [1, 1]


def test_affinity_transform_unannotated_masking():
    """Pairs touching an unannotated voxel are masked out."""
    ann = torch.zeros(1, 1, 1, 1, 4)
    ann[0, 0, 0, 0, :] = torch.tensor([2, 0, 2, 1]).float()

    target, mask = AffinityTargetTransform([[0, 0, 1]])(ann)

    # (0,1): obj2-unannotated  -> mask 0
    # (1,2): unannotated-obj2  -> mask 0
    # (2,3): obj2-bg           -> mask 1, target 0
    assert mask[0, 0, 0, 0, 0].item() == 0
    assert mask[0, 0, 0, 0, 1].item() == 0
    assert mask[0, 0, 0, 0, 2].item() == 1
    assert target[0, 0, 0, 0, 2].item() == 0


def test_affinity_transform_multiple_offsets():
    """Z, Y and X offsets each get their own channel."""
    ann = torch.zeros(1, 1, 3, 3, 3)
    ann[:] = 2  # fill with a single object
    ann[0, 0, 0, 0, 0] = 1  # one background corner

    target, mask = AffinityTargetTransform([[1, 0, 0], [0, 1, 0], [0, 0, 1]])(ann)

    assert target.shape == (1, 3, 3, 3, 3)
    assert mask.shape == (1, 3, 3, 3, 3)

    # Everything is annotated (>0), so mask is 1 wherever a valid pair exists.
    # Z channel: rows z=0,1 pair with z+1; z=2 has no partner.
    assert mask[0, 0, 2, :, :].sum().item() == 0
    assert mask[0, 0, 0, :, :].sum().item() == 9
    assert mask[0, 0, 1, :, :].sum().item() == 9

    # Corner (0,0,0) is bg while (1,0,0) is fg -> Z affinity at (0,0,0) is 0.
    assert target[0, 0, 0, 0, 0].item() == 0
    # (1,0,0) and (2,0,0) are both fg -> Z affinity at (1,0,0) is 1.
    assert target[0, 0, 1, 0, 0].item() == 1


def test_affinity_transform_negative_offset():
    """Negative offsets shift pairing in the opposite direction."""
    ann = torch.zeros(1, 1, 1, 1, 4)
    ann[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float()

    target, mask = AffinityTargetTransform([[0, 0, -1]])(ann)

    # Offset -1: source starts at index 1, destination at index 0.
    #   (1,0): obj2-bg   -> 0, mask 1
    #   (2,1): obj2-obj2 -> 1, mask 1
    #   (3,2): bg-obj2   -> 0, mask 1
    assert target[0, 0, 0, 0, 1].item() == 0
    assert target[0, 0, 0, 0, 2].item() == 1
    assert target[0, 0, 0, 0, 3].item() == 0
    assert mask[0, 0, 0, 0, 0].item() == 0  # index 0 has no pair


def test_offset_slices():
    """_offset_slices builds aligned source/destination views."""
    # Positive offset drops the tail of the source.
    src, dst = _offset_slices(10, 10, 10, 1, 0, 0)
    assert src == (slice(None, 9), slice(None), slice(None))
    assert dst == (slice(1, None), slice(None), slice(None))

    # Negative offset drops the head of the source.
    src, dst = _offset_slices(10, 10, 10, 0, 0, -2)
    assert src == (slice(None), slice(None), slice(2, None))
    assert dst == (slice(None), slice(None), slice(None, 8))

    # Zero offset: both views are the full extent.
    src, dst = _offset_slices(10, 10, 10, 0, 0, 0)
    assert src == (slice(None), slice(None), slice(None))
    assert dst == (slice(None), slice(None), slice(None))
slice(None)) + + +def test_affinity_transform_extra_channels_masked(): + """Extra channels (e.g. LSDs) should have mask=0.""" + annotation = torch.zeros(1, 1, 1, 1, 4) + annotation[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float() + + offsets = [[0, 0, 1]] # 1 affinity channel + transform = AffinityTargetTransform(offsets, num_channels=4) # 1 aff + 3 extra + target, mask = transform(annotation) + + assert target.shape == (1, 4, 1, 1, 4) + assert mask.shape == (1, 4, 1, 1, 4) + + # Channel 0 (affinity) should have valid mask + assert mask[0, 0, 0, 0, :3].sum().item() == 3 + # Channels 1-3 (extra, e.g. LSDs) should be fully masked out + assert mask[0, 1, :, :, :].sum().item() == 0 + assert mask[0, 2, :, :, :].sum().item() == 0 + assert mask[0, 3, :, :, :].sum().item() == 0 + + +if __name__ == "__main__": + test_binary_transform_basic() + test_binary_transform_multi_object() + test_broadcast_transform() + test_affinity_transform_same_object() + test_affinity_transform_different_objects() + test_affinity_transform_unannotated_masking() + test_affinity_transform_multiple_offsets() + test_affinity_transform_negative_offset() + test_offset_slices() + test_affinity_transform_extra_channels_masked() + print("All tests passed!")