diff --git a/deepem/train/option.py b/deepem/train/option.py
index ed081bb..c7c9748 100644
--- a/deepem/train/option.py
+++ b/deepem/train/option.py
@@ -118,6 +118,9 @@ def initialize(self):
         self.parser.add_argument('--optim', default='Adam')
         self.parser.add_argument('--lr', type=float, default=0.001)
 
+        # EMA (Exponential Moving Average)
+        self.parser.add_argument('--ema_decay', type=float, default=0.0)
+
         # Optimizer: Adam
         self.parser.add_argument('--betas', type=float, default=[0.9,0.999], nargs='+')
         self.parser.add_argument('--eps', type=float, default=1e-08)
diff --git a/deepem/train/run.py b/deepem/train/run.py
index 03fa17c..b260380 100644
--- a/deepem/train/run.py
+++ b/deepem/train/run.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
 
 import samwise
@@ -68,6 +69,18 @@ def train(opt):
     trainable = filter(lambda p: p.requires_grad, model.parameters())
     optimizer = load_optimizer(opt, trainable)
 
+    # EMA model
+    ema_model = None
+    if opt.ema_decay > 0:
+        base = model.module if opt.parallel == "DDP" else model
+        ema_model = AveragedModel(
+            base,
+            multi_avg_fn=get_ema_multi_avg_fn(opt.ema_decay),
+            use_buffers=True,
+        )
+        if opt.chkpt_num != 0:
+            load_ema_state(ema_model, opt)
+
     # Data loaders
     train_loader, val_loader = load_data(opt, local_rank)
@@ -75,10 +88,12 @@
     if opt.parallel == "DDP":
         if dist.get_rank() == 0:
             model = revert_sync_batchnorm(model)
-            save_chkpt(model.module, opt.model_dir, opt.chkpt_num, optimizer)
+            save_chkpt(model.module, opt.model_dir, opt.chkpt_num,
+                       optimizer, ema_model=ema_model)
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
     else:
-        save_chkpt(model, opt.model_dir, opt.chkpt_num, optimizer)
+        save_chkpt(model, opt.model_dir, opt.chkpt_num,
+                   optimizer, ema_model=ema_model)
 
     # Mixed-precision training
     scaler = None
@@ -135,6 +150,11 @@
         total_loss.backward()
         optimizer.step()
 
+        # Update EMA
+        if ema_model is not None:
+            ema_base = model.module if opt.parallel == "DDP" else model
+            ema_model.update_parameters(ema_base)
+
         # Elapsed time
         end.record()
         end.synchronize()  # waits only for work up to `end` on this stream
@@ -163,33 +183,49 @@
         if (i+1) % opt.eval_intv == 0:
             if opt.parallel == "DDP":
                 if dist.get_rank() == 0:
-                    model = revert_sync_batchnorm(model)
-                    eval_loop(i+1, model, val_loader, opt, logger, wandb_logger)
-                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+                    if ema_model is not None:
+                        eval_loop(i+1, ema_model.module,
+                                  val_loader, opt, logger,
+                                  wandb_logger)
+                    else:
+                        model = revert_sync_batchnorm(model)
+                        eval_loop(i+1, model, val_loader,
+                                  opt, logger, wandb_logger)
+                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             else:
-                eval_loop(i+1, model, val_loader, opt, logger, wandb_logger)
+                eval_model = (ema_model.module
+                              if ema_model is not None
+                              else model)
+                eval_loop(i+1, eval_model, val_loader,
+                          opt, logger, wandb_logger)
 
         # Frontier checkpoint (overwrites previous frontier)
         if opt.chkpt_sync_intv is not None and (i+1) % opt.chkpt_sync_intv == 0:
             if opt.parallel == "DDP":
                 if dist.get_rank() == 0:
                     model = revert_sync_batchnorm(model)
-                    save_frontier_chkpt(model.module, opt.model_dir, i+1, optimizer)
+                    save_frontier_chkpt(model.module, opt.model_dir,
+                                        i+1, optimizer,
+                                        ema_model=ema_model)
                     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             else:
-                save_frontier_chkpt(model, opt.model_dir, i+1, optimizer)
+                save_frontier_chkpt(model, opt.model_dir, i+1,
+                                    optimizer,
+                                    ema_model=ema_model)
 
         # Model checkpoint
         if (i+1) % opt.chkpt_intv == 0:
             if opt.parallel == "DDP":
                 if dist.get_rank() == 0:
                     model = revert_sync_batchnorm(model)
-                    save_chkpt(model.module, opt.model_dir, i+1, optimizer)
+                    save_chkpt(model.module, opt.model_dir, i+1,
+                               optimizer, ema_model=ema_model)
                     if opt.export_onnx:
                         export_onnx(opt, i+1)
                     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             else:
-                save_chkpt(model, opt.model_dir, i+1, optimizer)
+                save_chkpt(model, opt.model_dir, i+1,
+                           optimizer, ema_model=ema_model)
                 if opt.export_onnx:
                     export_onnx(opt, i+1)
diff --git a/deepem/train/utils.py b/deepem/train/utils.py
index 8f2fcbd..65f881d 100644
--- a/deepem/train/utils.py
+++ b/deepem/train/utils.py
@@ -155,22 +155,28 @@ def chkpt_num_from_filename(f):
     return latest_regular, False
 
 
-def save_chkpt(model, fpath, chkpt_num, optimizer):
+def save_chkpt(model, fpath, chkpt_num, optimizer, ema_model=None):
     print(f"SAVE CHECKPOINT: {chkpt_num} iters.")
     fname = os.path.join(fpath, f"model{chkpt_num}.chkpt")
     state = {'iter': chkpt_num,
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}
+    if ema_model is not None:
+        state['ema_state_dict'] = ema_model.module.state_dict()
+        state['ema_n_averaged'] = ema_model.n_averaged.item()
     torch.save(state, fname)
 
 
-def save_frontier_chkpt(model, fpath, chkpt_num, optimizer):
+def save_frontier_chkpt(model, fpath, chkpt_num, optimizer, ema_model=None):
     """Save a frontier checkpoint that overwrites the previous frontier checkpoint."""
     print(f"SAVE FRONTIER CHECKPOINT: {chkpt_num} iters.")
     fname = os.path.join(fpath, "model_frontier.chkpt")
     state = {'iter': chkpt_num,
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}
+    if ema_model is not None:
+        state['ema_state_dict'] = ema_model.module.state_dict()
+        state['ema_n_averaged'] = ema_model.n_averaged.item()
     torch.save(state, fname)
@@ -209,6 +215,29 @@ def load_optimizer_state(optimizer, fpath, chkpt_num, is_frontier=False):
                 state[k] = v.cuda()
 
 
+def load_ema_state(ema_model, opt):
+    """Load EMA state from checkpoint if available."""
+    is_frontier = getattr(opt, 'loaded_from_frontier', False)
+    if is_frontier:
+        fname = os.path.join(opt.model_dir, "model_frontier.chkpt")
+    else:
+        fname = os.path.join(opt.model_dir,
+                             f"model{opt.chkpt_num}.chkpt")
+    if not os.path.exists(fname):
+        return
+
+    chkpt = torch.load(fname)
+    if 'ema_state_dict' in chkpt:
+        print(f"LOAD EMA STATE: {opt.chkpt_num} iters.")
+        ema_model.module.model.load_state_dict(
+            chkpt['ema_state_dict'])
+        if 'ema_n_averaged' in chkpt:
+            ema_model.n_averaged.fill_(
+                chkpt['ema_n_averaged'])
+    else:
+        print("No EMA state in checkpoint, starting fresh.")
+
+
 def load_data(opt, local_rank):
     data_ids = list(set().union(opt.train_ids, opt.val_ids))
     if opt.zettaset_specs:
diff --git a/deepem/utils/onnx_utils.py b/deepem/utils/onnx_utils.py
index 979a53b..66202cf 100644
--- a/deepem/utils/onnx_utils.py
+++ b/deepem/utils/onnx_utils.py
@@ -78,7 +78,6 @@ def export_onnx(
     onnx_model = load_model(onnx_opt)
     onnx_model, count = batchnorm3d_to_instancenorm3d(onnx_model)
     print(f"Replaced {count} BatchNorm3d layer to InstanceNorm3d layer.")
-    fname = os.path.join(onnx_opt.model_dir, f"model{chkpt_num}.onnx")
 
     args = dummy_input(onnx_opt.in_spec, device=onnx_opt.device)
@@ -88,6 +87,8 @@
     # Generate input names based on spec keys
     input_names = sorted(onnx_opt.in_spec.keys())
 
+    # Export regular model
+    fname = os.path.join(onnx_opt.model_dir, f"model{chkpt_num}.onnx")
     torch.onnx.export(
         onnx_model,
         export_args,
@@ -95,7 +96,28 @@
         verbose=False,
         export_params=True,
         opset_version=onnx_opt.opset_version,
-        input_names=input_names,  # dynamic input names based on spec
+        input_names=input_names,
         output_names=["output"]
     )
     print(f"Relative ONNX filepath: {fname}")
+
+    # Export EMA model if available
+    if getattr(opt, 'ema_decay', 0) > 0:
+        chkpt_fname = os.path.join(
+            onnx_opt.model_dir, f"model{chkpt_num}.chkpt")
+        chkpt = torch.load(chkpt_fname)
+        if 'ema_state_dict' in chkpt:
+            onnx_model.model.load_state_dict(chkpt['ema_state_dict'])
+            fname_ema = os.path.join(
+                onnx_opt.model_dir, f"model{chkpt_num}_ema.onnx")
+            torch.onnx.export(
+                onnx_model,
+                export_args,
+                fname_ema,
+                verbose=False,
+                export_params=True,
+                opset_version=onnx_opt.opset_version,
+                input_names=input_names,
+                output_names=["output"]
+            )
+            print(f"Relative EMA ONNX filepath: {fname_ema}")