-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
83 lines (67 loc) · 3.04 KB
/
models.py
File metadata and controls
83 lines (67 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Authors: Massimiliano Todisco, Michele Panariello and chatGPT
Email: https://mailhide.io/e/Qk2FFM4a
Date: August 2024
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import speechbrain as sb
from speakerlab.models.campplus.DTDNN import CAMPPlus
from speakerlab.process.processor import FBank
import numpy as np
class Malacopula(nn.Module):
def __init__(self, num_layers=5, in_channels=1, out_channels=1, kernel_size=1025, padding='same', bias=False):
super().__init__()
self.kernel_size = kernel_size
self.convs = nn.ModuleList([
nn.utils.parametrizations.weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, bias=bias))
for _ in range(num_layers)
])
self.bartlett_window = self.create_bartlett_window()
self.apply_bartlett_window()
def create_bartlett_window(self):
bartlett_window = torch.bartlett_window(self.kernel_size)
return bartlett_window.unsqueeze(0).unsqueeze(0)
def apply_bartlett_window(self):
for conv in self.convs:
with torch.no_grad():
bartlett_window = self.bartlett_window.to(conv.weight.device)
conv.weight *= bartlett_window
def save_filter_coefficients(self, directory_path):
for i, conv in enumerate(self.convs, start=1):
bartlett_window = self.bartlett_window.to(conv.weight.device)
filter_weights = (conv.weight.data * bartlett_window).cpu().numpy()
filter_weights = np.squeeze(filter_weights)
filepath = f"{directory_path}/filter_{i}.txt"
np.savetxt(filepath, filter_weights, fmt='%.6f', delimiter=' ')
def forward(self, x):
outputs = []
self.apply_bartlett_window()
for i, conv in enumerate(self.convs, start=1):
powered_x = torch.pow(x, i)
output = conv(powered_x)
outputs.append(output)
summed_output = torch.sum(torch.stack(outputs, dim=0), dim=0)
max_abs_value = torch.max(torch.abs(summed_output))
norm_output = summed_output / max_abs_value
return norm_output
class CosineDistanceLoss(nn.Module):
    """Mean cosine distance (1 - cosine similarity) between paired embeddings.

    Expects `input` and `target` of identical shape; similarity is taken
    along the last dimension and the distances are averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # Paired comparison requires matching shapes.
        assert input.shape == target.shape, "Input and target must have the same shape"
        # Per-pair distance, averaged over all pairs.
        return (1 - F.cosine_similarity(input, target, dim=-1)).mean()
def load_models(device):
    """Load the two speaker-embedding models and the CAM++ feature extractor.

    Parameters
    ----------
    device : str or torch.device
        Device the models should run on (e.g. 'cpu', 'cuda:0').

    Returns
    -------
    tuple
        (model_ecapa, model_campp, feature_extractor_campp), with both
        models already switched to eval mode.
    """
    # ECAPA-TDNN embedding model fetched from the SpeechBrain hub.
    model_ecapa = sb.inference.speaker.EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device":device})
    model_ecapa.eval()

    # CAM++ checkpoint; expected to exist locally at this relative path.
    model_path = 'pretrained_models/campplus_voxceleb.bin'
    # BUG FIX: map_location lets a GPU-saved checkpoint load on CPU-only
    # hosts (and lands tensors directly on the target device).
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    d = torch.load(model_path, map_location=device)
    model_campp = CAMPPlus().to(device)
    feature_extractor_campp = FBank(80, sample_rate=16000, mean_nor=True)
    model_campp.load_state_dict(d)
    model_campp.eval()
    return model_ecapa, model_campp, feature_extractor_campp