-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
102 lines (75 loc) · 3.59 KB
/
train.py
File metadata and controls
102 lines (75 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Train a discrete-latent VAE variant (selected from `models`) on MNIST."""
import numpy as np
import tensorflow as tf
import models
import utils
import argparse
# Fix the global NumPy seed so runs are reproducible.
np.random.seed(0)
if __name__ == "__main__":
    # Discover model classes by name: every attribute of `models` whose name
    # contains 'VAE' is exposed as a CLI choice.
    available_models = [x for x in dir(models) if 'VAE' in x]

    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        'model',
        choices=available_models,
        help='Model class.')
    argparser.add_argument(
        '--code_size', type=int, default=200,
        help='Dimension of latent code')
    argparser.add_argument(
        '--prior_proba', type=float, default=0.5,
        help='Prior probability on code')
    argparser.add_argument(
        '--learning_rate', type=float, default=1e-4,
        help='Learning rate')
    argparser.add_argument(
        '--lam', type=float, default=1e-3,
        help='Regularisation coefficient')
    argparser.add_argument(
        '--tau', type=float, default=1.0,
        help='Relaxation temperature')
    argparser.add_argument(
        '--relaxation_distribution', type=str, default='Uniform',
        choices=models.GeneralizedRelaxedDVAE.DISTRIBUTION_FACTORIES.keys(),
        help='Underlying distribution for Generalized Sigmoid relaxation')
    argparser.add_argument(
        '--noise_distribution', type=str, default='Normal',
        choices=models.NoiseRelaxedDVAE.NOISE_FACTORIES.keys(),
        help='Noise distribution for noise relaxation')
    argparser.add_argument(
        '--batch_size', type=int, default=50,
        help='Batch size')
    argparser.add_argument(
        '--eval_batch_size', type=int, default=50,
        help='Batch size for evaluation')
    argparser.add_argument(
        '--epochs', type=int, default=10000,
        help='Number of epochs')
    argparser.add_argument(
        '--multisamples', type=int, nargs='+', default=[1, 100, 1000, 10000],
        help='List of sample sizes for multisample ELBOs')
    argparser.add_argument(
        '--evaluate_every', type=int, nargs='+', default=[1, 3, 10, 100],
        help='Evaluate model on ELBO every X epochs '
             '(for number for each multisample ELBO)')
    argparser.add_argument(
        '--experiment_path', type=str, default='experiments/tmp/',
        help='Path to save experiment\'s data')
    argparser.add_argument(
        '--subset_validation', type=int, default=1000*1000*1000,
        help='Number of validation samples to compute marginal '
             'log-likelihood on.')
    args = argparser.parse_args()

    # NOTE: no manual "unknown model" check is needed here -- argparse's
    # `choices=available_models` already rejects anything else with a clean
    # usage error before we get this far.

    # Map each multisample K to its evaluation period. Require matching
    # lengths explicitly: `zip` would otherwise truncate to the shorter list
    # and silently drop evaluation schedules.
    if len(args.evaluate_every) != len(args.multisamples):
        argparser.error(
            '--evaluate_every must supply exactly one value per '
            '--multisamples entry ({} vs {})'.format(
                len(args.evaluate_every), len(args.multisamples)))
    evaluate_every = dict(zip(args.multisamples, args.evaluate_every))

    model_class = getattr(models, args.model)
    dataset = utils.get_mnist_dataset()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Initialise the decoder output bias to the logit of the per-pixel
        # training mean (clipped away from 0/1 to keep the logit finite) --
        # this warm-starts the Bernoulli output layer at the data marginals.
        train_mean = dataset.train.images.mean(axis=0)
        output_bias = -np.log(1. / np.clip(train_mean, 0.001, 0.999) - 1.)

        dvae = model_class(
            code_size=args.code_size, input_size=28*28, prior_p=args.prior_proba,
            lam=args.lam, tau=args.tau,
            relaxation_distribution=args.relaxation_distribution,
            output_bias=output_bias, batch_size=args.batch_size,
            multisample_ks=args.multisamples,
            noise_distribution=args.noise_distribution)

        utils.train(
            dvae, dataset.train.images, dataset.validation.images,
            learning_rate=args.learning_rate, epochs_total=args.epochs,
            eval_batch_size=args.eval_batch_size, evaluate_every=evaluate_every,
            experiment_path=args.experiment_path, sess=sess,
            subset_validation=args.subset_validation)