forked from uni-medical/STU-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_finetuning.py
More file actions
254 lines (221 loc) · 13.7 KB
/
run_finetuning.py
File metadata and controls
254 lines (221 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import argparse
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.run.default_configuration import get_default_configuration
from nnunet.paths import default_plans_identifier
from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
def load_pretrained_weights(network, fname, verbose=True):
    """
    Initialise ``network`` from the ``'state_dict'`` entry of a checkpoint file.

    THIS DOES NOT TRANSFER SEGMENTATION HEADS!

    Only parameters that exist in both the checkpoint and ``network``, have
    identical shapes, and whose name does not start with ``seg_outputs`` are
    copied. Everything else in ``network`` keeps its current values.

    Input-modality adaptation: when the checkpoint's first-layer kernels
    (``conv_blocks_context.0.0.conv1.weight`` / ``...conv3.weight``) were
    trained with a single input channel and ``network`` expects several, the
    kernels are tiled along the input-channel axis so that they can be loaded.
    Checkpoints whose first layer already matches are loaded unchanged, and
    networks/checkpoints that lack these keys are handled without error.

    :param network: torch module whose parameters are overwritten in place
    :param fname: path (or file-like object) accepted by ``torch.load``;
        the loaded object must contain a ``'state_dict'`` mapping
    :param verbose: if True, print the checkpoint keys as well as the lists of
        loaded and skipped parameters
    :return: None
    """
    # Load onto the CPU so that checkpoints written on a GPU can be read on any
    # machine; load_state_dict below copies the values into the network's own
    # tensors regardless of where those live.
    saved_model = torch.load(fname, map_location=torch.device('cpu'))

    # If the state dict comes from nn.DataParallel / DDP but we use a
    # non-parallel model here, the keys carry a 'module.' prefix. Strip it.
    pretrained_dict = {}
    if verbose:
        print('below are keys in pretrained model')
    for k, value in saved_model['state_dict'].items():
        if verbose:
            print(k)
        key = k[len('module.'):] if k.startswith('module.') else k
        pretrained_dict[key] = value

    model_dict = network.state_dict()

    # Adjust the first layer for the number of input modalities.
    # The kernels are laid out as [out_channels, num_inputs, *kernel_size].
    input_conv_keys = ('conv_blocks_context.0.0.conv1.weight',
                       'conv_blocks_context.0.0.conv3.weight')
    if input_conv_keys[0] in model_dict:
        print('number of input modality: ', model_dict[input_conv_keys[0]].shape[1])
    for key in input_conv_keys:
        if key not in model_dict or key not in pretrained_dict:
            continue
        num_inputs = model_dict[key].shape[1]
        weight = pretrained_dict[key]
        # Only tile single-channel kernels; tiling kernels that already have
        # the right number of channels would make them incompatible.
        if num_inputs > 1 and weight.shape[1] == 1:
            trailing = [1] * (weight.dim() - 2)
            pretrained_dict[key] = weight.repeat(1, num_inputs, *trailing)

    # Split the overlapping keys into "skipped" and "loaded".
    # Segmentation heads (seg_outputs*) are never loaded.
    filtered_dict = {}
    loaded_dict = {}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            continue
        if model_dict[k].shape != v.shape or k.startswith('seg_outputs'):
            filtered_dict[k] = v
        else:
            loaded_dict[k] = v

    # Overwrite entries in the existing state dict
    model_dict.update(loaded_dict)
    print("################### Loading pretrained weights from file ", fname, '###################')
    if verbose:
        print("Below is the list of overlapping blocks in pretrained model:")
        for key in loaded_dict:
            print(key)
        print("Below is the list of not loaded blocks in pretrained model:")
        for key in filtered_dict:
            print(key)
    print("################### Done ###################")
    network.load_state_dict(model_dict)
def main():
    """
    Command-line entry point for fine-tuning.

    Parses the arguments, resolves the trainer class via
    ``get_default_configuration``, instantiates and initialises the trainer and
    then does one of the following:

    * ``--find_lr``: call ``trainer.find_lr()`` and stop;
    * training (default): optionally resume from the latest checkpoint (``-c``)
      or initialise from ``-pretrained_weights``, run the training, then
      validate;
    * ``-val``: load the final (or, with ``--valbest``, the best) checkpoint
      and validate only.

    For the ``3d_lowres`` configuration the segmentations for the next cascade
    stage are predicted afterwards unless ``--disable_next_stage_pred`` is set.
    """
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help="0, 1, ..., 5 or 'all'")

    # run mode
    parser.add_argument("-val", "--validation_only", action="store_true",
                        help="use this if you want to only run the validation")
    parser.add_argument("-c", "--continue_training", action="store_true",
                        help="use this if you want to continue a training")
    parser.add_argument("-p", required=False, default=default_plans_identifier,
                        help="plans identifier. Only change this if you created a custom experiment planner")

    # data handling / reproducibility
    parser.add_argument(
        "--use_compressed_data", required=False, default=False, action="store_true",
        help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
             "is much more CPU and RAM intensive and should only be used if you know what you are "
             "doing")
    parser.add_argument(
        "--deterministic", required=False, default=False, action="store_true",
        help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
             "this is not necessary. Deterministic training will make you overfit to some random seed. "
             "Don't use that.")
    parser.add_argument(
        "--npz", required=False, default=False, action="store_true",
        help="if set then nnUNet will export npz files of predicted segmentations "
             "in the validation as well. This is needed to run the ensembling step so unless "
             "you are developing nnUNet you should enable this")
    parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest", required=False, default=False, action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument("--fp32", required=False, default=False, action="store_true",
                        help="disable mixed precision training and run old school fp32")
    parser.add_argument("--val_folder", required=False, default="validation_raw",
                        help="name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--disable_saving", required=False, action='store_true',
        help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that "
             "will be removed at the end of the training). Useful for development when you are "
             "only interested in the results and want to save some disk space")
    parser.add_argument(
        "--disable_postprocessing_on_folds", required=False, action='store_true',
        help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
             "closely observing the model performance on specific configurations. You do not need it "
             "when applying nnU-Net because the postprocessing for this will be determined only once "
             "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
             "running postprocessing on each fold is computationally cheap, but some users have "
             "reported issues with very large images. If your images are large (>600x600x600 voxels) "
             "you should consider setting this flag.")

    # NOTE: store_false -> the attribute is True unless the flag is given; it is
    # forwarded as the `overwrite` argument of trainer.validate below.
    parser.add_argument('--val_disable_overwrite', action='store_false', default=True,
                        help='Validation does not overwrite existing segmentations')
    parser.add_argument('--disable_next_stage_pred', action='store_true', default=False,
                        help='do not predict next stage')
    parser.add_argument(
        '-pretrained_weights', type=str, required=False, default=None,
        help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
             'file, for example model_final_checkpoint.model). Will only be used when actually training. '
             'Optional. Beta. Use with caution.')

    args = parser.parse_args()

    network = args.network
    only_validate = args.validation_only

    # A bare task id is converted to the full "TaskXXX_..." name.
    task = args.task
    if not task.startswith("Task"):
        task = convert_id_to_task_name(int(task))

    # The fold is either the literal 'all' or an integer index.
    fold = args.fold
    if fold != 'all':
        fold = int(fold)

    configuration = get_default_configuration(network, task, args.network_trainer, args.p)
    plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = configuration

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        cascade_bases = (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)
        assert issubclass(trainer_class, cascade_bases), \
            "If running 3d_cascade_fullres then your trainer class must be derived from nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(trainer_class, nnUNetTrainer), \
            "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(
        plans_file,
        fold,
        output_folder=output_folder_name,
        dataset_directory=dataset_directory,
        batch_dice=batch_dice,
        stage=stage,
        unpack_data=not args.use_compressed_data,
        deterministic=args.deterministic,
        fp16=not args.fp32,
    )

    if args.disable_saving:
        # Do not keep the final or the best checkpoint ...
        trainer.save_final_checkpoint = False
        trainer.save_best_checkpoint = False
        # ... but keep writing a single "latest" checkpoint, which is needed
        # in case the training crashes.
        trainer.save_intermediate_checkpoints = True
        trainer.save_latest_only = True

    trainer.initialize(not only_validate)

    if args.find_lr:
        trainer.find_lr()
        return

    if only_validate:
        if args.valbest:
            trainer.load_best_checkpoint(train=False)
        else:
            trainer.load_final_checkpoint(train=False)
    else:
        if args.continue_training:
            # -c was set: continue a previous training and ignore pretrained weights
            trainer.load_latest_checkpoint()
        elif args.pretrained_weights is not None:
            # fresh training initialised from the given pretrained checkpoint
            load_pretrained_weights(trainer.network, args.pretrained_weights)
        # otherwise: fresh training without pretrained weights, nothing to load
        trainer.run_training()

    trainer.network.eval()

    # predict validation
    trainer.validate(save_softmax=args.npz,
                     validation_folder_name=args.val_folder,
                     run_postprocessing_on_folds=not args.disable_postprocessing_on_folds,
                     overwrite=args.val_disable_overwrite)

    if network == '3d_lowres' and not args.disable_next_stage_pred:
        print("predicting segmentations for the next stage of the cascade")
        next_stage_folder = join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
        predict_next_stage(trainer, next_stage_folder)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()