-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtest.py
More file actions
55 lines (43 loc) · 1.74 KB
/
test.py
File metadata and controls
55 lines (43 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from tqdm import tqdm
import os
from dataset.bi_sequence_dataset import get_dataloader
from config.config import Config
from config.file_utils import ensure_dir
from trainer.trainer import TrainerED
import torch
import numpy as np
import h5py
import json
from config.macro import *
def _sequence_length(commands, eos_idx):
    """Return the number of commands that precede the first ``eos_idx``.

    If ``eos_idx`` never occurs, the whole sequence is considered valid and its
    full length is returned. Previously ``list.index`` raised ``ValueError``
    here and aborted the evaluation run.
    """
    values = np.asarray(commands).tolist()
    try:
        return values.index(eos_idx)
    except ValueError:
        # NOTE(review): presumably a sequence that fills the maximum length and
        # therefore carries no EOS padding -- confirm against the dataset.
        return len(values)


def main():
    """Run the trained model over the test split and dump per-sample vectors.

    Loads the checkpoint named by ``cfg.ckpt`` and switches the network to eval
    mode. Each test batch is passed through ``TrainerED.forward`` (under
    ``torch.no_grad``) and ``logits2vec``. One ``<id>_vec.h5`` file per sample
    is written into ``<cfg.exp_dir>/test_results``, containing:

    * ``out_vec`` -- the predicted vector,
    * ``gt_vec``  -- the ground-truth vector (command index followed by args),

    Both are stored as ``int32`` and truncated to the number of ground-truth
    commands before the first ``CAD_EOS_IDX``, or to the full length if none.
    """
    cfg = Config('test')
    tr_agent = TrainerED(cfg)

    # load from checkpoint
    tr_agent.load_ckpt(cfg.ckpt)
    tr_agent.net.eval()

    # create dataloader
    test_loader = get_dataloader("test", cfg)
    print(f"Total number of test batch: {len(test_loader)}")

    # The output directory is identical for every sample, so create it once
    # rather than once per sample inside the loops.
    save_dir = os.path.join(cfg.exp_dir, 'test_results')
    ensure_dir(save_dir)

    # evaluate
    for i, data in enumerate(test_loader):
        cad_data = data['cad']
        # Ground truth: the command index is concatenated in front of the args
        # along the last axis, so column 0 of that axis is the command.
        gt_vec = torch.cat(
            [cad_data['command'].unsqueeze(-1), cad_data['args']], dim=-1
        ).squeeze(1).detach().cpu().numpy()
        cad_commands_ = gt_vec[:, :, 0]
        batch_size = cad_data['command'].shape[0]

        with torch.no_grad():
            outputs, _ = tr_agent.forward(data)
            batch_outputs = tr_agent.logits2vec(outputs)

        # Context manager guarantees the per-batch progress bar is closed.
        with tqdm(total=batch_size, desc='BATCH[{}]'.format(i)) as pbar:
            for j in range(batch_size):
                out = batch_outputs[j]
                seq_len = _sequence_length(cad_commands_[j], CAD_EOS_IDX)
                data_id = data['id'][j].split('/')[-1]
                save_path = os.path.join(save_dir, '{}_vec.h5'.format(data_id))
                with h5py.File(save_path, 'w') as f:
                    f.create_dataset('out_vec', data=out[:seq_len], dtype=np.int32)
                    f.create_dataset('gt_vec', data=gt_vec[j][:seq_len], dtype=np.int32)
                pbar.update(1)
if __name__ == '__main__':
    # Script entry point: only run the evaluation when executed directly,
    # not when this module is imported.
    main()