From 0abda77b39d56bd397b9f7f56792f8e0aa20d1ce Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 14:26:57 +0100 Subject: [PATCH 1/7] Add sentence spliting --- TTS/tts/layers/xtts/tokenizer.py | 74 +++++++++-- TTS/tts/models/xtts.py | 214 +++++++++++++++++-------------- 2 files changed, 179 insertions(+), 109 deletions(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 7726d829ac..f224534598 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,10 +1,10 @@ -import json import os import re -from functools import cached_property - -import pypinyin import torch +import pypinyin +import textwrap + +from functools import cached_property from hangul_romanize import Transliter from hangul_romanize.rule import academic from num2words import num2words @@ -12,6 +12,57 @@ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +from spacy.lang.en import English +from spacy.lang.zh import Chinese +from spacy.lang.ja import Japanese +from spacy.lang.ar import Arabic + + +def get_spacy_lang(lang): + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return Arabic() + else: + return English() + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + _whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: @@ -464,7 +515,7 @@ def _expand_number(m, lang="en"): def expand_numbers_multilingual(text, lang="en"): - if lang == "zh" or lang == "zh-cn": + if lang == "zh": text = zh_num2words()(text) else: if lang in ["en", "ru"]: @@ -525,7 +576,7 @@ def japanese_cleaners(text, katsu): return text -def korean_cleaners(text): +def korean_transliterate(text): r = Transliter(academic) return r.translit(text) @@ -546,7 +597,7 @@ def __init__(self, vocab_file=None): "it": 213, "pt": 203, "pl": 224, - "zh-cn": 82, + "zh": 82, "ar": 166, "cs": 186, "ru": 182, @@ -571,19 +622,20 @@ def check_input_length(self, txt, lang): ) def preprocess_text(self, txt, lang): - if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}: + if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}: txt = multilingual_cleaners(txt, lang) - if lang in {"zh", "zh-cn"}: + if lang == "zh": txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) elif lang == "ja": txt = japanese_cleaners(txt, self.katsu) - elif lang == "ko": - txt = korean_cleaners(txt) else: raise NotImplementedError(f"Language '{lang}' is not supported.") return txt def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region self.check_input_length(txt, lang) txt = self.preprocess_text(txt, lang) txt = f"[{lang}]{txt}" diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index f37f08449d..f305b76075 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -10,7 +10,7 @@ from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support -from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -510,63 +510,69 @@ def inference( do_sample=True, num_beams=1, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) - - # print(" > Input text: ", text) - # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language)) - # print(" > Input tokens: ", text_tokens) - # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy())) - assert ( - text_tokens.shape[-1] < self.args.gpt_max_text_tokens - ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - - with torch.no_grad(): - gpt_codes = self.gpt.generate( - cond_latents=gpt_cond_latent, - text_inputs=text_tokens, - input_tokens=None, - do_sample=do_sample, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=self.gpt_batch_size, - num_beams=num_beams, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - output_attentions=False, - **hf_generate_kwargs, - ) - expected_output_len = torch.tensor( - [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device - ) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) - text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) - gpt_latents = self.gpt( - text_tokens, - text_len, - gpt_codes, - expected_output_len, - cond_latents=gpt_cond_latent, - return_attentions=False, - return_latent=True, - ) + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) - wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) return { - "wav": wav.cpu().numpy().squeeze(), - "gpt_latents": gpt_latents, + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=0).numpy(), "speaker_embedding": speaker_embedding, } @@ -613,59 +619,71 @@ def inference_stream( top_p=0.85, do_sample=True, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] - fake_inputs = self.gpt.compute_embeddings( - gpt_cond_latent.to(self.device), - text_tokens, - ) - gpt_generator = self.gpt.get_generator( - fake_inputs=fake_inputs, - top_k=top_k, - top_p=top_p, - temperature=temperature, - do_sample=do_sample, - num_beams=1, - num_return_sequences=1, - length_penalty=float(length_penalty), - repetition_penalty=float(repetition_penalty), - output_attentions=False, - output_hidden_states=True, - **hf_generate_kwargs, - ) + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) - last_tokens = [] - all_latents = [] - wav_gen_prev = None - wav_overlap = None - is_end = False - - while not is_end: - try: - x, latent = next(gpt_generator) - last_tokens += [x] - all_latents += [latent] - except StopIteration: - is_end = True - - if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): - gpt_latents = torch.cat(all_latents, dim=0)[None, :] - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) - wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) - wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( - wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len - ) - last_tokens = [] - yield wav_chunk + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk def forward(self): raise NotImplementedError( From b5abb06dd9cbec1d30b26990bc71abf85a97f799 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 14:28:03 +0100 Subject: [PATCH 2/7] update requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 53e8af590c..b418e4fe29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,3 +54,4 @@ encodec==0.1.* # deps for XTTS unidecode==1.3.* num2words +spacy[ja] \ No newline at end of file From 2d9525a6696ecc0f9ced360a645f12dba8a2b955 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 14:30:48 +0100 Subject: [PATCH 3/7] update default args v2 --- TTS/tts/models/xtts.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index f305b76075..3fd04351d9 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -420,9 +420,9 @@ def full_inference( ref_audio_path, language, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, @@ -502,9 +502,9 @@ def inference( gpt_cond_latent, speaker_embedding, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, @@ -612,9 +612,9 @@ def inference_stream( stream_chunk_size=20, overlap_wav_len=1024, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, From d64cafccd2a107decc32ca659369b0dc521c80e2 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 14:35:30 +0100 Subject: [PATCH 4/7] Add spanish --- TTS/tts/layers/xtts/tokenizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index f224534598..56eb78aed4 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -16,6 +16,7 @@ from spacy.lang.zh import Chinese from spacy.lang.ja import Japanese from spacy.lang.ar import Arabic +from spacy.lang.es import Spanish def get_spacy_lang(lang): @@ -25,7 +26,10 @@ def get_spacy_lang(lang): return Japanese() elif lang == "ar": return Arabic() + elif lang == "es": + return Spanish() else: + # For most languages, Enlish does the job return English() def split_sentence(text, lang, text_split_length=250): From 83988c922a24c96aa4afbfcb9e87620603f07d51 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 14:37:53 +0100 Subject: [PATCH 5/7] Fix return gpt_latents --- TTS/tts/models/xtts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 3fd04351d9..5ccb26c314 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -568,11 +568,12 @@ def inference( mode="linear" ).transpose(1, 2) + gpt_latents_list.append(gpt_latents.cpu()) wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) return { "wav": torch.cat(wavs, dim=0).numpy(), - "gpt_latents": torch.cat(gpt_latents_list, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), "speaker_embedding": speaker_embedding, } From 1bb689aa9ff5476ef47739d5f82617f0fc5d613b Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 15:22:10 +0100 Subject: [PATCH 6/7] Update requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b418e4fe29..3ad74450e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,4 +54,4 @@ encodec==0.1.* # deps for XTTS unidecode==1.3.* num2words -spacy[ja] \ No newline at end of file +spacy[ja]>=3.* \ No newline at end of file From 491f16ec3775182a458468c9745bbc101a4c5c95 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 15 Nov 2023 15:23:55 +0100 Subject: [PATCH 7/7] Fix requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3ad74450e5..836de40ab6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,4 +54,4 @@ encodec==0.1.* # deps for XTTS unidecode==1.3.* num2words -spacy[ja]>=3.* \ No newline at end of file +spacy[ja]>=3 \ No newline at end of file