-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathdcrnn_train.py
More file actions
106 lines (93 loc) · 5.16 KB
/
dcrnn_train.py
File metadata and controls
106 lines (93 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import pandas as pd
import tensorflow as tf
from lib import log_helper
from lib.dcrnn_utils import load_graph_data
from model.dcrnn_supervisor import DCRNNSupervisor
# Command-line flags.
#
# Convention used by main(): numeric flags default to a negative value and
# string/bool-override flags default to None, both meaning "not given on the
# command line". Only flags that were actually set (value >= 0, or truthy /
# not None) override the corresponding entry of the JSON supervisor config.
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('batch_size', -1, 'Batch size')
flags.DEFINE_integer('cl_decay_steps', -1,
                     'Parameter to control the decay speed of probability of feeding ground truth '
                     'instead of model output.')
flags.DEFINE_string('config_filename', None, 'Configuration filename for restoring the model.')
flags.DEFINE_integer('epochs', -1, 'Maximum number of epochs to train.')
flags.DEFINE_string('filter_type', None, 'laplacian/random_walk/dual_random_walk.')
flags.DEFINE_string('graph_pkl_filename', 'data/sensor_graph/adj_mat.pkl',
                    'Pickle file containing: sensor_ids, sensor_id_to_ind_map, dist_matrix')
flags.DEFINE_integer('horizon', -1, 'Maximum number of timestamps to prediction.')
flags.DEFINE_float('l1_decay', -1.0, 'L1 Regularization')
flags.DEFINE_float('lr_decay', -1.0, 'Learning rate decay.')
flags.DEFINE_integer('lr_decay_epoch', -1, 'The epoch that starting decaying the parameter.')
flags.DEFINE_integer('lr_decay_interval', -1, 'Interval between each decay.')
flags.DEFINE_float('learning_rate', -1.0, 'Learning rate. -1: select by hyperopt tuning.')
flags.DEFINE_string('log_dir', None, 'Log directory for restoring the model from a checkpoint.')
flags.DEFINE_string('loss_func', None, 'MSE/MAPE/RMSE_MAPE: loss function.')
flags.DEFINE_float('min_learning_rate', -1.0, 'Minimum learning rate')
flags.DEFINE_integer('nb_weeks', 43, 'How many week\'s data should be used for train/test.')
flags.DEFINE_integer('patience', -1,
                     'Maximum number of epochs allowed for non-improving validation error before '
                     'early stopping.')
flags.DEFINE_integer('seq_len', -1, 'Sequence length.')
flags.DEFINE_integer('test_every_n_epochs', -1, 'Run model on the testing dataset every n epochs.')
# The readings file is loaded with pandas.read_csv in main(), so the help text
# describes a CSV file rather than an HDF5 store.
flags.DEFINE_string('traffic_df_filename', 'data/weather_ts.csv',
                    'Path to the CSV file with the readings (loaded with pandas.read_csv).')
flags.DEFINE_bool('use_cpu_only', False, 'Set to true to only use cpu.')
flags.DEFINE_bool('use_curriculum_learning', None,
                  'Set to true to use Curriculum learning in decoding stage.')
flags.DEFINE_integer('verbose', -1, '1: to log individual sensor information.')
def main():
    """Train a DCRNN model from a JSON supervisor config plus command-line overrides.

    Steps:
      1. Load the supervisor configuration from ``--config_filename``.
      2. Load the sensor graph (sensor ids and adjacency matrix) from
         ``--graph_pkl_filename`` and zero out adjacency entries below 0.1.
      3. Load the readings CSV from ``--traffic_df_filename`` and keep only the
         columns named by the graph's sensor ids.
      4. Overlay every flag that was set on the command line onto the config.
      5. Build a ``DCRNNSupervisor`` and call its ``train`` inside a TF session.

    Raises:
        ValueError: if ``--config_filename`` was not provided.
        KeyError: if a sensor id from the graph file is not a column of the CSV.
    """
    if not FLAGS.config_filename:
        # Without this check open(None) fails with an opaque TypeError.
        raise ValueError('--config_filename is required (path to the JSON supervisor configuration).')
    with open(FLAGS.config_filename) as f:
        supervisor_config = json.load(f)
    logger = log_helper.get_logger(supervisor_config.get('base_dir'), 'info.log')

    # Reads graph data.
    logger.info('Loading graph from: ' + FLAGS.graph_pkl_filename)
    sensor_ids, _, adj_mx = load_graph_data(FLAGS.graph_pkl_filename)
    # Entries below 0.1 are zeroed in place.
    adj_mx[adj_mx < 0.1] = 0

    logger.info('Loading traffic data from: ' + FLAGS.traffic_df_filename)
    traffic_reading_df = pd.read_csv(FLAGS.traffic_df_filename)
    # Select the columns listed in the graph file. `.loc` replaces `.ix`, which
    # was removed in pandas 1.0; unlike `.ix`, it raises KeyError on a missing
    # column instead of silently producing an all-NaN column.
    # NOTE(review): an earlier revision defined a hard-coded, never-used list
    # `sensors_ids` ('time_stamp' plus twelve 'aqi_*' columns) right before this
    # selection. The selection has always used the graph's `sensor_ids`; confirm
    # that this is the intended column set for the weather CSV.
    traffic_reading_df = traffic_reading_df.loc[:, sensor_ids]

    # Flags without a numeric "unset" marker: override only when provided.
    supervisor_config['use_cpu_only'] = FLAGS.use_cpu_only
    if FLAGS.log_dir:
        supervisor_config['log_dir'] = FLAGS.log_dir
    if FLAGS.use_curriculum_learning is not None:
        supervisor_config['use_curriculum_learning'] = FLAGS.use_curriculum_learning
    if FLAGS.loss_func:
        supervisor_config['loss_func'] = FLAGS.loss_func
    if FLAGS.filter_type:
        supervisor_config['filter_type'] = FLAGS.filter_type

    # Overwrites space with specified parameters. Numeric flags default to a
    # negative value, so a value >= 0 means it was set on the command line.
    numeric_overrides = (
        'batch_size', 'cl_decay_steps', 'epochs', 'horizon', 'l1_decay',
        'lr_decay', 'lr_decay_epoch', 'lr_decay_interval', 'learning_rate',
        'min_learning_rate', 'patience', 'seq_len', 'test_every_n_epochs', 'verbose',
    )
    for name in numeric_overrides:
        value = getattr(FLAGS, name)
        if value >= 0:
            supervisor_config[name] = value

    # device_count={'GPU': 0} hides all GPUs from the session when CPU-only is requested.
    if FLAGS.use_cpu_only:
        tf_config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        tf_config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        supervisor = DCRNNSupervisor(traffic_reading_df=traffic_reading_df, adj_mx=adj_mx,
                                     config=supervisor_config)
        supervisor.train(sess=sess)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main()