deepem/train/option.py (3 additions, 0 deletions)

@@ -118,6 +118,9 @@ def initialize(self):
         self.parser.add_argument('--optim', default='Adam')
         self.parser.add_argument('--lr', type=float, default=0.001)
 
+        # EMA (Exponential Moving Average)
+        self.parser.add_argument('--ema_decay', type=float, default=0.0)
+
         # Optimizer: Adam
         self.parser.add_argument('--betas', type=float, default=[0.9,0.999], nargs='+')
         self.parser.add_argument('--eps', type=float, default=1e-08)
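Note (not part of the diff): run.py below only builds an EMA copy when opt.ema_decay > 0, so the 0.0 default leaves training unchanged. For reference, a minimal sketch of the per-parameter rule that torch.optim.swa_utils.get_ema_multi_avg_fn(opt.ema_decay) applies on each AveragedModel.update_parameters() call; the 0.999 below is only an illustrative value, not a default introduced here.

# Sketch only: the update rule implemented by get_ema_multi_avg_fn(decay).
def ema_update(ema_param, param, decay):
    # decay near 1.0 -> slowly moving average; opt.ema_decay == 0.0 disables EMA in this PR.
    return decay * ema_param + (1.0 - decay) * param

print(ema_update(0.5, 1.0, 0.999))  # ~0.5005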
deepem/train/run.py (46 additions, 10 deletions)

@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
 
 import samwise
 
@@ -68,17 +69,31 @@ def train(opt):
     trainable = filter(lambda p: p.requires_grad, model.parameters())
     optimizer = load_optimizer(opt, trainable)
 
+    # EMA model
+    ema_model = None
+    if opt.ema_decay > 0:
+        base = model.module if opt.parallel == "DDP" else model
+        ema_model = AveragedModel(
+            base,
+            multi_avg_fn=get_ema_multi_avg_fn(opt.ema_decay),
+            use_buffers=True,
+        )
+        if opt.chkpt_num != 0:
+            load_ema_state(ema_model, opt)
+
     # Data loaders
     train_loader, val_loader = load_data(opt, local_rank)
 
     # Initial checkpoint
     if opt.parallel == "DDP":
         if dist.get_rank() == 0:
             model = revert_sync_batchnorm(model)
-            save_chkpt(model.module, opt.model_dir, opt.chkpt_num, optimizer)
+            save_chkpt(model.module, opt.model_dir, opt.chkpt_num,
+                       optimizer, ema_model=ema_model)
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
     else:
-        save_chkpt(model, opt.model_dir, opt.chkpt_num, optimizer)
+        save_chkpt(model, opt.model_dir, opt.chkpt_num,
+                   optimizer, ema_model=ema_model)
 
     # Mixed-precision training
     scaler = None
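For reference, a self-contained sketch (plain PyTorch, not deepem code) of how the AveragedModel built above behaves: the first update_parameters() call copies the current weights, later calls apply the EMA rule from get_ema_multi_avg_fn, and use_buffers=True folds buffers such as BatchNorm running statistics into the average, which is why the EMA copy can later be evaluated directly without a separate update_bn pass.

import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

torch.manual_seed(0)
net = torch.nn.Linear(4, 2)
ema = AveragedModel(net, multi_avg_fn=get_ema_multi_avg_fn(0.9), use_buffers=True)

ema.update_parameters(net)               # first call just copies the current weights
with torch.no_grad():
    old_ema = ema.module.weight.clone()
    net.weight.add_(1.0)                 # stand-in for an optimizer step
ema.update_parameters(net)               # EMA <- 0.9 * EMA + 0.1 * current

expected = 0.9 * old_ema + 0.1 * net.weight.detach()
print(torch.allclose(ema.module.weight, expected))  # True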
@@ -135,6 +150,11 @@ def train(opt):
             total_loss.backward()
             optimizer.step()
 
+            # Update EMA
+            if ema_model is not None:
+                ema_base = model.module if opt.parallel == "DDP" else model
+                ema_model.update_parameters(ema_base)
+
             # Elapsed time
             end.record()
             end.synchronize()  # waits only for work up to `end` on this stream
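Because update_parameters() runs once per optimizer.step(), the decay fixes the effective averaging window in optimizer steps, roughly 1/(1 - decay). A quick illustration with example values (none of these are defaults from this diff):

for decay in (0.99, 0.999, 0.9999):
    print(decay, round(1.0 / (1.0 - decay)))  # ~100, ~1000, ~10000 steps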
@@ -163,33 +183,49 @@ def train(opt):
             if (i+1) % opt.eval_intv == 0:
                 if opt.parallel == "DDP":
                     if dist.get_rank() == 0:
-                        model = revert_sync_batchnorm(model)
-                        eval_loop(i+1, model, val_loader, opt, logger, wandb_logger)
-                        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+                        if ema_model is not None:
+                            eval_loop(i+1, ema_model.module,
+                                      val_loader, opt, logger,
+                                      wandb_logger)
+                        else:
+                            model = revert_sync_batchnorm(model)
+                            eval_loop(i+1, model, val_loader,
+                                      opt, logger, wandb_logger)
+                            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                 else:
-                    eval_loop(i+1, model, val_loader, opt, logger, wandb_logger)
+                    eval_model = (ema_model.module
+                                  if ema_model is not None
+                                  else model)
+                    eval_loop(i+1, eval_model, val_loader,
+                              opt, logger, wandb_logger)
 
             # Frontier checkpoint (overwrites previous frontier)
             if opt.chkpt_sync_intv is not None and (i+1) % opt.chkpt_sync_intv == 0:
                 if opt.parallel == "DDP":
                     if dist.get_rank() == 0:
                         model = revert_sync_batchnorm(model)
-                        save_frontier_chkpt(model.module, opt.model_dir, i+1, optimizer)
+                        save_frontier_chkpt(model.module, opt.model_dir,
+                                            i+1, optimizer,
+                                            ema_model=ema_model)
                         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                 else:
-                    save_frontier_chkpt(model, opt.model_dir, i+1, optimizer)
+                    save_frontier_chkpt(model, opt.model_dir, i+1,
+                                        optimizer,
+                                        ema_model=ema_model)
 
             # Model checkpoint
             if (i+1) % opt.chkpt_intv == 0:
                 if opt.parallel == "DDP":
                     if dist.get_rank() == 0:
                         model = revert_sync_batchnorm(model)
-                        save_chkpt(model.module, opt.model_dir, i+1, optimizer)
+                        save_chkpt(model.module, opt.model_dir, i+1,
+                                   optimizer, ema_model=ema_model)
                         if opt.export_onnx:
                             export_onnx(opt, i+1)
                         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                 else:
-                    save_chkpt(model, opt.model_dir, i+1, optimizer)
+                    save_chkpt(model, opt.model_dir, i+1,
+                               optimizer, ema_model=ema_model)
                     if opt.export_onnx:
                         export_onnx(opt, i+1)
 
deepem/train/utils.py (31 additions, 2 deletions)

@@ -155,22 +155,28 @@ def chkpt_num_from_filename(f):
     return latest_regular, False
 
 
-def save_chkpt(model, fpath, chkpt_num, optimizer):
+def save_chkpt(model, fpath, chkpt_num, optimizer, ema_model=None):
     print(f"SAVE CHECKPOINT: {chkpt_num} iters.")
     fname = os.path.join(fpath, f"model{chkpt_num}.chkpt")
     state = {'iter': chkpt_num,
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}
+    if ema_model is not None:
+        state['ema_state_dict'] = ema_model.module.state_dict()
+        state['ema_n_averaged'] = ema_model.n_averaged.item()
     torch.save(state, fname)
 
 
-def save_frontier_chkpt(model, fpath, chkpt_num, optimizer):
+def save_frontier_chkpt(model, fpath, chkpt_num, optimizer, ema_model=None):
     """Save a frontier checkpoint that overwrites the previous frontier checkpoint."""
     print(f"SAVE FRONTIER CHECKPOINT: {chkpt_num} iters.")
     fname = os.path.join(fpath, "model_frontier.chkpt")
     state = {'iter': chkpt_num,
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}
+    if ema_model is not None:
+        state['ema_state_dict'] = ema_model.module.state_dict()
+        state['ema_n_averaged'] = ema_model.n_averaged.item()
     torch.save(state, fname)
@@ -209,6 +215,29 @@ def load_optimizer_state(optimizer, fpath, chkpt_num, is_frontier=False):
             state[k] = v.cuda()
 
 
+def load_ema_state(ema_model, opt):
+    """Load EMA state from checkpoint if available."""
+    is_frontier = getattr(opt, 'loaded_from_frontier', False)
+    if is_frontier:
+        fname = os.path.join(opt.model_dir, "model_frontier.chkpt")
+    else:
+        fname = os.path.join(opt.model_dir,
+                             f"model{opt.chkpt_num}.chkpt")
+    if not os.path.exists(fname):
+        return
+
+    chkpt = torch.load(fname)
+    if 'ema_state_dict' in chkpt:
+        print(f"LOAD EMA STATE: {opt.chkpt_num} iters.")
+        ema_model.module.model.load_state_dict(
+            chkpt['ema_state_dict'])
+        if 'ema_n_averaged' in chkpt:
+            ema_model.n_averaged.fill_(
+                chkpt['ema_n_averaged'])
+    else:
+        print("No EMA state in checkpoint, starting fresh.")
+
+
 def load_data(opt, local_rank):
     data_ids = list(set().union(opt.train_ids, opt.val_ids))
     if opt.zettaset_specs:
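For reference, a sketch (not deepem code) of how a consumer of these checkpoints could pick up the EMA weights written by save_chkpt/save_frontier_chkpt above; the filename is hypothetical, the key names come from this diff.

import torch

chkpt = torch.load("model10000.chkpt", map_location="cpu")  # hypothetical chkpt_num
print(chkpt['iter'])                           # iteration count
weights = chkpt.get('ema_state_dict',          # prefer EMA weights when present,
                    chkpt['state_dict'])       # fall back to the raw weights
print(chkpt.get('ema_n_averaged'))             # number of EMA updates, if saved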
deepem/utils/onnx_utils.py (24 additions, 2 deletions)

@@ -78,7 +78,6 @@ def export_onnx(
     onnx_model = load_model(onnx_opt)
     onnx_model, count = batchnorm3d_to_instancenorm3d(onnx_model)
     print(f"Replaced {count} BatchNorm3d layer to InstanceNorm3d layer.")
-    fname = os.path.join(onnx_opt.model_dir, f"model{chkpt_num}.onnx")
 
     args = dummy_input(onnx_opt.in_spec, device=onnx_opt.device)
 
@@ -88,14 +87,37 @@
     # Generate input names based on spec keys
     input_names = sorted(onnx_opt.in_spec.keys())
 
+    # Export regular model
+    fname = os.path.join(onnx_opt.model_dir, f"model{chkpt_num}.onnx")
     torch.onnx.export(
         onnx_model,
         export_args,
         fname,
         verbose=False,
         export_params=True,
         opset_version=onnx_opt.opset_version,
-        input_names=input_names,  # dynamic input names based on spec
+        input_names=input_names,
         output_names=["output"]
     )
     print(f"Relative ONNX filepath: {fname}")
+
+    # Export EMA model if available
+    if getattr(opt, 'ema_decay', 0) > 0:
+        chkpt_fname = os.path.join(
+            onnx_opt.model_dir, f"model{chkpt_num}.chkpt")
+        chkpt = torch.load(chkpt_fname)
+        if 'ema_state_dict' in chkpt:
+            onnx_model.model.load_state_dict(chkpt['ema_state_dict'])
+            fname_ema = os.path.join(
+                onnx_opt.model_dir, f"model{chkpt_num}_ema.onnx")
+            torch.onnx.export(
+                onnx_model,
+                export_args,
+                fname_ema,
+                verbose=False,
+                export_params=True,
+                opset_version=onnx_opt.opset_version,
+                input_names=input_names,
+                output_names=["output"]
+            )
+            print(f"Relative EMA ONNX filepath: {fname_ema}")