Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions runtime/core/bin/tts_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"
#include "processor/wetext_processor.h"
Expand Down Expand Up @@ -47,7 +50,8 @@ DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");
// VITS
DEFINE_string(speaker2id, "", "speaker to id");
DEFINE_string(phone2id, "", "phone to id");
DEFINE_string(vits_model, "", "e2e tts model file");
DEFINE_string(vits_encoder_model, "", "vits encoder model path");
DEFINE_string(vits_decoder_model, "", "vits decoder model path");
DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");

DEFINE_string(sname, "", "speaker name");
Expand All @@ -73,8 +77,8 @@ int main(int argc, char* argv[]) {
FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
FLAGS_pinyin2phones, g2p_en);
auto model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
tn, g2p_prosody);
FLAGS_vits_encoder_model, FLAGS_vits_decoder_model, FLAGS_speaker2id,
FLAGS_phone2id, FLAGS_sampling_rate, tn, g2p_prosody);

std::vector<float> audio;
int sid = model->GetSid(FLAGS_sname);
Expand Down
38 changes: 25 additions & 13 deletions runtime/core/model/tts_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,49 +28,58 @@

namespace wetts {

TtsModel::TtsModel(const std::string& model_path, const std::string& speaker2id,
const std::string& phone2id, const int sampling_rate,
TtsModel::TtsModel(const std::string& encoder_model_path,
const std::string& decoder_model_path,
const std::string& speaker2id, const std::string& phone2id,
const int sampling_rate,
std::shared_ptr<wetext::Processor> tn,
std::shared_ptr<G2pProsody> g2p_prosody)
: model_(model_path),
: encoder_(encoder_model_path),
decoder_(decoder_model_path),
tn_(std::move(tn)),
g2p_prosody_(std::move(g2p_prosody)) {
sampling_rate_ = sampling_rate;
ReadTableFile(phone2id, &phone2id_);
ReadTableFile(speaker2id, &speaker2id_);
}

void TtsModel::Forward(const std::vector<int64_t>& phonemes, const int sid,
std::vector<float>* audio) {
std::vector<Ort::Value> TtsModel::ForwardEncoder(
const std::vector<int64_t>& phonemes, const int sid) {
int num_phones = phonemes.size();
const int64_t inputs_shape[] = {1, num_phones};
auto inputs_ort = Ort::Value::CreateTensor<int64_t>(
model_.memory_info(), const_cast<int64_t*>(phonemes.data()), num_phones,
encoder_.memory_info(), const_cast<int64_t*>(phonemes.data()), num_phones,
inputs_shape, 2);

std::vector<int64_t> inputs_len = {num_phones};
const int64_t inputs_len_shape[] = {1};
auto inputs_len_ort =
Ort::Value::CreateTensor<int64_t>(model_.memory_info(), inputs_len.data(),
inputs_len.size(), inputs_len_shape, 1);
auto inputs_len_ort = Ort::Value::CreateTensor<int64_t>(
encoder_.memory_info(), inputs_len.data(), inputs_len.size(),
inputs_len_shape, 1);

std::vector<float> scales = {0.667, 1.0, 0.8};
const int64_t scales_shape[] = {1, 3};
auto scales_ort = Ort::Value::CreateTensor<float>(
model_.memory_info(), scales.data(), scales.size(), scales_shape, 2);
encoder_.memory_info(), scales.data(), scales.size(), scales_shape, 2);

std::vector<int64_t> spk_id = {sid};
const int64_t spk_id_shape[] = {1};
auto sid_ort = Ort::Value::CreateTensor<int64_t>(
model_.memory_info(), spk_id.data(), spk_id.size(), spk_id_shape, 1);
encoder_.memory_info(), spk_id.data(), spk_id.size(), spk_id_shape, 1);

std::vector<Ort::Value> ort_inputs;
ort_inputs.push_back(std::move(inputs_ort));
ort_inputs.push_back(std::move(inputs_len_ort));
ort_inputs.push_back(std::move(scales_ort));
ort_inputs.push_back(std::move(sid_ort));

auto outputs_ort = model_.Run(ort_inputs);
auto outputs_ort = encoder_.Run(ort_inputs);
return outputs_ort;
}

void TtsModel::ForwardDecoder(const std::vector<Ort::Value>& inputs,
std::vector<float>* audio) {
auto outputs_ort = decoder_.Run(inputs);
int len = outputs_ort[0].GetTensorTypeAndShapeInfo().GetShape()[2];
const float* outputs = outputs_ort[0].GetTensorData<float>();
audio->assign(outputs, outputs + len);
Expand All @@ -97,7 +106,9 @@ void TtsModel::Synthesis(const std::string& text, const int sid,
inputs.emplace_back(phone2id_[phone]);
}
LOG(INFO) << "phone sequence " << ss.str();
Forward(inputs, sid, audio);

auto outputs = ForwardEncoder(inputs, sid);
ForwardDecoder(outputs, audio);
for (size_t i = 0; i < audio->size(); i++) {
(*audio)[i] *= 32767.0;
}
Expand All @@ -112,4 +123,5 @@ int TtsModel::GetSid(const std::string& name) {
}
return speaker2id_[name];
}

} // namespace wetts
12 changes: 8 additions & 4 deletions runtime/core/model/tts_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,24 @@ namespace wetts {

class TtsModel {
public:
explicit TtsModel(const std::string& model_path,
explicit TtsModel(const std::string& encoder_model_path,
const std::string& decodder_model_path,
const std::string& speaker2id, const std::string& phone2id,
const int sampling_rate,
std::shared_ptr<wetext::Processor> processor,
std::shared_ptr<G2pProsody> g2p_prosody);
void Forward(const std::vector<int64_t>& phonemes, const int sid,
std::vector<float>* audio);
std::vector<Ort::Value> ForwardEncoder(const std::vector<int64_t>& phonemes,
const int sid);
void ForwardDecoder(const std::vector<Ort::Value>& inputs,
std::vector<float>* audio);
void Synthesis(const std::string& text, const int sid,
std::vector<float>* audio);
int GetSid(const std::string& name);
int sampling_rate() const { return sampling_rate_; }

private:
OnnxModel model_;
// VITS encoder & decoder
OnnxModel encoder_, decoder_;
int sampling_rate_;
std::unordered_map<std::string, int> phone2id_;
std::unordered_map<std::string, int> speaker2id_;
Expand Down