-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathevaluate.py
More file actions
71 lines (58 loc) · 2.09 KB
/
evaluate.py
File metadata and controls
71 lines (58 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
from pathlib import Path
from omegaconf import OmegaConf
import wandb
from src.models import model_registry
from src.datasets import get_dataloaders
from src.engine.loops import run_full_test_loop
from os.path import join
def test(args, cfg):
    """Evaluate the configured model on the validation loader and report metrics.

    Builds the dataloaders and model described by ``cfg``, moves the model to
    ``cfg.device`` in eval mode, runs ``run_full_test_loop`` on the validation
    loader and prints the returned metrics. If ``args.log_wandb`` is set, the
    metrics are also logged to Weights & Biases, each key suffixed with ``/val``.

    Args:
        args: Parsed CLI namespace. Reads ``include_full_ddf_metrics``,
            ``save_predictions`` and ``log_wandb``.
        cfg: OmegaConf config. Reads ``data``, ``model``, ``device``,
            ``log_dir``, ``use_amp``, ``logger_kw.wandb_project`` and the
            optional ``test_cfg`` mapping of extra keyword arguments that is
            forwarded to ``run_full_test_loop``.

    Raises:
        ValueError: If ``test_cfg`` sets a keyword that this function already
            passes to ``run_full_test_loop`` explicitly.
    """
    # get_dataloaders yields a (train, val) pair; only the val loader is used.
    _train_loader, val_loader = get_dataloaders(**cfg.data)

    model = model_registry.get_model(**cfg.model)
    model.to(cfg.device)
    model.eval()

    # Keyword arguments set explicitly from the top-level config and CLI flags.
    # ``test_cfg`` may add further options but must not repeat these, or the
    # call below would receive the same keyword twice.
    fixed_kwargs = {
        'output_dir': Path(cfg.log_dir),
        'device': cfg.device,
        'use_amp': cfg.use_amp,
        'include_full_ddf': args.include_full_ddf_metrics,
        'save_predictions': args.save_predictions,
        'save_images_with_predictions': args.save_predictions,
        'images_key_for_save': 'raw_images',
    }
    # A bare ``test_cfg:`` in YAML loads as None, which cannot be **-unpacked.
    test_cfg = cfg.get('test_cfg', None) or {}
    clashing = sorted(set(test_cfg).intersection(fixed_kwargs))
    if clashing:
        raise ValueError(
            f"test_cfg must not override {clashing}: these arguments are set "
            "from the top-level config and command-line flags"
        )

    metrics = run_full_test_loop(model, val_loader, **test_cfg, **fixed_kwargs)
    print(metrics)

    if args.log_wandb:
        wandb.init(
            project=cfg.logger_kw.wandb_project,
            job_type='test',
            config=OmegaConf.to_object(cfg),
        )
        try:
            # Suffix every metric name with '/val' to record the evaluated split.
            wandb.log({f'{k}/val': v for k, v in metrics.items()})
        finally:
            # Close the run explicitly rather than relying on interpreter exit.
            wandb.finish()
if __name__ == "__main__":
    # Command-line entry point: resolve paths, load the config, run `test`.
    parser = argparse.ArgumentParser(
        description='Evaluate a trained model on the validation split.')
    parser.add_argument(
        '--log_dir',
        help='Output directory for test results; overrides log_dir in the '
             'config (default with --train_dir: <train_dir>/test)')
    parser.add_argument(
        '--train_dir',
        help='Training run directory used to fill in defaults for '
             '--log_dir, --checkpoint and --config')
    parser.add_argument(
        '--config', '-c',
        help='Config file to load (default with --train_dir: '
             '<train_dir>/config_resolved.yaml)')
    parser.add_argument('--checkpoint', help='Checkpoint to load')
    parser.add_argument(
        '--log_wandb', action='store_true',
        help='Also log the metrics to Weights & Biases')
    parser.add_argument(
        '--include_full_ddf_metrics', action='store_true',
        help='Forwarded to the test loop as include_full_ddf')
    parser.add_argument(
        '--save_predictions', action='store_true',
        help='Save predictions and images with predictions')
    args = parser.parse_args()

    # Fill in any path not given explicitly from the training run directory.
    if args.train_dir:
        if args.log_dir is None:
            args.log_dir = join(args.train_dir, 'test')
        if args.checkpoint is None:
            args.checkpoint = join(args.train_dir, 'checkpoint', 'best.pt')
        if args.config is None:
            args.config = join(args.train_dir, 'config_resolved.yaml')

    # Without either flag there is no config to load; report a usage error
    # instead of letting OmegaConf.load(None) fail obscurely.
    if args.config is None:
        parser.error('one of --config or --train_dir is required')

    config = OmegaConf.load(args.config)
    # Paths given on the command line (or derived above) override the config.
    if args.log_dir:
        config.log_dir = args.log_dir
    if args.checkpoint:
        config.model.checkpoint = args.checkpoint
    test(args, config)