-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig_template.py
More file actions
38 lines (32 loc) · 2.57 KB
/
config_template.py
File metadata and controls
38 lines (32 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
# Flag-definition module used by every DEFINE_* call below.
# TensorFlow 1.x exposes it as `tf.flags`; that alias was removed in
# TensorFlow 2.x, where the same module is reachable as `tf.compat.v1.flags`.
# Prefer the 1.x name when present so existing 1.x behavior is unchanged.
flags = tf.flags if hasattr(tf, 'flags') else tf.compat.v1.flags
# ---------------------------------------------------------------------------
# Data locations
# ---------------------------------------------------------------------------
# Arguments are passed positionally (name, default, help) on purpose: the
# keyword names differ between flag-module implementations.

# Directory holding the input splits (per its help text: train/valid/test).
flags.DEFINE_string(
    'data_dir',
    '/home/ubuntu/data_store/training_data/10',
    'data directory. Should contain train.txt/valid.txt/test.txt with input data')

# Output directory for periodically saved models and summaries.
flags.DEFINE_string(
    'train_dir',
    'cv',
    'training directory (models and summaries are saved there periodically)')

# Optional checkpoint to resume from; defaults to None (no checkpoint).
flags.DEFINE_string(
    'load_model',
    None,
    '(optional) filename of the model to load. Useful for re-starting training from a checkpoint')

# Master CSV indexing where all of the training data lives.
flags.DEFINE_string(
    'master_file',
    'main/main_10.csv',
    'path to master .csv containing locations of all training data')
# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
# Arguments are passed positionally (name, default, help) on purpose: the
# keyword names differ between flag-module implementations.

# Which model architecture to build; 'lstm' by default.
flags.DEFINE_string(
    'model_choice',
    'lstm',
    'model choice')

# Optional path to pretrained embeddings; None (the default) means none given.
# Fixed help-text typo ("emebdding" -> "embedding"), visible in --help output.
flags.DEFINE_string(
    'embedding_path',
    None,
    'pretrained embedding path')

flags.DEFINE_integer(
    'rnn_size',
    650,
    'size of LSTM internal state')

flags.DEFINE_integer(
    'highway_layers',
    2,
    'number of highway layers')

flags.DEFINE_integer(
    'word_embed_size',
    3,
    'dimensionality of word embeddings')

# The two CNN settings below are list literals stored as *strings*; they are
# presumably parsed into Python lists by the consumer (TODO confirm how).
# Both defaults have 7 entries, which suggests they are paired element-wise
# (one feature count per kernel width). Keep their lengths equal if editing.
flags.DEFINE_string(
    'kernels',
    '[1,2,3,4,5,6,7]',
    'CNN kernel widths')

flags.DEFINE_string(
    'kernel_features',
    '[50,100,150,200,200,200,200]',
    'number of features in the CNN kernel')

flags.DEFINE_integer(
    'rnn_layers',
    2,
    'number of layers in the LSTM')

# Per the help text, 0 disables dropout.
flags.DEFINE_float(
    'dropout',
    0.5,
    'dropout. 0 = no dropout')
# ---------------------------------------------------------------------------
# Optimization schedule and training-loop limits
# ---------------------------------------------------------------------------
# Arguments are passed positionally (name, default, help) on purpose: the
# keyword names differ between flag-module implementations.

# Learning-rate schedule: start at `learning_rate`, and apply
# `learning_rate_decay` when validation perplexity stalls by less than
# `decay_when` (per the help texts below).
flags.DEFINE_float(
    'learning_rate_decay',
    0.5,
    'learning rate decay')

flags.DEFINE_float(
    'learning_rate',
    1.0,
    'starting learning rate')

flags.DEFINE_float(
    'decay_when',
    1.0,
    'decay if validation perplexity does not improve by more than this much')

# Magnitude used for parameter initialization.
flags.DEFINE_float(
    'param_init',
    0.05,
    'initialize parameters at')

flags.DEFINE_integer(
    'batch_size',
    5,
    'number of sequences to train on in parallel')

flags.DEFINE_integer(
    'max_epochs',
    25,
    'number of full passes through the training data')

# Gradient-norm threshold used for normalization/clipping.
flags.DEFINE_float(
    'max_grad_norm',
    5.0,
    'normalize gradients at')

# Input length limits; the units (lines per document, tokens or characters
# per line) are not established here — TODO confirm against the data loader.
flags.DEFINE_integer(
    'max_doc_length',
    3000,
    'max_doc_length')

flags.DEFINE_integer(
    'max_line_length',
    130,
    'maximum sentence length')
# ---------------------------------------------------------------------------
# Bookkeeping
# ---------------------------------------------------------------------------
# Arguments are passed positionally (name, default, help) on purpose: the
# keyword names differ between flag-module implementations.

flags.DEFINE_integer(
    'seed',
    3435,
    'random number generator seed')

# Logging cadence for the training loss (units not fixed here; presumably
# steps/batches — TODO confirm against the training loop).
flags.DEFINE_integer(
    'print_every',
    5,
    'how often to print current loss')

# End-of-sequence marker character; '+' by default.
flags.DEFINE_string(
    'EOS',
    '+',
    '<EOS> symbol. should be a single unused character (like +) for PTB and blank for others')

# Module-level handle to the flag values defined above; importers read
# settings as FLAGS.<name>. NOTE(review): with absl-backed flags, values are
# only readable after the command line has been parsed (typically by the app
# entry point) — confirm callers do this.
FLAGS = flags.FLAGS