Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions wetts/vits/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pathlib import Path

import torch

from model.models import SynthesizerTrn
from utils import task

Expand Down Expand Up @@ -99,7 +98,7 @@ def main():
args=dummy_input,
f=add_prefix(args.onnx_model, 'encoder_'),
input_names=["input", "input_lengths", "scales", "sid"],
output_names=["z", "g"],
output_names=["z"],
dynamic_axes={
"input": {
0: "batch",
Expand All @@ -115,22 +114,20 @@ def main():
0: "batch"
},
"z": {0: "batch", 2: "L"},
"g": {0: "batch"},
},
opset_version=13,
verbose=False,
)
net_g.forward = net_g.export_decoder_forward
dummy_input = (z, g)
dummy_input = (z, sid)
torch.onnx.export(
model=net_g,
args=dummy_input,
f=add_prefix(args.onnx_model, 'decoder_'),
input_names=["z", "g"],
input_names=["z", "sid"],
output_names=["output"],
dynamic_axes={
"z": {0: "batch", 2: "L"},
"g": {0: "batch"},
"output": {
0: "batch",
1: "audio",
Expand Down
5 changes: 3 additions & 2 deletions wetts/vits/inference_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ def main():

# Copy from: runtime/cpu_triton_stream/model_repo/stream_tts/1/model.py
def tts(ort_inputs):
z, g = encoder_ort_sess.run(None, ort_inputs)
sid = ort_inputs['sid']
z = encoder_ort_sess.run(None, ort_inputs)[0]
z_chunks = get_chunks(z, args.chunk_size, args.pad_size)
num_chunks = len(z_chunks)
audios = []
for i, chunk in enumerate(z_chunks):
decoder_inputs = {"z": chunk, "g": g}
decoder_inputs = {"z": chunk, "sid": sid}
audio_chunk = decoder_ort_sess.run(None, decoder_inputs)[0]
audio_clip = depadding(audio_chunk.reshape(1, -1), num_chunks,
i, args.chunk_size, args.pad_size, 256)
Expand Down
14 changes: 6 additions & 8 deletions wetts/vits/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
import time

import torch
from torch import nn

from model.decoders import Generator, VocosGenerator
from model.duration_predictors import (
StochasticDurationPredictor,
DurationPredictor
)
from model.encoders import TextEncoder, PosteriorEncoder
from model.duration_predictors import (DurationPredictor,
StochasticDurationPredictor)
from model.encoders import PosteriorEncoder, TextEncoder
from model.flows import AVAILABLE_FLOW_TYPES, ResidualCouplingTransformersBlock
from torch import nn
from utils import commons, monotonic_align


Expand Down Expand Up @@ -358,7 +355,8 @@ def export_encoder_forward(self, x, x_lengths, scales, sid):
)
return z, g

def export_decoder_forward(self, z, g):
def export_decoder_forward(self, z, sid):
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
return self.dec(z, g=g)

# currently vits-2 is not capable of voice conversion
Expand Down