Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions runtime/core/frontend/g2p_prosody.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

Expand All @@ -31,6 +32,16 @@

namespace wetts {

// 标点到韵律值的映射规则
// 逗号、冒号映射为 #3,顿号映射为 #2
static const std::unordered_map<std::string, std::string> kPunctProsodyMap = {
{",", "#3"}, // 英文逗号
{",", "#3"}, // 中文逗号
{":", "#3"}, // 英文冒号
{":", "#3"}, // 中文冒号
{"、", "#2"}, // 中文顿号
};

G2pProsody::G2pProsody(const std::string& g2p_prosody_model,
const std::string& g2p_prosody_vocab,
const std::string& lexicon_file,
Expand Down Expand Up @@ -200,15 +211,23 @@ void G2pProsody::Compute(const std::string& str,
phonemes->insert(phonemes->end(), pinyin.begin(), pinyin.end());
phonemes->emplace_back(prosody[0]);
} else {
// Not English, Not in Lexicon, ignore now
// TODO(Binbin Zhang): Deal punct
LOG(WARNING) << "Ignore word " << word;
// Deal with OOV & punctuation
auto it = kPunctProsodyMap.find(word);
if (it != kPunctProsodyMap.end()) {
if (!phonemes->empty()) {
phonemes->back() = it->second;
}
} else {
LOG(WARNING) << "Ignore word " << word;
}
}
VLOG(2) << "Word, g2p & prosody: " << word << " "
<< pinyins[idx] << " " << prosodys[idx];
}
// Last token should be "#4"
phonemes->back() = "#4";
if (phonemes->size() > 0) {
phonemes->back() = "#4";
}
}


Expand Down
59 changes: 40 additions & 19 deletions runtime/core/model/tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,33 @@ TTS::TTS(const std::string& encoder_model_path,
ReadTableFile(speaker2id, &speaker2id_);
}

void TTS::Text2PhoneIds(const std::string& text,
bool TTS::Text2PhoneIds(const std::string& text,
std::vector<int64_t>* phone_ids) {
phone_ids->clear();
// 1. TN
std::string norm_text = tn_->Normalize(text);
LOG(INFO) << text << " --TN--> " << norm_text;
// 2. G2P: char => pinyin => phones => ids
std::vector<std::string> phonemes;
g2p_prosody_->Compute(norm_text, &phonemes);
// 3. Convert to phone id
std::stringstream ss;
phone_ids->emplace_back(phone2id_["sil"]);
ss << "sil";
for (const auto& phone : phonemes) {
if (phone2id_.count(phone) == 0) {
LOG(ERROR) << "Can't find `" << phone << "` in phone2id.";
continue;
if (phonemes.size() > 0) {
std::stringstream ss;
phone_ids->emplace_back(phone2id_["sil"]);
ss << "sil";
for (const auto& phone : phonemes) {
if (phone2id_.count(phone) == 0) {
LOG(ERROR) << "Can't find `" << phone << "` in phone2id.";
continue;
}
ss << " " << phone;
phone_ids->emplace_back(phone2id_[phone]);
}
ss << " " << phone;
phone_ids->emplace_back(phone2id_[phone]);
LOG(INFO) << "phone sequence " << ss.str();
return true;
} else {
return false;
}
LOG(INFO) << "phone sequence " << ss.str();
}

void TTS::Synthesis(const std::string& text, const int sid,
Expand All @@ -73,10 +79,12 @@ void TTS::Synthesis(const std::string& text, const int sid,
SentenceSegement(text, &text_arrs);
for (const auto& text : text_arrs) {
std::vector<int64_t> phonemes;
Text2PhoneIds(text, &phonemes);
std::vector<float> sub_audio;
vits_.Forward(phonemes, sid, &sub_audio);
audio->insert(audio->end(), sub_audio.begin(), sub_audio.end());
bool ok = Text2PhoneIds(text, &phonemes);
if (ok) {
std::vector<float> sub_audio;
vits_.Forward(phonemes, sid, &sub_audio);
audio->insert(audio->end(), sub_audio.begin(), sub_audio.end());
}
}
}

Expand All @@ -93,10 +101,23 @@ bool TTS::StreamSynthesis(std::vector<float>* audio) {
return true; // all text done
}
if (new_text_) {
std::vector<int64_t> phonemes;
Text2PhoneIds(text_arrs_[cur_text_idx_], &phonemes);
vits_.SetInput(phonemes, sid_);
new_text_ = false;
// 连续跳过转换失败的文本片段,直到找到有效的片段
while (cur_text_idx_ < text_arrs_.size()) {
std::vector<int64_t> phonemes;
bool ok = Text2PhoneIds(text_arrs_[cur_text_idx_], &phonemes);
if (ok && !phonemes.empty()) {
// 找到有效的片段,设置输入并退出循环
vits_.SetInput(phonemes, sid_);
new_text_ = false;
break;
}
// 转换失败或结果为空,跳过当前文本片段,继续下一个
cur_text_idx_++;
}
// 如果所有片段都处理完了,返回 true
if (cur_text_idx_ >= text_arrs_.size()) {
return true;
}
}
bool done = vits_.StreamDecode(audio);
if (done) {
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/model/tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TTS {
// true: synthesis done, stop call
// false: synthesis work in progress, you should continue to call me
bool StreamSynthesis(std::vector<float>* audio);
void Text2PhoneIds(const std::string& text, std::vector<int64_t>* phone_ids);
bool Text2PhoneIds(const std::string& text, std::vector<int64_t>* phone_ids);
int GetSid(const std::string& name);
int sampling_rate() const { return sampling_rate_; }

Expand Down