-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapply_attack_semiinformed.py
More file actions
55 lines (40 loc) · 1.67 KB
/
apply_attack_semiinformed.py
File metadata and controls
55 lines (40 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Description text for the command-line interface: it is passed to
# ArgumentParser(description=DESC), so it is printed verbatim by ``--help``.
# NOTE(review): the second sentence reads as a personal aside rather than usage
# help; consider replacing it with a description of the inputs/outputs.
DESC="""
This thing takes a reverser checkpoint, a folder full of wavs, and tries to reverse their drift.
Will plunge me into depression. This time for real.
"""
import os
from argparse import ArgumentParser
import torch
import numpy as np
from atk_tools import SemiInformedReverser as Reverser, SupervisedDataset
###########


def ids_from_filenames(filenames, suffix='_gen.wav'):
    """Return the sorted utterance ids of every filename ending in ``suffix``.

    The id is the filename with the trailing ``suffix`` removed, e.g.
    ``'spk1_utt3_gen.wav'`` -> ``'spk1_utt3'``. Only the trailing suffix is
    stripped, so a ``'_gen'`` elsewhere in the name is kept intact, and names
    that merely contain ``suffix`` somewhere (``'x_gen.wav.bak'``) are ignored.
    Sorting makes the processing order independent of ``os.listdir`` order.

    Args:
        filenames: Iterable of bare file names (no directory components).
        suffix: Trailing marker identifying the wavs to process.

    Returns:
        List of ids, sorted lexicographically.
    """
    cut = len(suffix)
    # len(fn) - cut (rather than -cut) keeps an empty suffix from yielding ''.
    return sorted(fn[:len(fn) - cut] for fn in filenames if fn.endswith(suffix))


def main():
    """Run the reverser over ``--in_fold`` and save one x-vector per input wav.

    For every ``<id>_gen.wav`` in ``--in_fold``, the reverser's embedding output
    is written with ``np.save`` to ``--out_fold/<id>_gen.xvector``. ``np.save``
    appends ``.npy`` to file names that lack it, so the files on disk are named
    ``<id>_gen.xvector.npy``.
    """
    parser = ArgumentParser(description=DESC)
    parser.add_argument('--reverser_path', required=True, type=str,
                        help='Path to a state dict of a Reverser.')
    parser.add_argument('--in_fold', required=True, type=str,
                        help='Where to find wavs to reverse.')
    parser.add_argument('--out_fold', required=True, type=str,
                        help='Where to output the files in numpy xvector format.')
    parser.add_argument('--device', required=False, type=str, default='cuda')
    args = parser.parse_args()
    device = args.device

    # exist_ok replaces the racy exists()-then-makedirs() check.
    os.makedirs(args.out_fold, exist_ok=True)

    # Parse a list of ids from the files.
    ids = ids_from_filenames(os.listdir(args.in_fold))
    if not ids:
        print(f'Warning: no *_gen.wav files found in {args.in_fold}.')
    ds = SupervisedDataset(args.in_fold, ids, return_ids=True)
    # batch_size must stay 1: the loop below reads only the first utterance id.
    dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False)

    reverser = Reverser().eval().to(device)
    # Map the checkpoint onto the requested device; a hard-coded 'cuda' made
    # --device cpu fail on machines without a GPU.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    reverser.load_state_dict(torch.load(args.reverser_path, map_location=device))

    tot = len(dl)
    with torch.no_grad():
        for i, (wav, _y, uttids) in enumerate(dl, start=1):
            uttid = uttids[0]  # batch size is hardcoded to 1
            reco_xv = reverser(wav.to(device), embeddings=True)
            fp = os.path.join(args.out_fold, f'{uttid}_gen.xvector')
            print(f'[{i}/{tot}] Saving to {fp}')
            # np.save appends '.npy' to fp on disk.
            np.save(fp, reco_xv.cpu().numpy())
    print('Done.')


if __name__ == '__main__':
    main()