From 986ade70f811d75c0c8a52e9b9d73d66c50a5826 Mon Sep 17 00:00:00 2001 From: dm4 Date: Sat, 1 Mar 2025 17:05:30 +0800 Subject: [PATCH] tts: handle outetts-0.3 --- examples/tts/tts.cpp | 100 +++++++++++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 27 deletions(-) diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 998e093ac9d7f..bfdce785e7837 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -19,6 +19,11 @@ using json = nlohmann::ordered_json; +enum outetts_version { + OUTETTS_V0_2, + OUTETTS_V0_3, +}; + // // Terminal utils // @@ -374,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) { } // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 -static std::string process_text(const std::string & text) { +static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) { // For now I skipped text romanization as I am unsure how to handle // uroman and MeCab implementations in C++ @@ -404,7 +409,8 @@ static std::string process_text(const std::string & text) { if (c == ' ') { prompt_clean += "<|text_sep|>"; */ - processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>"); + std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>"; + processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator); return processed_text; } @@ -428,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) { prompt_add(prompt, vocab, "<|im_start|>\n", true, true); } -static std::vector prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) { - const std::string& delimiter = "<|text_sep|>"; +static std::vector prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) { + const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>"); std::vector result; size_t start = 0; @@ -466,31 +472,62 @@ static json speaker_from_file(const std::string & speaker_file) { return speaker; } -static std::string audio_text_from_speaker(json speaker) { +static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) { + if (speaker.contains("version")) { + std::string version = speaker["version"].get(); + if (version == "0.2") { + return OUTETTS_V0_2; + } else if (version == "0.3") { + return OUTETTS_V0_3; + } else { + LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str()); + } + } + + // Also could get version from model itself + const char *chat_template = llama_model_chat_template(model, nullptr); + if (chat_template && std::string(chat_template) == "outetts-0.3") { + return OUTETTS_V0_3; + } + + // Use 0.2 as the default version + return OUTETTS_V0_2; +} + +static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { std::string audio_text = "<|text_start|>"; - for (const auto &word : speaker["words"]) { - audio_text += word["word"].get() + "<|text_sep|>"; + + if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { + std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>"; + for (const auto &word : speaker["words"]) { + audio_text += word["word"].get() + separator; + } } return audio_text; } -static std::string audio_data_from_speaker(json speaker) { +static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { std::string audio_data = "<|audio_start|>\n"; - for (const auto &word : speaker["words"]) { - std::string word_text = word["word"].get(); - double duration = word["duration"].get(); - std::vector codes = word["codes"].get>(); - - // Create the audio output entry - std::ostringstream word_entry; - word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2) - << duration << "|><|code_start|>"; - for (const auto &Code : codes) { - word_entry << "<|" << Code << "|>"; + + if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { + std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>"; + std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>"; + for (const auto &word : speaker["words"]) { + std::string word_text = word["word"].get(); + double duration = word["duration"].get(); + std::vector codes = word["codes"].get>(); + + // Create the audio output entry + std::ostringstream word_entry; + word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2) + << duration << "|>" + code_start; + for (const auto &Code : codes) { + word_entry << "<|" << Code << "|>"; + } + word_entry << code_end << "\n"; + audio_data += word_entry.str(); } - word_entry << "<|code_end|>\n"; - audio_data += word_entry.str(); } return audio_data; @@ -601,6 +638,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|>< looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; + // audio data for 0.3 version + outetts_version tts_version = get_tts_version(model_ttc); + if (tts_version == OUTETTS_V0_3) { + audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>"); + audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), ""); + audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>"); + } + // load speaker if given if (!params.vocoder.speaker_file.empty()) { LOG_INF("%s: loading speaker ..\n", __func__); @@ -609,8 +654,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str()); return 1; } - audio_text = audio_text_from_speaker(speaker); - audio_data = audio_data_from_speaker(speaker); + audio_text = audio_text_from_speaker(speaker, tts_version); + audio_data = audio_data_from_speaker(speaker, tts_version); } // process prompt and generate voice codes @@ -625,9 +670,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // convert the input text into the necessary format expected by OuteTTS { - std::string prompt_clean = process_text(params.prompt); + std::string prompt_clean = process_text(params.prompt, tts_version); if (params.vocoder.use_guide_tokens) { - guide_tokens = prepare_guide_tokens(vocab, prompt_clean); + guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version); } LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); @@ -641,15 +686,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 prompt_add(prompt_inp, vocab, audio_data, false, true); } else { // disabled to save time on tokenizing each time -#if 0 +#if 1 const std::string voice_data = audio_data; auto tmp = common_tokenize(vocab, voice_data, false, true); printf("\n\n"); - for (int i = 0; i < tmp.size(); ++i) { + for (size_t i = 0; i < tmp.size(); ++i) { printf("%d, ", tmp[i]); } printf("\n\n"); + prompt_add(prompt_inp, tmp); #else prompt_add(prompt_inp, llama_tokens { 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,