-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsave_sasv_score.py
More file actions
executable file
·126 lines (107 loc) · 3.93 KB
/
save_sasv_score.py
File metadata and controls
executable file
·126 lines (107 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
import os
import pickle as pk
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from aasist.data_utils import Dataset_ASVspoof2019_devNeval, Dataset_ASVspoof2019_train
from aasist.models.AASIST import Model as AASISTModel
# from ECAPATDNN.model import ECAPA_TDNN
from ResNetModels.ResNetSE34V2 import MainModel
from models.ResNetSE34V2_AASIST_OC import Model
from dataloaders.Datalaoder_sasv import get_evalset
from loss.loss import OCSoftmax
from utils import load_parameters
# list of dataset partitions
# SET_PARTITION = ["dev"]
SET_PARTITION = ["eval"]

# Root of the ASVspoof2019 LA corpus; must be edited to the local path.
database = "/path/to/your/LA/"
# directories of each dataset partition
SET_DIR = {
    "trn": database + "ASVspoof2019_LA_train/",
    "dev": database + "ASVspoof2019_LA_dev/",
    "eval": database + "ASVspoof2019_LA_eval/",
}

# ASV eval trial list (.trl) and enrolment list (.trn) protocol files.
sasv_eval_trial = "./protocols/ASVspoof2019.LA.asv.eval.gi.trl.txt"
sasv_eval_trn_trial = "./protocols/ASVspoof2019.LA.asv.eval.gi.trn.txt"

utt2spk = {}
# NOTE: both protocol-path variables are rebound below from a path string
# to the list of lines read from that file.
with open(sasv_eval_trial, "r") as f:
    sasv_eval_trial = f.readlines()
with open(sasv_eval_trn_trial, "r") as f:
    sasv_eval_trn_trial = f.readlines()

# Map each speaker id (first space-separated field) to its enrolment
# field (second field).  NOTE(review): only tmp[1] is kept — if the .trn
# file lists several enrolment utterances per line, the remainder are
# dropped; confirm against the protocol format.
for line in sasv_eval_trn_trial:
    tmp = line.strip().split(" ")
    spk = tmp[0]
    utts = tmp[1]
    utt2spk[spk] = utts

# Pre-built evaluation dataset, read as a global by save_embeddings below.
eval_ds = get_evalset(sasv_eval_trial, utt2spk, 'eval', database)
def save_embeddings(
    set_name, model, loss, device, config_name
):
    """Score the SASV evaluation set and append the results to a text file.

    Iterates over the module-level ``eval_ds`` dataset, computes an
    OC-Softmax score for every (enrolment, test) trial, and writes one
    line per trial in the form ``<speaker> <utt_id> <key> <score>``.

    Args:
        set_name: partition label used only in the output filename.
            NOTE(review): the dataset itself is always ``eval_ds`` —
            confirm that is intended for partitions other than "eval".
        model: joint model exposing ``validate(enr_wave, tst_wave)``
            returning ``(features, cm_outputs)``.
        loss: OCSoftmax module returning ``(loss_value, score)``.
        device: torch device string ("cuda" or "cpu").
        config_name: output directory; created if missing.
    """
    loader = DataLoader(
        eval_ds, batch_size=10, shuffle=False, drop_last=False, pin_memory=True
    )
    preds, keys, spkrs, utter_id = [], [], [], []
    for wave_asv_enr, wave_asv_tst, key, spkmd, utt_id, labels in tqdm(loader):
        wave_asv_enr = wave_asv_enr.to(device)
        wave_asv_tst = wave_asv_tst.to(device)
        labels = labels.to(device)
        # Inference only — no gradients needed.
        with torch.no_grad():
            feats, lfcc_outputs = model.validate(wave_asv_enr, wave_asv_tst)
            _, pred = loss(feats, labels)
        preds.append(pred)
        keys.extend(list(key))
        spkrs.extend(list(spkmd))
        utter_id.extend(list(utt_id))
    preds = torch.cat(preds, dim=0).detach().cpu().numpy()
    os.makedirs(config_name, exist_ok=True)
    out_path = os.path.join(config_name, f"score_SASV_{set_name}.txt")
    # Mode "a" (was "+a"): the unused read flag is dropped.  NOTE(review):
    # append mode accumulates lines across repeated runs — remove the old
    # score file first (or switch to "w") if duplicates would be a problem.
    with open(out_path, "a") as fh:
        for s, u, k, cm in zip(spkrs, utter_id, keys, preds):
            fh.write(f"{s} {u} {k} {cm}\n")
if __name__ == "__main__":
    # CLI: paths to the AASIST config/weights and the trained SASV checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-aasist_config", type=str, default="./aasist/config/AASIST.conf"
    )
    parser.add_argument(
        "-aasist_weight", type=str, default="./aasist/models/weights/AASIST.pth"
    )
    parser.add_argument(
        "--model", type=str, default="./pre_trained_models/ResNetSE34V2_AASIST_OC.ckpt"
    )
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device: {}".format(device))

    PATH = args.model
    # address = PATH.split('/')
    # config_name = address[9]
    config_name = 'EXP-score'  # output directory for the score files

    # The AASIST config is JSON despite the .conf extension.
    with open(args.aasist_config, "r") as f_json:
        config = json.loads(f_json.read())

    loss = OCSoftmax(feat_dim=256, r_real=0.8, r_fake=0.2, alpha=10.0)
    model = Model()

    # Bug fix: map_location lets a checkpoint saved on GPU load on a
    # CPU-only host; without it torch.load raises when CUDA is absent.
    checkpoint = torch.load(PATH, map_location=device)

    # The checkpoint stores model weights under a "model." prefix and the
    # OC-Softmax center under "loss.center"; split them into the two
    # state dicts expected by Model and OCSoftmax.
    state = checkpoint['state_dict']
    dicts = {
        k[len('model.'):]: v for k, v in state.items()
        if k.startswith('model.')
    }
    loss_dict = {'center': state['loss.center']}

    model.load_state_dict(dicts)
    model.to(device)
    model.eval()
    loss.load_state_dict(loss_dict)
    loss.to(device)
    loss.eval()

    for set_name in SET_PARTITION:
        save_embeddings(
            set_name,
            model,
            loss,
            device,
            config_name,
        )