diff --git a/tensorrec/input_utils.py b/tensorrec/input_utils.py index 405a648..021a836 100644 --- a/tensorrec/input_utils.py +++ b/tensorrec/input_utils.py @@ -123,5 +123,5 @@ def parse_tensorrec_tfrecord(example_proto): return (parsed_features['row_index'], parsed_features['col_index'], parsed_features['values'], parsed_features['d0'], parsed_features['d1']) - dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_tensorrec_tfrecord) + dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_tensorrec_tfrecord,num_parallel_calls=tf.data.experimental.AUTOTUNE) return dataset