diff --git a/evaluate_model.py b/evaluate_model.py index e08a5d7..71927c3 100644 --- a/evaluate_model.py +++ b/evaluate_model.py @@ -84,7 +84,8 @@ "evaluate_single_checkpoint", "", "If set, defines the checkpoint file prefix to evaluate " - 'and then exit, e.g. "model.ckpt-97231".', + 'and then exit, e.g. "model.ckpt-97231". Use "latest" to evaluate ' + "the latest checkpoint.", ) FLAGS = flags.FLAGS @@ -146,7 +147,14 @@ def evaluate(hps, result_dir, tuner=None, trial_name=None): ) if FLAGS.evaluate_single_checkpoint: - meta_file = FLAGS.evaluate_single_checkpoint + ".meta" + if FLAGS.evaluate_single_checkpoint == "latest": + latest_ckpt = tf.train.latest_checkpoint(result_dir) + if not latest_ckpt: + logging.warning("No checkpoints in %s", result_dir) + return + meta_file = os.path.basename(latest_ckpt) + ".meta" + else: + meta_file = FLAGS.evaluate_single_checkpoint + ".meta" else: potential_files = tf.gfile.ListDirectory(result_dir)