-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_wsj02mix.py
More file actions
211 lines (183 loc) · 7.12 KB
/
export_wsj02mix.py
File metadata and controls
211 lines (183 loc) · 7.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import pathlib
import shutil
import urllib.error
import urllib.request

import huggingface_hub
import speechbrain.inference.interfaces as sb_int
import speechbrain.utils.fetching as sb_fetch
import speechbrain.utils.parameter_transfer as sb_pt
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from speechbrain.inference.separation import SepformerSeparation
def _fetch_copy(
    filename,
    source,
    savedir="./pretrained_model_checkpoints",
    overwrite=False,
    save_filename=None,
    use_auth_token=False,
    revision=None,
    huggingface_cache_dir=None,
    **_ignored_kwargs,
):
    """Fetch ``filename`` from ``source`` into ``savedir`` as a real file copy.

    Replacement for SpeechBrain's ``fetch`` (it is assigned over the
    ``fetch`` attribute of several SpeechBrain modules in this file). Unlike
    a symlinking fetch, every successful branch leaves a regular file at
    ``savedir / save_filename``.

    ``source`` is resolved in this order:
      1. an existing local directory (unless a ``FetchSource`` explicitly says
         Hugging Face or URI),
      2. an ``http:``/``https:`` address, or ``FetchFrom.URI``,
      3. otherwise a Hugging Face Hub repo id.

    Args:
        filename: Name of the file inside ``source``.
        source: Local directory, URL, Hub repo id, or a ``FetchSource``
            ``(fetch_from, source)`` pair.
        savedir: Directory receiving the copy (created if missing).
        overwrite: Re-fetch even if the destination already exists.
        save_filename: Name of the copy inside ``savedir``; defaults to
            ``filename``.
        use_auth_token: Forwarded to ``huggingface_hub.hf_hub_download``.
        revision: Hub revision, forwarded to ``hf_hub_download``.
        huggingface_cache_dir: Hub cache dir, forwarded to ``hf_hub_download``.
        **_ignored_kwargs: Accepted and ignored, for SpeechBrain versions whose
            ``fetch`` is called with additional keyword options. This function
            always copies, so link-related options have no effect.

    Returns:
        pathlib.Path: Path of the local file.

    Raises:
        ValueError: If the file is missing locally, cannot be downloaded from
            the URL, or cannot be fetched from the Hub. All "not available"
            cases use this one type. NOTE(review): presumably SpeechBrain's
            loaders catch ValueError to skip optional files; verify against
            the installed version.
    """
    if save_filename is None:
        save_filename = filename
    savedir = pathlib.Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    fetch_from = None
    if isinstance(source, sb_fetch.FetchSource):
        fetch_from, source = source
    sourcefile = f"{source}/{filename}"
    destination = savedir / save_filename

    # Reuse an existing copy unless the caller asked for a refresh.
    if destination.exists() and not overwrite:
        return destination

    remote_kinds = (sb_fetch.FetchFrom.HUGGING_FACE, sb_fetch.FetchFrom.URI)
    if pathlib.Path(source).is_dir() and fetch_from not in remote_kinds:
        # Local directory source.
        sourcepath = pathlib.Path(sourcefile).absolute()
        if not sourcepath.is_file():
            raise ValueError(f"File not found in local directory: {sourcepath}")
        real_source = sourcepath.resolve()
        # When source and destination are the same directory entry (for
        # example source dir == savedir with overwrite=True), unlinking the
        # destination would delete the only copy before it is read.
        if real_source == destination.parent.resolve() / destination.name:
            return destination
        destination.unlink(missing_ok=True)
        shutil.copyfile(real_source, destination)
        return destination

    if (
        str(source).startswith(("http:", "https:"))
        or fetch_from is sb_fetch.FetchFrom.URI
    ):
        # Remove any stale file or symlink first so the download never writes
        # through a link into some other file.
        destination.unlink(missing_ok=True)
        try:
            urllib.request.urlretrieve(sourcefile, destination)
        except urllib.error.URLError as e:
            # A partial file would be reused by the next non-overwrite call.
            destination.unlink(missing_ok=True)
            raise ValueError(f"Could not download {sourcefile}") from e
        return destination

    # Otherwise treat ``source`` as a Hugging Face Hub repo id.
    try:
        fetched_file = huggingface_hub.hf_hub_download(
            repo_id=source,
            filename=filename,
            use_auth_token=use_auth_token,
            revision=revision,
            cache_dir=huggingface_cache_dir,
        )
    except Exception as e:
        # Broad on purpose: the exception classes raised by hf_hub_download
        # vary between huggingface_hub versions. The original error is chained
        # so the real cause (missing file, network, auth, ...) stays visible.
        raise ValueError(
            f"File not found on HF hub: {source}/{filename}"
        ) from e
    sourcepath = pathlib.Path(fetched_file).absolute()
    destination.unlink(missing_ok=True)
    shutil.copyfile(sourcepath, destination)
    return destination
# Point the ``fetch`` attribute of each of these SpeechBrain modules at the
# copying implementation above. Presumably each module holds its own
# reference to ``fetch``, which is why all three are patched.
for _sb_module in (sb_fetch, sb_int, sb_pt):
    setattr(_sb_module, "fetch", _fetch_copy)
del _sb_module
# ==========================================
# 1. 更加健壮的模型包装器
# ==========================================
class SepFormerWrapper(nn.Module):
    """Trace-friendly wrapper around a SpeechBrain SepFormer separator.

    Runs the separator's encoder -> mask network -> decoder chain in a single
    ``forward`` so the whole model can be exported with ``torch.onnx.export``.

    Args:
        sb_model: Loaded SpeechBrain separator exposing ``mods.encoder``,
            ``mods.masknet``, ``mods.decoder`` and ``hparams.num_spks``.
        match_input_length: If True (default), zero-pad the end of the time
            axis so the output has as many samples as the input. Presumably
            the conv encoder/decoder pair can return a few samples fewer than
            the input when its length does not align with the encoder stride;
            confirm for the model in use. If the decoder ever returned more
            samples, ``torch.nn.functional.pad`` with a negative amount would
            crop instead.
    """

    def __init__(self, sb_model, match_input_length=True):
        super().__init__()
        self.encoder = sb_model.mods.encoder
        self.masknet = sb_model.mods.masknet
        self.decoder = sb_model.mods.decoder
        self.num_spks = sb_model.hparams.num_spks
        self.match_input_length = match_input_length

    def forward(self, mix):
        """Separate a batch of mixtures.

        Args:
            mix: Tensor of shape (Batch, Time).

        Returns:
            Tensor of shape (Batch, Time, Spks). With ``match_input_length``
            disabled, Time is whatever length the decoder produces.
        """
        # Tensor sizes read during tracing (``mix.size(0)`` and ``mix.shape[0]``
        # alike) are recorded as shape ops, so the reshapes below stay dynamic
        # in the exported graph.
        batch_size = mix.size(0)

        # 1. Encoder: (Batch, Time) -> (Batch, Feats, Frames).
        mix_w = self.encoder(mix)

        # 2. Mask network. NOTE(review): assumes SpeechBrain's speaker-major
        #    mask layout (Spks, Batch, Feats, Frames), as produced by its
        #    Dual_Path_Model; confirm if a different masknet is plugged in.
        est_mask = self.masknet(mix_w)

        # 3. Apply the masks: (1, B, F, L) * (S, B, F, L) -> (S, B, F, L).
        #    Adding the new axis at dim 0 keeps every mask paired with its
        #    own mixture for any batch size.
        sep_h = mix_w.unsqueeze(0) * est_mask

        # 4. Fold speakers into the batch axis for the decoder: (S * B, F, L).
        #    Row s * B + b holds speaker s of batch item b.
        feats = sep_h.size(2)
        sep_h_flat = sep_h.reshape(self.num_spks * batch_size, feats, -1)

        # 5. Decoder. It may return (N, 1, T) or (N, T, 1); normalise to (N, T).
        est_source_flat = self.decoder(sep_h_flat)
        if est_source_flat.dim() == 3 and est_source_flat.size(1) == 1:
            est_source_flat = est_source_flat.squeeze(1)
        elif est_source_flat.dim() == 3 and est_source_flat.size(2) == 1:
            est_source_flat = est_source_flat.squeeze(2)

        # 6. Unfold in the same speaker-major order, then reorder the axes:
        #    (S * B, T) -> (S, B, T) -> (B, T, S).
        est_source = est_source_flat.reshape(self.num_spks, batch_size, -1)
        est_source = est_source.permute(1, 2, 0)

        # 7. Optionally restore the input length on the time axis.
        if self.match_input_length:
            missing = mix.size(1) - est_source.size(1)
            est_source = torch.nn.functional.pad(est_source, (0, 0, 0, missing))
        return est_source
def _load_separator(model_source, savedir):
    """Load ``model_source`` with SpeechBrain, retrying from a local snapshot.

    If the first attempt fails with Windows error 1314
    (ERROR_PRIVILEGE_NOT_HELD, e.g. creating a symlink without the required
    privilege), the repo is downloaded into ``savedir`` and loaded from there.
    Any other ``OSError`` is re-raised.
    """
    try:
        return SepformerSeparation.from_hparams(
            source=model_source,
            savedir=savedir,
        )
    except OSError as e:
        if getattr(e, "winerror", None) != 1314:
            raise
    os.makedirs(savedir, exist_ok=True)
    snapshot_download(
        repo_id=model_source,
        local_dir=savedir,
        # NOTE(review): this flag may be deprecated/ignored by newer
        # huggingface_hub releases; confirm for the pinned version.
        local_dir_use_symlinks=False,
    )
    return SepformerSeparation.from_hparams(
        source=savedir,
        savedir=savedir,
    )


def _verify_onnx(onnx_path, model_wrapper, batch_size=2, num_samples=24000):
    """Best-effort check of the exported graph with onnxruntime.

    Checks three things:
      * the output layout is (Batch, Time, Spks);
      * each row of a batched run matches the same row run on its own, which
        catches batch items leaking into each other;
      * the maximum absolute difference against the PyTorch wrapper, which is
        printed.

    Problems, including a missing onnxruntime, are printed and never raised.
    """
    try:
        import numpy as np
        import onnxruntime as ort

        print("\n--- 最终验证 ---")
        ort_session = ort.InferenceSession(onnx_path)

        # Use batch > 1: a batch of 1 broadcasts against anything and hides
        # batch-handling bugs.
        test_input = np.random.randn(batch_size, num_samples).astype(np.float32)
        ort_out = ort_session.run(None, {'mix': test_input})[0]
        out_shape = ort_out.shape
        print(f"输入 {test_input.shape} -> 输出 {out_shape}")
        if not (
            len(out_shape) == 3
            and out_shape[0] == batch_size
            and out_shape[2] == model_wrapper.num_spks
        ):
            print("❌ 格式依然不对,请检查上面的输出形状。")
            return
        print("✅ 验证完美通过!格式为 (Batch, Time, Spks)")

        rows_match = all(
            np.allclose(
                ort_session.run(None, {'mix': test_input[i:i + 1]})[0][0],
                ort_out[i],
                rtol=1e-3,
                atol=1e-3,
            )
            for i in range(batch_size)
        )
        if rows_match:
            print("✅ 批量推理与逐条推理结果一致")
        else:
            print("❌ 批量推理与逐条推理结果不一致,请检查 batch 维的处理。")

        with torch.no_grad():
            torch_out = model_wrapper(torch.from_numpy(test_input)).numpy()
        max_diff = float(np.abs(ort_out - torch_out).max())
        print(f"ONNX 与 PyTorch 输出的最大绝对误差: {max_diff:.2e}")
    except Exception as e:
        print(f"验证出错: {e}")


def export_onnx(
    model_source="speechbrain/sepformer-wsj02mix",
    onnx_path="sepformer_wsj02mix.onnx",
    savedir=None,
    opset_version=17,
):
    """Export a SpeechBrain SepFormer separator to ONNX and sanity-check it.

    Args:
        model_source: Hub repo id (or local path) handed to SpeechBrain.
        onnx_path: Output file for the exported graph.
        savedir: Local checkpoint directory; defaults to
            ``pretrained_models/<last component of model_source>``.
        opset_version: ONNX opset used for the export.

    Export failures are printed and the function returns without validating.
    Validation problems are printed too (see ``_verify_onnx``). Errors while
    loading the model propagate.
    """
    if savedir is None:
        repo_name = model_source.rstrip("/").rsplit("/", 1)[-1]
        savedir = f"pretrained_models/{repo_name}"

    print(f"--- 正在加载模型: {model_source} ---")
    sb_model = _load_separator(model_source, savedir)

    model_wrapper = SepFormerWrapper(sb_model)
    model_wrapper.eval()

    # Trace with a batch of 2 rather than 1. A size-1 batch broadcasts against
    # anything, which can hide batch bugs, and may let the tracer bake the
    # batch size in as a constant.
    dummy_input = torch.randn(2, 16000)

    print(f"--- 正在导出到 {onnx_path} ---")
    try:
        torch.onnx.export(
            model_wrapper,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['mix'],
            output_names=['est_sources'],
            dynamic_axes={
                'mix': {0: 'batch_size', 1: 'time'},
                'est_sources': {0: 'batch_size', 1: 'time'},
            },
        )
    except Exception as e:
        # Deliberately non-fatal: report the failure and skip validation.
        print(f"❌ 导出失败: {e}")
        return
    print(f"✅ 导出成功!文件: {onnx_path}")

    _verify_onnx(onnx_path, model_wrapper)
if __name__ == "__main__":
    # Script entry point: export the model to ONNX and run the validation step.
    export_onnx()