From 5723ce85c849c0fbbc7a2c862d5e98e6ef264383 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Mar 2024 12:06:47 +0800 Subject: [PATCH 1/8] Copy files --- egs/baker_zh/TTS/README.md | 0 egs/baker_zh/TTS/prepare.sh | 0 egs/baker_zh/TTS/shared | 1 + egs/baker_zh/TTS/vits/duration_predictor.py | 1 + egs/baker_zh/TTS/vits/flow.py | 1 + egs/baker_zh/TTS/vits/generator.py | 1 + egs/baker_zh/TTS/vits/hifigan.py | 1 + egs/baker_zh/TTS/vits/loss.py | 1 + egs/baker_zh/TTS/vits/monotonic_align | 1 + egs/baker_zh/TTS/vits/posterior_encoder.py | 1 + egs/baker_zh/TTS/vits/residual_coupling.py | 1 + egs/baker_zh/TTS/vits/text_encoder.py | 1 + egs/baker_zh/TTS/vits/train.py | 1 + egs/baker_zh/TTS/vits/transform.py | 1 + egs/baker_zh/TTS/vits/tts_datamodule.py | 329 ++++++++++++++++++++ egs/baker_zh/TTS/vits/utils.py | 1 + egs/baker_zh/TTS/vits/vits.py | 1 + egs/baker_zh/TTS/vits/wavenet.py | 1 + 18 files changed, 344 insertions(+) create mode 100644 egs/baker_zh/TTS/README.md create mode 100755 egs/baker_zh/TTS/prepare.sh create mode 120000 egs/baker_zh/TTS/shared create mode 120000 egs/baker_zh/TTS/vits/duration_predictor.py create mode 120000 egs/baker_zh/TTS/vits/flow.py create mode 120000 egs/baker_zh/TTS/vits/generator.py create mode 120000 egs/baker_zh/TTS/vits/hifigan.py create mode 120000 egs/baker_zh/TTS/vits/loss.py create mode 120000 egs/baker_zh/TTS/vits/monotonic_align create mode 120000 egs/baker_zh/TTS/vits/posterior_encoder.py create mode 120000 egs/baker_zh/TTS/vits/residual_coupling.py create mode 120000 egs/baker_zh/TTS/vits/text_encoder.py create mode 120000 egs/baker_zh/TTS/vits/train.py create mode 120000 egs/baker_zh/TTS/vits/transform.py create mode 100644 egs/baker_zh/TTS/vits/tts_datamodule.py create mode 120000 egs/baker_zh/TTS/vits/utils.py create mode 120000 egs/baker_zh/TTS/vits/vits.py create mode 120000 egs/baker_zh/TTS/vits/wavenet.py diff --git a/egs/baker_zh/TTS/README.md b/egs/baker_zh/TTS/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh new file mode 100755 index 0000000000..e69de29bb2 diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared new file mode 120000 index 0000000000..4cbd91a7e9 --- /dev/null +++ b/egs/baker_zh/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/duration_predictor.py b/egs/baker_zh/TTS/vits/duration_predictor.py new file mode 120000 index 0000000000..9972b476f9 --- /dev/null +++ b/egs/baker_zh/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/flow.py b/egs/baker_zh/TTS/vits/flow.py new file mode 120000 index 0000000000..e65d91ea75 --- /dev/null +++ b/egs/baker_zh/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/generator.py b/egs/baker_zh/TTS/vits/generator.py new file mode 120000 index 0000000000..611679bfa8 --- /dev/null +++ b/egs/baker_zh/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/hifigan.py b/egs/baker_zh/TTS/vits/hifigan.py new file mode 120000 index 0000000000..5ac025de72 --- /dev/null +++ b/egs/baker_zh/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/loss.py b/egs/baker_zh/TTS/vits/loss.py new file mode 120000 
index 0000000000..672e5ff68d --- /dev/null +++ b/egs/baker_zh/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/monotonic_align b/egs/baker_zh/TTS/vits/monotonic_align new file mode 120000 index 0000000000..71934e7cca --- /dev/null +++ b/egs/baker_zh/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/posterior_encoder.py b/egs/baker_zh/TTS/vits/posterior_encoder.py new file mode 120000 index 0000000000..41d64a3a66 --- /dev/null +++ b/egs/baker_zh/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/residual_coupling.py b/egs/baker_zh/TTS/vits/residual_coupling.py new file mode 120000 index 0000000000..f979adbf00 --- /dev/null +++ b/egs/baker_zh/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/text_encoder.py b/egs/baker_zh/TTS/vits/text_encoder.py new file mode 120000 index 0000000000..0efba277e1 --- /dev/null +++ b/egs/baker_zh/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/train.py b/egs/baker_zh/TTS/vits/train.py new file mode 120000 index 0000000000..ea0fad02a8 --- /dev/null +++ b/egs/baker_zh/TTS/vits/train.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/train.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/transform.py b/egs/baker_zh/TTS/vits/transform.py new file mode 120000 index 0000000000..962647408b --- /dev/null +++ b/egs/baker_zh/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/tts_datamodule.py b/egs/baker_zh/TTS/vits/tts_datamodule.py new file mode 100644 index 0000000000..e1a9c7b3ca --- /dev/null +++ b/egs/baker_zh/TTS/vits/tts_datamodule.py @@ -0,0 +1,329 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
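+ +# Note: this file starts as a verbatim copy of +# egs/ljspeech/TTS/vits/tts_datamodule.py; the class name and the manifest +# filenames below therefore still refer to LJSpeech.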
+ + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
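+ # Re-seeding each worker with (seed + worker_id) gives every dataloader + # worker a distinct but reproducible random stream (see _SeedWorkers above).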
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/baker_zh/TTS/vits/utils.py b/egs/baker_zh/TTS/vits/utils.py new file mode 120000 index 0000000000..085e764b43 --- /dev/null +++ b/egs/baker_zh/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/vits.py b/egs/baker_zh/TTS/vits/vits.py new file mode 120000 index 0000000000..1f58cf6fea --- /dev/null +++ b/egs/baker_zh/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/wavenet.py 
b/egs/baker_zh/TTS/vits/wavenet.py new file mode 120000 index 0000000000..28f0a78eeb --- /dev/null +++ b/egs/baker_zh/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file From 8b867affee88b79d0f441dc54198628c8a6d9ba0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Mar 2024 10:04:07 +0800 Subject: [PATCH 2/8] first working version --- .gitignore | 3 + docs/source/recipes/TTS/ljspeech/vits.rst | 2 +- egs/baker_zh/TTS/local/README.md | 7 + egs/baker_zh/TTS/local/__init__.py | 0 .../TTS/local/compute_spectrogram_baker.py | 106 ++ egs/baker_zh/TTS/local/pinyin_dict.py | 421 ++++++++ egs/baker_zh/TTS/local/prepare_token_file.py | 53 + .../TTS/local/prepare_tokens_baker_zh.py | 59 ++ egs/baker_zh/TTS/local/pypinyin-local.dict | 328 +++++++ egs/baker_zh/TTS/local/symbols.py | 73 ++ egs/baker_zh/TTS/local/tokenizer.py | 137 +++ egs/baker_zh/TTS/local/validate_manifest.py | 1 + egs/baker_zh/TTS/prepare.sh | 124 +++ egs/baker_zh/TTS/vits/export-onnx.py | 414 ++++++++ egs/baker_zh/TTS/vits/generate_lexicon.py | 39 + egs/baker_zh/TTS/vits/pinyin_dict.py | 1 + egs/baker_zh/TTS/vits/pypinyin-local.dict | 1 + egs/baker_zh/TTS/vits/test_onnx.py | 142 +++ egs/baker_zh/TTS/vits/tokenizer.py | 1 + egs/baker_zh/TTS/vits/train.py | 928 +++++++++++++++++- egs/baker_zh/TTS/vits/tts_datamodule.py | 17 +- .../TTS/vits/monotonic_align/setup.py | 5 +- egs/ljspeech/TTS/vits/tokenizer.py | 6 +- egs/ljspeech/TTS/vits/tts_datamodule.py | 2 +- 24 files changed, 2855 insertions(+), 15 deletions(-) create mode 100644 egs/baker_zh/TTS/local/README.md create mode 100644 egs/baker_zh/TTS/local/__init__.py create mode 100755 egs/baker_zh/TTS/local/compute_spectrogram_baker.py create mode 100644 egs/baker_zh/TTS/local/pinyin_dict.py create mode 100755 egs/baker_zh/TTS/local/prepare_token_file.py create mode 100755 egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py create mode 100644 egs/baker_zh/TTS/local/pypinyin-local.dict create mode 100644 egs/baker_zh/TTS/local/symbols.py create mode 100644 egs/baker_zh/TTS/local/tokenizer.py create mode 120000 egs/baker_zh/TTS/local/validate_manifest.py create mode 100755 egs/baker_zh/TTS/vits/export-onnx.py create mode 100755 egs/baker_zh/TTS/vits/generate_lexicon.py create mode 120000 egs/baker_zh/TTS/vits/pinyin_dict.py create mode 120000 egs/baker_zh/TTS/vits/pypinyin-local.dict create mode 100755 egs/baker_zh/TTS/vits/test_onnx.py create mode 120000 egs/baker_zh/TTS/vits/tokenizer.py mode change 120000 => 100755 egs/baker_zh/TTS/vits/train.py diff --git a/.gitignore b/.gitignore index fa18ca83c3..620427501b 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ node_modules .DS_Store *.fst *.arpa +core.c +*.so +build diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 9499a3aea2..37c8bff1e6 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -19,7 +19,7 @@ Install extra dependencies .. 
code-block:: bash pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html - pip install numba espnet_tts_frontend + pip install numba espnet_tts_frontend cython Data preparation ---------------- diff --git a/egs/baker_zh/TTS/local/README.md b/egs/baker_zh/TTS/local/README.md new file mode 100644 index 0000000000..dac1388537 --- /dev/null +++ b/egs/baker_zh/TTS/local/README.md @@ -0,0 +1,7 @@ +# Introduction + +[./symbols.py](./symbols.py) is copied from +https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py + +[./pypinyin-local.dict](./pypinyin-local.dict) is copied from +https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict diff --git a/egs/baker_zh/TTS/local/__init__.py b/egs/baker_zh/TTS/local/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/baker_zh/TTS/local/compute_spectrogram_baker.py b/egs/baker_zh/TTS/local/compute_spectrogram_baker.py new file mode 100755 index 0000000000..1a15c7c0d4 --- /dev/null +++ b/egs/baker_zh/TTS/local/compute_spectrogram_baker.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes spectrogram features of the baker_zh dataset. +It looks for manifests in the directory data/manifests. + +The generated spectrogram features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slows things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_baker_zh(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 48000 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "baker_zh" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once.
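+ # Note: with sampling_rate = 48000, the config above corresponds to a + # 1024-sample (~21.3 ms) FFT window hopped every 256 samples (~5.3 ms).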
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_baker_zh() diff --git a/egs/baker_zh/TTS/local/pinyin_dict.py b/egs/baker_zh/TTS/local/pinyin_dict.py new file mode 100644 index 0000000000..950fb39fc0 --- /dev/null +++ b/egs/baker_zh/TTS/local/pinyin_dict.py @@ -0,0 +1,421 @@ +# This dict is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py +pinyin_dict = { + "a": ("^", "a"), + "ai": ("^", "ai"), + "an": ("^", "an"), + "ang": ("^", "ang"), + "ao": ("^", "ao"), + "ba": ("b", "a"), + "bai": ("b", "ai"), + "ban": ("b", "an"), + "bang": ("b", "ang"), + "bao": ("b", "ao"), + "be": ("b", "e"), + "bei": ("b", "ei"), + "ben": ("b", "en"), + "beng": ("b", "eng"), + "bi": ("b", "i"), + "bian": ("b", "ian"), + "biao": ("b", "iao"), + "bie": ("b", "ie"), + "bin": ("b", "in"), + "bing": ("b", "ing"), + "bo": ("b", "o"), + "bu": ("b", "u"), + "ca": ("c", "a"), + "cai": ("c", "ai"), + "can": ("c", "an"), + "cang": ("c", "ang"), + "cao": ("c", "ao"), + "ce": ("c", "e"), + "cen": ("c", "en"), + "ceng": ("c", "eng"), + "cha": ("ch", "a"), + "chai": ("ch", "ai"), + "chan": ("ch", "an"), + "chang": ("ch", "ang"), + "chao": ("ch", "ao"), + "che": ("ch", "e"), + "chen": ("ch", "en"), + "cheng": ("ch", "eng"), + "chi": ("ch", "iii"), + "chong": ("ch", "ong"), + "chou": ("ch", "ou"), + "chu": ("ch", "u"), + "chua": ("ch", "ua"), + "chuai": ("ch", "uai"), + "chuan": ("ch", "uan"), + "chuang": ("ch", "uang"), + "chui": ("ch", "uei"), + "chun": ("ch", "uen"), + "chuo": ("ch", "uo"), + "ci": ("c", "ii"), + "cong": ("c", "ong"), + "cou": ("c", "ou"), + "cu": ("c", "u"), + "cuan": ("c", "uan"), + "cui": ("c", "uei"), + "cun": ("c", "uen"), + "cuo": ("c", "uo"), + "da": ("d", "a"), + "dai": ("d", "ai"), + "dan": ("d", "an"), + "dang": ("d", "ang"), + "dao": ("d", "ao"), + "de": ("d", "e"), + "dei": ("d", "ei"), + "den": ("d", "en"), + "deng": ("d", "eng"), + "di": ("d", "i"), + "dia": ("d", "ia"), + "dian": ("d", "ian"), + "diao": ("d", "iao"), + "die": ("d", "ie"), + "ding": ("d", "ing"), + "diu": ("d", "iou"), + "dong": ("d", "ong"), + "dou": ("d", "ou"), + "du": ("d", "u"), + "duan": ("d", "uan"), + "dui": ("d", "uei"), + "dun": ("d", "uen"), + "duo": ("d", "uo"), + "e": ("^", "e"), + "ei": ("^", "ei"), + "en": ("^", "en"), + "ng": ("^", "en"), + "eng": ("^", "eng"), + "er": ("^", "er"), + "fa": ("f", "a"), + "fan": ("f", "an"), + "fang": ("f", "ang"), + "fei": ("f", "ei"), + "fen": ("f", "en"), + "feng": ("f", "eng"), + "fo": ("f", "o"), + "fou": ("f", "ou"), + "fu": ("f", "u"), + "ga": ("g", "a"), + "gai": ("g", "ai"), + "gan": ("g", "an"), + "gang": ("g", "ang"), + "gao": ("g", "ao"), + "ge": ("g", "e"), + "gei": ("g", "ei"), + "gen": ("g", "en"), + "geng": ("g", "eng"), + "gong": ("g", "ong"), + "gou": ("g", 
"ou"), + "gu": ("g", "u"), + "gua": ("g", "ua"), + "guai": ("g", "uai"), + "guan": ("g", "uan"), + "guang": ("g", "uang"), + "gui": ("g", "uei"), + "gun": ("g", "uen"), + "guo": ("g", "uo"), + "ha": ("h", "a"), + "hai": ("h", "ai"), + "han": ("h", "an"), + "hang": ("h", "ang"), + "hao": ("h", "ao"), + "he": ("h", "e"), + "hei": ("h", "ei"), + "hen": ("h", "en"), + "heng": ("h", "eng"), + "hong": ("h", "ong"), + "hou": ("h", "ou"), + "hu": ("h", "u"), + "hua": ("h", "ua"), + "huai": ("h", "uai"), + "huan": ("h", "uan"), + "huang": ("h", "uang"), + "hui": ("h", "uei"), + "hun": ("h", "uen"), + "huo": ("h", "uo"), + "ji": ("j", "i"), + "jia": ("j", "ia"), + "jian": ("j", "ian"), + "jiang": ("j", "iang"), + "jiao": ("j", "iao"), + "jie": ("j", "ie"), + "jin": ("j", "in"), + "jing": ("j", "ing"), + "jiong": ("j", "iong"), + "jiu": ("j", "iou"), + "ju": ("j", "v"), + "juan": ("j", "van"), + "jue": ("j", "ve"), + "jun": ("j", "vn"), + "ka": ("k", "a"), + "kai": ("k", "ai"), + "kan": ("k", "an"), + "kang": ("k", "ang"), + "kao": ("k", "ao"), + "ke": ("k", "e"), + "kei": ("k", "ei"), + "ken": ("k", "en"), + "keng": ("k", "eng"), + "kong": ("k", "ong"), + "kou": ("k", "ou"), + "ku": ("k", "u"), + "kua": ("k", "ua"), + "kuai": ("k", "uai"), + "kuan": ("k", "uan"), + "kuang": ("k", "uang"), + "kui": ("k", "uei"), + "kun": ("k", "uen"), + "kuo": ("k", "uo"), + "la": ("l", "a"), + "lai": ("l", "ai"), + "lan": ("l", "an"), + "lang": ("l", "ang"), + "lao": ("l", "ao"), + "le": ("l", "e"), + "lei": ("l", "ei"), + "leng": ("l", "eng"), + "li": ("l", "i"), + "lia": ("l", "ia"), + "lian": ("l", "ian"), + "liang": ("l", "iang"), + "liao": ("l", "iao"), + "lie": ("l", "ie"), + "lin": ("l", "in"), + "ling": ("l", "ing"), + "liu": ("l", "iou"), + "lo": ("l", "o"), + "long": ("l", "ong"), + "lou": ("l", "ou"), + "lu": ("l", "u"), + "lv": ("l", "v"), + "luan": ("l", "uan"), + "lve": ("l", "ve"), + "lue": ("l", "ve"), + "lun": ("l", "uen"), + "luo": ("l", "uo"), + "ma": ("m", "a"), + "mai": ("m", "ai"), + "man": ("m", "an"), + "mang": ("m", "ang"), + "mao": ("m", "ao"), + "me": ("m", "e"), + "mei": ("m", "ei"), + "men": ("m", "en"), + "meng": ("m", "eng"), + "mi": ("m", "i"), + "mian": ("m", "ian"), + "miao": ("m", "iao"), + "mie": ("m", "ie"), + "min": ("m", "in"), + "ming": ("m", "ing"), + "miu": ("m", "iou"), + "mo": ("m", "o"), + "mou": ("m", "ou"), + "mu": ("m", "u"), + "na": ("n", "a"), + "nai": ("n", "ai"), + "nan": ("n", "an"), + "nang": ("n", "ang"), + "nao": ("n", "ao"), + "ne": ("n", "e"), + "nei": ("n", "ei"), + "nen": ("n", "en"), + "neng": ("n", "eng"), + "ni": ("n", "i"), + "nia": ("n", "ia"), + "nian": ("n", "ian"), + "niang": ("n", "iang"), + "niao": ("n", "iao"), + "nie": ("n", "ie"), + "nin": ("n", "in"), + "ning": ("n", "ing"), + "niu": ("n", "iou"), + "nong": ("n", "ong"), + "nou": ("n", "ou"), + "nu": ("n", "u"), + "nv": ("n", "v"), + "nuan": ("n", "uan"), + "nve": ("n", "ve"), + "nue": ("n", "ve"), + "nuo": ("n", "uo"), + "o": ("^", "o"), + "ou": ("^", "ou"), + "pa": ("p", "a"), + "pai": ("p", "ai"), + "pan": ("p", "an"), + "pang": ("p", "ang"), + "pao": ("p", "ao"), + "pe": ("p", "e"), + "pei": ("p", "ei"), + "pen": ("p", "en"), + "peng": ("p", "eng"), + "pi": ("p", "i"), + "pian": ("p", "ian"), + "piao": ("p", "iao"), + "pie": ("p", "ie"), + "pin": ("p", "in"), + "ping": ("p", "ing"), + "po": ("p", "o"), + "pou": ("p", "ou"), + "pu": ("p", "u"), + "qi": ("q", "i"), + "qia": ("q", "ia"), + "qian": ("q", "ian"), + "qiang": ("q", "iang"), + "qiao": ("q", "iao"), + "qie": ("q", "ie"), + "qin": 
("q", "in"), + "qing": ("q", "ing"), + "qiong": ("q", "iong"), + "qiu": ("q", "iou"), + "qu": ("q", "v"), + "quan": ("q", "van"), + "que": ("q", "ve"), + "qun": ("q", "vn"), + "ran": ("r", "an"), + "rang": ("r", "ang"), + "rao": ("r", "ao"), + "re": ("r", "e"), + "ren": ("r", "en"), + "reng": ("r", "eng"), + "ri": ("r", "iii"), + "rong": ("r", "ong"), + "rou": ("r", "ou"), + "ru": ("r", "u"), + "rua": ("r", "ua"), + "ruan": ("r", "uan"), + "rui": ("r", "uei"), + "run": ("r", "uen"), + "ruo": ("r", "uo"), + "sa": ("s", "a"), + "sai": ("s", "ai"), + "san": ("s", "an"), + "sang": ("s", "ang"), + "sao": ("s", "ao"), + "se": ("s", "e"), + "sen": ("s", "en"), + "seng": ("s", "eng"), + "sha": ("sh", "a"), + "shai": ("sh", "ai"), + "shan": ("sh", "an"), + "shang": ("sh", "ang"), + "shao": ("sh", "ao"), + "she": ("sh", "e"), + "shei": ("sh", "ei"), + "shen": ("sh", "en"), + "sheng": ("sh", "eng"), + "shi": ("sh", "iii"), + "shou": ("sh", "ou"), + "shu": ("sh", "u"), + "shua": ("sh", "ua"), + "shuai": ("sh", "uai"), + "shuan": ("sh", "uan"), + "shuang": ("sh", "uang"), + "shui": ("sh", "uei"), + "shun": ("sh", "uen"), + "shuo": ("sh", "uo"), + "si": ("s", "ii"), + "song": ("s", "ong"), + "sou": ("s", "ou"), + "su": ("s", "u"), + "suan": ("s", "uan"), + "sui": ("s", "uei"), + "sun": ("s", "uen"), + "suo": ("s", "uo"), + "ta": ("t", "a"), + "tai": ("t", "ai"), + "tan": ("t", "an"), + "tang": ("t", "ang"), + "tao": ("t", "ao"), + "te": ("t", "e"), + "tei": ("t", "ei"), + "teng": ("t", "eng"), + "ti": ("t", "i"), + "tian": ("t", "ian"), + "tiao": ("t", "iao"), + "tie": ("t", "ie"), + "ting": ("t", "ing"), + "tong": ("t", "ong"), + "tou": ("t", "ou"), + "tu": ("t", "u"), + "tuan": ("t", "uan"), + "tui": ("t", "uei"), + "tun": ("t", "uen"), + "tuo": ("t", "uo"), + "wa": ("^", "ua"), + "wai": ("^", "uai"), + "wan": ("^", "uan"), + "wang": ("^", "uang"), + "wei": ("^", "uei"), + "wen": ("^", "uen"), + "weng": ("^", "ueng"), + "wo": ("^", "uo"), + "wu": ("^", "u"), + "xi": ("x", "i"), + "xia": ("x", "ia"), + "xian": ("x", "ian"), + "xiang": ("x", "iang"), + "xiao": ("x", "iao"), + "xie": ("x", "ie"), + "xin": ("x", "in"), + "xing": ("x", "ing"), + "xiong": ("x", "iong"), + "xiu": ("x", "iou"), + "xu": ("x", "v"), + "xuan": ("x", "van"), + "xue": ("x", "ve"), + "xun": ("x", "vn"), + "ya": ("^", "ia"), + "yan": ("^", "ian"), + "yang": ("^", "iang"), + "yao": ("^", "iao"), + "ye": ("^", "ie"), + "yi": ("^", "i"), + "yin": ("^", "in"), + "ying": ("^", "ing"), + "yo": ("^", "iou"), + "yong": ("^", "iong"), + "you": ("^", "iou"), + "yu": ("^", "v"), + "yuan": ("^", "van"), + "yue": ("^", "ve"), + "yun": ("^", "vn"), + "za": ("z", "a"), + "zai": ("z", "ai"), + "zan": ("z", "an"), + "zang": ("z", "ang"), + "zao": ("z", "ao"), + "ze": ("z", "e"), + "zei": ("z", "ei"), + "zen": ("z", "en"), + "zeng": ("z", "eng"), + "zha": ("zh", "a"), + "zhai": ("zh", "ai"), + "zhan": ("zh", "an"), + "zhang": ("zh", "ang"), + "zhao": ("zh", "ao"), + "zhe": ("zh", "e"), + "zhei": ("zh", "ei"), + "zhen": ("zh", "en"), + "zheng": ("zh", "eng"), + "zhi": ("zh", "iii"), + "zhong": ("zh", "ong"), + "zhou": ("zh", "ou"), + "zhu": ("zh", "u"), + "zhua": ("zh", "ua"), + "zhuai": ("zh", "uai"), + "zhuan": ("zh", "uan"), + "zhuang": ("zh", "uang"), + "zhui": ("zh", "uei"), + "zhun": ("zh", "uen"), + "zhuo": ("zh", "uo"), + "zi": ("z", "ii"), + "zong": ("z", "ong"), + "zou": ("z", "ou"), + "zu": ("z", "u"), + "zuan": ("z", "uan"), + "zui": ("z", "uei"), + "zun": ("z", "uen"), + "zuo": ("z", "uo"), +} diff --git 
a/egs/baker_zh/TTS/local/prepare_token_file.py b/egs/baker_zh/TTS/local/prepare_token_file.py new file mode 100755 index 0000000000..d90910ab02 --- /dev/null +++ b/egs/baker_zh/TTS/local/prepare_token_file.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file generates the file that maps tokens to IDs. +""" + +import argparse +from pathlib import Path + +from symbols import symbols + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the file that maps text tokens to IDs", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + tokens = Path(args.tokens) + + with open(tokens, "w", encoding="utf-8") as f: + for token_id, token in enumerate(symbols): + f.write(f"{token} {token_id}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py b/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py new file mode 100755 index 0000000000..0b27fd1e9e --- /dev/null +++ b/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in the given manifest and saves the new cuts with tokens.
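+ +It reads data/spectrogram/baker_zh_cuts_all.jsonl.gz and writes +data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz.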
+""" + +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest + +from tokenizer import Tokenizer + + +def prepare_tokens_baker_zh(): + output_dir = Path("data/spectrogram") + prefix = "baker_zh" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + tokenizer = Tokenizer() + + new_cuts = [] + i = 0 + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].normalized_text + cut.tokens = tokenizer.text_to_tokens(text) + + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_baker_zh() diff --git a/egs/baker_zh/TTS/local/pypinyin-local.dict b/egs/baker_zh/TTS/local/pypinyin-local.dict new file mode 100644 index 0000000000..5e386014c8 --- /dev/null +++ b/egs/baker_zh/TTS/local/pypinyin-local.dict @@ -0,0 +1,328 @@ +姐姐 jie3 jie +宝宝 bao3 bao +哥哥 ge1 ge +妹妹 mei4 mei +弟弟 di4 di +妈妈 ma1 ma +开心哦 kai1 xin1 o +爸爸 ba4 ba +秘密哟 mi4 mi4 yo +哦 o +一年 yi4 nian2 +一夜 yi2 ye4 +一切 yi2 qie4 +一座 yi2 zuo4 +一下 yi2 xia4 +上一山 shang4 yi2 shan1 +下一山 xia4 yi2 shan1 +休息 xiu1 xi2 +东西 dong1 xi +上一届 shang4 yi2 jie4 +便宜 pian2 yi4 +加长 jia1 chang2 +单田芳 shan4 tian2 fang1 +帧 zhen1 +长时间 chang2 shi2 jian1 +长时 chang2 shi2 +识别 shi2 bie2 +生命中 sheng1 ming4 zhong1 +踏实 ta1 shi +嗯 en4 +溜达 liu1 da +少儿 shao4 er2 +爷爷 ye2 ye +不是 bu2 shi4 +一圈 yi1 quan1 +厜读一声 zui1 du2 yi4 sheng1 +一种 yi4 zhong3 +一簇簇 yi2 cu4 cu4 +一个 yi2 ge4 +一样 yi2 yang4 +一跩一跩 yi4 zhuai3 yi4 zhuai3 +一会儿 yi2 hui4 er +一幢 yi2 zhuang4 +挨了 ai2 le +熬菜 ao1 cai4 +扒鸡 pa2 ji1 +背枪 bei1 qiang1 +绷瓷儿 beng4 ci2 er2 +绷劲儿 beng3 jin4 er +绷着脸 beng3 zhe lian3 +藏医 zang4 yi1 +噌吰 cheng1 hong2 +差点儿 cha4 dian3 er +差失 cha1 shi1 +差误 cha1 wu4 +孱头 can4 tou +乘间 cheng2 jian4 +锄镰棘矜 chu2 lian2 ji2 qin2 +川藏 chuan1 zang4 +穿著 chuan1 zhuo2 +答讪 da1 shan4 +答言 da1 yan2 +大伯子 da4 bai3 zi +大夫 dai4 fu +弹冠 tan2 guan1 +当间 dang1 jian4 +当然咯 dang1 ran2 lo +点种 dian3 zhong3 +垛好 duo4 hao3 +发疟子 fa1 yao4 zi +饭熟了 fan4 shou2 le +附著 fu4 zhuo2 +复沓 fu4 ta4 +供稿 gong1 gao3 +供养 gong1 yang3 +骨朵 gu1 duo +骨碌 gu1 lu +果脯 guo3 fu3 +哈什玛 ha4 shi2 ma3 +海蜇 hai3 zhe2 +呵欠 he1 qian +河水汤汤 he2 shui3 shang1 shang1 +鹄立 hu2 li4 +鹄望 hu2 wang4 +混人 hun2 ren2 +混水 hun2 shui3 +鸡血 ji1 xie3 +缉鞋口 qi1 xie2 kou3 +亟来闻讯 qi4 lai2 wen2 xun4 +计量 ji4 liang2 +济水 ji3 shui3 +间杂 jian4 za2 +脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2 +脚儿 jue2 er2 +口角 kou3 jiao3 +勒石 le4 shi2 +累进 lei3 jin4 +累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3 +累年 lei3 nian2 +脸涨通红 lian3 zhang4 tong1 hong2 +踉锵 liang4 qiang1 +燎眉毛 liao3 mei2 mao2 +燎头发 liao3 tou2 fa4 +溜达 liu1 da +溜缝儿 liu4 feng4 er +馏口饭 liu4 kou3 fan4 +遛马 liu4 ma3 +遛鸟 liu4 niao3 +遛弯儿 liu4 wan1 er +楼枪机 lou1 qiang1 ji1 +搂钱 lou1 qian2 +鹿脯 lu4 fu3 +露头 lou4 tou2 +落魄 luo4 po4 +捋胡子 lv3 hu2 zi +绿地 lv4 di4 +麦垛 mai4 duo4 +没劲儿 mei2 jin4 er +闷棍 men4 gun4 +闷葫芦 men4 hu2 lu +闷头干 men1 tou2 gan4 +蒙古 meng3 gu3 +靡日不思 mi3 ri4 bu4 si1 +缪姓 miao4 xing4 +抹墙 mo4 qiang2 +抹下脸 ma1 xia4 lian3 +泥子 ni4 zi +拗不过 niu4 bu guo4 +排车 pai3 che1 +盘诘 pan2 jie2 +膀肿 pang1 zhong3 +炮干 bao1 gan1 +炮格 pao2 ge2 +碰钉子 peng4 ding1 zi +缥色 piao3 se4 +瀑河 bao4 he2 +蹊径 xi1 jing4 +前后相属 qian2 hou4 xiang1 zhu3 +翘尾巴 qiao4 wei3 ba +趄坡儿 qie4 po1 er +秦桧 qin2 hui4 +圈马 juan1 ma3 +雀盲眼 qiao3 mang2 yan3 +雀子 qiao1 zi +三年五载 san1 nian2 wu3 
zai3 +加载 jia1 zai3 +山大王 shan1 dai4 wang +苫屋草 shan4 wu1 cao3 +数数 shu3 shu4 +说客 shui4 ke4 +思量 si1 liang2 +伺侯 ci4 hou +踏实 ta1 shi +提溜 di1 liu +调拨 diao4 bo1 +帖子 tie3 zi +铜钿 tong2 tian2 +头昏脑涨 tou2 hun1 nao3 zhang4 +褪色 tui4 se4 +褪着手 tun4 zhe shou3 +圩子 wei2 zi +尾巴 wei3 ba +系好船只 xi4 hao3 chuan2 zhi1 +系好马匹 xi4 hao3 ma3 pi3 +杏脯 xing4 fu3 +姓单 xing4 shan4 +姓葛 xing4 ge3 +姓哈 xing4 ha3 +姓解 xing4 xie4 +姓秘 xing4 bi4 +姓宁 xing4 ning4 +旋风 xuan4 feng1 +旋根车轴 xuan4 gen1 che1 zhou2 +荨麻 qian2 ma2 +一幢楼房 yi1 zhuang4 lou2 fang2 +遗之千金 wei4 zhi1 qian1 jin1 +殷殷 yin3 yin3 +应招 ying4 zhao1 +用称约 yong4 cheng4 yao1 +约斤肉 yao1 jin1 rou4 +晕机 yun4 ji1 +熨贴 yu4 tie1 +咋办 za3 ban4 +咋呼 zha1 hu +仔兽 zi3 shou4 +扎彩 za1 cai3 +扎实 zha1 shi +扎腰带 za1 yao1 dai4 +轧朋友 ga2 peng2 you3 +爪子 zhua3 zi +折腾 zhe1 teng +着实 zhuo2 shi2 +着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2 +枝蔓 zhi1 man4 +中鹄 zhong1 hu2 +中选 zhong4 xuan3 +猪圈 zhu1 juan4 +拽住不放 zhuai4 zhu4 bu4 fang4 +转悠 zhuan4 you +庄稼熟了 zhuang1 jia shou2 le +酌量 zhuo2 liang2 +罪行累累 zui4 xing2 lei3 lei3 +一手 yi4 shou3 +一去不复返 yi2 qu4 bu2 fu4 fan3 +一颗 yi4 ke1 +一件 yi2 jian4 +一斤 yi4 jin1 +一点 yi4 dian3 +一朵 yi4 duo3 +一声 yi4 sheng1 +一身 yi4 shen1 +不要 bu2 yao4 +一人 yi4 ren2 +一个 yi2 ge4 +一把 yi4 ba3 +一门 yi4 men2 +一門 yi4 men2 +一艘 yi4 sou1 +一片 yi2 pian4 +一篇 yi2 pian1 +一份 yi2 fen4 +好嗲 hao3 dia3 +随地 sui2 di4 +扁担长 bian3 dan4 chang3 +一堆 yi4 dui1 +不义 bu2 yi4 +放一放 fang4 yi2 fang4 +一米 yi4 mi3 +一顿 yi2 dun4 +一层楼 yi4 ceng2 lou2 +一条 yi4 tiao2 +一件 yi2 jian4 +一棵 yi4 ke1 +一小股 yi4 xiao3 gu3 +一拐一拐 yi4 guai3 yi4 guai3 +一根 yi4 gen1 +沆瀣一气 hang4 xie4 yi2 qi4 +一丝 yi4 si1 +一毫 yi4 hao2 +一樣 yi2 yang4 +处处 chu4 chu4 +一餐 yi4 can +永不 yong3 bu2 +一看 yi2 kan4 +一架 yi2 jia4 +送还 song4 huan2 +一见 yi2 jian4 +一座 yi2 zuo4 +一块 yi2 kuai4 +一天 yi4 tian1 +一只 yi4 zhi1 +一支 yi4 zhi1 +一字 yi2 zi4 +一句 yi2 ju4 +一张 yi4 zhang1 +一條 yi4 tiao2 +一场 yi4 chang3 +一粒 yi2 li4 +小俩口 xiao3 liang3 kou3 +一首 yi4 shou3 +一对 yi2 dui4 +一手 yi4 shou3 +又一村 you4 yi4 cun1 +一概而论 yi2 gai4 er2 lun4 +一峰峰 yi4 feng1 feng1 +不但 bu2 dan4 +一笑 yi2 xiao4 +挠痒痒 nao2 yang3 yang +不对 bu2 dui4 +拧开 ning3 kai1 +爱不释手 ai4 bu2 shi4 shou3 +一念 yi2 nian4 +夺得 duo2 de2 +一袭 yi4 xi2 +一定 yi2 ding4 +不慎 bu2 shen4 +剽窃 piao2 qie4 +一时 yi4 shi2 +撇开 pie3 kai1 +一祭 yi2 ji4 +发卡 fa4 qia3 +少不了 shao3 bu4 liao3 +千虑一失 qian1 lv4 yi4 shi1 +呛得 qiang4 de2 +切菜 qie1 cai4 +茄盒 qie2 he2 +不去 bu2 qu4 +一大圈 yi2 da4 quan1 +不再 bu2 zai4 +一群 yi4 qun2 +不必 bu2 bi4 +一些 yi4 xie1 +一路 yi2 lu4 +一股 yi4 gu3 +一到 yi2 dao4 +一拨 yi4 bo1 +一排 yi4 pai2 +一空 yi4 kong1 +吮吸着 shun3 xi1 zhe +不适合 bu2 shi4 he2 +一串串 yi2 chuan4 chuan4 +一提起 yi4 ti2 qi3 +一尘不染 yi4 chen2 bu4 ran3 +一生 yi4 sheng1 +一派 yi2 pai4 +不断 bu2 duan4 +一次 yi2 ci4 +不进步 bu2 jin4 bu4 +娃娃 wa2 wa +万户侯 wan4 hu4 hou2 +一方 yi4 fang1 +一番话 yi4 fan1 hua4 +一遍 yi2 bian4 +不计较 bu2 ji4 jiao4 +诇 xiong4 +一边 yi4 bian1 +一束 yi2 shu4 +一听到 yi4 ting1 dao4 +炸鸡 zha2 ji1 +乍暧还寒 zha4 ai4 huan2 han2 +我说诶 wo3 shuo1 ei1 +棒诶 bang4 ei1 +寒碜 han2 chen4 +应采儿 ying4 cai3 er2 +晕车 yun1 che1 +必应 bi4 ying4 +应援 ying4 yuan2 +应力 ying4 li4 \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/symbols.py b/egs/baker_zh/TTS/local/symbols.py new file mode 100644 index 0000000000..1e68788704 --- /dev/null +++ b/egs/baker_zh/TTS/local/symbols.py @@ -0,0 +1,73 @@ +# This file is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py +_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] + +_initials = [ + "^", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "x", + "z", + "zh", +] + +_tones = ["1", "2", "3", "4", "5"] + +_finals = [ + "a", + "ai", + "an", + "ang", + 
"ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "ia", + "ian", + "iang", + "iao", + "ie", + "ii", + "iii", + "in", + "ing", + "iong", + "iou", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "uei", + "uen", + "ueng", + "uo", + "v", + "van", + "ve", + "vn", +] + +symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/egs/baker_zh/TTS/local/tokenizer.py b/egs/baker_zh/TTS/local/tokenizer.py new file mode 100644 index 0000000000..cbf6c9c773 --- /dev/null +++ b/egs/baker_zh/TTS/local/tokenizer.py @@ -0,0 +1,137 @@ +# This file is modified from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py + +import logging +from pathlib import Path +from typing import List + +# Note pinyin_dict is from ./pinyin_dict.py +from pinyin_dict import pinyin_dict +from pypinyin import Style +from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin +from pypinyin.converter import DefaultConverter +from pypinyin.core import Pinyin, load_phrases_dict + + +class _MyConverter(NeutralToneWith5Mixin, DefaultConverter): + pass + + +class Tokenizer: + def __init__(self, tokens: str = ""): + self._load_pinyin_dict() + self._pinyin_parser = Pinyin(_MyConverter()) + + if tokens != "": + self._load_tokens(tokens) + + def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]: + """ + Args: + texts: + A list of sentences. + kwargs: + Not used. It is for compatibility with other TTS recipes in icefall. + """ + tokens = [] + + for text in texts: + tokens.append(self.text_to_tokens(text)) + + return self.tokens_to_token_ids(tokens) + + def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]: + ans = [] + + for token_list in tokens: + token_ids = [] + for t in token_list: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) + ans.append(token_ids) + + return ans + + def text_to_tokens(self, text: str) -> List[str]: + # Convert "," to ["sp", "sil"] + # Convert "。" to ["sil"] + # append ["eos"] at the end of a sentence + phonemes = ["sil"] + pinyins = self._pinyin_parser.pinyin( + text, + style=Style.TONE3, + errors=lambda x: [[w] for w in x], + ) + + new_pinyin = [] + for p in pinyins: + p = p[0] + if p == ",": + new_pinyin.extend(["sp", "sil"]) + elif p == "。": + new_pinyin.append("sil") + else: + new_pinyin.append(p) + sub_phonemes = self._get_phoneme4pinyin(new_pinyin) + sub_phonemes.append("eos") + phonemes.extend(sub_phonemes) + return phonemes + + def _get_phoneme4pinyin(self, pinyins): + result = [] + for pinyin in pinyins: + if pinyin in ("sil", "sp"): + result.append(pinyin) + elif pinyin[:-1] in pinyin_dict: + tone = pinyin[-1] + a = pinyin[:-1] + a1, a2 = pinyin_dict[a] + # every word is appended with a #0 + result += [a1, a2 + tone, "#0"] + + return result + + def _load_pinyin_dict(self): + this_dir = Path(__file__).parent.resolve() + my_dict = {} + with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f: + content = f.readlines() + for line in content: + cuts = line.strip().split() + hanzi = cuts[0] + pinyin = cuts[1:] + my_dict[hanzi] = [[p] for p in pinyin] + + load_phrases_dict(my_dict) + + def _load_tokens(self, filename): + token2id: Dict[str, int] = {} + + with open(filename, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + idx = int(info[0]) + else: + token, idx = info[0], int(info[1]) + + assert token not in token2id, token + 
+ token2id[token] = idx + + self.token2id = token2id + self.vocab_size = len(self.token2id) + self.pad_id = self.token2id["#0"] + + +def main(): + tokenizer = Tokenizer() + print(tokenizer.text_to_tokens("你好,好的。")) + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py new file mode 120000 index 0000000000..b4d52ebca0 --- /dev/null +++ b/egs/baker_zh/TTS/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh index e69de29bb2..6fa87fe438 100755 --- a/egs/baker_zh/TTS/prepare.sh +++ b/egs/baker_zh/TTS/prepare.sh @@ -0,0 +1,124 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Download data" + + # The directory $dl_dir/BZNSYP will contain 3 sub directories: + # - PhoneLabeling + # - ProsodyLabeling + # - Wave + + # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink + # + # ln -sfv /path/to/BZNSYP $dl_dir/ + # touch $dl_dir/BZNSYP/.completed + # + if [ ! -d $dl_dir/BZNSYP ]; then + lhotse download baker-zh $dl_dir + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare baker-zh manifest" + # We assume that you have downloaded the baker corpus + # to $dl_dir/BZNSYP + mkdir -p data/manifests + if [ ! -e data/manifests/.baker.done ]; then + lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests + touch data/manifests/.baker.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute spectrogram for baker (may take 3 minutes)" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.baker.done ]; then + ./local/compute_spectrogram_baker.py + touch data/spectrogram/.baker.done + fi + + if [ ! -e data/spectrogram/.baker-validated.done ]; then + log "Validating data/spectrogram for baker" + python3 ./local/validate_manifest.py \ + data/spectrogram/baker_zh_cuts_all.jsonl.gz + touch data/spectrogram/.baker-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)" + if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then + + ./local/prepare_tokens_baker_zh.py + + mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/baker_zh_cuts_all.jsonl.gz + + touch data/spectrogram/.baker_zh_with_token.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)" + if [ !
-e data/spectrogram/.baker_zh_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/baker_zh_cuts_all.jsonl.gz \ + data/spectrogram/baker_zh_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ + data/spectrogram/baker_zh_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ + data/spectrogram/baker_zh_cuts_test.jsonl.gz + + rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/baker_zh_cuts_all.jsonl.gz \ + data/spectrogram/baker_zh_cuts_train.jsonl.gz + touch data/spectrogram/.baker_zh_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate token file" + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py --tokens data/tokens.txt + fi +fi diff --git a/egs/baker_zh/TTS/vits/export-onnx.py b/egs/baker_zh/TTS/vits/export-onnx.py new file mode 100755 index 0000000000..11c8a9791f --- /dev/null +++ b/egs/baker_zh/TTS/vits/export-onnx.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate one file inside vits/exp: + - vits-epoch-1000.onnx + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
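+ Values are converted to strings before being stored, since ONNX + metadata_props can only hold string-valued entries.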
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + alpha: float = 1.0, + noise_scale_dur: float = 0.8, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.generator.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + vocab_size: int, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. + opset_version: + The opset version to use. 
+ """ + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "alpha", + "noise_scale_dur", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + + meta_data = { + "model_type": "vits", + "version": "1", + "model_author": "k2-fsa", + "comment": "icefall", # must be icefall for models from icefall + "language": "Chinese", + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + params.vocab_size, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" +Supported languages. + +LJSpeech is using "en-us" from the second column. 
+ +Pty Language Age/Gender VoiceName File Other Languages + 5 af --/M Afrikaans gmw/af + 5 am --/M Amharic sem/am + 5 an --/M Aragonese roa/an + 5 ar --/M Arabic sem/ar + 5 as --/M Assamese inc/as + 5 az --/M Azerbaijani trk/az + 5 ba --/M Bashkir trk/ba + 5 be --/M Belarusian zle/be + 5 bg --/M Bulgarian zls/bg + 5 bn --/M Bengali inc/bn + 5 bpy --/M Bishnupriya_Manipuri inc/bpy + 5 bs --/M Bosnian zls/bs + 5 ca --/M Catalan roa/ca + 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr + 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) + 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) + 5 cs --/M Czech zlw/cs + 5 cv --/M Chuvash trk/cv + 5 cy --/M Welsh cel/cy + 5 da --/M Danish gmq/da + 5 de --/M German gmw/de + 5 el --/M Greek grk/el + 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) + 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) + 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) + 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) + 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) + 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) + 2 en-us --/M English_(America) gmw/en-US (en 3) + 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc + 5 eo --/M Esperanto art/eo + 5 es --/M Spanish_(Spain) roa/es + 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) + 5 et --/M Estonian urj/et + 5 eu --/M Basque eu + 5 fa --/M Persian ira/fa + 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn + 5 fi --/M Finnish urj/fi + 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) + 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) + 5 fr-fr --/M French_(France) roa/fr (fr 5) + 5 ga --/M Gaelic_(Irish) cel/ga + 5 gd --/M Gaelic_(Scottish) cel/gd + 5 gn --/M Guarani sai/gn + 5 grc --/M Greek_(Ancient) grk/grc + 5 gu --/M Gujarati inc/gu + 5 hak --/M Hakka_Chinese sit/hak + 5 haw --/M Hawaiian map/haw + 5 he --/M Hebrew sem/he + 5 hi --/M Hindi inc/hi + 5 hr --/M Croatian zls/hr (hbs 5) + 5 ht --/M Haitian_Creole roa/ht + 5 hu --/M Hungarian urj/hu + 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) + 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) + 5 ia --/M Interlingua art/ia + 5 id --/M Indonesian poz/id + 5 io --/M Ido art/io + 5 is --/M Icelandic gmq/is + 5 it --/M Italian roa/it + 5 ja --/M Japanese jpx/ja + 5 jbo --/M Lojban art/jbo + 5 ka --/M Georgian ccs/ka + 5 kk --/M Kazakh trk/kk + 5 kl --/M Greenlandic esx/kl + 5 kn --/M Kannada dra/kn + 5 ko --/M Korean ko + 5 kok --/M Konkani inc/kok + 5 ku --/M Kurdish ira/ku + 5 ky --/M Kyrgyz trk/ky + 5 la --/M Latin itc/la + 5 lb --/M Luxembourgish gmw/lb + 5 lfn --/M Lingua_Franca_Nova art/lfn + 5 lt --/M Lithuanian bat/lt + 5 ltg --/M Latgalian bat/ltg + 5 lv --/M Latvian bat/lv + 5 mi --/M Māori poz/mi + 5 mk --/M Macedonian zls/mk + 5 ml --/M Malayalam dra/ml + 5 mr --/M Marathi inc/mr + 5 ms --/M Malay poz/ms + 5 mt --/M Maltese sem/mt + 5 mto --/M Totontepec_Mixe miz/mto + 5 my --/M Myanmar_(Burmese) sit/my + 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) + 5 nci --/M Nahuatl_(Classical) azc/nci + 5 ne --/M Nepali inc/ne + 5 nl --/M Dutch gmw/nl + 5 nog --/M Nogai trk/nog + 5 om --/M Oromo cus/om + 5 or --/M Oriya inc/or + 5 pa --/M Punjabi inc/pa + 5 pap --/M Papiamento roa/pap + 5 piqd --/M Klingon art/piqd + 5 pl --/M Polish zlw/pl + 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) + 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) + 5 py --/M 
Pyash art/py + 5 qdb --/M Lang_Belta art/qdb + 5 qu --/M Quechua qu + 5 quc --/M K'iche' myn/quc + 5 qya --/M Quenya art/qya + 5 ro --/M Romanian roa/ro + 5 ru --/M Russian zle/ru + 5 ru-cl --/M Russian_(Classic) zle/ru-cl + 2 ru-lv --/M Russian_(Latvia) zle/ru-LV + 5 sd --/M Sindhi inc/sd + 5 shn --/M Shan_(Tai_Yai) tai/shn + 5 si --/M Sinhala inc/si + 5 sjn --/M Sindarin art/sjn + 5 sk --/M Slovak zlw/sk + 5 sl --/M Slovenian zls/sl + 5 smj --/M Lule_Saami urj/smj + 5 sq --/M Albanian ine/sq + 5 sr --/M Serbian zls/sr + 5 sv --/M Swedish gmq/sv + 5 sw --/M Swahili bnt/sw + 5 ta --/M Tamil dra/ta + 5 te --/M Telugu dra/te + 5 th --/M Thai tai/th + 5 tk --/M Turkmen trk/tk + 5 tn --/M Setswana bnt/tn + 5 tr --/M Turkish trk/tr + 5 tt --/M Tatar trk/tt + 5 ug --/M Uyghur trk/ug + 5 uk --/M Ukrainian zle/uk + 5 ur --/M Urdu inc/ur + 5 uz --/M Uzbek trk/uz + 5 vi --/M Vietnamese_(Northern) aav/vi + 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central + 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south + 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) + 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) +""" diff --git a/egs/baker_zh/TTS/vits/generate_lexicon.py b/egs/baker_zh/TTS/vits/generate_lexicon.py new file mode 100755 index 0000000000..6d040ef539 --- /dev/null +++ b/egs/baker_zh/TTS/vits/generate_lexicon.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +from pypinyin import phrases_dict, pinyin_dict +from tokenizer import Tokenizer + + +def main(): + filename = "lexicon.txt" + tokens = "./data/tokens.txt" + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + + # 1 to remove the initial sil + # :-1 to remove the final eos + tokens = tokenizer.text_to_tokens(w)[1:-1] + + tokens = " ".join(tokens) + f.write(f"{w} {tokens}\n") + + for key in phrases: + # 1 to remove the initial sil + # :-1 to remove the final eos + tokens = tokenizer.text_to_tokens(key)[1:-1] + tokens = " ".join(tokens) + f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/vits/pinyin_dict.py b/egs/baker_zh/TTS/vits/pinyin_dict.py new file mode 120000 index 0000000000..b8683bd2dc --- /dev/null +++ b/egs/baker_zh/TTS/vits/pinyin_dict.py @@ -0,0 +1 @@ +../local/pinyin_dict.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/pypinyin-local.dict b/egs/baker_zh/TTS/vits/pypinyin-local.dict new file mode 120000 index 0000000000..5bc9b77282 --- /dev/null +++ b/egs/baker_zh/TTS/vits/pypinyin-local.dict @@ -0,0 +1 @@ +../local/pypinyin-local.dict \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/test_onnx.py b/egs/baker_zh/TTS/vits/test_onnx.py new file mode 100755 index 0000000000..66c94270ce --- /dev/null +++ b/egs/baker_zh/TTS/vits/test_onnx.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to test the exported onnx model by vits/export-onnx.py
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+  --model-filename vits/exp/vits-epoch-1000.onnx \
+  --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+
+import onnxruntime as ort
+import torch
+import torchaudio
+from tokenizer import Tokenizer
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--model-filename",
+        type=str,
+        required=True,
+        help="Path to the onnx model.",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to vocabulary.""",
+    )
+
+    parser.add_argument(
+        "--text",
+        type=str,
+        default="Ask not what your country can do for you; ask what you can do for your country.",
+        help="Text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-filename",
+        type=str,
+        default="test_onnx.wav",
+        help="Filename to save the generated wave file.",
+    )
+
+    return parser
+
+
+class OnnxModel:
+    def __init__(self, model_filename: str):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+
+        self.model = ort.InferenceSession(
+            model_filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+        metadata = self.model.get_modelmeta().custom_metadata_map
+        self.sample_rate = int(metadata["sample_rate"])
+
+    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          tokens:
+            A 2-D int64 tensor of shape (1, T)
+          tokens_lens:
+            A 1-D int64 tensor of shape (1,)
+        Returns:
+          A tensor of shape (1, T')
+        """
+        noise_scale = torch.tensor([0.667], dtype=torch.float32)
+        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+        alpha = torch.tensor([1.0], dtype=torch.float32)
+
+        out = self.model.run(
+            [
+                self.model.get_outputs()[0].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: tokens.numpy(),
+                self.model.get_inputs()[1].name: tokens_lens.numpy(),
+                self.model.get_inputs()[2].name: noise_scale.numpy(),
+                self.model.get_inputs()[3].name: alpha.numpy(),
+                self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
+            },
+        )[0]
+        return torch.from_numpy(out)
+
+
+def main():
+    args = get_parser().parse_args()
+    logging.info(vars(args))
+
+    tokenizer = Tokenizer(args.tokens)
+
+    logging.info("About to create onnx model")
+    model = OnnxModel(args.model_filename)
+
+    text = args.text
+    tokens = tokenizer.texts_to_token_ids([text])
+    tokens = torch.tensor(tokens)  # (1, T)
+    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
+    audio = model(tokens, tokens_lens)  # (1, T')
+
+    output_filename = args.output_filename
+    torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
+    logging.info(f"Saved to {output_filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
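+
+# A quick way to sanity-check a freshly exported model (a debugging
+# sketch; the model path below is hypothetical): print the declared
+# input names and the metadata written by export-onnx.py.
+#
+#   import onnxruntime as ort
+#   sess = ort.InferenceSession(
+#       "vits/exp/vits-epoch-1000.onnx",
+#       providers=["CPUExecutionProvider"],
+#   )
+#   print([i.name for i in sess.get_inputs()])
+#   print(sess.get_modelmeta().custom_metadata_map)
diff --git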
a/egs/baker_zh/TTS/vits/tokenizer.py b/egs/baker_zh/TTS/vits/tokenizer.py new file mode 120000 index 0000000000..0368e07d34 --- /dev/null +++ b/egs/baker_zh/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../local/tokenizer.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/train.py b/egs/baker_zh/TTS/vits/train.py deleted file mode 120000 index ea0fad02a8..0000000000 --- a/egs/baker_zh/TTS/vits/train.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/train.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/train.py b/egs/baker_zh/TTS/vits/train.py new file mode 100755 index 0000000000..694129a89d --- /dev/null +++ b/egs/baker_zh/TTS/vits/train.py @@ -0,0 +1,927 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import BakerZhSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. 
+        It specifies the directory where all training-related
+        files, e.g., checkpoints, logs, etc., are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to vocabulary.""",
+    )
+
+    parser.add_argument(
+        "--lr", type=float, default=2.0e-4, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=20,
+        help="""Save a checkpoint periodically, after processing this number
+        of epochs. We save a checkpoint to exp-dir/ whenever
+        params.cur_epoch % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+        Since training takes around 1000 epochs, we suggest using a large
+        save_every_n to save disk space.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="high",
+        choices=["low", "medium", "high"],
+        help="""Valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+      - best_train_loss: Best training loss so far. It is used to select
+                         the model that has the lowest training loss. It is
+                         updated during the training.
+
+      - best_valid_loss: Best validation loss so far. It is used to select
+                         the model that has the lowest validation loss. It is
+                         updated during the training.
+
+      - best_train_epoch: It is the epoch that has the best training loss.
+
+      - best_valid_epoch: It is the epoch that has the best validation loss.
+
+      - batch_idx_train: Used for writing statistics to tensorboard. It
+                         contains the number of batches trained so far across
+                         epochs.
+
+      - log_interval: Print training loss if batch_idx % log_interval is 0
+
+      - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+      - feature_dim: The model input dim. It has to match the one used
+                     in computing features.
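+
+      For example, after params.update(vars(args)) in run(), you can
+      write params.lr instead of args.lr.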
+ """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 48000, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + model_type=params.model_type, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the 
model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+            cur_grad_scale = scaler._scale.item()
+
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+                if cur_grad_scale < 1.0e-05:
+                    save_bad_model()
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )
+
+        if params.batch_idx_train % params.log_interval == 0:
+            cur_lr_g = max(scheduler_g.get_last_lr())
+            cur_lr_d = max(scheduler_d.get_last_lr())
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+                f"loss[{loss_info}], tot_loss[{tot_loss}], "
+                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+                )
+                tb_writer.add_scalar(
+                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+                )
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+                if "returned_sample" in stats_g:
+                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                    tb_writer.add_audio(
+                        "train/speech_hat_",
+                        speech_hat_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
+                    )
+                    tb_writer.add_audio(
+                        "train/speech_",
+                        speech_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
+                    )
+                    tb_writer.add_image(
+                        "train/mel_hat_",
+                        plot_feature(mel_hat_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )
+                    tb_writer.add_image(
+                        "train/mel_",
+                        plot_feature(mel_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )
+
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
+            logging.info("Computing validation loss")
+            valid_info, (speech_hat, speech) = compute_validation_loss(
+                params=params,
+                model=model,
+                tokenizer=tokenizer,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+
+    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations
+    tot_loss = MetricsTracker()
+    returned_sample = None
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            batch_size = len(batch["tokens"])
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device)
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            # forward discriminator
+            loss_d, stats_d = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_generator=False,
+            )
+            assert loss_d.requires_grad is False
+            for k, v in stats_d.items():
+                loss_info[k] = v * batch_size
+
+            # forward generator
+            loss_g, stats_g = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_generator=True,
+            )
+            assert loss_g.requires_grad is False
+            for k, v in stats_g.items():
+                loss_info[k] = v * batch_size
+
+            # summary stats
+            tot_loss = tot_loss + loss_info
+
+            # infer for the first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, : tokens_lens[0].item()]
+                )
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+                returned_sample = (audio_pred, audio_gt)
+
+    if world_size > 1:
+        tot_loss.reduce(device)
+
+    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    tokenizer: Tokenizer,
+    optimizer_g: torch.optim.Optimizer,
+    optimizer_d: torch.optim.Optimizer,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
+        try:
+            # for discriminator
+            with autocast(enabled=params.use_fp16):
+                loss_d, stats_d = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_generator=False,
+                )
+            optimizer_d.zero_grad()
+            loss_d.backward()
+            # for generator
+            with autocast(enabled=params.use_fp16):
+                loss_g, stats_g = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    forward_generator=True,
+                )
+            optimizer_g.zero_grad()
+            loss_g.backward()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + baker_zh = BakerZhSpeechTtsDataModule(args) + + train_cuts = baker_zh.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use 
../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = baker_zh.train_dataloaders(train_cuts) + + valid_cuts = baker_zh.valid_cuts() + valid_dl = baker_zh.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + BakerZhSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/vits/tts_datamodule.py b/egs/baker_zh/TTS/vits/tts_datamodule.py index e1a9c7b3ca..96c5422771 100644 --- a/egs/baker_zh/TTS/vits/tts_datamodule.py +++ b/egs/baker_zh/TTS/vits/tts_datamodule.py @@ -52,7 +52,7 @@ def __call__(self, worker_id: int): fix_random_seed(self.seed + worker_id) -class LJSpeechTtsDataModule: +class BakerZhSpeechTtsDataModule: """ DataModule for tts experiments. 
It assumes there is always one train and valid dataloader, @@ -66,11 +66,12 @@ class LJSpeechTtsDataModule: - cut concatenation, - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. + This class should be derived for specific corpora used in TTS tasks. """ def __init__(self, args: argparse.Namespace): self.args = args + self.sampling_rate = 48000 @classmethod def add_arguments(cls, parser: argparse.ArgumentParser): @@ -175,7 +176,7 @@ def train_dataloaders( ) if self.args.on_the_fly_feats: - sampling_rate = 22050 + sampling_rate = self.sampling_rate config = SpectrogramConfig( sampling_rate=sampling_rate, frame_length=1024 / sampling_rate, # (in second), @@ -232,7 +233,7 @@ def train_dataloaders( def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - sampling_rate = 22050 + sampling_rate = self.sampling_rate config = SpectrogramConfig( sampling_rate=sampling_rate, frame_length=1024 / sampling_rate, # (in second), @@ -272,7 +273,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.info("About to create test dataset") if self.args.on_the_fly_feats: - sampling_rate = 22050 + sampling_rate = self.sampling_rate config = SpectrogramConfig( sampling_rate=sampling_rate, frame_length=1024 / sampling_rate, # (in second), @@ -311,19 +312,19 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: def train_cuts(self) -> CutSet: logging.info("About to get train cuts") return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" ) @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get validation cuts") return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" ) @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" ) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py index 33d75e1765..dc9ddaf489 100644 --- a/egs/ljspeech/TTS/vits/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py @@ -1,7 +1,10 @@ # https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py """Setup cython code.""" -from Cython.Build import cythonize +try: + from Cython.Build import cythonize +except ModuleNotFoundError as ex: + raise RuntimeError(f'{ex}\nPlease run:\n pip install cython') from setuptools import Extension, setup from setuptools.command.build_ext import build_ext as _build_ext diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 3c9046adde..f314cc3624 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -44,11 +44,11 @@ def __init__(self, tokens: str): if len(info) == 1: # case of space token = " " - id = int(info[0]) + idx = int(info[0]) else: - token, id = info[0], int(info[1]) + token, idx = info[0], int(info[1]) assert token not in self.token2id, token - self.token2id[token] = id + self.token2id[token] = idx # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md self.pad_id = self.token2id["_"] # padding diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py 
b/egs/ljspeech/TTS/vits/tts_datamodule.py index e1a9c7b3ca..005e1da494 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -66,7 +66,7 @@ class LJSpeechTtsDataModule: - cut concatenation, - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. + This class should be derived for specific corpora used in TTS tasks. """ def __init__(self, args: argparse.Namespace): From d4fda2b3545a2c00dfa47e7b1af8926bb91883b9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Mar 2024 10:07:50 +0800 Subject: [PATCH 3/8] copy files for aishell3 --- egs/aishell3/TTS/vits/duration_predictor.py | 1 + egs/aishell3/TTS/vits/flow.py | 1 + egs/aishell3/TTS/vits/generator.py | 1 + egs/aishell3/TTS/vits/hifigan.py | 1 + egs/aishell3/TTS/vits/loss.py | 1 + egs/aishell3/TTS/vits/monotonic_align | 1 + egs/aishell3/TTS/vits/posterior_encoder.py | 1 + egs/aishell3/TTS/vits/residual_coupling.py | 1 + egs/aishell3/TTS/vits/text_encoder.py | 1 + egs/aishell3/TTS/vits/tokenizer.py | 1 + egs/aishell3/TTS/vits/transform.py | 1 + egs/aishell3/TTS/vits/utils.py | 1 + egs/aishell3/TTS/vits/vits.py | 1 + egs/aishell3/TTS/vits/wavenet.py | 1 + 14 files changed, 14 insertions(+) create mode 120000 egs/aishell3/TTS/vits/duration_predictor.py create mode 120000 egs/aishell3/TTS/vits/flow.py create mode 120000 egs/aishell3/TTS/vits/generator.py create mode 120000 egs/aishell3/TTS/vits/hifigan.py create mode 120000 egs/aishell3/TTS/vits/loss.py create mode 120000 egs/aishell3/TTS/vits/monotonic_align create mode 120000 egs/aishell3/TTS/vits/posterior_encoder.py create mode 120000 egs/aishell3/TTS/vits/residual_coupling.py create mode 120000 egs/aishell3/TTS/vits/text_encoder.py create mode 120000 egs/aishell3/TTS/vits/tokenizer.py create mode 120000 egs/aishell3/TTS/vits/transform.py create mode 120000 egs/aishell3/TTS/vits/utils.py create mode 120000 egs/aishell3/TTS/vits/vits.py create mode 120000 egs/aishell3/TTS/vits/wavenet.py diff --git a/egs/aishell3/TTS/vits/duration_predictor.py b/egs/aishell3/TTS/vits/duration_predictor.py new file mode 120000 index 0000000000..9972b476f9 --- /dev/null +++ b/egs/aishell3/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/flow.py b/egs/aishell3/TTS/vits/flow.py new file mode 120000 index 0000000000..e65d91ea75 --- /dev/null +++ b/egs/aishell3/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/generator.py b/egs/aishell3/TTS/vits/generator.py new file mode 120000 index 0000000000..611679bfa8 --- /dev/null +++ b/egs/aishell3/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/hifigan.py b/egs/aishell3/TTS/vits/hifigan.py new file mode 120000 index 0000000000..5ac025de72 --- /dev/null +++ b/egs/aishell3/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/loss.py b/egs/aishell3/TTS/vits/loss.py new file mode 120000 index 0000000000..672e5ff68d --- /dev/null +++ b/egs/aishell3/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/monotonic_align b/egs/aishell3/TTS/vits/monotonic_align new file mode 120000 index 0000000000..2c4923075e --- /dev/null +++ 
b/egs/aishell3/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align/ \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/posterior_encoder.py b/egs/aishell3/TTS/vits/posterior_encoder.py new file mode 120000 index 0000000000..41d64a3a66 --- /dev/null +++ b/egs/aishell3/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/residual_coupling.py b/egs/aishell3/TTS/vits/residual_coupling.py new file mode 120000 index 0000000000..f979adbf00 --- /dev/null +++ b/egs/aishell3/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/text_encoder.py b/egs/aishell3/TTS/vits/text_encoder.py new file mode 120000 index 0000000000..0efba277e1 --- /dev/null +++ b/egs/aishell3/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/tokenizer.py b/egs/aishell3/TTS/vits/tokenizer.py new file mode 120000 index 0000000000..057b0dc4b1 --- /dev/null +++ b/egs/aishell3/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/transform.py b/egs/aishell3/TTS/vits/transform.py new file mode 120000 index 0000000000..962647408b --- /dev/null +++ b/egs/aishell3/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/utils.py b/egs/aishell3/TTS/vits/utils.py new file mode 120000 index 0000000000..085e764b43 --- /dev/null +++ b/egs/aishell3/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/vits.py b/egs/aishell3/TTS/vits/vits.py new file mode 120000 index 0000000000..1f58cf6fea --- /dev/null +++ b/egs/aishell3/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/wavenet.py b/egs/aishell3/TTS/vits/wavenet.py new file mode 120000 index 0000000000..28f0a78eeb --- /dev/null +++ b/egs/aishell3/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file From 5e8cd61e48bb184c3905c0057aac0a373e0ac791 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 6 Apr 2024 21:49:32 +0800 Subject: [PATCH 4/8] add aishell3 --- .../TTS/local/compute_spectrogram_aishell3.py | 110 ++ egs/aishell3/TTS/local/pinyin_dict.py | 421 +++++++ egs/aishell3/TTS/local/prepare_token_file.py | 53 + .../TTS/local/prepare_tokens_aishell3.py | 62 + egs/aishell3/TTS/local/pypinyin-local.dict | 328 ++++++ egs/aishell3/TTS/local/tokenizer.py | 137 +++ egs/aishell3/TTS/local/validate_manifest.py | 1 + egs/aishell3/TTS/prepare.sh | 137 +++ egs/aishell3/TTS/shared | 1 + egs/aishell3/TTS/vits/export-onnx.py | 433 +++++++ egs/aishell3/TTS/vits/pinyin_dict.py | 1 + egs/aishell3/TTS/vits/pypinyin-local.dict | 1 + egs/aishell3/TTS/vits/tokenizer.py | 2 +- egs/aishell3/TTS/vits/train.py | 1003 +++++++++++++++++ egs/aishell3/TTS/vits/tts_datamodule.py | 349 ++++++ 15 files changed, 3038 insertions(+), 1 deletion(-) create mode 100755 egs/aishell3/TTS/local/compute_spectrogram_aishell3.py create mode 100644 egs/aishell3/TTS/local/pinyin_dict.py create mode 100755 egs/aishell3/TTS/local/prepare_token_file.py create mode 100755 egs/aishell3/TTS/local/prepare_tokens_aishell3.py create 
mode 100644 egs/aishell3/TTS/local/pypinyin-local.dict
create mode 100644 egs/aishell3/TTS/local/tokenizer.py
create mode 120000 egs/aishell3/TTS/local/validate_manifest.py
create mode 100755 egs/aishell3/TTS/prepare.sh
create mode 120000 egs/aishell3/TTS/shared
create mode 100755 egs/aishell3/TTS/vits/export-onnx.py
create mode 120000 egs/aishell3/TTS/vits/pinyin_dict.py
create mode 120000 egs/aishell3/TTS/vits/pypinyin-local.dict
create mode 100755 egs/aishell3/TTS/vits/train.py
create mode 100644 egs/aishell3/TTS/vits/tts_datamodule.py
diff --git a/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py b/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py
new file mode 100755
index 0000000000..1c7fccad63
--- /dev/null
+++ b/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# Copyright  2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                     Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the aishell3 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    CutSet,
+    LilcomChunkyWriter,
+    Spectrogram,
+    SpectrogramConfig,
+    load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_aishell3():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/spectrogram")
+    num_jobs = min(4, os.cpu_count())
+
+    sampling_rate = 8000
+    frame_length = 1024 / sampling_rate  # (in seconds)
+    frame_shift = 256 / sampling_rate  # (in seconds)
+    use_fft_mag = True
+
+    prefix = "aishell3"
+    suffix = "jsonl.gz"
+    partitions = ("test", "train")
+
+    config = SpectrogramConfig(
+        sampling_rate=sampling_rate,
+        frame_length=frame_length,
+        frame_shift=frame_shift,
+        use_fft_mag=use_fft_mag,
+    )
+    extractor = Spectrogram(config)
+
+    for partition in partitions:
+        recordings = load_manifest(
+            src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
+        )
+        supervisions = load_manifest(
+            src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
+        )
+
+        # resample from 44100 to 8000
+        recordings = recordings.resample(sampling_rate)
+
+        with get_executor() as ex:  # Initialize the executor only once.
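+            # With the 8 kHz config above, each frame covers
+            # 1024 / 8000 = 0.128 s of audio, and consecutive frames
+            # are 256 / 8000 = 0.032 s apart.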
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_aishell3() diff --git a/egs/aishell3/TTS/local/pinyin_dict.py b/egs/aishell3/TTS/local/pinyin_dict.py new file mode 100644 index 0000000000..950fb39fc0 --- /dev/null +++ b/egs/aishell3/TTS/local/pinyin_dict.py @@ -0,0 +1,421 @@ +# This dict is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py +pinyin_dict = { + "a": ("^", "a"), + "ai": ("^", "ai"), + "an": ("^", "an"), + "ang": ("^", "ang"), + "ao": ("^", "ao"), + "ba": ("b", "a"), + "bai": ("b", "ai"), + "ban": ("b", "an"), + "bang": ("b", "ang"), + "bao": ("b", "ao"), + "be": ("b", "e"), + "bei": ("b", "ei"), + "ben": ("b", "en"), + "beng": ("b", "eng"), + "bi": ("b", "i"), + "bian": ("b", "ian"), + "biao": ("b", "iao"), + "bie": ("b", "ie"), + "bin": ("b", "in"), + "bing": ("b", "ing"), + "bo": ("b", "o"), + "bu": ("b", "u"), + "ca": ("c", "a"), + "cai": ("c", "ai"), + "can": ("c", "an"), + "cang": ("c", "ang"), + "cao": ("c", "ao"), + "ce": ("c", "e"), + "cen": ("c", "en"), + "ceng": ("c", "eng"), + "cha": ("ch", "a"), + "chai": ("ch", "ai"), + "chan": ("ch", "an"), + "chang": ("ch", "ang"), + "chao": ("ch", "ao"), + "che": ("ch", "e"), + "chen": ("ch", "en"), + "cheng": ("ch", "eng"), + "chi": ("ch", "iii"), + "chong": ("ch", "ong"), + "chou": ("ch", "ou"), + "chu": ("ch", "u"), + "chua": ("ch", "ua"), + "chuai": ("ch", "uai"), + "chuan": ("ch", "uan"), + "chuang": ("ch", "uang"), + "chui": ("ch", "uei"), + "chun": ("ch", "uen"), + "chuo": ("ch", "uo"), + "ci": ("c", "ii"), + "cong": ("c", "ong"), + "cou": ("c", "ou"), + "cu": ("c", "u"), + "cuan": ("c", "uan"), + "cui": ("c", "uei"), + "cun": ("c", "uen"), + "cuo": ("c", "uo"), + "da": ("d", "a"), + "dai": ("d", "ai"), + "dan": ("d", "an"), + "dang": ("d", "ang"), + "dao": ("d", "ao"), + "de": ("d", "e"), + "dei": ("d", "ei"), + "den": ("d", "en"), + "deng": ("d", "eng"), + "di": ("d", "i"), + "dia": ("d", "ia"), + "dian": ("d", "ian"), + "diao": ("d", "iao"), + "die": ("d", "ie"), + "ding": ("d", "ing"), + "diu": ("d", "iou"), + "dong": ("d", "ong"), + "dou": ("d", "ou"), + "du": ("d", "u"), + "duan": ("d", "uan"), + "dui": ("d", "uei"), + "dun": ("d", "uen"), + "duo": ("d", "uo"), + "e": ("^", "e"), + "ei": ("^", "ei"), + "en": ("^", "en"), + "ng": ("^", "en"), + "eng": ("^", "eng"), + "er": ("^", "er"), + "fa": ("f", "a"), + "fan": ("f", "an"), + "fang": ("f", "ang"), + "fei": ("f", "ei"), + "fen": ("f", "en"), + "feng": ("f", "eng"), + "fo": ("f", "o"), + "fou": ("f", "ou"), + "fu": ("f", "u"), + "ga": ("g", "a"), + "gai": ("g", "ai"), + "gan": ("g", "an"), + "gang": ("g", "ang"), + "gao": ("g", "ao"), + "ge": ("g", "e"), + "gei": ("g", "ei"), + "gen": ("g", "en"), + "geng": ("g", "eng"), + "gong": ("g", "ong"), + "gou": ("g", 
"ou"), + "gu": ("g", "u"), + "gua": ("g", "ua"), + "guai": ("g", "uai"), + "guan": ("g", "uan"), + "guang": ("g", "uang"), + "gui": ("g", "uei"), + "gun": ("g", "uen"), + "guo": ("g", "uo"), + "ha": ("h", "a"), + "hai": ("h", "ai"), + "han": ("h", "an"), + "hang": ("h", "ang"), + "hao": ("h", "ao"), + "he": ("h", "e"), + "hei": ("h", "ei"), + "hen": ("h", "en"), + "heng": ("h", "eng"), + "hong": ("h", "ong"), + "hou": ("h", "ou"), + "hu": ("h", "u"), + "hua": ("h", "ua"), + "huai": ("h", "uai"), + "huan": ("h", "uan"), + "huang": ("h", "uang"), + "hui": ("h", "uei"), + "hun": ("h", "uen"), + "huo": ("h", "uo"), + "ji": ("j", "i"), + "jia": ("j", "ia"), + "jian": ("j", "ian"), + "jiang": ("j", "iang"), + "jiao": ("j", "iao"), + "jie": ("j", "ie"), + "jin": ("j", "in"), + "jing": ("j", "ing"), + "jiong": ("j", "iong"), + "jiu": ("j", "iou"), + "ju": ("j", "v"), + "juan": ("j", "van"), + "jue": ("j", "ve"), + "jun": ("j", "vn"), + "ka": ("k", "a"), + "kai": ("k", "ai"), + "kan": ("k", "an"), + "kang": ("k", "ang"), + "kao": ("k", "ao"), + "ke": ("k", "e"), + "kei": ("k", "ei"), + "ken": ("k", "en"), + "keng": ("k", "eng"), + "kong": ("k", "ong"), + "kou": ("k", "ou"), + "ku": ("k", "u"), + "kua": ("k", "ua"), + "kuai": ("k", "uai"), + "kuan": ("k", "uan"), + "kuang": ("k", "uang"), + "kui": ("k", "uei"), + "kun": ("k", "uen"), + "kuo": ("k", "uo"), + "la": ("l", "a"), + "lai": ("l", "ai"), + "lan": ("l", "an"), + "lang": ("l", "ang"), + "lao": ("l", "ao"), + "le": ("l", "e"), + "lei": ("l", "ei"), + "leng": ("l", "eng"), + "li": ("l", "i"), + "lia": ("l", "ia"), + "lian": ("l", "ian"), + "liang": ("l", "iang"), + "liao": ("l", "iao"), + "lie": ("l", "ie"), + "lin": ("l", "in"), + "ling": ("l", "ing"), + "liu": ("l", "iou"), + "lo": ("l", "o"), + "long": ("l", "ong"), + "lou": ("l", "ou"), + "lu": ("l", "u"), + "lv": ("l", "v"), + "luan": ("l", "uan"), + "lve": ("l", "ve"), + "lue": ("l", "ve"), + "lun": ("l", "uen"), + "luo": ("l", "uo"), + "ma": ("m", "a"), + "mai": ("m", "ai"), + "man": ("m", "an"), + "mang": ("m", "ang"), + "mao": ("m", "ao"), + "me": ("m", "e"), + "mei": ("m", "ei"), + "men": ("m", "en"), + "meng": ("m", "eng"), + "mi": ("m", "i"), + "mian": ("m", "ian"), + "miao": ("m", "iao"), + "mie": ("m", "ie"), + "min": ("m", "in"), + "ming": ("m", "ing"), + "miu": ("m", "iou"), + "mo": ("m", "o"), + "mou": ("m", "ou"), + "mu": ("m", "u"), + "na": ("n", "a"), + "nai": ("n", "ai"), + "nan": ("n", "an"), + "nang": ("n", "ang"), + "nao": ("n", "ao"), + "ne": ("n", "e"), + "nei": ("n", "ei"), + "nen": ("n", "en"), + "neng": ("n", "eng"), + "ni": ("n", "i"), + "nia": ("n", "ia"), + "nian": ("n", "ian"), + "niang": ("n", "iang"), + "niao": ("n", "iao"), + "nie": ("n", "ie"), + "nin": ("n", "in"), + "ning": ("n", "ing"), + "niu": ("n", "iou"), + "nong": ("n", "ong"), + "nou": ("n", "ou"), + "nu": ("n", "u"), + "nv": ("n", "v"), + "nuan": ("n", "uan"), + "nve": ("n", "ve"), + "nue": ("n", "ve"), + "nuo": ("n", "uo"), + "o": ("^", "o"), + "ou": ("^", "ou"), + "pa": ("p", "a"), + "pai": ("p", "ai"), + "pan": ("p", "an"), + "pang": ("p", "ang"), + "pao": ("p", "ao"), + "pe": ("p", "e"), + "pei": ("p", "ei"), + "pen": ("p", "en"), + "peng": ("p", "eng"), + "pi": ("p", "i"), + "pian": ("p", "ian"), + "piao": ("p", "iao"), + "pie": ("p", "ie"), + "pin": ("p", "in"), + "ping": ("p", "ing"), + "po": ("p", "o"), + "pou": ("p", "ou"), + "pu": ("p", "u"), + "qi": ("q", "i"), + "qia": ("q", "ia"), + "qian": ("q", "ian"), + "qiang": ("q", "iang"), + "qiao": ("q", "iao"), + "qie": ("q", "ie"), + "qin": 
("q", "in"), + "qing": ("q", "ing"), + "qiong": ("q", "iong"), + "qiu": ("q", "iou"), + "qu": ("q", "v"), + "quan": ("q", "van"), + "que": ("q", "ve"), + "qun": ("q", "vn"), + "ran": ("r", "an"), + "rang": ("r", "ang"), + "rao": ("r", "ao"), + "re": ("r", "e"), + "ren": ("r", "en"), + "reng": ("r", "eng"), + "ri": ("r", "iii"), + "rong": ("r", "ong"), + "rou": ("r", "ou"), + "ru": ("r", "u"), + "rua": ("r", "ua"), + "ruan": ("r", "uan"), + "rui": ("r", "uei"), + "run": ("r", "uen"), + "ruo": ("r", "uo"), + "sa": ("s", "a"), + "sai": ("s", "ai"), + "san": ("s", "an"), + "sang": ("s", "ang"), + "sao": ("s", "ao"), + "se": ("s", "e"), + "sen": ("s", "en"), + "seng": ("s", "eng"), + "sha": ("sh", "a"), + "shai": ("sh", "ai"), + "shan": ("sh", "an"), + "shang": ("sh", "ang"), + "shao": ("sh", "ao"), + "she": ("sh", "e"), + "shei": ("sh", "ei"), + "shen": ("sh", "en"), + "sheng": ("sh", "eng"), + "shi": ("sh", "iii"), + "shou": ("sh", "ou"), + "shu": ("sh", "u"), + "shua": ("sh", "ua"), + "shuai": ("sh", "uai"), + "shuan": ("sh", "uan"), + "shuang": ("sh", "uang"), + "shui": ("sh", "uei"), + "shun": ("sh", "uen"), + "shuo": ("sh", "uo"), + "si": ("s", "ii"), + "song": ("s", "ong"), + "sou": ("s", "ou"), + "su": ("s", "u"), + "suan": ("s", "uan"), + "sui": ("s", "uei"), + "sun": ("s", "uen"), + "suo": ("s", "uo"), + "ta": ("t", "a"), + "tai": ("t", "ai"), + "tan": ("t", "an"), + "tang": ("t", "ang"), + "tao": ("t", "ao"), + "te": ("t", "e"), + "tei": ("t", "ei"), + "teng": ("t", "eng"), + "ti": ("t", "i"), + "tian": ("t", "ian"), + "tiao": ("t", "iao"), + "tie": ("t", "ie"), + "ting": ("t", "ing"), + "tong": ("t", "ong"), + "tou": ("t", "ou"), + "tu": ("t", "u"), + "tuan": ("t", "uan"), + "tui": ("t", "uei"), + "tun": ("t", "uen"), + "tuo": ("t", "uo"), + "wa": ("^", "ua"), + "wai": ("^", "uai"), + "wan": ("^", "uan"), + "wang": ("^", "uang"), + "wei": ("^", "uei"), + "wen": ("^", "uen"), + "weng": ("^", "ueng"), + "wo": ("^", "uo"), + "wu": ("^", "u"), + "xi": ("x", "i"), + "xia": ("x", "ia"), + "xian": ("x", "ian"), + "xiang": ("x", "iang"), + "xiao": ("x", "iao"), + "xie": ("x", "ie"), + "xin": ("x", "in"), + "xing": ("x", "ing"), + "xiong": ("x", "iong"), + "xiu": ("x", "iou"), + "xu": ("x", "v"), + "xuan": ("x", "van"), + "xue": ("x", "ve"), + "xun": ("x", "vn"), + "ya": ("^", "ia"), + "yan": ("^", "ian"), + "yang": ("^", "iang"), + "yao": ("^", "iao"), + "ye": ("^", "ie"), + "yi": ("^", "i"), + "yin": ("^", "in"), + "ying": ("^", "ing"), + "yo": ("^", "iou"), + "yong": ("^", "iong"), + "you": ("^", "iou"), + "yu": ("^", "v"), + "yuan": ("^", "van"), + "yue": ("^", "ve"), + "yun": ("^", "vn"), + "za": ("z", "a"), + "zai": ("z", "ai"), + "zan": ("z", "an"), + "zang": ("z", "ang"), + "zao": ("z", "ao"), + "ze": ("z", "e"), + "zei": ("z", "ei"), + "zen": ("z", "en"), + "zeng": ("z", "eng"), + "zha": ("zh", "a"), + "zhai": ("zh", "ai"), + "zhan": ("zh", "an"), + "zhang": ("zh", "ang"), + "zhao": ("zh", "ao"), + "zhe": ("zh", "e"), + "zhei": ("zh", "ei"), + "zhen": ("zh", "en"), + "zheng": ("zh", "eng"), + "zhi": ("zh", "iii"), + "zhong": ("zh", "ong"), + "zhou": ("zh", "ou"), + "zhu": ("zh", "u"), + "zhua": ("zh", "ua"), + "zhuai": ("zh", "uai"), + "zhuan": ("zh", "uan"), + "zhuang": ("zh", "uang"), + "zhui": ("zh", "uei"), + "zhun": ("zh", "uen"), + "zhuo": ("zh", "uo"), + "zi": ("z", "ii"), + "zong": ("z", "ong"), + "zou": ("z", "ou"), + "zu": ("z", "u"), + "zuan": ("z", "uan"), + "zui": ("z", "uei"), + "zun": ("z", "uen"), + "zuo": ("z", "uo"), +} diff --git 
a/egs/aishell3/TTS/local/prepare_token_file.py b/egs/aishell3/TTS/local/prepare_token_file.py
new file mode 100755
index 0000000000..d90910ab02
--- /dev/null
+++ b/egs/aishell3/TTS/local/prepare_token_file.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+from symbols import symbols
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default=Path("data/tokens.txt"),
+        help="Path to the file to save the token-to-ID mapping",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    tokens = Path(args.tokens)
+
+    with open(tokens, "w", encoding="utf-8") as f:
+        for token_id, token in enumerate(symbols):
+            f.write(f"{token} {token_id}\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell3/TTS/local/prepare_tokens_aishell3.py b/egs/aishell3/TTS/local/prepare_tokens_aishell3.py
new file mode 100755
index 0000000000..4b2b5094fd
--- /dev/null
+++ b/egs/aishell3/TTS/local/prepare_tokens_aishell3.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifests and saves the new cuts with tokens.
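+
+Usage (it takes no arguments; prepare.sh runs it in stage 4):
+  ./local/prepare_tokens_aishell3.py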
+""" + +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest + +from tokenizer import Tokenizer + + +def prepare_tokens_aishell3(): + output_dir = Path("data/spectrogram") + prefix = "aishell3" + suffix = "jsonl.gz" + partitions = ("train", "test") + + tokenizer = Tokenizer() + + for partition in partitions: + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + new_cuts = [] + i = 0 + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].text + cut.tokens = tokenizer.text_to_tokens(text) + + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file( + output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_aishell3() diff --git a/egs/aishell3/TTS/local/pypinyin-local.dict b/egs/aishell3/TTS/local/pypinyin-local.dict new file mode 100644 index 0000000000..5e386014c8 --- /dev/null +++ b/egs/aishell3/TTS/local/pypinyin-local.dict @@ -0,0 +1,328 @@ +姐姐 jie3 jie +宝宝 bao3 bao +哥哥 ge1 ge +妹妹 mei4 mei +弟弟 di4 di +妈妈 ma1 ma +开心哦 kai1 xin1 o +爸爸 ba4 ba +秘密哟 mi4 mi4 yo +哦 o +一年 yi4 nian2 +一夜 yi2 ye4 +一切 yi2 qie4 +一座 yi2 zuo4 +一下 yi2 xia4 +上一山 shang4 yi2 shan1 +下一山 xia4 yi2 shan1 +休息 xiu1 xi2 +东西 dong1 xi +上一届 shang4 yi2 jie4 +便宜 pian2 yi4 +加长 jia1 chang2 +单田芳 shan4 tian2 fang1 +帧 zhen1 +长时间 chang2 shi2 jian1 +长时 chang2 shi2 +识别 shi2 bie2 +生命中 sheng1 ming4 zhong1 +踏实 ta1 shi +嗯 en4 +溜达 liu1 da +少儿 shao4 er2 +爷爷 ye2 ye +不是 bu2 shi4 +一圈 yi1 quan1 +厜读一声 zui1 du2 yi4 sheng1 +一种 yi4 zhong3 +一簇簇 yi2 cu4 cu4 +一个 yi2 ge4 +一样 yi2 yang4 +一跩一跩 yi4 zhuai3 yi4 zhuai3 +一会儿 yi2 hui4 er +一幢 yi2 zhuang4 +挨了 ai2 le +熬菜 ao1 cai4 +扒鸡 pa2 ji1 +背枪 bei1 qiang1 +绷瓷儿 beng4 ci2 er2 +绷劲儿 beng3 jin4 er +绷着脸 beng3 zhe lian3 +藏医 zang4 yi1 +噌吰 cheng1 hong2 +差点儿 cha4 dian3 er +差失 cha1 shi1 +差误 cha1 wu4 +孱头 can4 tou +乘间 cheng2 jian4 +锄镰棘矜 chu2 lian2 ji2 qin2 +川藏 chuan1 zang4 +穿著 chuan1 zhuo2 +答讪 da1 shan4 +答言 da1 yan2 +大伯子 da4 bai3 zi +大夫 dai4 fu +弹冠 tan2 guan1 +当间 dang1 jian4 +当然咯 dang1 ran2 lo +点种 dian3 zhong3 +垛好 duo4 hao3 +发疟子 fa1 yao4 zi +饭熟了 fan4 shou2 le +附著 fu4 zhuo2 +复沓 fu4 ta4 +供稿 gong1 gao3 +供养 gong1 yang3 +骨朵 gu1 duo +骨碌 gu1 lu +果脯 guo3 fu3 +哈什玛 ha4 shi2 ma3 +海蜇 hai3 zhe2 +呵欠 he1 qian +河水汤汤 he2 shui3 shang1 shang1 +鹄立 hu2 li4 +鹄望 hu2 wang4 +混人 hun2 ren2 +混水 hun2 shui3 +鸡血 ji1 xie3 +缉鞋口 qi1 xie2 kou3 +亟来闻讯 qi4 lai2 wen2 xun4 +计量 ji4 liang2 +济水 ji3 shui3 +间杂 jian4 za2 +脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2 +脚儿 jue2 er2 +口角 kou3 jiao3 +勒石 le4 shi2 +累进 lei3 jin4 +累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3 +累年 lei3 nian2 +脸涨通红 lian3 zhang4 tong1 hong2 +踉锵 liang4 qiang1 +燎眉毛 liao3 mei2 mao2 +燎头发 liao3 tou2 fa4 +溜达 liu1 da +溜缝儿 liu4 feng4 er +馏口饭 liu4 kou3 fan4 +遛马 liu4 ma3 +遛鸟 liu4 niao3 +遛弯儿 liu4 wan1 er +楼枪机 lou1 qiang1 ji1 +搂钱 lou1 qian2 +鹿脯 lu4 fu3 +露头 lou4 tou2 +落魄 luo4 po4 +捋胡子 lv3 hu2 zi +绿地 lv4 di4 +麦垛 mai4 duo4 +没劲儿 mei2 jin4 er +闷棍 men4 gun4 +闷葫芦 men4 hu2 lu +闷头干 men1 tou2 gan4 +蒙古 meng3 gu3 +靡日不思 mi3 ri4 bu4 si1 +缪姓 miao4 xing4 +抹墙 mo4 qiang2 +抹下脸 ma1 xia4 lian3 +泥子 ni4 zi +拗不过 niu4 bu guo4 +排车 pai3 che1 +盘诘 pan2 jie2 +膀肿 pang1 zhong3 +炮干 bao1 gan1 +炮格 pao2 ge2 +碰钉子 peng4 ding1 zi +缥色 piao3 se4 +瀑河 bao4 he2 +蹊径 xi1 jing4 +前后相属 qian2 hou4 xiang1 zhu3 +翘尾巴 qiao4 wei3 ba +趄坡儿 qie4 po1 er +秦桧 qin2 hui4 +圈马 juan1 ma3 +雀盲眼 qiao3 mang2 
yan3 +雀子 qiao1 zi +三年五载 san1 nian2 wu3 zai3 +加载 jia1 zai3 +山大王 shan1 dai4 wang +苫屋草 shan4 wu1 cao3 +数数 shu3 shu4 +说客 shui4 ke4 +思量 si1 liang2 +伺侯 ci4 hou +踏实 ta1 shi +提溜 di1 liu +调拨 diao4 bo1 +帖子 tie3 zi +铜钿 tong2 tian2 +头昏脑涨 tou2 hun1 nao3 zhang4 +褪色 tui4 se4 +褪着手 tun4 zhe shou3 +圩子 wei2 zi +尾巴 wei3 ba +系好船只 xi4 hao3 chuan2 zhi1 +系好马匹 xi4 hao3 ma3 pi3 +杏脯 xing4 fu3 +姓单 xing4 shan4 +姓葛 xing4 ge3 +姓哈 xing4 ha3 +姓解 xing4 xie4 +姓秘 xing4 bi4 +姓宁 xing4 ning4 +旋风 xuan4 feng1 +旋根车轴 xuan4 gen1 che1 zhou2 +荨麻 qian2 ma2 +一幢楼房 yi1 zhuang4 lou2 fang2 +遗之千金 wei4 zhi1 qian1 jin1 +殷殷 yin3 yin3 +应招 ying4 zhao1 +用称约 yong4 cheng4 yao1 +约斤肉 yao1 jin1 rou4 +晕机 yun4 ji1 +熨贴 yu4 tie1 +咋办 za3 ban4 +咋呼 zha1 hu +仔兽 zi3 shou4 +扎彩 za1 cai3 +扎实 zha1 shi +扎腰带 za1 yao1 dai4 +轧朋友 ga2 peng2 you3 +爪子 zhua3 zi +折腾 zhe1 teng +着实 zhuo2 shi2 +着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2 +枝蔓 zhi1 man4 +中鹄 zhong1 hu2 +中选 zhong4 xuan3 +猪圈 zhu1 juan4 +拽住不放 zhuai4 zhu4 bu4 fang4 +转悠 zhuan4 you +庄稼熟了 zhuang1 jia shou2 le +酌量 zhuo2 liang2 +罪行累累 zui4 xing2 lei3 lei3 +一手 yi4 shou3 +一去不复返 yi2 qu4 bu2 fu4 fan3 +一颗 yi4 ke1 +一件 yi2 jian4 +一斤 yi4 jin1 +一点 yi4 dian3 +一朵 yi4 duo3 +一声 yi4 sheng1 +一身 yi4 shen1 +不要 bu2 yao4 +一人 yi4 ren2 +一个 yi2 ge4 +一把 yi4 ba3 +一门 yi4 men2 +一門 yi4 men2 +一艘 yi4 sou1 +一片 yi2 pian4 +一篇 yi2 pian1 +一份 yi2 fen4 +好嗲 hao3 dia3 +随地 sui2 di4 +扁担长 bian3 dan4 chang3 +一堆 yi4 dui1 +不义 bu2 yi4 +放一放 fang4 yi2 fang4 +一米 yi4 mi3 +一顿 yi2 dun4 +一层楼 yi4 ceng2 lou2 +一条 yi4 tiao2 +一件 yi2 jian4 +一棵 yi4 ke1 +一小股 yi4 xiao3 gu3 +一拐一拐 yi4 guai3 yi4 guai3 +一根 yi4 gen1 +沆瀣一气 hang4 xie4 yi2 qi4 +一丝 yi4 si1 +一毫 yi4 hao2 +一樣 yi2 yang4 +处处 chu4 chu4 +一餐 yi4 can +永不 yong3 bu2 +一看 yi2 kan4 +一架 yi2 jia4 +送还 song4 huan2 +一见 yi2 jian4 +一座 yi2 zuo4 +一块 yi2 kuai4 +一天 yi4 tian1 +一只 yi4 zhi1 +一支 yi4 zhi1 +一字 yi2 zi4 +一句 yi2 ju4 +一张 yi4 zhang1 +一條 yi4 tiao2 +一场 yi4 chang3 +一粒 yi2 li4 +小俩口 xiao3 liang3 kou3 +一首 yi4 shou3 +一对 yi2 dui4 +一手 yi4 shou3 +又一村 you4 yi4 cun1 +一概而论 yi2 gai4 er2 lun4 +一峰峰 yi4 feng1 feng1 +不但 bu2 dan4 +一笑 yi2 xiao4 +挠痒痒 nao2 yang3 yang +不对 bu2 dui4 +拧开 ning3 kai1 +爱不释手 ai4 bu2 shi4 shou3 +一念 yi2 nian4 +夺得 duo2 de2 +一袭 yi4 xi2 +一定 yi2 ding4 +不慎 bu2 shen4 +剽窃 piao2 qie4 +一时 yi4 shi2 +撇开 pie3 kai1 +一祭 yi2 ji4 +发卡 fa4 qia3 +少不了 shao3 bu4 liao3 +千虑一失 qian1 lv4 yi4 shi1 +呛得 qiang4 de2 +切菜 qie1 cai4 +茄盒 qie2 he2 +不去 bu2 qu4 +一大圈 yi2 da4 quan1 +不再 bu2 zai4 +一群 yi4 qun2 +不必 bu2 bi4 +一些 yi4 xie1 +一路 yi2 lu4 +一股 yi4 gu3 +一到 yi2 dao4 +一拨 yi4 bo1 +一排 yi4 pai2 +一空 yi4 kong1 +吮吸着 shun3 xi1 zhe +不适合 bu2 shi4 he2 +一串串 yi2 chuan4 chuan4 +一提起 yi4 ti2 qi3 +一尘不染 yi4 chen2 bu4 ran3 +一生 yi4 sheng1 +一派 yi2 pai4 +不断 bu2 duan4 +一次 yi2 ci4 +不进步 bu2 jin4 bu4 +娃娃 wa2 wa +万户侯 wan4 hu4 hou2 +一方 yi4 fang1 +一番话 yi4 fan1 hua4 +一遍 yi2 bian4 +不计较 bu2 ji4 jiao4 +诇 xiong4 +一边 yi4 bian1 +一束 yi2 shu4 +一听到 yi4 ting1 dao4 +炸鸡 zha2 ji1 +乍暧还寒 zha4 ai4 huan2 han2 +我说诶 wo3 shuo1 ei1 +棒诶 bang4 ei1 +寒碜 han2 chen4 +应采儿 ying4 cai3 er2 +晕车 yun1 che1 +必应 bi4 ying4 +应援 ying4 yuan2 +应力 ying4 li4 \ No newline at end of file diff --git a/egs/aishell3/TTS/local/tokenizer.py b/egs/aishell3/TTS/local/tokenizer.py new file mode 100644 index 0000000000..cbf6c9c773 --- /dev/null +++ b/egs/aishell3/TTS/local/tokenizer.py @@ -0,0 +1,137 @@ +# This file is modified from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py + +import logging +from pathlib import Path +from typing import List + +# Note pinyin_dict is from ./pinyin_dict.py +from pinyin_dict import pinyin_dict +from pypinyin import Style +from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin +from pypinyin.converter 
import DefaultConverter
+from pypinyin.core import Pinyin, load_phrases_dict
+
+
+class _MyConverter(NeutralToneWith5Mixin, DefaultConverter):
+    pass
+
+
+class Tokenizer:
+    def __init__(self, tokens: str = ""):
+        self._load_pinyin_dict()
+        self._pinyin_parser = Pinyin(_MyConverter())
+
+        if tokens != "":
+            self._load_tokens(tokens)
+
+    def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]:
+        """
+        Args:
+          texts:
+            A list of sentences.
+          kwargs:
+            Not used. It is for compatibility with other TTS recipes in icefall.
+        """
+        tokens = []
+
+        for text in texts:
+            tokens.append(self.text_to_tokens(text))
+
+        return self.tokens_to_token_ids(tokens)
+
+    def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
+        ans = []
+
+        for token_list in tokens:
+            token_ids = []
+            for t in token_list:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
+            ans.append(token_ids)
+
+        return ans
+
+    def text_to_tokens(self, text: str) -> List[str]:
+        # Convert "," to ["sp", "sil"]
+        # Convert "。" to ["sil"]
+        # append ["eos"] at the end of a sentence
+        phonemes = ["sil"]
+        pinyins = self._pinyin_parser.pinyin(
+            text,
+            style=Style.TONE3,
+            errors=lambda x: [[w] for w in x],
+        )
+
+        new_pinyin = []
+        for p in pinyins:
+            p = p[0]
+            if p == ",":
+                new_pinyin.extend(["sp", "sil"])
+            elif p == "。":
+                new_pinyin.append("sil")
+            else:
+                new_pinyin.append(p)
+        sub_phonemes = self._get_phoneme4pinyin(new_pinyin)
+        sub_phonemes.append("eos")
+        phonemes.extend(sub_phonemes)
+        return phonemes
+
+    def _get_phoneme4pinyin(self, pinyins):
+        result = []
+        for pinyin in pinyins:
+            if pinyin in ("sil", "sp"):
+                result.append(pinyin)
+            elif pinyin[:-1] in pinyin_dict:
+                tone = pinyin[-1]
+                a = pinyin[:-1]
+                a1, a2 = pinyin_dict[a]
+                # every word is appended with a #0
+                result += [a1, a2 + tone, "#0"]
+
+        return result
+
+    def _load_pinyin_dict(self):
+        this_dir = Path(__file__).parent.resolve()
+        my_dict = {}
+        with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f:
+            content = f.readlines()
+            for line in content:
+                cuts = line.strip().split()
+                hanzi = cuts[0]
+                pinyin = cuts[1:]
+                my_dict[hanzi] = [[p] for p in pinyin]
+
+        load_phrases_dict(my_dict)
+
+    def _load_tokens(self, filename):
+        token2id = {}  # maps token (str) to id (int)
+
+        with open(filename, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    idx = int(info[0])
+                else:
+                    token, idx = info[0], int(info[1])
+
+                assert token not in token2id, token
+
+                token2id[token] = idx
+
+        self.token2id = token2id
+        self.vocab_size = len(self.token2id)
+        self.pad_id = self.token2id["#0"]
+
+
+def main():
+    tokenizer = Tokenizer()
+    print(tokenizer.text_to_tokens("你好,好的。"))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell3/TTS/local/validate_manifest.py b/egs/aishell3/TTS/local/validate_manifest.py
new file mode 120000
index 0000000000..b4d52ebca0
--- /dev/null
+++ b/egs/aishell3/TTS/local/validate_manifest.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/local/validate_manifest.py
\ No newline at end of file
diff --git a/egs/aishell3/TTS/prepare.sh b/egs/aishell3/TTS/prepare.sh
new file mode 100755
index 0000000000..af532c2296
--- /dev/null
+++ b/egs/aishell3/TTS/prepare.sh
@@ -0,0 +1,137 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
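+
+# You can run a subset of the stages below, e.g.
+#
+#   ./prepare.sh --stage 3 --stop-stage 3
+#
+# The --stage and --stop-stage options are parsed by ./shared/parse_options.sh.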
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: build monotonic_align lib"
+  if [ ! -d vits/monotonic_align/build ]; then
+    cd vits/monotonic_align
+    python3 setup.py build_ext --inplace
+    cd ../../
+  else
+    log "monotonic_align lib already built"
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Download data"
+
+  # The directory $dl_dir/aishell3 will contain the following files
+  # and sub directories
+  #   ChangeLog  ReadMe.txt  phone_set.txt  spk-info.txt  test  train
+  # If you have pre-downloaded it to /path/to/aishell3, you can create a symlink
+  #
+  #   ln -sfv /path/to/aishell3 $dl_dir/
+  #   touch $dl_dir/aishell3/.completed
+  #
+  if [ ! -d $dl_dir/aishell3 ]; then
+    lhotse download aishell3 $dl_dir
+  fi
+fi
+
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare aishell3 manifest (may take 13 minutes)"
+  # We assume that you have downloaded the aishell3 corpus
+  # to $dl_dir/aishell3.
+  # You can find files like spk-info.txt inside $dl_dir/aishell3
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.aishell3.done ]; then
+    lhotse prepare aishell3 $dl_dir/aishell3 data/manifests
+    touch data/manifests/.aishell3.done
+  fi
+fi
+
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute spectrogram for aishell3 (may take 5 minutes)"
+  mkdir -p data/spectrogram
+  if [ ! -e data/spectrogram/.aishell3.done ]; then
+    ./local/compute_spectrogram_aishell3.py
+    touch data/spectrogram/.aishell3.done
+  fi
+
+  if [ ! -e data/spectrogram/.aishell3-validated.done ]; then
+    log "Validating data/spectrogram for aishell3"
+    python3 ./local/validate_manifest.py \
+      data/spectrogram/aishell3_cuts_train.jsonl.gz
+
+    python3 ./local/validate_manifest.py \
+      data/spectrogram/aishell3_cuts_test.jsonl.gz
+
+    touch data/spectrogram/.aishell3-validated.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare tokens for aishell3 (may take 20 seconds)"
+  if [ ! -e data/spectrogram/.aishell3_with_token.done ]; then
+
+    ./local/prepare_tokens_aishell3.py
+
+    mv -v data/spectrogram/aishell3_cuts_with_tokens_train.jsonl.gz \
+      data/spectrogram/aishell3_cuts_train.jsonl.gz
+
+    mv -v data/spectrogram/aishell3_cuts_with_tokens_test.jsonl.gz \
+      data/spectrogram/aishell3_cuts_test.jsonl.gz
+
+    touch data/spectrogram/.aishell3_with_token.done
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Split the aishell3 cuts into train, valid and test sets (may take 25 seconds)"
+  if [ !
-e data/spectrogram/.aishell3_split.done ]; then + lhotse subset --last 1000 \ + data/spectrogram/aishell3_cuts_test.jsonl.gz \ + data/spectrogram/aishell3_cuts_valid.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/aishell3_cuts_test.jsonl.gz | wc -l) - 1000 )) + + lhotse subset --first $n \ + data/spectrogram/aishell3_cuts_test.jsonl.gz \ + data/spectrogram/aishell3_cuts_test2.jsonl.gz + + mv data/spectrogram/aishell3_cuts_test2.jsonl.gz data/spectrogram/aishell3_cuts_test.jsonl.gz + + touch data/spectrogram/.aishell3_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate token file" + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py --tokens data/tokens.txt + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate speakers file" + if [ ! -e data/speakers.txt ]; then + gunzip -c data/manifests/aishell3_supervisions_train.jsonl.gz \ + | jq '.speaker' | sed 's/"//g' \ + | sort | uniq > data/speakers.txt + fi +fi diff --git a/egs/aishell3/TTS/shared b/egs/aishell3/TTS/shared new file mode 120000 index 0000000000..4cbd91a7e9 --- /dev/null +++ b/egs/aishell3/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/export-onnx.py b/egs/aishell3/TTS/vits/export-onnx.py new file mode 100755 index 0000000000..ed5a1c6a33 --- /dev/null +++ b/egs/aishell3/TTS/vits/export-onnx.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --speakers ./data/speakers.txt \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate one file inside vits/exp: + - vits-epoch-1000.onnx + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + + parser.add_argument( + "--model-type", + type=str, + default="medium", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + alpha: float = 1.0, + noise_scale_dur: float = 0.8, + speaker: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + speaker (int): + Speaker ID. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.generator.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + sids=speaker, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + vocab_size: int, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. + opset_version: + The opset version to use. 
+ """ + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + speaker = torch.tensor([1], dtype=torch.int64) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "alpha", + "noise_scale_dur", + "speaker", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + "speaker": {0: "N"}, + }, + ) + + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + + meta_data = { + "model_type": "vits", + "version": "1", + "model_author": "k2-fsa", + "comment": "icefall", # must be icefall for models from icefall + "language": "Chinese", + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + params.vocab_size, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" +Supported languages. + +LJSpeech is using "en-us" from the second column. 
+ +Pty Language Age/Gender VoiceName File Other Languages + 5 af --/M Afrikaans gmw/af + 5 am --/M Amharic sem/am + 5 an --/M Aragonese roa/an + 5 ar --/M Arabic sem/ar + 5 as --/M Assamese inc/as + 5 az --/M Azerbaijani trk/az + 5 ba --/M Bashkir trk/ba + 5 be --/M Belarusian zle/be + 5 bg --/M Bulgarian zls/bg + 5 bn --/M Bengali inc/bn + 5 bpy --/M Bishnupriya_Manipuri inc/bpy + 5 bs --/M Bosnian zls/bs + 5 ca --/M Catalan roa/ca + 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr + 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) + 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) + 5 cs --/M Czech zlw/cs + 5 cv --/M Chuvash trk/cv + 5 cy --/M Welsh cel/cy + 5 da --/M Danish gmq/da + 5 de --/M German gmw/de + 5 el --/M Greek grk/el + 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) + 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) + 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) + 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) + 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) + 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) + 2 en-us --/M English_(America) gmw/en-US (en 3) + 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc + 5 eo --/M Esperanto art/eo + 5 es --/M Spanish_(Spain) roa/es + 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) + 5 et --/M Estonian urj/et + 5 eu --/M Basque eu + 5 fa --/M Persian ira/fa + 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn + 5 fi --/M Finnish urj/fi + 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) + 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) + 5 fr-fr --/M French_(France) roa/fr (fr 5) + 5 ga --/M Gaelic_(Irish) cel/ga + 5 gd --/M Gaelic_(Scottish) cel/gd + 5 gn --/M Guarani sai/gn + 5 grc --/M Greek_(Ancient) grk/grc + 5 gu --/M Gujarati inc/gu + 5 hak --/M Hakka_Chinese sit/hak + 5 haw --/M Hawaiian map/haw + 5 he --/M Hebrew sem/he + 5 hi --/M Hindi inc/hi + 5 hr --/M Croatian zls/hr (hbs 5) + 5 ht --/M Haitian_Creole roa/ht + 5 hu --/M Hungarian urj/hu + 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) + 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) + 5 ia --/M Interlingua art/ia + 5 id --/M Indonesian poz/id + 5 io --/M Ido art/io + 5 is --/M Icelandic gmq/is + 5 it --/M Italian roa/it + 5 ja --/M Japanese jpx/ja + 5 jbo --/M Lojban art/jbo + 5 ka --/M Georgian ccs/ka + 5 kk --/M Kazakh trk/kk + 5 kl --/M Greenlandic esx/kl + 5 kn --/M Kannada dra/kn + 5 ko --/M Korean ko + 5 kok --/M Konkani inc/kok + 5 ku --/M Kurdish ira/ku + 5 ky --/M Kyrgyz trk/ky + 5 la --/M Latin itc/la + 5 lb --/M Luxembourgish gmw/lb + 5 lfn --/M Lingua_Franca_Nova art/lfn + 5 lt --/M Lithuanian bat/lt + 5 ltg --/M Latgalian bat/ltg + 5 lv --/M Latvian bat/lv + 5 mi --/M Māori poz/mi + 5 mk --/M Macedonian zls/mk + 5 ml --/M Malayalam dra/ml + 5 mr --/M Marathi inc/mr + 5 ms --/M Malay poz/ms + 5 mt --/M Maltese sem/mt + 5 mto --/M Totontepec_Mixe miz/mto + 5 my --/M Myanmar_(Burmese) sit/my + 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) + 5 nci --/M Nahuatl_(Classical) azc/nci + 5 ne --/M Nepali inc/ne + 5 nl --/M Dutch gmw/nl + 5 nog --/M Nogai trk/nog + 5 om --/M Oromo cus/om + 5 or --/M Oriya inc/or + 5 pa --/M Punjabi inc/pa + 5 pap --/M Papiamento roa/pap + 5 piqd --/M Klingon art/piqd + 5 pl --/M Polish zlw/pl + 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) + 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) + 5 py --/M 
Pyash art/py + 5 qdb --/M Lang_Belta art/qdb + 5 qu --/M Quechua qu + 5 quc --/M K'iche' myn/quc + 5 qya --/M Quenya art/qya + 5 ro --/M Romanian roa/ro + 5 ru --/M Russian zle/ru + 5 ru-cl --/M Russian_(Classic) zle/ru-cl + 2 ru-lv --/M Russian_(Latvia) zle/ru-LV + 5 sd --/M Sindhi inc/sd + 5 shn --/M Shan_(Tai_Yai) tai/shn + 5 si --/M Sinhala inc/si + 5 sjn --/M Sindarin art/sjn + 5 sk --/M Slovak zlw/sk + 5 sl --/M Slovenian zls/sl + 5 smj --/M Lule_Saami urj/smj + 5 sq --/M Albanian ine/sq + 5 sr --/M Serbian zls/sr + 5 sv --/M Swedish gmq/sv + 5 sw --/M Swahili bnt/sw + 5 ta --/M Tamil dra/ta + 5 te --/M Telugu dra/te + 5 th --/M Thai tai/th + 5 tk --/M Turkmen trk/tk + 5 tn --/M Setswana bnt/tn + 5 tr --/M Turkish trk/tr + 5 tt --/M Tatar trk/tt + 5 ug --/M Uyghur trk/ug + 5 uk --/M Ukrainian zle/uk + 5 ur --/M Urdu inc/ur + 5 uz --/M Uzbek trk/uz + 5 vi --/M Vietnamese_(Northern) aav/vi + 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central + 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south + 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) + 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) +""" diff --git a/egs/aishell3/TTS/vits/pinyin_dict.py b/egs/aishell3/TTS/vits/pinyin_dict.py new file mode 120000 index 0000000000..b8683bd2dc --- /dev/null +++ b/egs/aishell3/TTS/vits/pinyin_dict.py @@ -0,0 +1 @@ +../local/pinyin_dict.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/pypinyin-local.dict b/egs/aishell3/TTS/vits/pypinyin-local.dict new file mode 120000 index 0000000000..5bc9b77282 --- /dev/null +++ b/egs/aishell3/TTS/vits/pypinyin-local.dict @@ -0,0 +1 @@ +../local/pypinyin-local.dict \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/tokenizer.py b/egs/aishell3/TTS/vits/tokenizer.py index 057b0dc4b1..0368e07d34 120000 --- a/egs/aishell3/TTS/vits/tokenizer.py +++ b/egs/aishell3/TTS/vits/tokenizer.py @@ -1 +1 @@ -../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file +../local/tokenizer.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/train.py b/egs/aishell3/TTS/vits/train.py new file mode 100755 index 0000000000..f3f99ebbc6 --- /dev/null +++ b/egs/aishell3/TTS/vits/train.py @@ -0,0 +1,1003 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
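+
+# This script trains a multi-speaker VITS model on aishell3.
+# A typical single-GPU invocation (the flags are defined in get_parser() below):
+#
+#   ./vits/train.py --world-size 1 --num-epochs 1000 --exp-dir vits/exp
+#
+# Use --world-size > 1 for DDP training on multiple GPUs.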
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import Aishell3SpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--model-type", + type=str, + default="medium", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. 
+ """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 8000, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + model_type=params.model_type, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: Dict[str, int], +): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + speakers = ( + torch.Tensor([speaker_map.get(sid, 0) for sid in batch["speakers"]]) + .int() + .to(device) + ) + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
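+            # If the scale has collapsed after repeated overflows, nudge it
+            # back up manually instead of waiting for the scaler's own
+            # growth schedule.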
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + speaker_map=speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + sids=speakers[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + speaker_map: Dict[str, int], + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
+        try:
+            # for the discriminator
+            with autocast(enabled=params.use_fp16):
+                loss_d, stats_d = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    sids=speakers,
+                    forward_generator=False,
+                )
+            optimizer_d.zero_grad()
+            loss_d.backward()
+            # for the generator
+            with autocast(enabled=params.use_fp16):
+                loss_g, stats_g = model(
+                    text=tokens,
+                    text_lengths=tokens_lens,
+                    feats=features,
+                    feats_lengths=features_lens,
+                    speech=audio,
+                    speech_lengths=audio_lens,
+                    sids=speakers,
+                    forward_generator=True,
+                )
+            optimizer_g.zero_grad()
+            loss_g.backward()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoints.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+
+    aishell3 = Aishell3SpeechTtsDataModule(args)
+    speaker_map = aishell3.speakers()
+    params.num_spks = len(speaker_map)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    generator = model.generator
+    discriminator = model.discriminator
+
+    num_param_g = sum([p.numel() for p in generator.parameters()])
+    logging.info(f"Number of parameters in generator: {num_param_g}")
+    num_param_d = sum([p.numel() for p in discriminator.parameters()])
+    logging.info(f"Number of parameters in discriminator: {num_param_d}")
+    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer_g = torch.optim.AdamW(
+        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+    )
+    optimizer_d = torch.optim.AdamW(
+        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+    )
+
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+
+    if checkpoints is not None:
+        # load state_dict for optimizers
+        if "optimizer_g" in checkpoints:
+            logging.info("Loading optimizer_g state dict")
+            optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+        if "optimizer_d" in checkpoints:
+            logging.info("Loading optimizer_d state dict")
+            optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+        # load state_dict for schedulers
+        if "scheduler_g" in checkpoints:
+            logging.info("Loading scheduler_g state dict")
+            scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+        if "scheduler_d" in checkpoints:
+            logging.info("Loading scheduler_d state dict")
+            scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(512)
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    train_cuts = aishell3.train_cuts()
+
+    logging.info(params)
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with a duration between 1 second and 20 seconds.
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset before
+        # choosing these thresholds.
+        if c.duration < 1.0 or c.duration > 20.0:
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
+            return False
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_dl = aishell3.train_dataloaders(train_cuts)
+
+    valid_cuts = aishell3.valid_cuts()
+    valid_dl = aishell3.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            tokenizer=tokenizer,
+            optimizer_g=optimizer_g,
+            optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"Start epoch {epoch}")
+
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        params.cur_epoch = epoch
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            tokenizer=tokenizer,
+            optimizer_g=optimizer_g,
+            optimizer_d=optimizer_d,
+            scheduler_g=scheduler_g,
+            scheduler_d=scheduler_d,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            speaker_map=speaker_map,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+            save_checkpoint(
+                filename=filename,
+                params=params,
+                model=model,
+                optimizer_g=optimizer_g,
+                optimizer_d=optimizer_d,
+                scheduler_g=scheduler_g,
+                scheduler_d=scheduler_d,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            if rank == 0:
+                if params.best_train_epoch == params.cur_epoch:
+                    best_train_filename = params.exp_dir / "best-train-loss.pt"
+                    copyfile(src=filename, dst=best_train_filename)
+
+                if params.best_valid_epoch == params.cur_epoch:
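+                    # Keep a copy of the checkpoint that achieved the
+                    # best validation loss so far.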
+                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+                    copyfile(src=filename, dst=best_valid_filename)
+
+        # step the schedulers once per epoch
+        scheduler_g.step()
+        scheduler_d.step()
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    Aishell3SpeechTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+# Torch's multithreaded behavior needs to be disabled, or
+# it wastes a lot of CPU and slows things down.
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell3/TTS/vits/tts_datamodule.py b/egs/aishell3/TTS/vits/tts_datamodule.py
new file mode 100644
index 0000000000..a08c645382
--- /dev/null
+++ b/egs/aishell3/TTS/vits/tts_datamodule.py
@@ -0,0 +1,349 @@
+# Copyright      2021  Piotr Żelasko
+# Copyright      2022-2023  Xiaomi Corporation (Authors: Mingshuang Luo,
+#                                                        Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpecAugment,
+    SpeechSynthesisDataset,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 for AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class Aishell3SpeechTtsDataModule:
+    """
+    DataModule for TTS experiments.
+    It assumes there is always one train and one valid dataloader,
+    but there can be multiple test dataloaders (e.g. one per test set).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
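+
+    A minimal usage sketch (for illustration only; the flag value shown is
+    hypothetical and the real options are defined in add_arguments() below):
+
+        parser = argparse.ArgumentParser()
+        Aishell3SpeechTtsDataModule.add_arguments(parser)
+        args = parser.parse_args(["--manifest-dir", "data/spectrogram"])
+
+        aishell3 = Aishell3SpeechTtsDataModule(args)
+        train_dl = aishell3.train_dataloaders(aishell3.train_cuts())
+        valid_dl = aishell3.valid_dataloaders(aishell3.valid_cuts())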
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + self.sampling_rate = 8000 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + 
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = SpeechSynthesisDataset(
+                return_text=False,
+                return_tokens=True,
+                return_spk_ids=True,
+                feature_input_strategy=eval(self.args.input_strategy)(),
+                return_cuts=self.args.return_cuts,
+            )
+        test_sampler = DynamicBucketingSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            num_buckets=self.args.num_buckets,
+            shuffle=False,
+        )
+        logging.info("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=test_sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell3_cuts_train.jsonl.gz"
+        )
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        logging.info("About to get validation cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell3_cuts_valid.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell3_cuts_test.jsonl.gz"
+        )
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers

From f9bd5ced9d952f537da208873a5a2bb793df89d2 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 6 Apr 2024 21:50:15 +0800
Subject: [PATCH 5/8] remove baker_zh

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 620427501b..9e45df61c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ node_modules
 core.c
 *.so
 build
+*.wav

From 35578f0593fae7888619d970e754e1bfbe68d7b9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 6 Apr 2024 21:51:09 +0800
Subject: [PATCH 6/8] remove baker-zh

---
 egs/baker_zh/TTS/README.md                    |   0
 egs/baker_zh/TTS/local/README.md              |   7 -
 egs/baker_zh/TTS/local/__init__.py            |   0
 .../TTS/local/compute_spectrogram_baker.py    | 106 --
 egs/baker_zh/TTS/local/pinyin_dict.py         | 421 --------
 egs/baker_zh/TTS/local/prepare_token_file.py  |  53 -
 .../TTS/local/prepare_tokens_baker_zh.py      |  59 --
 egs/baker_zh/TTS/local/pypinyin-local.dict    | 328 -------
 egs/baker_zh/TTS/local/symbols.py             |  73 --
 egs/baker_zh/TTS/local/tokenizer.py           | 137 ---
 egs/baker_zh/TTS/local/validate_manifest.py   |   1 -
 egs/baker_zh/TTS/prepare.sh                   | 124 ---
 egs/baker_zh/TTS/shared                       |   1 -
 egs/baker_zh/TTS/vits/duration_predictor.py   |   1 -
 egs/baker_zh/TTS/vits/export-onnx.py          | 414 --------
 egs/baker_zh/TTS/vits/flow.py                 |   1 -
 egs/baker_zh/TTS/vits/generate_lexicon.py     |  39 -
 egs/baker_zh/TTS/vits/generator.py            |   1 -
 egs/baker_zh/TTS/vits/hifigan.py              |   1 -
 egs/baker_zh/TTS/vits/loss.py                 |   1 -
 egs/baker_zh/TTS/vits/monotonic_align         |   1 -
 egs/baker_zh/TTS/vits/pinyin_dict.py          |   1 -
 egs/baker_zh/TTS/vits/posterior_encoder.py    |   1 -
 egs/baker_zh/TTS/vits/pypinyin-local.dict     |   1 -
 egs/baker_zh/TTS/vits/residual_coupling.py    |   1 -
 egs/baker_zh/TTS/vits/test_onnx.py            | 142 ---
 egs/baker_zh/TTS/vits/text_encoder.py         |   1 -
 egs/baker_zh/TTS/vits/tokenizer.py            |   1 -
 egs/baker_zh/TTS/vits/train.py                | 927 ------------------
 egs/baker_zh/TTS/vits/transform.py            |   1 -
 egs/baker_zh/TTS/vits/tts_datamodule.py       | 330 -------
 egs/baker_zh/TTS/vits/utils.py                |   1 -
 egs/baker_zh/TTS/vits/wavenet.py              |   1 -
 34 files changed, 3178 deletions(-)
 delete mode 100644 egs/baker_zh/TTS/README.md
 delete mode 100644 egs/baker_zh/TTS/local/README.md
 delete mode 100644
egs/baker_zh/TTS/local/__init__.py delete mode 100755 egs/baker_zh/TTS/local/compute_spectrogram_baker.py delete mode 100644 egs/baker_zh/TTS/local/pinyin_dict.py delete mode 100755 egs/baker_zh/TTS/local/prepare_token_file.py delete mode 100755 egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py delete mode 100644 egs/baker_zh/TTS/local/pypinyin-local.dict delete mode 100644 egs/baker_zh/TTS/local/symbols.py delete mode 100644 egs/baker_zh/TTS/local/tokenizer.py delete mode 120000 egs/baker_zh/TTS/local/validate_manifest.py delete mode 100755 egs/baker_zh/TTS/prepare.sh delete mode 120000 egs/baker_zh/TTS/shared delete mode 120000 egs/baker_zh/TTS/vits/duration_predictor.py delete mode 100755 egs/baker_zh/TTS/vits/export-onnx.py delete mode 120000 egs/baker_zh/TTS/vits/flow.py delete mode 100755 egs/baker_zh/TTS/vits/generate_lexicon.py delete mode 120000 egs/baker_zh/TTS/vits/generator.py delete mode 120000 egs/baker_zh/TTS/vits/hifigan.py delete mode 120000 egs/baker_zh/TTS/vits/loss.py delete mode 120000 egs/baker_zh/TTS/vits/monotonic_align delete mode 120000 egs/baker_zh/TTS/vits/pinyin_dict.py delete mode 120000 egs/baker_zh/TTS/vits/posterior_encoder.py delete mode 120000 egs/baker_zh/TTS/vits/pypinyin-local.dict delete mode 120000 egs/baker_zh/TTS/vits/residual_coupling.py delete mode 100755 egs/baker_zh/TTS/vits/test_onnx.py delete mode 120000 egs/baker_zh/TTS/vits/text_encoder.py delete mode 120000 egs/baker_zh/TTS/vits/tokenizer.py delete mode 100755 egs/baker_zh/TTS/vits/train.py delete mode 120000 egs/baker_zh/TTS/vits/transform.py delete mode 100644 egs/baker_zh/TTS/vits/tts_datamodule.py delete mode 120000 egs/baker_zh/TTS/vits/utils.py delete mode 120000 egs/baker_zh/TTS/vits/vits.py delete mode 120000 egs/baker_zh/TTS/vits/wavenet.py diff --git a/egs/baker_zh/TTS/README.md b/egs/baker_zh/TTS/README.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/egs/baker_zh/TTS/local/README.md b/egs/baker_zh/TTS/local/README.md deleted file mode 100644 index dac1388537..0000000000 --- a/egs/baker_zh/TTS/local/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Introduction - -[./symbols.py](./symbols.py) is copied from -https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py - -[./pypinyin-local.dict](./pypinyin-local.dict) is copied from -https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict diff --git a/egs/baker_zh/TTS/local/__init__.py b/egs/baker_zh/TTS/local/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/egs/baker_zh/TTS/local/compute_spectrogram_baker.py b/egs/baker_zh/TTS/local/compute_spectrogram_baker.py deleted file mode 100755 index 1a15c7c0d4..0000000000 --- a/egs/baker_zh/TTS/local/compute_spectrogram_baker.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -""" -This file computes fbank features of the baker_zh dataset. -It looks for manifests in the directory data/manifests. - -The generated spectrogram features are saved in data/spectrogram. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - LilcomChunkyWriter, - Spectrogram, - SpectrogramConfig, - load_manifest, -) -from lhotse.audio import RecordingSet -from lhotse.supervision import SupervisionSet - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_spectrogram_baker_zh(): - src_dir = Path("data/manifests") - output_dir = Path("data/spectrogram") - num_jobs = min(4, os.cpu_count()) - - sampling_rate = 48000 - frame_length = 1024 / sampling_rate # (in second) - frame_shift = 256 / sampling_rate # (in second) - use_fft_mag = True - - prefix = "baker_zh" - suffix = "jsonl.gz" - partition = "all" - - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet - ) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet - ) - - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=frame_length, - frame_shift=frame_shift, - use_fft_mag=use_fft_mag, - ) - extractor = Spectrogram(config) - - with get_executor() as ex: # Initialize the executor only once. - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_spectrogram_baker_zh() diff --git a/egs/baker_zh/TTS/local/pinyin_dict.py b/egs/baker_zh/TTS/local/pinyin_dict.py deleted file mode 100644 index 950fb39fc0..0000000000 --- a/egs/baker_zh/TTS/local/pinyin_dict.py +++ /dev/null @@ -1,421 +0,0 @@ -# This dict is copied from -# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py -pinyin_dict = { - "a": ("^", "a"), - "ai": ("^", "ai"), - "an": ("^", "an"), - "ang": ("^", "ang"), - "ao": ("^", "ao"), - "ba": ("b", "a"), - "bai": ("b", "ai"), - "ban": ("b", "an"), - "bang": ("b", "ang"), - "bao": ("b", "ao"), - "be": ("b", "e"), - "bei": ("b", "ei"), - "ben": ("b", "en"), - "beng": ("b", "eng"), - "bi": ("b", "i"), - "bian": ("b", "ian"), - "biao": ("b", "iao"), - "bie": ("b", "ie"), - "bin": ("b", "in"), - "bing": ("b", "ing"), - "bo": ("b", "o"), - "bu": ("b", "u"), - "ca": ("c", "a"), - "cai": ("c", "ai"), - "can": ("c", "an"), - "cang": ("c", "ang"), - "cao": ("c", "ao"), - "ce": ("c", "e"), - "cen": ("c", "en"), - "ceng": ("c", "eng"), - "cha": ("ch", "a"), - "chai": ("ch", "ai"), - "chan": ("ch", "an"), 
- "chang": ("ch", "ang"), - "chao": ("ch", "ao"), - "che": ("ch", "e"), - "chen": ("ch", "en"), - "cheng": ("ch", "eng"), - "chi": ("ch", "iii"), - "chong": ("ch", "ong"), - "chou": ("ch", "ou"), - "chu": ("ch", "u"), - "chua": ("ch", "ua"), - "chuai": ("ch", "uai"), - "chuan": ("ch", "uan"), - "chuang": ("ch", "uang"), - "chui": ("ch", "uei"), - "chun": ("ch", "uen"), - "chuo": ("ch", "uo"), - "ci": ("c", "ii"), - "cong": ("c", "ong"), - "cou": ("c", "ou"), - "cu": ("c", "u"), - "cuan": ("c", "uan"), - "cui": ("c", "uei"), - "cun": ("c", "uen"), - "cuo": ("c", "uo"), - "da": ("d", "a"), - "dai": ("d", "ai"), - "dan": ("d", "an"), - "dang": ("d", "ang"), - "dao": ("d", "ao"), - "de": ("d", "e"), - "dei": ("d", "ei"), - "den": ("d", "en"), - "deng": ("d", "eng"), - "di": ("d", "i"), - "dia": ("d", "ia"), - "dian": ("d", "ian"), - "diao": ("d", "iao"), - "die": ("d", "ie"), - "ding": ("d", "ing"), - "diu": ("d", "iou"), - "dong": ("d", "ong"), - "dou": ("d", "ou"), - "du": ("d", "u"), - "duan": ("d", "uan"), - "dui": ("d", "uei"), - "dun": ("d", "uen"), - "duo": ("d", "uo"), - "e": ("^", "e"), - "ei": ("^", "ei"), - "en": ("^", "en"), - "ng": ("^", "en"), - "eng": ("^", "eng"), - "er": ("^", "er"), - "fa": ("f", "a"), - "fan": ("f", "an"), - "fang": ("f", "ang"), - "fei": ("f", "ei"), - "fen": ("f", "en"), - "feng": ("f", "eng"), - "fo": ("f", "o"), - "fou": ("f", "ou"), - "fu": ("f", "u"), - "ga": ("g", "a"), - "gai": ("g", "ai"), - "gan": ("g", "an"), - "gang": ("g", "ang"), - "gao": ("g", "ao"), - "ge": ("g", "e"), - "gei": ("g", "ei"), - "gen": ("g", "en"), - "geng": ("g", "eng"), - "gong": ("g", "ong"), - "gou": ("g", "ou"), - "gu": ("g", "u"), - "gua": ("g", "ua"), - "guai": ("g", "uai"), - "guan": ("g", "uan"), - "guang": ("g", "uang"), - "gui": ("g", "uei"), - "gun": ("g", "uen"), - "guo": ("g", "uo"), - "ha": ("h", "a"), - "hai": ("h", "ai"), - "han": ("h", "an"), - "hang": ("h", "ang"), - "hao": ("h", "ao"), - "he": ("h", "e"), - "hei": ("h", "ei"), - "hen": ("h", "en"), - "heng": ("h", "eng"), - "hong": ("h", "ong"), - "hou": ("h", "ou"), - "hu": ("h", "u"), - "hua": ("h", "ua"), - "huai": ("h", "uai"), - "huan": ("h", "uan"), - "huang": ("h", "uang"), - "hui": ("h", "uei"), - "hun": ("h", "uen"), - "huo": ("h", "uo"), - "ji": ("j", "i"), - "jia": ("j", "ia"), - "jian": ("j", "ian"), - "jiang": ("j", "iang"), - "jiao": ("j", "iao"), - "jie": ("j", "ie"), - "jin": ("j", "in"), - "jing": ("j", "ing"), - "jiong": ("j", "iong"), - "jiu": ("j", "iou"), - "ju": ("j", "v"), - "juan": ("j", "van"), - "jue": ("j", "ve"), - "jun": ("j", "vn"), - "ka": ("k", "a"), - "kai": ("k", "ai"), - "kan": ("k", "an"), - "kang": ("k", "ang"), - "kao": ("k", "ao"), - "ke": ("k", "e"), - "kei": ("k", "ei"), - "ken": ("k", "en"), - "keng": ("k", "eng"), - "kong": ("k", "ong"), - "kou": ("k", "ou"), - "ku": ("k", "u"), - "kua": ("k", "ua"), - "kuai": ("k", "uai"), - "kuan": ("k", "uan"), - "kuang": ("k", "uang"), - "kui": ("k", "uei"), - "kun": ("k", "uen"), - "kuo": ("k", "uo"), - "la": ("l", "a"), - "lai": ("l", "ai"), - "lan": ("l", "an"), - "lang": ("l", "ang"), - "lao": ("l", "ao"), - "le": ("l", "e"), - "lei": ("l", "ei"), - "leng": ("l", "eng"), - "li": ("l", "i"), - "lia": ("l", "ia"), - "lian": ("l", "ian"), - "liang": ("l", "iang"), - "liao": ("l", "iao"), - "lie": ("l", "ie"), - "lin": ("l", "in"), - "ling": ("l", "ing"), - "liu": ("l", "iou"), - "lo": ("l", "o"), - "long": ("l", "ong"), - "lou": ("l", "ou"), - "lu": ("l", "u"), - "lv": ("l", "v"), - "luan": ("l", "uan"), - "lve": ("l", "ve"), - 
"lue": ("l", "ve"), - "lun": ("l", "uen"), - "luo": ("l", "uo"), - "ma": ("m", "a"), - "mai": ("m", "ai"), - "man": ("m", "an"), - "mang": ("m", "ang"), - "mao": ("m", "ao"), - "me": ("m", "e"), - "mei": ("m", "ei"), - "men": ("m", "en"), - "meng": ("m", "eng"), - "mi": ("m", "i"), - "mian": ("m", "ian"), - "miao": ("m", "iao"), - "mie": ("m", "ie"), - "min": ("m", "in"), - "ming": ("m", "ing"), - "miu": ("m", "iou"), - "mo": ("m", "o"), - "mou": ("m", "ou"), - "mu": ("m", "u"), - "na": ("n", "a"), - "nai": ("n", "ai"), - "nan": ("n", "an"), - "nang": ("n", "ang"), - "nao": ("n", "ao"), - "ne": ("n", "e"), - "nei": ("n", "ei"), - "nen": ("n", "en"), - "neng": ("n", "eng"), - "ni": ("n", "i"), - "nia": ("n", "ia"), - "nian": ("n", "ian"), - "niang": ("n", "iang"), - "niao": ("n", "iao"), - "nie": ("n", "ie"), - "nin": ("n", "in"), - "ning": ("n", "ing"), - "niu": ("n", "iou"), - "nong": ("n", "ong"), - "nou": ("n", "ou"), - "nu": ("n", "u"), - "nv": ("n", "v"), - "nuan": ("n", "uan"), - "nve": ("n", "ve"), - "nue": ("n", "ve"), - "nuo": ("n", "uo"), - "o": ("^", "o"), - "ou": ("^", "ou"), - "pa": ("p", "a"), - "pai": ("p", "ai"), - "pan": ("p", "an"), - "pang": ("p", "ang"), - "pao": ("p", "ao"), - "pe": ("p", "e"), - "pei": ("p", "ei"), - "pen": ("p", "en"), - "peng": ("p", "eng"), - "pi": ("p", "i"), - "pian": ("p", "ian"), - "piao": ("p", "iao"), - "pie": ("p", "ie"), - "pin": ("p", "in"), - "ping": ("p", "ing"), - "po": ("p", "o"), - "pou": ("p", "ou"), - "pu": ("p", "u"), - "qi": ("q", "i"), - "qia": ("q", "ia"), - "qian": ("q", "ian"), - "qiang": ("q", "iang"), - "qiao": ("q", "iao"), - "qie": ("q", "ie"), - "qin": ("q", "in"), - "qing": ("q", "ing"), - "qiong": ("q", "iong"), - "qiu": ("q", "iou"), - "qu": ("q", "v"), - "quan": ("q", "van"), - "que": ("q", "ve"), - "qun": ("q", "vn"), - "ran": ("r", "an"), - "rang": ("r", "ang"), - "rao": ("r", "ao"), - "re": ("r", "e"), - "ren": ("r", "en"), - "reng": ("r", "eng"), - "ri": ("r", "iii"), - "rong": ("r", "ong"), - "rou": ("r", "ou"), - "ru": ("r", "u"), - "rua": ("r", "ua"), - "ruan": ("r", "uan"), - "rui": ("r", "uei"), - "run": ("r", "uen"), - "ruo": ("r", "uo"), - "sa": ("s", "a"), - "sai": ("s", "ai"), - "san": ("s", "an"), - "sang": ("s", "ang"), - "sao": ("s", "ao"), - "se": ("s", "e"), - "sen": ("s", "en"), - "seng": ("s", "eng"), - "sha": ("sh", "a"), - "shai": ("sh", "ai"), - "shan": ("sh", "an"), - "shang": ("sh", "ang"), - "shao": ("sh", "ao"), - "she": ("sh", "e"), - "shei": ("sh", "ei"), - "shen": ("sh", "en"), - "sheng": ("sh", "eng"), - "shi": ("sh", "iii"), - "shou": ("sh", "ou"), - "shu": ("sh", "u"), - "shua": ("sh", "ua"), - "shuai": ("sh", "uai"), - "shuan": ("sh", "uan"), - "shuang": ("sh", "uang"), - "shui": ("sh", "uei"), - "shun": ("sh", "uen"), - "shuo": ("sh", "uo"), - "si": ("s", "ii"), - "song": ("s", "ong"), - "sou": ("s", "ou"), - "su": ("s", "u"), - "suan": ("s", "uan"), - "sui": ("s", "uei"), - "sun": ("s", "uen"), - "suo": ("s", "uo"), - "ta": ("t", "a"), - "tai": ("t", "ai"), - "tan": ("t", "an"), - "tang": ("t", "ang"), - "tao": ("t", "ao"), - "te": ("t", "e"), - "tei": ("t", "ei"), - "teng": ("t", "eng"), - "ti": ("t", "i"), - "tian": ("t", "ian"), - "tiao": ("t", "iao"), - "tie": ("t", "ie"), - "ting": ("t", "ing"), - "tong": ("t", "ong"), - "tou": ("t", "ou"), - "tu": ("t", "u"), - "tuan": ("t", "uan"), - "tui": ("t", "uei"), - "tun": ("t", "uen"), - "tuo": ("t", "uo"), - "wa": ("^", "ua"), - "wai": ("^", "uai"), - "wan": ("^", "uan"), - "wang": ("^", "uang"), - "wei": ("^", "uei"), - "wen": ("^", 
"uen"), - "weng": ("^", "ueng"), - "wo": ("^", "uo"), - "wu": ("^", "u"), - "xi": ("x", "i"), - "xia": ("x", "ia"), - "xian": ("x", "ian"), - "xiang": ("x", "iang"), - "xiao": ("x", "iao"), - "xie": ("x", "ie"), - "xin": ("x", "in"), - "xing": ("x", "ing"), - "xiong": ("x", "iong"), - "xiu": ("x", "iou"), - "xu": ("x", "v"), - "xuan": ("x", "van"), - "xue": ("x", "ve"), - "xun": ("x", "vn"), - "ya": ("^", "ia"), - "yan": ("^", "ian"), - "yang": ("^", "iang"), - "yao": ("^", "iao"), - "ye": ("^", "ie"), - "yi": ("^", "i"), - "yin": ("^", "in"), - "ying": ("^", "ing"), - "yo": ("^", "iou"), - "yong": ("^", "iong"), - "you": ("^", "iou"), - "yu": ("^", "v"), - "yuan": ("^", "van"), - "yue": ("^", "ve"), - "yun": ("^", "vn"), - "za": ("z", "a"), - "zai": ("z", "ai"), - "zan": ("z", "an"), - "zang": ("z", "ang"), - "zao": ("z", "ao"), - "ze": ("z", "e"), - "zei": ("z", "ei"), - "zen": ("z", "en"), - "zeng": ("z", "eng"), - "zha": ("zh", "a"), - "zhai": ("zh", "ai"), - "zhan": ("zh", "an"), - "zhang": ("zh", "ang"), - "zhao": ("zh", "ao"), - "zhe": ("zh", "e"), - "zhei": ("zh", "ei"), - "zhen": ("zh", "en"), - "zheng": ("zh", "eng"), - "zhi": ("zh", "iii"), - "zhong": ("zh", "ong"), - "zhou": ("zh", "ou"), - "zhu": ("zh", "u"), - "zhua": ("zh", "ua"), - "zhuai": ("zh", "uai"), - "zhuan": ("zh", "uan"), - "zhuang": ("zh", "uang"), - "zhui": ("zh", "uei"), - "zhun": ("zh", "uen"), - "zhuo": ("zh", "uo"), - "zi": ("z", "ii"), - "zong": ("z", "ong"), - "zou": ("z", "ou"), - "zu": ("z", "u"), - "zuan": ("z", "uan"), - "zui": ("z", "uei"), - "zun": ("z", "uen"), - "zuo": ("z", "uo"), -} diff --git a/egs/baker_zh/TTS/local/prepare_token_file.py b/egs/baker_zh/TTS/local/prepare_token_file.py deleted file mode 100755 index d90910ab02..0000000000 --- a/egs/baker_zh/TTS/local/prepare_token_file.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file generates the file that maps tokens to IDs. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict -from symbols import symbols - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--tokens", - type=Path, - default=Path("data/tokens.txt"), - help="Path to the dict that maps the text tokens to IDs", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - tokens = Path(args.tokens) - - with open(tokens, "w", encoding="utf-8") as f: - for token_id, token in enumerate(symbols): - f.write(f"{token} {token_id}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py b/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py deleted file mode 100755 index 0b27fd1e9e..0000000000 --- a/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and save the new cuts with tokens. -""" - -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest - -from tokenizer import Tokenizer - - -def prepare_tokens_baker_zh(): - output_dir = Path("data/spectrogram") - prefix = "baker_zh" - suffix = "jsonl.gz" - partition = "all" - - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - tokenizer = Tokenizer() - - new_cuts = [] - i = 0 - for cut in cut_set: - # Each cut only contains one supervision - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) - text = cut.supervisions[0].normalized_text - cut.tokens = tokenizer.text_to_tokens(text) - - new_cuts.append(cut) - - new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_tokens_baker_zh() diff --git a/egs/baker_zh/TTS/local/pypinyin-local.dict b/egs/baker_zh/TTS/local/pypinyin-local.dict deleted file mode 100644 index 5e386014c8..0000000000 --- a/egs/baker_zh/TTS/local/pypinyin-local.dict +++ /dev/null @@ -1,328 +0,0 @@ -姐姐 jie3 jie -宝宝 bao3 bao -哥哥 ge1 ge -妹妹 mei4 mei -弟弟 di4 di -妈妈 ma1 ma -开心哦 kai1 xin1 o -爸爸 ba4 ba -秘密哟 mi4 mi4 yo -哦 o -一年 yi4 nian2 -一夜 yi2 ye4 -一切 yi2 qie4 -一座 yi2 zuo4 -一下 yi2 xia4 -上一山 shang4 yi2 shan1 -下一山 xia4 yi2 shan1 -休息 xiu1 xi2 -东西 dong1 xi -上一届 shang4 yi2 jie4 -便宜 pian2 yi4 -加长 jia1 chang2 -单田芳 shan4 tian2 fang1 -帧 zhen1 -长时间 chang2 shi2 jian1 -长时 chang2 shi2 -识别 shi2 bie2 -生命中 sheng1 ming4 zhong1 -踏实 ta1 shi -嗯 en4 -溜达 liu1 da -少儿 shao4 er2 -爷爷 ye2 ye -不是 bu2 shi4 -一圈 yi1 quan1 -厜读一声 zui1 du2 yi4 sheng1 -一种 yi4 zhong3 -一簇簇 yi2 cu4 cu4 -一个 yi2 ge4 -一样 yi2 yang4 -一跩一跩 yi4 zhuai3 yi4 zhuai3 -一会儿 yi2 hui4 er -一幢 yi2 zhuang4 -挨了 ai2 le -熬菜 ao1 cai4 -扒鸡 pa2 ji1 -背枪 bei1 qiang1 -绷瓷儿 beng4 ci2 er2 -绷劲儿 beng3 jin4 er -绷着脸 beng3 zhe lian3 -藏医 zang4 yi1 -噌吰 cheng1 hong2 -差点儿 cha4 dian3 er -差失 cha1 shi1 -差误 cha1 wu4 -孱头 can4 tou -乘间 cheng2 jian4 -锄镰棘矜 chu2 lian2 ji2 qin2 -川藏 chuan1 zang4 -穿著 chuan1 zhuo2 -答讪 da1 shan4 -答言 da1 yan2 -大伯子 da4 bai3 zi -大夫 dai4 fu -弹冠 tan2 guan1 -当间 dang1 jian4 -当然咯 dang1 ran2 lo -点种 dian3 zhong3 -垛好 duo4 hao3 -发疟子 fa1 yao4 zi -饭熟了 fan4 shou2 le -附著 fu4 zhuo2 -复沓 fu4 ta4 -供稿 gong1 gao3 -供养 gong1 yang3 -骨朵 gu1 duo -骨碌 gu1 lu -果脯 guo3 fu3 -哈什玛 ha4 shi2 ma3 -海蜇 hai3 zhe2 -呵欠 he1 qian -河水汤汤 he2 shui3 shang1 shang1 -鹄立 hu2 li4 -鹄望 hu2 wang4 -混人 hun2 ren2 -混水 hun2 shui3 -鸡血 ji1 xie3 -缉鞋口 qi1 xie2 kou3 -亟来闻讯 qi4 lai2 wen2 xun4 -计量 ji4 liang2 -济水 ji3 shui3 -间杂 jian4 za2 -脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2 -脚儿 jue2 er2 -口角 kou3 jiao3 -勒石 le4 shi2 -累进 lei3 jin4 -累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3 -累年 lei3 nian2 -脸涨通红 lian3 zhang4 tong1 
hong2 -踉锵 liang4 qiang1 -燎眉毛 liao3 mei2 mao2 -燎头发 liao3 tou2 fa4 -溜达 liu1 da -溜缝儿 liu4 feng4 er -馏口饭 liu4 kou3 fan4 -遛马 liu4 ma3 -遛鸟 liu4 niao3 -遛弯儿 liu4 wan1 er -楼枪机 lou1 qiang1 ji1 -搂钱 lou1 qian2 -鹿脯 lu4 fu3 -露头 lou4 tou2 -落魄 luo4 po4 -捋胡子 lv3 hu2 zi -绿地 lv4 di4 -麦垛 mai4 duo4 -没劲儿 mei2 jin4 er -闷棍 men4 gun4 -闷葫芦 men4 hu2 lu -闷头干 men1 tou2 gan4 -蒙古 meng3 gu3 -靡日不思 mi3 ri4 bu4 si1 -缪姓 miao4 xing4 -抹墙 mo4 qiang2 -抹下脸 ma1 xia4 lian3 -泥子 ni4 zi -拗不过 niu4 bu guo4 -排车 pai3 che1 -盘诘 pan2 jie2 -膀肿 pang1 zhong3 -炮干 bao1 gan1 -炮格 pao2 ge2 -碰钉子 peng4 ding1 zi -缥色 piao3 se4 -瀑河 bao4 he2 -蹊径 xi1 jing4 -前后相属 qian2 hou4 xiang1 zhu3 -翘尾巴 qiao4 wei3 ba -趄坡儿 qie4 po1 er -秦桧 qin2 hui4 -圈马 juan1 ma3 -雀盲眼 qiao3 mang2 yan3 -雀子 qiao1 zi -三年五载 san1 nian2 wu3 zai3 -加载 jia1 zai3 -山大王 shan1 dai4 wang -苫屋草 shan4 wu1 cao3 -数数 shu3 shu4 -说客 shui4 ke4 -思量 si1 liang2 -伺侯 ci4 hou -踏实 ta1 shi -提溜 di1 liu -调拨 diao4 bo1 -帖子 tie3 zi -铜钿 tong2 tian2 -头昏脑涨 tou2 hun1 nao3 zhang4 -褪色 tui4 se4 -褪着手 tun4 zhe shou3 -圩子 wei2 zi -尾巴 wei3 ba -系好船只 xi4 hao3 chuan2 zhi1 -系好马匹 xi4 hao3 ma3 pi3 -杏脯 xing4 fu3 -姓单 xing4 shan4 -姓葛 xing4 ge3 -姓哈 xing4 ha3 -姓解 xing4 xie4 -姓秘 xing4 bi4 -姓宁 xing4 ning4 -旋风 xuan4 feng1 -旋根车轴 xuan4 gen1 che1 zhou2 -荨麻 qian2 ma2 -一幢楼房 yi1 zhuang4 lou2 fang2 -遗之千金 wei4 zhi1 qian1 jin1 -殷殷 yin3 yin3 -应招 ying4 zhao1 -用称约 yong4 cheng4 yao1 -约斤肉 yao1 jin1 rou4 -晕机 yun4 ji1 -熨贴 yu4 tie1 -咋办 za3 ban4 -咋呼 zha1 hu -仔兽 zi3 shou4 -扎彩 za1 cai3 -扎实 zha1 shi -扎腰带 za1 yao1 dai4 -轧朋友 ga2 peng2 you3 -爪子 zhua3 zi -折腾 zhe1 teng -着实 zhuo2 shi2 -着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2 -枝蔓 zhi1 man4 -中鹄 zhong1 hu2 -中选 zhong4 xuan3 -猪圈 zhu1 juan4 -拽住不放 zhuai4 zhu4 bu4 fang4 -转悠 zhuan4 you -庄稼熟了 zhuang1 jia shou2 le -酌量 zhuo2 liang2 -罪行累累 zui4 xing2 lei3 lei3 -一手 yi4 shou3 -一去不复返 yi2 qu4 bu2 fu4 fan3 -一颗 yi4 ke1 -一件 yi2 jian4 -一斤 yi4 jin1 -一点 yi4 dian3 -一朵 yi4 duo3 -一声 yi4 sheng1 -一身 yi4 shen1 -不要 bu2 yao4 -一人 yi4 ren2 -一个 yi2 ge4 -一把 yi4 ba3 -一门 yi4 men2 -一門 yi4 men2 -一艘 yi4 sou1 -一片 yi2 pian4 -一篇 yi2 pian1 -一份 yi2 fen4 -好嗲 hao3 dia3 -随地 sui2 di4 -扁担长 bian3 dan4 chang3 -一堆 yi4 dui1 -不义 bu2 yi4 -放一放 fang4 yi2 fang4 -一米 yi4 mi3 -一顿 yi2 dun4 -一层楼 yi4 ceng2 lou2 -一条 yi4 tiao2 -一件 yi2 jian4 -一棵 yi4 ke1 -一小股 yi4 xiao3 gu3 -一拐一拐 yi4 guai3 yi4 guai3 -一根 yi4 gen1 -沆瀣一气 hang4 xie4 yi2 qi4 -一丝 yi4 si1 -一毫 yi4 hao2 -一樣 yi2 yang4 -处处 chu4 chu4 -一餐 yi4 can -永不 yong3 bu2 -一看 yi2 kan4 -一架 yi2 jia4 -送还 song4 huan2 -一见 yi2 jian4 -一座 yi2 zuo4 -一块 yi2 kuai4 -一天 yi4 tian1 -一只 yi4 zhi1 -一支 yi4 zhi1 -一字 yi2 zi4 -一句 yi2 ju4 -一张 yi4 zhang1 -一條 yi4 tiao2 -一场 yi4 chang3 -一粒 yi2 li4 -小俩口 xiao3 liang3 kou3 -一首 yi4 shou3 -一对 yi2 dui4 -一手 yi4 shou3 -又一村 you4 yi4 cun1 -一概而论 yi2 gai4 er2 lun4 -一峰峰 yi4 feng1 feng1 -不但 bu2 dan4 -一笑 yi2 xiao4 -挠痒痒 nao2 yang3 yang -不对 bu2 dui4 -拧开 ning3 kai1 -爱不释手 ai4 bu2 shi4 shou3 -一念 yi2 nian4 -夺得 duo2 de2 -一袭 yi4 xi2 -一定 yi2 ding4 -不慎 bu2 shen4 -剽窃 piao2 qie4 -一时 yi4 shi2 -撇开 pie3 kai1 -一祭 yi2 ji4 -发卡 fa4 qia3 -少不了 shao3 bu4 liao3 -千虑一失 qian1 lv4 yi4 shi1 -呛得 qiang4 de2 -切菜 qie1 cai4 -茄盒 qie2 he2 -不去 bu2 qu4 -一大圈 yi2 da4 quan1 -不再 bu2 zai4 -一群 yi4 qun2 -不必 bu2 bi4 -一些 yi4 xie1 -一路 yi2 lu4 -一股 yi4 gu3 -一到 yi2 dao4 -一拨 yi4 bo1 -一排 yi4 pai2 -一空 yi4 kong1 -吮吸着 shun3 xi1 zhe -不适合 bu2 shi4 he2 -一串串 yi2 chuan4 chuan4 -一提起 yi4 ti2 qi3 -一尘不染 yi4 chen2 bu4 ran3 -一生 yi4 sheng1 -一派 yi2 pai4 -不断 bu2 duan4 -一次 yi2 ci4 -不进步 bu2 jin4 bu4 -娃娃 wa2 wa -万户侯 wan4 hu4 hou2 -一方 yi4 fang1 -一番话 yi4 fan1 hua4 -一遍 yi2 bian4 -不计较 bu2 ji4 jiao4 -诇 xiong4 -一边 yi4 bian1 -一束 yi2 shu4 -一听到 yi4 ting1 dao4 -炸鸡 zha2 ji1 -乍暧还寒 zha4 ai4 huan2 han2 -我说诶 wo3 shuo1 ei1 -棒诶 bang4 ei1 
-寒碜 han2 chen4 -应采儿 ying4 cai3 er2 -晕车 yun1 che1 -必应 bi4 ying4 -应援 ying4 yuan2 -应力 ying4 li4 \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/symbols.py b/egs/baker_zh/TTS/local/symbols.py deleted file mode 100644 index 1e68788704..0000000000 --- a/egs/baker_zh/TTS/local/symbols.py +++ /dev/null @@ -1,73 +0,0 @@ -# This file is copied from -# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py -_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] - -_initials = [ - "^", - "b", - "c", - "ch", - "d", - "f", - "g", - "h", - "j", - "k", - "l", - "m", - "n", - "p", - "q", - "r", - "s", - "sh", - "t", - "x", - "z", - "zh", -] - -_tones = ["1", "2", "3", "4", "5"] - -_finals = [ - "a", - "ai", - "an", - "ang", - "ao", - "e", - "ei", - "en", - "eng", - "er", - "i", - "ia", - "ian", - "iang", - "iao", - "ie", - "ii", - "iii", - "in", - "ing", - "iong", - "iou", - "o", - "ong", - "ou", - "u", - "ua", - "uai", - "uan", - "uang", - "uei", - "uen", - "ueng", - "uo", - "v", - "van", - "ve", - "vn", -] - -symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/egs/baker_zh/TTS/local/tokenizer.py b/egs/baker_zh/TTS/local/tokenizer.py deleted file mode 100644 index cbf6c9c773..0000000000 --- a/egs/baker_zh/TTS/local/tokenizer.py +++ /dev/null @@ -1,137 +0,0 @@ -# This file is modified from -# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py - -import logging -from pathlib import Path -from typing import List - -# Note pinyin_dict is from ./pinyin_dict.py -from pinyin_dict import pinyin_dict -from pypinyin import Style -from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin -from pypinyin.converter import DefaultConverter -from pypinyin.core import Pinyin, load_phrases_dict - - -class _MyConverter(NeutralToneWith5Mixin, DefaultConverter): - pass - - -class Tokenizer: - def __init__(self, tokens: str = ""): - self._load_pinyin_dict() - self._pinyin_parser = Pinyin(_MyConverter()) - - if tokens != "": - self._load_tokens(tokens) - - def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]: - """ - Args: - texts: - A list of sentences. - kwargs: - Not used. It is for compatibility with other TTS recipes in icefall. 
- """ - tokens = [] - - for text in texts: - tokens.append(self.text_to_tokens(text)) - - return self.tokens_to_token_ids(tokens) - - def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]: - ans = [] - - for token_list in tokens: - token_ids = [] - for t in token_list: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - ans.append(token_ids) - - return ans - - def text_to_tokens(self, text: str) -> List[str]: - # Convert "," to ["sp", "sil"] - # Convert "。" to ["sil"] - # append ["eos"] at the end of a sentence - phonemes = ["sil"] - pinyins = self._pinyin_parser.pinyin( - text, - style=Style.TONE3, - errors=lambda x: [[w] for w in x], - ) - - new_pinyin = [] - for p in pinyins: - p = p[0] - if p == ",": - new_pinyin.extend(["sp", "sil"]) - elif p == "。": - new_pinyin.append("sil") - else: - new_pinyin.append(p) - sub_phonemes = self._get_phoneme4pinyin(new_pinyin) - sub_phonemes.append("eos") - phonemes.extend(sub_phonemes) - return phonemes - - def _get_phoneme4pinyin(self, pinyins): - result = [] - for pinyin in pinyins: - if pinyin in ("sil", "sp"): - result.append(pinyin) - elif pinyin[:-1] in pinyin_dict: - tone = pinyin[-1] - a = pinyin[:-1] - a1, a2 = pinyin_dict[a] - # every word is appended with a #0 - result += [a1, a2 + tone, "#0"] - - return result - - def _load_pinyin_dict(self): - this_dir = Path(__file__).parent.resolve() - my_dict = {} - with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f: - content = f.readlines() - for line in content: - cuts = line.strip().split() - hanzi = cuts[0] - pinyin = cuts[1:] - my_dict[hanzi] = [[p] for p in pinyin] - - load_phrases_dict(my_dict) - - def _load_tokens(self, filename): - token2id: Dict[str, int] = {} - - with open(filename, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - idx = int(info[0]) - else: - token, idx = info[0], int(info[1]) - - assert token not in token2id, token - - token2id[token] = idx - - self.token2id = token2id - self.vocab_size = len(self.token2id) - self.pad_id = self.token2id["#0"] - - -def main(): - tokenizer = Tokenizer() - tokenizer._sentence_to_ids("你好,好的。") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py deleted file mode 120000 index b4d52ebca0..0000000000 --- a/egs/baker_zh/TTS/local/validate_manifest.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh deleted file mode 100755 index 6fa87fe438..0000000000 --- a/egs/baker_zh/TTS/prepare.sh +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=-1 -stop_stage=100 - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: build monotonic_align lib" - if [ ! 
-d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib already built" - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Download data" - - # The directory $dl_dir/BZNSYP will contain 3 sub directories: - # - PhoneLabeling - # - ProsodyLabeling - # - Wave - - # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink - # - # ln -sfv /path/to/BZNSYP $dl_dir/ - # touch $dl_dir/BZNSYP/.completed - # - if [ ! -d $dl_dir/BZNSYP ]; then - lhotse download baker-zh $dl_dir - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare baker-zh manifest" - # We assume that you have downloaded the baker corpus - # to $dl_dir/BZNSYP - mkdir -p data/manifests - if [ ! -e data/manifests/.baker.done ]; then - lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests - touch data/manifests/.baker.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute spectrogram for baker (may take 3 minutes)" - mkdir -p data/spectrogram - if [ ! -e data/spectrogram/.baker.done ]; then - ./local/compute_spectrogram_baker.py - touch data/spectrogram/.baker.done - fi - - if [ ! -e data/spectrogram/.baker-validated.done ]; then - log "Validating data/spectrogram for baker" - python3 ./local/validate_manifest.py \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz - touch data/spectrogram/.baker-validated.done - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)" - if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then - - ./local/prepare_tokens_baker_zh.py - - mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz - - touch data/spectrogram/.baker_zh_with_token.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)" - if [ ! -e data/spectrogram/.baker_zh_split.done ]; then - lhotse subset --last 600 \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ - data/spectrogram/baker_zh_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ - data/spectrogram/baker_zh_cuts_test.jsonl.gz - - rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 )) - lhotse subset --first $n \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz \ - data/spectrogram/baker_zh_cuts_train.jsonl.gz - touch data/spectrogram/.baker_zh_split.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate token file" - if [ ! 
-e data/tokens.txt ]; then - ./local/prepare_token_file.py --tokens data/tokens.txt - fi -fi diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared deleted file mode 120000 index 4cbd91a7e9..0000000000 --- a/egs/baker_zh/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/duration_predictor.py b/egs/baker_zh/TTS/vits/duration_predictor.py deleted file mode 120000 index 9972b476f9..0000000000 --- a/egs/baker_zh/TTS/vits/duration_predictor.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/export-onnx.py b/egs/baker_zh/TTS/vits/export-onnx.py deleted file mode 100755 index 11c8a9791f..0000000000 --- a/egs/baker_zh/TTS/vits/export-onnx.py +++ /dev/null @@ -1,414 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script exports a VITS model from PyTorch to ONNX. - -Export the model to ONNX: -./vits/export-onnx.py \ - --epoch 1000 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt - -It will generate one file inside vits/exp: - - vits-epoch-1000.onnx - -See ./test_onnx.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import torch -import torch.nn as nn -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. - It controls the model size. low -> runs faster. - """, - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class OnnxModel(nn.Module): - """A wrapper for VITS generator.""" - - def __init__(self, model: nn.Module): - """ - Args: - model: - A VITS generator. - frame_shift: - The frame shift in samples. 
- """ - super().__init__() - self.model = model - - def forward( - self, - tokens: torch.Tensor, - tokens_lens: torch.Tensor, - noise_scale: float = 0.667, - alpha: float = 1.0, - noise_scale_dur: float = 0.8, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of VITS.inference_batch - - Args: - tokens: - Input text token indexes (1, T_text) - tokens_lens: - Number of tokens of shape (1,) - noise_scale (float): - Noise scale parameter for flow. - noise_scale_dur (float): - Noise scale parameter for duration predictor. - alpha (float): - Alpha parameter to control the speed of generated speech. - - Returns: - Return a tuple containing: - - audio, generated wavform tensor, (B, T_wav) - """ - audio, _, _ = self.model.generator.inference( - text=tokens, - text_lengths=tokens_lens, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - ) - return audio - - -def export_model_onnx( - model: nn.Module, - model_filename: str, - vocab_size: int, - opset_version: int = 11, -) -> None: - """Export the given generator model to ONNX format. - The exported model has one input: - - - tokens, a tensor of shape (1, T_text); dtype is torch.int64 - - and it has one output: - - - audio, a tensor of shape (1, T'); dtype is torch.float32 - - Args: - model: - The VITS generator. - model_filename: - The filename to save the exported ONNX model. - vocab_size: - Number of tokens used in training. - opset_version: - The opset version to use. - """ - tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1], dtype=torch.float32) - noise_scale_dur = torch.tensor([1], dtype=torch.float32) - alpha = torch.tensor([1], dtype=torch.float32) - - torch.onnx.export( - model, - (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), - model_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "tokens", - "tokens_lens", - "noise_scale", - "alpha", - "noise_scale_dur", - ], - output_names=["audio"], - dynamic_axes={ - "tokens": {0: "N", 1: "T"}, - "tokens_lens": {0: "N"}, - "audio": {0: "N", 1: "T"}, - }, - ) - - if model.model.spks is None: - num_speakers = 1 - else: - num_speakers = model.model.spks - - meta_data = { - "model_type": "vits", - "version": "1", - "model_author": "k2-fsa", - "comment": "icefall", # must be icefall for models from icefall - "language": "Chinese", - "n_speakers": num_speakers, - "sample_rate": model.model.sampling_rate, # Must match the real sample rate - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=model_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to("cpu") - model.eval() - - model = OnnxModel(model=model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") - - suffix = f"epoch-{params.epoch}" - - opset_version = 13 - - logging.info("Exporting encoder") - model_filename = params.exp_dir / f"vits-{suffix}.onnx" - export_model_onnx( - model, - 
model_filename, - params.vocab_size, - opset_version=opset_version, - ) - logging.info(f"Exported generator to {model_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() - -""" -Supported languages. - -LJSpeech is using "en-us" from the second column. - -Pty Language Age/Gender VoiceName File Other Languages - 5 af --/M Afrikaans gmw/af - 5 am --/M Amharic sem/am - 5 an --/M Aragonese roa/an - 5 ar --/M Arabic sem/ar - 5 as --/M Assamese inc/as - 5 az --/M Azerbaijani trk/az - 5 ba --/M Bashkir trk/ba - 5 be --/M Belarusian zle/be - 5 bg --/M Bulgarian zls/bg - 5 bn --/M Bengali inc/bn - 5 bpy --/M Bishnupriya_Manipuri inc/bpy - 5 bs --/M Bosnian zls/bs - 5 ca --/M Catalan roa/ca - 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr - 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) - 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) - 5 cs --/M Czech zlw/cs - 5 cv --/M Chuvash trk/cv - 5 cy --/M Welsh cel/cy - 5 da --/M Danish gmq/da - 5 de --/M German gmw/de - 5 el --/M Greek grk/el - 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) - 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) - 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) - 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) - 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) - 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) - 2 en-us --/M English_(America) gmw/en-US (en 3) - 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc - 5 eo --/M Esperanto art/eo - 5 es --/M Spanish_(Spain) roa/es - 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) - 5 et --/M Estonian urj/et - 5 eu --/M Basque eu - 5 fa --/M Persian ira/fa - 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn - 5 fi --/M Finnish urj/fi - 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) - 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) - 5 fr-fr --/M French_(France) roa/fr (fr 5) - 5 ga --/M Gaelic_(Irish) cel/ga - 5 gd --/M Gaelic_(Scottish) cel/gd - 5 gn --/M Guarani sai/gn - 5 grc --/M Greek_(Ancient) grk/grc - 5 gu --/M Gujarati inc/gu - 5 hak --/M Hakka_Chinese sit/hak - 5 haw --/M Hawaiian map/haw - 5 he --/M Hebrew sem/he - 5 hi --/M Hindi inc/hi - 5 hr --/M Croatian zls/hr (hbs 5) - 5 ht --/M Haitian_Creole roa/ht - 5 hu --/M Hungarian urj/hu - 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) - 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) - 5 ia --/M Interlingua art/ia - 5 id --/M Indonesian poz/id - 5 io --/M Ido art/io - 5 is --/M Icelandic gmq/is - 5 it --/M Italian roa/it - 5 ja --/M Japanese jpx/ja - 5 jbo --/M Lojban art/jbo - 5 ka --/M Georgian ccs/ka - 5 kk --/M Kazakh trk/kk - 5 kl --/M Greenlandic esx/kl - 5 kn --/M Kannada dra/kn - 5 ko --/M Korean ko - 5 kok --/M Konkani inc/kok - 5 ku --/M Kurdish ira/ku - 5 ky --/M Kyrgyz trk/ky - 5 la --/M Latin itc/la - 5 lb --/M Luxembourgish gmw/lb - 5 lfn --/M Lingua_Franca_Nova art/lfn - 5 lt --/M Lithuanian bat/lt - 5 ltg --/M Latgalian bat/ltg - 5 lv --/M Latvian bat/lv - 5 mi --/M Māori poz/mi - 5 mk --/M Macedonian zls/mk - 5 ml --/M Malayalam dra/ml - 5 mr --/M Marathi inc/mr - 5 ms --/M Malay poz/ms - 5 mt --/M Maltese sem/mt - 5 mto --/M Totontepec_Mixe miz/mto - 5 my --/M Myanmar_(Burmese) sit/my - 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) - 5 nci 
--/M Nahuatl_(Classical) azc/nci - 5 ne --/M Nepali inc/ne - 5 nl --/M Dutch gmw/nl - 5 nog --/M Nogai trk/nog - 5 om --/M Oromo cus/om - 5 or --/M Oriya inc/or - 5 pa --/M Punjabi inc/pa - 5 pap --/M Papiamento roa/pap - 5 piqd --/M Klingon art/piqd - 5 pl --/M Polish zlw/pl - 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) - 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) - 5 py --/M Pyash art/py - 5 qdb --/M Lang_Belta art/qdb - 5 qu --/M Quechua qu - 5 quc --/M K'iche' myn/quc - 5 qya --/M Quenya art/qya - 5 ro --/M Romanian roa/ro - 5 ru --/M Russian zle/ru - 5 ru-cl --/M Russian_(Classic) zle/ru-cl - 2 ru-lv --/M Russian_(Latvia) zle/ru-LV - 5 sd --/M Sindhi inc/sd - 5 shn --/M Shan_(Tai_Yai) tai/shn - 5 si --/M Sinhala inc/si - 5 sjn --/M Sindarin art/sjn - 5 sk --/M Slovak zlw/sk - 5 sl --/M Slovenian zls/sl - 5 smj --/M Lule_Saami urj/smj - 5 sq --/M Albanian ine/sq - 5 sr --/M Serbian zls/sr - 5 sv --/M Swedish gmq/sv - 5 sw --/M Swahili bnt/sw - 5 ta --/M Tamil dra/ta - 5 te --/M Telugu dra/te - 5 th --/M Thai tai/th - 5 tk --/M Turkmen trk/tk - 5 tn --/M Setswana bnt/tn - 5 tr --/M Turkish trk/tr - 5 tt --/M Tatar trk/tt - 5 ug --/M Uyghur trk/ug - 5 uk --/M Ukrainian zle/uk - 5 ur --/M Urdu inc/ur - 5 uz --/M Uzbek trk/uz - 5 vi --/M Vietnamese_(Northern) aav/vi - 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central - 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south - 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) - 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) -""" diff --git a/egs/baker_zh/TTS/vits/flow.py b/egs/baker_zh/TTS/vits/flow.py deleted file mode 120000 index e65d91ea75..0000000000 --- a/egs/baker_zh/TTS/vits/flow.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/generate_lexicon.py b/egs/baker_zh/TTS/vits/generate_lexicon.py deleted file mode 100755 index 6d040ef539..0000000000 --- a/egs/baker_zh/TTS/vits/generate_lexicon.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 - -from pypinyin import phrases_dict, pinyin_dict -from tokenizer import Tokenizer - - -def main(): - filename = "lexicon.txt" - tokens = "./data/tokens.txt" - tokenizer = Tokenizer(tokens) - - word_dict = pinyin_dict.pinyin_dict - phrases = phrases_dict.phrases_dict - - i = 0 - with open(filename, "w", encoding="utf-8") as f: - for key in word_dict: - if not (0x4E00 <= key <= 0x9FFF): - continue - - w = chr(key) - - # 1 to remove the initial sil - # :-1 to remove the final eos - tokens = tokenizer.text_to_tokens(w)[1:-1] - - tokens = " ".join(tokens) - f.write(f"{w} {tokens}\n") - - for key in phrases: - # 1 to remove the initial sil - # :-1 to remove the final eos - tokens = tokenizer.text_to_tokens(key)[1:-1] - tokens = " ".join(tokens) - f.write(f"{key} {tokens}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/vits/generator.py b/egs/baker_zh/TTS/vits/generator.py deleted file mode 120000 index 611679bfa8..0000000000 --- a/egs/baker_zh/TTS/vits/generator.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/hifigan.py b/egs/baker_zh/TTS/vits/hifigan.py deleted file mode 120000 index 5ac025de72..0000000000 --- a/egs/baker_zh/TTS/vits/hifigan.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/loss.py b/egs/baker_zh/TTS/vits/loss.py deleted file 
mode 120000 index 672e5ff68d..0000000000 --- a/egs/baker_zh/TTS/vits/loss.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/monotonic_align b/egs/baker_zh/TTS/vits/monotonic_align deleted file mode 120000 index 71934e7cca..0000000000 --- a/egs/baker_zh/TTS/vits/monotonic_align +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/pinyin_dict.py b/egs/baker_zh/TTS/vits/pinyin_dict.py deleted file mode 120000 index b8683bd2dc..0000000000 --- a/egs/baker_zh/TTS/vits/pinyin_dict.py +++ /dev/null @@ -1 +0,0 @@ -../local/pinyin_dict.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/posterior_encoder.py b/egs/baker_zh/TTS/vits/posterior_encoder.py deleted file mode 120000 index 41d64a3a66..0000000000 --- a/egs/baker_zh/TTS/vits/posterior_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/pypinyin-local.dict b/egs/baker_zh/TTS/vits/pypinyin-local.dict deleted file mode 120000 index 5bc9b77282..0000000000 --- a/egs/baker_zh/TTS/vits/pypinyin-local.dict +++ /dev/null @@ -1 +0,0 @@ -../local/pypinyin-local.dict \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/residual_coupling.py b/egs/baker_zh/TTS/vits/residual_coupling.py deleted file mode 120000 index f979adbf00..0000000000 --- a/egs/baker_zh/TTS/vits/residual_coupling.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/test_onnx.py b/egs/baker_zh/TTS/vits/test_onnx.py deleted file mode 100755 index 66c94270ce..0000000000 --- a/egs/baker_zh/TTS/vits/test_onnx.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -This script is used to test the exported onnx model by vits/export-onnx.py - -Use the onnx model to generate a wav: -./vits/test_onnx.py \ - --model-filename vits/exp/vits-epoch-1000.onnx \ - --tokens data/tokens.txt -""" - - -import argparse -import logging - -import onnxruntime as ort -import torch -import torchaudio -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model.", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--text", - type=str, - default="Ask not what your country can do for you; ask what you can do for your country.", - help="Text to generate speech for", - ) - - parser.add_argument( - "--output-filename", - type=str, - default="test_onnx.wav", - help="Filename to save the generated wave file.", - ) - - return parser - - -class OnnxModel: - def __init__(self, model_filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.model = ort.InferenceSession( - model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - - metadata = self.model.get_modelmeta().custom_metadata_map - self.sample_rate = int(metadata["sample_rate"]) - - def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: - """ - Args: - tokens: - A 1-D tensor of shape (1, T) - Returns: - A tensor of shape (1, T') - """ - noise_scale = torch.tensor([0.667], dtype=torch.float32) - noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) - alpha = torch.tensor([1.0], dtype=torch.float32) - - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: tokens.numpy(), - self.model.get_inputs()[1].name: tokens_lens.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: alpha.numpy(), - self.model.get_inputs()[4].name: noise_scale_dur.numpy(), - }, - )[0] - return torch.from_numpy(out) - - -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - tokenizer = Tokenizer(args.tokens) - - logging.info("About to create onnx model") - model = OnnxModel(args.model_filename) - - text = args.text - tokens = tokenizer.texts_to_token_ids([text]) - tokens = torch.tensor(tokens) # (1, T) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) - audio = model(tokens, tokens_lens) # (1, T') - - output_filename = args.output_filename - torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) - logging.info(f"Saved to {output_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/baker_zh/TTS/vits/text_encoder.py b/egs/baker_zh/TTS/vits/text_encoder.py deleted file mode 120000 index 0efba277e1..0000000000 --- a/egs/baker_zh/TTS/vits/text_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/tokenizer.py b/egs/baker_zh/TTS/vits/tokenizer.py deleted file mode 120000 index 0368e07d34..0000000000 
--- a/egs/baker_zh/TTS/vits/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/tokenizer.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/train.py b/egs/baker_zh/TTS/vits/train.py deleted file mode 100755 index 694129a89d..0000000000 --- a/egs/baker_zh/TTS/vits/train.py +++ /dev/null @@ -1,927 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import BakerZhSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--lr", type=float, default=2.0e-4, help="The base learning rate." 
- ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=20, - help="""Save checkpoint after processing this number of epochs - periodically. We save checkpoint to exp-dir/ whenever - params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. - Since it will take around 1000 epochs, we suggest using a large - save_every_n to save disk space. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. - It controls the model size. low -> runs faster. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used for writing statistics to tensorboard. It - contains the number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - """ - params = AttributeDict( - { - # training params - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 200, - "env_info": get_env_info(), - "sampling_rate": 48000, - "frame_shift": 256, - "frame_length": 1024, - "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "n_mels": 80, - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_mel": 45.0, # loss scaling coefficient for Mel loss - "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss - "lambda_dur": 1.0, # loss scaling coefficient for duration loss - "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`.
- - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = { - "n_mels": params.n_mels, - "frame_length": params.frame_length, - "frame_shift": params.frame_shift, - } - model = VITS( - vocab_size=params.vocab_size, - feature_dim=params.feature_dim, - sampling_rate=params.sampling_rate, - model_type=params.model_type, - mel_loss_params=mel_loss_params, - lambda_adv=params.lambda_adv, - lambda_mel=params.lambda_mel, - lambda_feat_match=params.lambda_feat_match, - lambda_dur=params.lambda_dur, - lambda_kl=params.lambda_kl, - ) - return model - - -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): - """Parse batch data""" - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - - tokens = tokenizer.tokens_to_token_ids(tokens) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - return audio, audio_lens, features, features_lens, tokens, tokens_lens - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to phonemes. - optimizer_g: - The optimizer for the generator. - optimizer_d: - The optimizer for the discriminator. - scheduler_g: - The learning rate scheduler for the generator; we call step() once per epoch. - scheduler_d: - The learning rate scheduler for the discriminator; we call step() once per epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mixed precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0.
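 For reference, a small worked example of the token-padding logic used in prepare_input() above; this is only a sketch and assumes k2 is installed, but every call mirrors one made in this file:

     import k2

     tokens = k2.RaggedTensor([[5, 9, 2], [7, 4]])
     row_splits = tokens.shape.row_splits(1)         # tensor([0, 3, 5])
     tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 2])
     # Pad the ragged ids into a regular (B, T) tensor, here (2, 3);
     # 0 stands in for tokenizer.pad_id:
     padded = tokens.pad(mode="constant", padding_value=0)
     # [[5, 9, 2],
     #  [7, 4, 0]]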
- """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(loss_d).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(loss_g).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/speech_", - speech_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used
to summarize the stats over iterations - tot_loss = MetricsTracker() - returned_sample = None - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # infer for first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference( - text=tokens[0, : tokens_lens[0].item()] - ) - audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = ( - (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - ) - assert audio_len_pred == len(audio_pred), ( - audio_len_pred, - len(audio_pred), - ) - audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() - returned_sample = (audio_pred, audio_gt) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - tokenizer: Tokenizer, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " "max_duration setting.
We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - generator = model.generator - discriminator = model.discriminator - - num_param_g = sum([p.numel() for p in generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer_g = torch.optim.AdamW( - generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - optimizer_d = torch.optim.AdamW( - discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - baker_zh = BakerZhSpeechTtsDataModule(args) - - train_cuts = baker_zh.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # You should use 
../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = baker_zh.train_dataloaders(train_cuts) - - valid_cuts = baker_zh.valid_cuts() - valid_dl = baker_zh.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - # step per epoch - scheduler_g.step() - scheduler_d.step() - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - BakerZhSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/vits/transform.py b/egs/baker_zh/TTS/vits/transform.py deleted file mode 120000 index 962647408b..0000000000 --- a/egs/baker_zh/TTS/vits/transform.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/tts_datamodule.py b/egs/baker_zh/TTS/vits/tts_datamodule.py deleted file mode 100644 index 96c5422771..0000000000 --- a/egs/baker_zh/TTS/vits/tts_datamodule.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: 
Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class BakerZhSpeechTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in TTS tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - self.sampling_rate = 48000 - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" - ) diff --git a/egs/baker_zh/TTS/vits/utils.py b/egs/baker_zh/TTS/vits/utils.py deleted file mode 120000 index 085e764b43..0000000000 --- a/egs/baker_zh/TTS/vits/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/vits.py b/egs/baker_zh/TTS/vits/vits.py deleted file mode 120000 index 1f58cf6fea..0000000000 --- a/egs/baker_zh/TTS/vits/vits.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git 
a/egs/baker_zh/TTS/vits/wavenet.py b/egs/baker_zh/TTS/vits/wavenet.py deleted file mode 120000 index 28f0a78eeb..0000000000 --- a/egs/baker_zh/TTS/vits/wavenet.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file From bfae73cb7483ed4d78ce57007fdf8723af32b65b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 6 Apr 2024 22:13:27 +0800 Subject: [PATCH 7/8] Add CI test for aishell3 --- .github/scripts/aishell3/TTS/run.sh | 86 +++++++++++++++++++++++++++++ .github/workflows/aishell3.yml | 73 ++++++++++++++++++++++++ egs/aishell3/TTS/local/symbols.py | 73 ++++++++++++++++++++++++ egs/aishell3/TTS/prepare.sh | 2 +- egs/aishell3/TTS/vits/train.py | 4 ++ 5 files changed, 237 insertions(+), 1 deletion(-) create mode 100755 .github/scripts/aishell3/TTS/run.sh create mode 100644 .github/workflows/aishell3.yml create mode 100644 egs/aishell3/TTS/local/symbols.py diff --git a/.github/scripts/aishell3/TTS/run.sh b/.github/scripts/aishell3/TTS/run.sh new file mode 100755 index 0000000000..81fba1de44 --- /dev/null +++ b/.github/scripts/aishell3/TTS/run.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install numba +python3 -m pip install pypinyin +python3 -m pip install cython + +apt-get update +apt-get install -y jq + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell3/TTS + +sed -i.bak s/1000/10/g ./prepare.sh + + +function download_data() { + mkdir download + pushd download + curl -SL -O https://huggingface.co/csukuangfj/aishell3-ci-data/resolve/main/aishell3.tar.bz2 + tar xf aishell3.tar.bz2 + rm aishell3.tar.bz2 + ls -lh + popd +} + +function prepare_data() { + ./prepare.sh + + echo "----------tokens.txt----------" + cat data/tokens.txt + echo "------------------------------" + wc -l data/tokens.txt + echo "------------------------------" +} + +function train() { + pushd ./vits + sed -i.bak s/200/50/g ./train.py + git diff . 
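+  # The sed above rewrites every literal 200 in ./vits/train.py to 50,
+  # presumably so that validation (valid_interval defaults to 200) is
+  # reached within this one-epoch CI smoke test; `git diff .` records the
+  # rewrite in the CI log.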
+ popd + + for t in low medium high; do + ./vits/train.py \ + --exp-dir vits/exp-$t \ + --model-type $t \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh vits/exp-$t + done +} + +function export_onnx() { + for t in low medium high; do + ./vits/export-onnx.py \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt + --speakers ./data/speakers.txt + + ls -lh vits/exp-$t/ + done +} + +function test_low() { + echo "TODO" +} + + +download_data +prepare_data +train +export_onnx +test_low diff --git a/.github/workflows/aishell3.yml b/.github/workflows/aishell3.yml new file mode 100644 index 0000000000..e60c85f4d4 --- /dev/null +++ b/.github/workflows/aishell3.yml @@ -0,0 +1,73 @@ +name: aishell + +on: + push: + branches: + - master + - tts-aishell3 + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: aishell3-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell3') + + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + aishell3: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run aishell3 tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/aishell3/TTS/run.sh diff --git a/egs/aishell3/TTS/local/symbols.py b/egs/aishell3/TTS/local/symbols.py new file mode 100644 index 0000000000..1e68788704 --- /dev/null +++ b/egs/aishell3/TTS/local/symbols.py @@ -0,0 +1,73 @@ +# This file is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py +_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] + +_initials = [ + "^", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "x", + "z", + "zh", +] + +_tones = ["1", "2", "3", "4", "5"] + +_finals = [ + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "ia", + "ian", + "iang", + "iao", + "ie", + "ii", + "iii", + "in", + "ing", + "iong", + "iou", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "uei", + "uen", + "ueng", + "uo", + "v", + "van", + "ve", + "vn", 
+] + +symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/egs/aishell3/TTS/prepare.sh b/egs/aishell3/TTS/prepare.sh index af532c2296..fe3f762054 100755 --- a/egs/aishell3/TTS/prepare.sh +++ b/egs/aishell3/TTS/prepare.sh @@ -59,7 +59,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # You can find files like spk-info.txt inside $dl_dir/aishell3 mkdir -p data/manifests if [ ! -e data/manifests/.aishell3.done ]; then - lhotse prepare aishell3 $dl_dir/aishell3 data/manifests + lhotse prepare aishell3 $dl_dir/aishell3 data/manifests >/dev/null 2>&1 touch data/manifests/.aishell3.done fi fi diff --git a/egs/aishell3/TTS/vits/train.py b/egs/aishell3/TTS/vits/train.py index f3f99ebbc6..b92386e37d 100755 --- a/egs/aishell3/TTS/vits/train.py +++ b/egs/aishell3/TTS/vits/train.py @@ -820,6 +820,10 @@ def run(rank, world_size, args): params.vocab_size = tokenizer.vocab_size aishell3 = Aishell3SpeechTtsDataModule(args) + assert aishell3.sampling_rate == params.sampling_rate, ( + aishell3.sampling_rate, + params.sampling_rate, + ) speaker_map = aishell3.speakers() params.num_spks = len(speaker_map) From c25dc02d5d192a03fc61302d05d2ee602c008b4d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 6 Apr 2024 23:27:23 +0800 Subject: [PATCH 8/8] add lexicon --- .github/scripts/aishell3/TTS/run.sh | 40 ++++++++++-- .github/workflows/aishell3.yml | 13 +++- egs/aishell3/TTS/local/generate_lexicon.py | 68 ++++++++++++++++++++ egs/aishell3/TTS/local/prepare_token_file.py | 2 +- egs/aishell3/TTS/prepare.sh | 6 +- egs/aishell3/TTS/vits/export-onnx.py | 2 +- egs/aishell3/TTS/vits/train.py | 2 +- 7 files changed, 124 insertions(+), 9 deletions(-) create mode 100755 egs/aishell3/TTS/local/generate_lexicon.py diff --git a/.github/scripts/aishell3/TTS/run.sh b/.github/scripts/aishell3/TTS/run.sh index 81fba1de44..93ff695728 100755 --- a/.github/scripts/aishell3/TTS/run.sh +++ b/.github/scripts/aishell3/TTS/run.sh @@ -39,6 +39,13 @@ function prepare_data() { echo "------------------------------" wc -l data/tokens.txt echo "------------------------------" + + echo "----------lexicon.txt----------" + head data/lexicon.txt + echo "----" + tail data/lexicon.txt + echo "----" + wc -l data/lexicon.txt } function train() { @@ -47,7 +54,8 @@ function train() { git diff . 
popd - for t in low medium high; do + # for t in low medium high; do + for t in low; do ./vits/train.py \ --exp-dir vits/exp-$t \ --model-type $t \ @@ -62,12 +70,13 @@ function train() { } function export_onnx() { - for t in low medium high; do + # for t in low medium high; do + for t in low; do ./vits/export-onnx.py \ --model-type $t \ --epoch 1 \ --exp-dir ./vits/exp-$t \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ --speakers ./data/speakers.txt ls -lh vits/exp-$t/ @@ -75,7 +84,30 @@ function export_onnx() { } function test_low() { - echo "TODO" + git clone https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06 + repo=icefall-tts-aishell3-vits-low-2024-04-06 + + ./vits/export-onnx.py \ + --model-type low \ + --epoch 1000 \ + --exp-dir $repo/exp \ + --tokens $repo/data/tokens.txt \ + --speakers $repo/data/speakers.txt + + ls -lh $repo/exp/vits-epoch-1000.onnx + + python3 -m pip install sherpa-onnx + + sherpa-onnx-offline-tts \ + --vits-model=$repo/exp/vits-epoch-960.onnx \ + --vits-tokens=$repo/data/tokens.txt \ + --vits-lexicon=$repo/data/lexicon.txt \ + --num-threads=1 \ + --vits-length-scale=1.0 \ + --sid=33 \ + --output-filename=/icefall/low.wav \ + --debug=1 \ + "这是一个语音合成测试" } diff --git a/.github/workflows/aishell3.yml b/.github/workflows/aishell3.yml index e60c85f4d4..542c77663d 100644 --- a/.github/workflows/aishell3.yml +++ b/.github/workflows/aishell3.yml @@ -1,4 +1,4 @@ -name: aishell +name: aishell3 on: push: @@ -71,3 +71,14 @@ jobs: git config --global --add safe.directory /icefall .github/scripts/aishell3/TTS/run.sh + + - name: display files + shell: bash + run: | + ls -lh + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} + path: ./*.wav diff --git a/egs/aishell3/TTS/local/generate_lexicon.py b/egs/aishell3/TTS/local/generate_lexicon.py new file mode 100755 index 0000000000..77dd77d625 --- /dev/null +++ b/egs/aishell3/TTS/local/generate_lexicon.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +""" +This file generates the file lexicon.txt that contains pronunciations of all +words and phrases +""" + +from pypinyin import phrases_dict, pinyin_dict +from tokenizer import Tokenizer + +import argparse + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lexicon", + type=str, + default="data/lexicon.txt", + help="""Path to save the generated lexicon.""", + ) + return parser + + +def main(): + args = get_parser().parse_args() + filename = args.lexicon + tokens = args.tokens + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + + # 1 to remove the initial sil + # :-1 to remove the final eos + tokens = tokenizer.text_to_tokens(w)[1:-1] + + tokens = " ".join(tokens) + f.write(f"{w} {tokens}\n") + + # TODO(fangjun): Add phrases + # for key in phrases: + # # 1 to remove the initial sil + # # :-1 to remove the final eos + # tokens = tokenizer.text_to_tokens(key)[1:-1] + # tokens = " ".join(tokens) + # f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git 
a/egs/aishell3/TTS/local/prepare_token_file.py b/egs/aishell3/TTS/local/prepare_token_file.py index d90910ab02..57ef837b82 100755 --- a/egs/aishell3/TTS/local/prepare_token_file.py +++ b/egs/aishell3/TTS/local/prepare_token_file.py @@ -17,7 +17,7 @@ """ -This file generates the file that maps tokens to IDs. +This file generates the file tokens.txt that maps tokens to IDs. """ import argparse diff --git a/egs/aishell3/TTS/prepare.sh b/egs/aishell3/TTS/prepare.sh index fe3f762054..db721e67fa 100755 --- a/egs/aishell3/TTS/prepare.sh +++ b/egs/aishell3/TTS/prepare.sh @@ -121,10 +121,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate token file" + log "Stage 6: Generate tokens.txt and lexicon.txt " if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi + + if [ ! -e data/lexicon.txt ]; then + ./local/generate_lexicon.py --tokens data/tokens.txt --lexicon data/lexicon.txt + fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then diff --git a/egs/aishell3/TTS/vits/export-onnx.py b/egs/aishell3/TTS/vits/export-onnx.py index ed5a1c6a33..a2afcaeca6 100755 --- a/egs/aishell3/TTS/vits/export-onnx.py +++ b/egs/aishell3/TTS/vits/export-onnx.py @@ -84,7 +84,7 @@ def get_parser(): parser.add_argument( "--model-type", type=str, - default="medium", + default="low", choices=["low", "medium", "high"], help="""If not empty, valid values are: low, medium, high. It controls the model size. low -> runs faster. diff --git a/egs/aishell3/TTS/vits/train.py b/egs/aishell3/TTS/vits/train.py index b92386e37d..ad30384855 100755 --- a/egs/aishell3/TTS/vits/train.py +++ b/egs/aishell3/TTS/vits/train.py @@ -156,7 +156,7 @@ def get_parser(): parser.add_argument( "--model-type", type=str, - default="medium", + default="low", choices=["low", "medium", "high"], help="""If not empty, valid values are: low, medium, high. It controls the model size. low -> runs faster.
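
As a quick cross-check of the token inventory defined in egs/aishell3/TTS/local/symbols.py above: pauses and initials are toneless, and only finals are crossed with the five tones, so the lists in that file imply 219 symbols in total. A minimal sketch, with the counts read off those lists:

    num_pause = 7      # sil, eos, sp, #0, #1, #2, #3
    num_initials = 22  # ^, b, c, ch, ..., z, zh
    num_finals = 38    # a, ai, ..., ve, vn
    num_tones = 5      # tones 1-4, plus 5 which conventionally marks the neutral tone
    print(num_pause + num_initials + num_finals * num_tones)  # 219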