forked from baowaly/SynthEHR
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
150 lines (137 loc) · 9.24 KB
/
train.py
File metadata and controls
150 lines (137 loc) · 9.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
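# Select the non-interactive Agg backend before importing pyplot so plotting works on
# headless machines; otherwise matplotlib fails with
# "_tkinter.TclError: no display name and no $DISPLAY environment variable".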
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys, time, argparse, os, re
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from tensorflow.contrib.layers import l2_regularizer
from tensorflow.contrib.layers import batch_norm
import tensorflow.contrib.slim as slim
import tqdm
from scipy.stats.stats import pearsonr
_VALIDATION_RATIO = 0.1
from model import MEDGAN, MEDWGAN, MEDBGAN
def parse_arguments(parser):
parser.add_argument('--model', type=str, default='medGAN', help='Specify the model name (medGAN, medWGAN, etc.). A dedicated folder will be created to save all models and outputs for this model (default value: medGAN)')
parser.add_argument('--model_name', type=str, default='model001',
help='Specify the model name. A dedicated folder will be created to save all models and outputs for this model (default value: medGAN)')
parser.add_argument('--embed_size', type=int, default=128, help='The dimension size of the embedding, which will be generated by the generator. (default value: 128)')
parser.add_argument('--noise_size', type=int, default=128, help='The dimension size of the random noise, on which the generator is conditioned. (default value: 128)')
parser.add_argument('--generator_size', type=tuple, default=(128, 128), help='The dimension size of the generator. Note that another layer of size "--embed_size" is always added. (default value: (128, 128))')
parser.add_argument('--discriminator_size', type=tuple, default=(256, 128, 1), help='The dimension size of the discriminator. (default value: (256, 128, 1))')
parser.add_argument('--compressor_size', type=tuple, default=(), help='The dimension size of the encoder of the autoencoder. Note that another layer of size "--embed_size" is always added. Therefore this can be a blank tuple. (default value: ())')
parser.add_argument('--decompressor_size', type=tuple, default=(), help='The dimension size of the decoder of the autoencoder. Note that another layer, whose size is equal to the dimension of the <patient_matrix>, is always added. Therefore this can be a blank tuple. (default value: ())')
parser.add_argument('--data_type', type=str, default='binary', choices=['binary', 'count'], help='The input data type. The <patient matrix> could either contain binary values or count values. (default value: "binary")')
parser.add_argument('--batchnorm_decay', type=float, default=0.99, help='Decay value for the moving average used in Batch Normalization. (default value: 0.99)')
parser.add_argument('--L2', type=float, default=0.001, help='L2 regularization coefficient for all weights. (default value: 0.001)')
parser.add_argument('--gp_scale', type=float, default=10.0, help='Gradient penalty scale used in WGAN (default value: 10.0)')
parser.add_argument('--data_file', type=str, default='data/inpatient_final_data.npy', help='The path to the numpy matrix containing aggregated patient records.')
parser.add_argument('--out_name', type=str, default='generated.npy', help='The file name of the generating data.')
parser.add_argument('--init_from', type=str, default=None, help='Continue training from saved model in the "models" sub-folder in this folder. If None, train from scratch. (default value: None)')
parser.add_argument('--n_pretrain_epoch', type=int, default=100, help='The number of epochs to pre-train the autoencoder. (default value: 100)')
parser.add_argument('--n_epoch', type=int, default=1000, help='The number of epochs to train medGAN. (default value: 1000)')
parser.add_argument('--n_discriminator_update', type=int, default=2, help='The number of times to update the discriminator per epoch. (default value: 2)')
parser.add_argument('--n_generator_update', type=int, default=1, help='The number of times to update the generator per epoch. (default value: 1)')
parser.add_argument('--pretrain_batch_size', type=int, default=100, help='The size of a single mini-batch for pre-training the autoencoder. (default value: 100)')
parser.add_argument('--batch_size', type=int, default=1000, help='The size of a single mini-batch for training medGAN. (default value: 1000)')
parser.add_argument('--save_max_keep', type=int, default=0, help='The number of models to keep. Setting this to 0 will save models for every epoch. (default value: 0)')
parser.add_argument('--generate_data', type=bool, default=False, help='If True the model generates data, if False the model is trained (default value: False)')
parser.add_argument('--n_synthetic_samples', type=int, default=1000, help='The number of samples of synthetic data')
parser.add_argument('--model_file', type=str, metavar='<model_file>', default='', help='The path to the model file, in case you want to continue training. (default value: '')')
args = parser.parse_args()
return args
## for jupyter notebook
# args = parser.parse_known_args()
# return args[0]
def _make_gan(sess, opts, model_name, input_dim, announce=False):
    """Instantiate the GAN variant selected by ``opts.model``.

    ``opts.model`` is matched with ``re.search`` (case-sensitive substring):
    'medWGAN' -> MEDWGAN (the only variant that also receives ``gp_scale``),
    'medBGAN' -> MEDBGAN, anything else -> MEDGAN.

    Args:
        sess: the active ``tf.Session``.
        opts: the parsed command-line options.
        model_name: value forwarded to the model as ``model_name``.
        input_dim: number of columns of the patient matrix.
        announce: if True, print which variant is used (done for training only).

    Returns:
        The constructed (not yet built) model object.
    """
    shared = dict(model_name=model_name,
                  dataType=opts.data_type,
                  inputDim=input_dim,
                  compressDims=opts.compressor_size,
                  decompressDims=opts.decompressor_size,
                  bnDecay=opts.batchnorm_decay,
                  l2scale=opts.L2)
    if re.search('medWGAN', opts.model):
        if announce:
            print('****Using medWGAN****')
        return MEDWGAN(sess, gp_scale=opts.gp_scale, **shared)
    if re.search('medBGAN', opts.model):
        if announce:
            print('****Using medBGAN****')
        return MEDBGAN(sess, **shared)
    if announce:
        print('****Using medGAN****')
    return MEDGAN(sess, **shared)


parser = argparse.ArgumentParser()
args = parse_arguments(parser)

# Pre-load training data. In this script it is only used to read the input
# dimension (train_data.shape[1]); train() is handed args.data_file directly.
train_data = np.load(args.data_file)

# Both modes start from a fresh default graph inside a single session.
tf.reset_default_graph()
with tf.Session() as sess:
    if not args.generate_data:
        # Training GAN
        mg = _make_gan(sess, args, args.model_name, train_data.shape[1], announce=True)
        mg.build_model()
        results = mg.train(data_path=args.data_file,
                           pretrainEpochs=args.n_pretrain_epoch,
                           nEpochs=args.n_epoch,
                           discriminatorTrainPeriod=args.n_discriminator_update,
                           generatorTrainPeriod=args.n_generator_update,
                           pretrainBatchSize=args.pretrain_batch_size,
                           batchSize=args.batch_size,
                           saveMaxKeep=args.save_max_keep)
    else:
        # Generate synthetic data.
        # NOTE(review): training constructs the model with model_name=args.model_name,
        # while generation uses args.model (also passed as gen_from). Confirm in
        # model.py that both resolve to the same checkpoint location.
        mg = _make_gan(sess, args, args.model, train_data.shape[1])
        mg.build_model()
        # --out_name was previously ignored; its default ('generated.npy')
        # reproduces the old hard-coded file name exactly.
        mg.generateData(nSamples=args.n_synthetic_samples,
                        gen_from=args.model,
                        gen_from_ckpt=args.model_file,
                        out_name=str(args.n_synthetic_samples) + args.out_name,
                        batchSize=args.batch_size)