-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_eval.py
More file actions
executable file
·31 lines (24 loc) · 1.07 KB
/
train_eval.py
File metadata and controls
executable file
·31 lines (24 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import tensorflow as tf
from vae_celeb import *
from config import *
from helper import *
from utils import create_dataset
# Lower tf.logging's threshold to INFO so INFO-level messages (such as the
# Estimator's training/evaluation progress) are shown. This runs at import time.
tf.logging.set_verbosity(tf.logging.INFO)
def train_input_fn_from_tfr(path='./db/train.tfrecords'):
    """Build the Estimator ``input_fn`` for the training split.

    Args:
        path: TFRecord file to read. Defaults to the original hard-coded
            training file, so existing zero-argument callers are unaffected.

    Returns:
        A zero-argument callable that, when invoked (by the Estimator), calls
        ``create_dataset`` with ``path`` and with the module-level
        ``buffer_size``, ``batch_size`` and ``num_epochs`` (presumably
        provided by ``from config import *`` -- verify against config.py).
    """
    # The config values are deliberately referenced inside the lambda, as in
    # the original, so they are read when the Estimator calls the input_fn
    # rather than when this factory runs.
    return lambda: create_dataset(path=path,
                                  buffer_size=buffer_size,
                                  batch_size=batch_size,
                                  num_epochs=num_epochs)
def eval_input_fn_from_tfr(path='./db/val.tfrecords', *, batch_size=64, num_epochs=1):
    """Build the Estimator ``input_fn`` for the validation split.

    Args:
        path: TFRecord file to read. Defaults to the original validation file.
        batch_size: Evaluation batch size. Defaults to 64, the value that was
            previously hard-coded (independent of the training batch size).
        num_epochs: Passes over the validation file. Defaults to 1, the value
            that was previously hard-coded.

    Returns:
        A zero-argument callable that, when invoked (by the Estimator), calls
        ``create_dataset`` with the arguments above and the module-level
        ``buffer_size`` (presumably provided by ``from config import *``).
    """
    # ``batch_size`` / ``num_epochs`` here are this function's parameters
    # (shadowing any module-level config names); ``buffer_size`` is still the
    # module-level value, looked up when the input_fn is called.
    return lambda: create_dataset(path=path,
                                  buffer_size=buffer_size,
                                  batch_size=batch_size,
                                  num_epochs=num_epochs)
def main():
    """Train and evaluate the Estimator built from ``model_fn``.

    ``model_fn`` is presumably provided by ``from vae_celeb import *``. The
    Estimator saves checkpoints and summaries under ``./logs``. Training reads
    ``./db/train.tfrecords`` and evaluation reads ``./db/val.tfrecords`` via the
    input_fn factories defined in this module.
    """
    # No explicit tf.Session here: the Estimator creates and manages its own
    # graph and session. The previous ``with tf.Session() as sess:`` wrapper
    # was never used, and its ``sess.close()`` was redundant inside ``with``.
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='./logs')
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn_from_tfr())
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn_from_tfr())
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == '__main__':
    main()