-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain.py
More file actions
110 lines (97 loc) · 4.61 KB
/
train.py
File metadata and controls
110 lines (97 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vision_transformer import SwinUnet as ViT_seg
from trainer_unet import trainer_synapse, trainer_assd
from unet.unet_stvit import UNet_STA
# Command-line interface of the training script.
# Every option is optional at the argparse level; `--output_dir` has no
# default, so it is None unless given on the command line.
parser = argparse.ArgumentParser()

# --- Dataset / experiment location ---------------------------------------
parser.add_argument('--root_path', type=str,
                    default='data/Synapse', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--output_dir', type=str, help='output dir')

# --- Optimisation schedule -----------------------------------------------
# The help text of --max_iterations used to say "epoch" (copy-paste of the
# --max_epochs help); it now describes what the option actually counts.
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=32, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')

# --- Options not read anywhere in this file ------------------------------
# NOTE(review): presumably consumed by the trainer (which receives the whole
# namespace) or kept for command-line compatibility -- confirm before removing.
# parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)
parser.add_argument('--zip', action='store_true',
                    help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part',
                    choices=['no', 'full', 'part'],
                    help='no: no cache, '
                         'full: cache all data, '
                         'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int,
                    help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1',
                    choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true',
                    help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true',
                    help='Test throughput only')

# --- Device selection and extra hyper-parameters -------------------------
# --gpu_id is the CUDA device index the entry point passes to
# torch.cuda.set_device.
parser.add_argument('--gpu_id', default=1, type=int)
# NOTE(review): lambda_x, dino_weight, alpha and sigma are not read in this
# file; presumably loss/regularisation weights used by the trainer -- confirm.
parser.add_argument('--lambda_x', default=0.015, type=float)
parser.add_argument('--dino_weight', default=0.3, type=float)
parser.add_argument('--alpha', default=20., type=float)
parser.add_argument('--sigma', default=5., type=float)
# Parse the command line at module level so that `args` is a module-wide name.
args = parser.parse_args()

# For the Synapse dataset, redirect root_path to its "train_npz" subdirectory;
# any other dataset name leaves root_path exactly as given.
if args.dataset == "Synapse":
    args.root_path = os.path.join(args.root_path, "train_npz")
if __name__ == "__main__":
    # Script entry point: validate the command line, configure the GPU and
    # the random number generators, build the network and hand control to
    # the dataset-specific trainer.

    # --output_dir has no default. Report a usage error here instead of
    # letting the filesystem calls below fail with an opaque TypeError.
    if args.output_dir is None:
        parser.error('--output_dir is required')

    # Map each supported dataset name to its training routine and reject
    # unknown names before any GPU work is done (previously this surfaced as
    # a bare KeyError after the model had already been built).
    trainer = {'Synapse': trainer_synapse}
    # trainer = {'ASSD': trainer_assd}
    dataset_name = args.dataset
    if dataset_name not in trainer:
        parser.error('unsupported --dataset {!r}; expected one of: {}'.format(
            dataset_name, ', '.join(trainer)))

    torch.cuda.set_device(args.gpu_id)

    # cuDNN autotuning (benchmark) and deterministic kernels are mutually
    # exclusive here: exactly one of them is enabled, chosen by
    # --deterministic.
    deterministic = bool(args.deterministic)
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic

    # Seed every random number generator this script has access to.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Scale the learning rate linearly against a reference batch size of 24,
    # but only for batch sizes that are a multiple of 6; every other batch
    # size keeps base_lr unchanged.
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24

    # Uncomment while training the main model
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    ## Training our model
    net = UNet_STA(n_in=1, n_class=args.num_classes).cuda()
    # weight = torch.load('/media/wzhu59/easystore/SelfReg-UNet/src/train_synase/ASSD pretrained weights/epoch_150.pth')
    # net.load_state_dict(weight,strict=False)
    # print('weights loaded')

    trainer[dataset_name](args, net, args.output_dir)