Skip to content

Commit

Permalink
tts: handle outetts-0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
dm4 committed Mar 1, 2025
1 parent 0b9db0a commit 986ade7
Showing 1 changed file with 73 additions and 27 deletions.
100 changes: 73 additions & 27 deletions examples/tts/tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

using json = nlohmann::ordered_json;

// Prompt-format revisions of OuteTTS that this example can produce.
// The enumerators are deliberately unscoped so call sites can name them
// directly (e.g. as default arguments).
enum outetts_version {
    OUTETTS_V0_2 = 0, // words are separated by <|text_sep|>
    OUTETTS_V0_3 = 1, // words are separated by <|space|>
};

//
// Terminal utils
//
Expand Down Expand Up @@ -374,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
}

// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) {
static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {

// Text romanization is skipped for now because it is unclear how to port
// the uroman and MeCab implementations to C++
Expand Down Expand Up @@ -404,7 +409,8 @@ static std::string process_text(const std::string & text) {
if (c == ' ') {
prompt_clean += "<|text_sep|>";
*/
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);

return processed_text;
}
Expand All @@ -428,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}

static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");

std::vector<llama_token> result;
size_t start = 0;
Expand Down Expand Up @@ -466,31 +472,62 @@ static json speaker_from_file(const std::string & speaker_file) {
return speaker;
}

static std::string audio_text_from_speaker(json speaker) {
// Decide which OuteTTS prompt format to use.
//
// Detection order:
//   1. the speaker profile's "version" field, when it is "0.2" or "0.3";
//   2. the model's chat template, when it is exactly "outetts-0.3";
//   3. otherwise OUTETTS_V0_2.
//
// An unrecognised or non-string speaker version is logged and detection
// falls through to the model instead of throwing.
//
//   model   - model whose chat template is queried (with a null template name)
//   speaker - parsed speaker profile; defaults to an empty object. Taken by
//             const reference so the per-word code arrays are not copied.
static outetts_version get_tts_version(llama_model *model, const json & speaker = json::object()) {
    // find() returns end() for non-object values, so this is safe for any json.
    const auto version_it = speaker.find("version");
    if (version_it != speaker.end()) {
        if (version_it->is_string()) {
            const std::string version = version_it->get<std::string>();
            if (version == "0.2") {
                return OUTETTS_V0_2;
            }
            if (version == "0.3") {
                return OUTETTS_V0_3;
            }
            LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
        } else {
            // get<std::string>() would throw here; report the raw value instead.
            LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version_it->dump().c_str());
        }
    }

    // Also could get version from model itself
    const char *chat_template = llama_model_chat_template(model, nullptr);
    if (chat_template != nullptr && std::string(chat_template) == "outetts-0.3") {
        return OUTETTS_V0_3;
    }

    // Use 0.2 as the default version
    return OUTETTS_V0_2;
}

// Build the text part of the speaker prompt: "<|text_start|>" followed by
// every entry of speaker["words"], each one's "word" string terminated by
// the version-specific separator ("<|space|>" for 0.3, "<|text_sep|>" for
// 0.2).
//
// The stale pre-0.3 loop that unconditionally appended "<|text_sep|>" is
// removed; leaving it in place emitted every word twice and left a brace
// unclosed.
//
//   speaker     - parsed speaker profile
//   tts_version - prompt format to emit; defaults to OUTETTS_V0_2
//
// Returns the assembled text; for a version outside the two known ones only
// the "<|text_start|>" prefix is returned.
static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
    std::string audio_text = "<|text_start|>";

    if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
        const std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
        for (const auto & word : speaker["words"]) {
            audio_text += word["word"].get<std::string>();
            audio_text += separator;
        }
    }

    return audio_text;
}

static std::string audio_data_from_speaker(json speaker) {
static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
std::string audio_data = "<|audio_start|>\n";
for (const auto &word : speaker["words"]) {
std::string word_text = word["word"].get<std::string>();
double duration = word["duration"].get<double>();
std::vector<int> codes = word["codes"].get<std::vector<int>>();

// Create the audio output entry
std::ostringstream word_entry;
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
<< duration << "|><|code_start|>";
for (const auto &Code : codes) {
word_entry << "<|" << Code << "|>";

if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
for (const auto &word : speaker["words"]) {
std::string word_text = word["word"].get<std::string>();
double duration = word["duration"].get<double>();
std::vector<int> codes = word["codes"].get<std::vector<int>>();

// Create the audio output entry
std::ostringstream word_entry;
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
<< duration << "|>" + code_start;
for (const auto &Code : codes) {
word_entry << "<|" << Code << "|>";
}
word_entry << code_end << "\n";
audio_data += word_entry.str();
}
word_entry << "<|code_end|>\n";
audio_data += word_entry.str();
}

return audio_data;
Expand Down Expand Up @@ -601,6 +638,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";

// audio data for 0.3 version
outetts_version tts_version = get_tts_version(model_ttc);
if (tts_version == OUTETTS_V0_3) {
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
}

// load speaker if given
if (!params.vocoder.speaker_file.empty()) {
LOG_INF("%s: loading speaker ..\n", __func__);
Expand All @@ -609,8 +654,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
return 1;
}
audio_text = audio_text_from_speaker(speaker);
audio_data = audio_data_from_speaker(speaker);
audio_text = audio_text_from_speaker(speaker, tts_version);
audio_data = audio_data_from_speaker(speaker, tts_version);
}

// process prompt and generate voice codes
Expand All @@ -625,9 +670,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14

// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
std::string prompt_clean = process_text(params.prompt, tts_version);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
}

LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
Expand All @@ -641,15 +686,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
prompt_add(prompt_inp, vocab, audio_data, false, true);
} else {
// disabled to save time on tokenizing each time
#if 0
#if 1
const std::string voice_data = audio_data;

auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) {
for (size_t i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
}
printf("\n\n");
prompt_add(prompt_inp, tmp);
#else
prompt_add(prompt_inp, llama_tokens {
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
Expand Down

0 comments on commit 986ade7

Please sign in to comment.