diff --git a/.github/scripts/aishell3/TTS/run.sh b/.github/scripts/aishell3/TTS/run.sh new file mode 100755 index 0000000000..93ff695728 --- /dev/null +++ b/.github/scripts/aishell3/TTS/run.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install numba +python3 -m pip install pypinyin +python3 -m pip install cython + +apt-get update +apt-get install -y jq + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell3/TTS + +sed -i.bak s/1000/10/g ./prepare.sh + + +function download_data() { + mkdir download + pushd download + curl -SL -O https://huggingface.co/csukuangfj/aishell3-ci-data/resolve/main/aishell3.tar.bz2 + tar xf aishell3.tar.bz2 + rm aishell3.tar.bz2 + ls -lh + popd +} + +function prepare_data() { + ./prepare.sh + + echo "----------tokens.txt----------" + cat data/tokens.txt + echo "------------------------------" + wc -l data/tokens.txt + echo "------------------------------" + + echo "----------lexicon.txt----------" + head data/lexicon.txt + echo "----" + tail data/lexicon.txt + echo "----" + wc -l data/lexicon.txt +} + +function train() { + pushd ./vits + sed -i.bak s/200/50/g ./train.py + git diff . + popd + + # for t in low medium high; do + for t in low; do + ./vits/train.py \ + --exp-dir vits/exp-$t \ + --model-type $t \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh vits/exp-$t + done +} + +function export_onnx() { + # for t in low medium high; do + for t in low; do + ./vits/export-onnx.py \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt \ + --speakers ./data/speakers.txt + + ls -lh vits/exp-$t/ + done +} + +function test_low() { + git clone https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06 + repo=icefall-tts-aishell3-vits-low-2024-04-06 + + ./vits/export-onnx.py \ + --model-type low \ + --epoch 1000 \ + --exp-dir $repo/exp \ + --tokens $repo/data/tokens.txt \ + --speakers $repo/data/speakers.txt + + ls -lh $repo/exp/vits-epoch-1000.onnx + + python3 -m pip install sherpa-onnx + + sherpa-onnx-offline-tts \ + --vits-model=$repo/exp/vits-epoch-960.onnx \ + --vits-tokens=$repo/data/tokens.txt \ + --vits-lexicon=$repo/data/lexicon.txt \ + --num-threads=1 \ + --vits-length-scale=1.0 \ + --sid=33 \ + --output-filename=/icefall/low.wav \ + --debug=1 \ + "这是一个语音合成测试" +} + + +download_data +prepare_data +train +export_onnx +test_low diff --git a/.github/workflows/aishell3.yml b/.github/workflows/aishell3.yml new file mode 100644 index 0000000000..542c77663d --- /dev/null +++ b/.github/workflows/aishell3.yml @@ -0,0 +1,84 @@ +name: aishell3 + +on: + push: + branches: + - master + - tts-aishell3 + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: aishell3-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell3') + + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: 
set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + aishell3: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run aishell3 tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/aishell3/TTS/run.sh + + - name: display files + shell: bash + run: | + ls -lh + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} + path: ./*.wav diff --git a/.gitignore b/.gitignore index fa18ca83c3..9e45df61c9 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,7 @@ node_modules .DS_Store *.fst *.arpa +core.c +*.so +build +*.wav diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 9499a3aea2..37c8bff1e6 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -19,7 +19,7 @@ Install extra dependencies .. code-block:: bash pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html - pip install numba espnet_tts_frontend + pip install numba espnet_tts_frontend cython Data preparation ---------------- diff --git a/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py b/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py new file mode 100755 index 0000000000..1c7fccad63 --- /dev/null +++ b/egs/aishell3/TTS/local/compute_spectrogram_aishell3.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the aishell3 dataset. +It looks for manifests in the directory data/manifests. + +The generated spectrogram features are saved in data/spectrogram. 
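+
+Usage:
+  ./local/compute_spectrogram_aishell3.py
+
+It is invoked by stage 3 of ./prepare.sh.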
+""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_aishell3(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 8000 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "aishell3" + suffix = "jsonl.gz" + partitions = ("test", "train") + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + for partition in partitions: + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + # resample from 44100 to 8000 + recordings = recordings.resample(sampling_rate) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_aishell3() diff --git a/egs/aishell3/TTS/local/generate_lexicon.py b/egs/aishell3/TTS/local/generate_lexicon.py new file mode 100755 index 0000000000..77dd77d625 --- /dev/null +++ b/egs/aishell3/TTS/local/generate_lexicon.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +""" +This file generates the file lexicon.txt that contains pronunciations of all +words and phrases +""" + +from pypinyin import phrases_dict, pinyin_dict +from tokenizer import Tokenizer + +import argparse + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lexicon", + type=str, + default="data/lexicon.txt", + help="""Path to save the generated lexicon.""", + ) + return parser + + +def main(): + args = get_parser().parse_args() + filename = args.lexicon + tokens = args.tokens + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + 
for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + + # 1 to remove the initial sil + # :-1 to remove the final eos + tokens = tokenizer.text_to_tokens(w)[1:-1] + + tokens = " ".join(tokens) + f.write(f"{w} {tokens}\n") + + # TODO(fangjun): Add phrases + # for key in phrases: + # # 1 to remove the initial sil + # # :-1 to remove the final eos + # tokens = tokenizer.text_to_tokens(key)[1:-1] + # tokens = " ".join(tokens) + # f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell3/TTS/local/pinyin_dict.py b/egs/aishell3/TTS/local/pinyin_dict.py new file mode 100644 index 0000000000..950fb39fc0 --- /dev/null +++ b/egs/aishell3/TTS/local/pinyin_dict.py @@ -0,0 +1,421 @@ +# This dict is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py +pinyin_dict = { + "a": ("^", "a"), + "ai": ("^", "ai"), + "an": ("^", "an"), + "ang": ("^", "ang"), + "ao": ("^", "ao"), + "ba": ("b", "a"), + "bai": ("b", "ai"), + "ban": ("b", "an"), + "bang": ("b", "ang"), + "bao": ("b", "ao"), + "be": ("b", "e"), + "bei": ("b", "ei"), + "ben": ("b", "en"), + "beng": ("b", "eng"), + "bi": ("b", "i"), + "bian": ("b", "ian"), + "biao": ("b", "iao"), + "bie": ("b", "ie"), + "bin": ("b", "in"), + "bing": ("b", "ing"), + "bo": ("b", "o"), + "bu": ("b", "u"), + "ca": ("c", "a"), + "cai": ("c", "ai"), + "can": ("c", "an"), + "cang": ("c", "ang"), + "cao": ("c", "ao"), + "ce": ("c", "e"), + "cen": ("c", "en"), + "ceng": ("c", "eng"), + "cha": ("ch", "a"), + "chai": ("ch", "ai"), + "chan": ("ch", "an"), + "chang": ("ch", "ang"), + "chao": ("ch", "ao"), + "che": ("ch", "e"), + "chen": ("ch", "en"), + "cheng": ("ch", "eng"), + "chi": ("ch", "iii"), + "chong": ("ch", "ong"), + "chou": ("ch", "ou"), + "chu": ("ch", "u"), + "chua": ("ch", "ua"), + "chuai": ("ch", "uai"), + "chuan": ("ch", "uan"), + "chuang": ("ch", "uang"), + "chui": ("ch", "uei"), + "chun": ("ch", "uen"), + "chuo": ("ch", "uo"), + "ci": ("c", "ii"), + "cong": ("c", "ong"), + "cou": ("c", "ou"), + "cu": ("c", "u"), + "cuan": ("c", "uan"), + "cui": ("c", "uei"), + "cun": ("c", "uen"), + "cuo": ("c", "uo"), + "da": ("d", "a"), + "dai": ("d", "ai"), + "dan": ("d", "an"), + "dang": ("d", "ang"), + "dao": ("d", "ao"), + "de": ("d", "e"), + "dei": ("d", "ei"), + "den": ("d", "en"), + "deng": ("d", "eng"), + "di": ("d", "i"), + "dia": ("d", "ia"), + "dian": ("d", "ian"), + "diao": ("d", "iao"), + "die": ("d", "ie"), + "ding": ("d", "ing"), + "diu": ("d", "iou"), + "dong": ("d", "ong"), + "dou": ("d", "ou"), + "du": ("d", "u"), + "duan": ("d", "uan"), + "dui": ("d", "uei"), + "dun": ("d", "uen"), + "duo": ("d", "uo"), + "e": ("^", "e"), + "ei": ("^", "ei"), + "en": ("^", "en"), + "ng": ("^", "en"), + "eng": ("^", "eng"), + "er": ("^", "er"), + "fa": ("f", "a"), + "fan": ("f", "an"), + "fang": ("f", "ang"), + "fei": ("f", "ei"), + "fen": ("f", "en"), + "feng": ("f", "eng"), + "fo": ("f", "o"), + "fou": ("f", "ou"), + "fu": ("f", "u"), + "ga": ("g", "a"), + "gai": ("g", "ai"), + "gan": ("g", "an"), + "gang": ("g", "ang"), + "gao": ("g", "ao"), + "ge": ("g", "e"), + "gei": ("g", "ei"), + "gen": ("g", "en"), + "geng": ("g", "eng"), + "gong": ("g", "ong"), + "gou": ("g", "ou"), + "gu": ("g", "u"), + "gua": ("g", "ua"), + "guai": ("g", "uai"), + "guan": ("g", "uan"), + "guang": ("g", "uang"), + "gui": ("g", "uei"), + "gun": ("g", "uen"), + "guo": ("g", "uo"), + "ha": ("h", "a"), + "hai": ("h", "ai"), + "han": ("h", "an"), + "hang": ("h", "ang"), + "hao": ("h", "ao"), + "he": 
("h", "e"), + "hei": ("h", "ei"), + "hen": ("h", "en"), + "heng": ("h", "eng"), + "hong": ("h", "ong"), + "hou": ("h", "ou"), + "hu": ("h", "u"), + "hua": ("h", "ua"), + "huai": ("h", "uai"), + "huan": ("h", "uan"), + "huang": ("h", "uang"), + "hui": ("h", "uei"), + "hun": ("h", "uen"), + "huo": ("h", "uo"), + "ji": ("j", "i"), + "jia": ("j", "ia"), + "jian": ("j", "ian"), + "jiang": ("j", "iang"), + "jiao": ("j", "iao"), + "jie": ("j", "ie"), + "jin": ("j", "in"), + "jing": ("j", "ing"), + "jiong": ("j", "iong"), + "jiu": ("j", "iou"), + "ju": ("j", "v"), + "juan": ("j", "van"), + "jue": ("j", "ve"), + "jun": ("j", "vn"), + "ka": ("k", "a"), + "kai": ("k", "ai"), + "kan": ("k", "an"), + "kang": ("k", "ang"), + "kao": ("k", "ao"), + "ke": ("k", "e"), + "kei": ("k", "ei"), + "ken": ("k", "en"), + "keng": ("k", "eng"), + "kong": ("k", "ong"), + "kou": ("k", "ou"), + "ku": ("k", "u"), + "kua": ("k", "ua"), + "kuai": ("k", "uai"), + "kuan": ("k", "uan"), + "kuang": ("k", "uang"), + "kui": ("k", "uei"), + "kun": ("k", "uen"), + "kuo": ("k", "uo"), + "la": ("l", "a"), + "lai": ("l", "ai"), + "lan": ("l", "an"), + "lang": ("l", "ang"), + "lao": ("l", "ao"), + "le": ("l", "e"), + "lei": ("l", "ei"), + "leng": ("l", "eng"), + "li": ("l", "i"), + "lia": ("l", "ia"), + "lian": ("l", "ian"), + "liang": ("l", "iang"), + "liao": ("l", "iao"), + "lie": ("l", "ie"), + "lin": ("l", "in"), + "ling": ("l", "ing"), + "liu": ("l", "iou"), + "lo": ("l", "o"), + "long": ("l", "ong"), + "lou": ("l", "ou"), + "lu": ("l", "u"), + "lv": ("l", "v"), + "luan": ("l", "uan"), + "lve": ("l", "ve"), + "lue": ("l", "ve"), + "lun": ("l", "uen"), + "luo": ("l", "uo"), + "ma": ("m", "a"), + "mai": ("m", "ai"), + "man": ("m", "an"), + "mang": ("m", "ang"), + "mao": ("m", "ao"), + "me": ("m", "e"), + "mei": ("m", "ei"), + "men": ("m", "en"), + "meng": ("m", "eng"), + "mi": ("m", "i"), + "mian": ("m", "ian"), + "miao": ("m", "iao"), + "mie": ("m", "ie"), + "min": ("m", "in"), + "ming": ("m", "ing"), + "miu": ("m", "iou"), + "mo": ("m", "o"), + "mou": ("m", "ou"), + "mu": ("m", "u"), + "na": ("n", "a"), + "nai": ("n", "ai"), + "nan": ("n", "an"), + "nang": ("n", "ang"), + "nao": ("n", "ao"), + "ne": ("n", "e"), + "nei": ("n", "ei"), + "nen": ("n", "en"), + "neng": ("n", "eng"), + "ni": ("n", "i"), + "nia": ("n", "ia"), + "nian": ("n", "ian"), + "niang": ("n", "iang"), + "niao": ("n", "iao"), + "nie": ("n", "ie"), + "nin": ("n", "in"), + "ning": ("n", "ing"), + "niu": ("n", "iou"), + "nong": ("n", "ong"), + "nou": ("n", "ou"), + "nu": ("n", "u"), + "nv": ("n", "v"), + "nuan": ("n", "uan"), + "nve": ("n", "ve"), + "nue": ("n", "ve"), + "nuo": ("n", "uo"), + "o": ("^", "o"), + "ou": ("^", "ou"), + "pa": ("p", "a"), + "pai": ("p", "ai"), + "pan": ("p", "an"), + "pang": ("p", "ang"), + "pao": ("p", "ao"), + "pe": ("p", "e"), + "pei": ("p", "ei"), + "pen": ("p", "en"), + "peng": ("p", "eng"), + "pi": ("p", "i"), + "pian": ("p", "ian"), + "piao": ("p", "iao"), + "pie": ("p", "ie"), + "pin": ("p", "in"), + "ping": ("p", "ing"), + "po": ("p", "o"), + "pou": ("p", "ou"), + "pu": ("p", "u"), + "qi": ("q", "i"), + "qia": ("q", "ia"), + "qian": ("q", "ian"), + "qiang": ("q", "iang"), + "qiao": ("q", "iao"), + "qie": ("q", "ie"), + "qin": ("q", "in"), + "qing": ("q", "ing"), + "qiong": ("q", "iong"), + "qiu": ("q", "iou"), + "qu": ("q", "v"), + "quan": ("q", "van"), + "que": ("q", "ve"), + "qun": ("q", "vn"), + "ran": ("r", "an"), + "rang": ("r", "ang"), + "rao": ("r", "ao"), + "re": ("r", "e"), + "ren": ("r", "en"), + "reng": ("r", "eng"), + 
"ri": ("r", "iii"), + "rong": ("r", "ong"), + "rou": ("r", "ou"), + "ru": ("r", "u"), + "rua": ("r", "ua"), + "ruan": ("r", "uan"), + "rui": ("r", "uei"), + "run": ("r", "uen"), + "ruo": ("r", "uo"), + "sa": ("s", "a"), + "sai": ("s", "ai"), + "san": ("s", "an"), + "sang": ("s", "ang"), + "sao": ("s", "ao"), + "se": ("s", "e"), + "sen": ("s", "en"), + "seng": ("s", "eng"), + "sha": ("sh", "a"), + "shai": ("sh", "ai"), + "shan": ("sh", "an"), + "shang": ("sh", "ang"), + "shao": ("sh", "ao"), + "she": ("sh", "e"), + "shei": ("sh", "ei"), + "shen": ("sh", "en"), + "sheng": ("sh", "eng"), + "shi": ("sh", "iii"), + "shou": ("sh", "ou"), + "shu": ("sh", "u"), + "shua": ("sh", "ua"), + "shuai": ("sh", "uai"), + "shuan": ("sh", "uan"), + "shuang": ("sh", "uang"), + "shui": ("sh", "uei"), + "shun": ("sh", "uen"), + "shuo": ("sh", "uo"), + "si": ("s", "ii"), + "song": ("s", "ong"), + "sou": ("s", "ou"), + "su": ("s", "u"), + "suan": ("s", "uan"), + "sui": ("s", "uei"), + "sun": ("s", "uen"), + "suo": ("s", "uo"), + "ta": ("t", "a"), + "tai": ("t", "ai"), + "tan": ("t", "an"), + "tang": ("t", "ang"), + "tao": ("t", "ao"), + "te": ("t", "e"), + "tei": ("t", "ei"), + "teng": ("t", "eng"), + "ti": ("t", "i"), + "tian": ("t", "ian"), + "tiao": ("t", "iao"), + "tie": ("t", "ie"), + "ting": ("t", "ing"), + "tong": ("t", "ong"), + "tou": ("t", "ou"), + "tu": ("t", "u"), + "tuan": ("t", "uan"), + "tui": ("t", "uei"), + "tun": ("t", "uen"), + "tuo": ("t", "uo"), + "wa": ("^", "ua"), + "wai": ("^", "uai"), + "wan": ("^", "uan"), + "wang": ("^", "uang"), + "wei": ("^", "uei"), + "wen": ("^", "uen"), + "weng": ("^", "ueng"), + "wo": ("^", "uo"), + "wu": ("^", "u"), + "xi": ("x", "i"), + "xia": ("x", "ia"), + "xian": ("x", "ian"), + "xiang": ("x", "iang"), + "xiao": ("x", "iao"), + "xie": ("x", "ie"), + "xin": ("x", "in"), + "xing": ("x", "ing"), + "xiong": ("x", "iong"), + "xiu": ("x", "iou"), + "xu": ("x", "v"), + "xuan": ("x", "van"), + "xue": ("x", "ve"), + "xun": ("x", "vn"), + "ya": ("^", "ia"), + "yan": ("^", "ian"), + "yang": ("^", "iang"), + "yao": ("^", "iao"), + "ye": ("^", "ie"), + "yi": ("^", "i"), + "yin": ("^", "in"), + "ying": ("^", "ing"), + "yo": ("^", "iou"), + "yong": ("^", "iong"), + "you": ("^", "iou"), + "yu": ("^", "v"), + "yuan": ("^", "van"), + "yue": ("^", "ve"), + "yun": ("^", "vn"), + "za": ("z", "a"), + "zai": ("z", "ai"), + "zan": ("z", "an"), + "zang": ("z", "ang"), + "zao": ("z", "ao"), + "ze": ("z", "e"), + "zei": ("z", "ei"), + "zen": ("z", "en"), + "zeng": ("z", "eng"), + "zha": ("zh", "a"), + "zhai": ("zh", "ai"), + "zhan": ("zh", "an"), + "zhang": ("zh", "ang"), + "zhao": ("zh", "ao"), + "zhe": ("zh", "e"), + "zhei": ("zh", "ei"), + "zhen": ("zh", "en"), + "zheng": ("zh", "eng"), + "zhi": ("zh", "iii"), + "zhong": ("zh", "ong"), + "zhou": ("zh", "ou"), + "zhu": ("zh", "u"), + "zhua": ("zh", "ua"), + "zhuai": ("zh", "uai"), + "zhuan": ("zh", "uan"), + "zhuang": ("zh", "uang"), + "zhui": ("zh", "uei"), + "zhun": ("zh", "uen"), + "zhuo": ("zh", "uo"), + "zi": ("z", "ii"), + "zong": ("z", "ong"), + "zou": ("z", "ou"), + "zu": ("z", "u"), + "zuan": ("z", "uan"), + "zui": ("z", "uei"), + "zun": ("z", "uen"), + "zuo": ("z", "uo"), +} diff --git a/egs/aishell3/TTS/local/prepare_token_file.py b/egs/aishell3/TTS/local/prepare_token_file.py new file mode 100755 index 0000000000..57ef837b82 --- /dev/null +++ b/egs/aishell3/TTS/local/prepare_token_file.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file generates the file tokens.txt that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict +from symbols import symbols + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the dict that maps the text tokens to IDs", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + tokens = Path(args.tokens) + + with open(tokens, "w", encoding="utf-8") as f: + for token_id, token in enumerate(symbols): + f.write(f"{token} {token_id}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell3/TTS/local/prepare_tokens_aishell3.py b/egs/aishell3/TTS/local/prepare_tokens_aishell3.py new file mode 100755 index 0000000000..4b2b5094fd --- /dev/null +++ b/egs/aishell3/TTS/local/prepare_tokens_aishell3.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with tokens. 
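+
+Usage:
+  ./local/prepare_tokens_aishell3.py
+
+It is invoked by stage 4 of ./prepare.sh.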
+""" + +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest + +from tokenizer import Tokenizer + + +def prepare_tokens_aishell3(): + output_dir = Path("data/spectrogram") + prefix = "aishell3" + suffix = "jsonl.gz" + partitions = ("train", "test") + + tokenizer = Tokenizer() + + for partition in partitions: + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + new_cuts = [] + i = 0 + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].text + cut.tokens = tokenizer.text_to_tokens(text) + + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file( + output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_aishell3() diff --git a/egs/aishell3/TTS/local/pypinyin-local.dict b/egs/aishell3/TTS/local/pypinyin-local.dict new file mode 100644 index 0000000000..5e386014c8 --- /dev/null +++ b/egs/aishell3/TTS/local/pypinyin-local.dict @@ -0,0 +1,328 @@ +姐姐 jie3 jie +宝宝 bao3 bao +哥哥 ge1 ge +妹妹 mei4 mei +弟弟 di4 di +妈妈 ma1 ma +开心哦 kai1 xin1 o +爸爸 ba4 ba +秘密哟 mi4 mi4 yo +哦 o +一年 yi4 nian2 +一夜 yi2 ye4 +一切 yi2 qie4 +一座 yi2 zuo4 +一下 yi2 xia4 +上一山 shang4 yi2 shan1 +下一山 xia4 yi2 shan1 +休息 xiu1 xi2 +东西 dong1 xi +上一届 shang4 yi2 jie4 +便宜 pian2 yi4 +加长 jia1 chang2 +单田芳 shan4 tian2 fang1 +帧 zhen1 +长时间 chang2 shi2 jian1 +长时 chang2 shi2 +识别 shi2 bie2 +生命中 sheng1 ming4 zhong1 +踏实 ta1 shi +嗯 en4 +溜达 liu1 da +少儿 shao4 er2 +爷爷 ye2 ye +不是 bu2 shi4 +一圈 yi1 quan1 +厜读一声 zui1 du2 yi4 sheng1 +一种 yi4 zhong3 +一簇簇 yi2 cu4 cu4 +一个 yi2 ge4 +一样 yi2 yang4 +一跩一跩 yi4 zhuai3 yi4 zhuai3 +一会儿 yi2 hui4 er +一幢 yi2 zhuang4 +挨了 ai2 le +熬菜 ao1 cai4 +扒鸡 pa2 ji1 +背枪 bei1 qiang1 +绷瓷儿 beng4 ci2 er2 +绷劲儿 beng3 jin4 er +绷着脸 beng3 zhe lian3 +藏医 zang4 yi1 +噌吰 cheng1 hong2 +差点儿 cha4 dian3 er +差失 cha1 shi1 +差误 cha1 wu4 +孱头 can4 tou +乘间 cheng2 jian4 +锄镰棘矜 chu2 lian2 ji2 qin2 +川藏 chuan1 zang4 +穿著 chuan1 zhuo2 +答讪 da1 shan4 +答言 da1 yan2 +大伯子 da4 bai3 zi +大夫 dai4 fu +弹冠 tan2 guan1 +当间 dang1 jian4 +当然咯 dang1 ran2 lo +点种 dian3 zhong3 +垛好 duo4 hao3 +发疟子 fa1 yao4 zi +饭熟了 fan4 shou2 le +附著 fu4 zhuo2 +复沓 fu4 ta4 +供稿 gong1 gao3 +供养 gong1 yang3 +骨朵 gu1 duo +骨碌 gu1 lu +果脯 guo3 fu3 +哈什玛 ha4 shi2 ma3 +海蜇 hai3 zhe2 +呵欠 he1 qian +河水汤汤 he2 shui3 shang1 shang1 +鹄立 hu2 li4 +鹄望 hu2 wang4 +混人 hun2 ren2 +混水 hun2 shui3 +鸡血 ji1 xie3 +缉鞋口 qi1 xie2 kou3 +亟来闻讯 qi4 lai2 wen2 xun4 +计量 ji4 liang2 +济水 ji3 shui3 +间杂 jian4 za2 +脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2 +脚儿 jue2 er2 +口角 kou3 jiao3 +勒石 le4 shi2 +累进 lei3 jin4 +累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3 +累年 lei3 nian2 +脸涨通红 lian3 zhang4 tong1 hong2 +踉锵 liang4 qiang1 +燎眉毛 liao3 mei2 mao2 +燎头发 liao3 tou2 fa4 +溜达 liu1 da +溜缝儿 liu4 feng4 er +馏口饭 liu4 kou3 fan4 +遛马 liu4 ma3 +遛鸟 liu4 niao3 +遛弯儿 liu4 wan1 er +楼枪机 lou1 qiang1 ji1 +搂钱 lou1 qian2 +鹿脯 lu4 fu3 +露头 lou4 tou2 +落魄 luo4 po4 +捋胡子 lv3 hu2 zi +绿地 lv4 di4 +麦垛 mai4 duo4 +没劲儿 mei2 jin4 er +闷棍 men4 gun4 +闷葫芦 men4 hu2 lu +闷头干 men1 tou2 gan4 +蒙古 meng3 gu3 +靡日不思 mi3 ri4 bu4 si1 +缪姓 miao4 xing4 +抹墙 mo4 qiang2 +抹下脸 ma1 xia4 lian3 +泥子 ni4 zi +拗不过 niu4 bu guo4 +排车 pai3 che1 +盘诘 pan2 jie2 +膀肿 pang1 zhong3 +炮干 bao1 gan1 +炮格 pao2 ge2 +碰钉子 peng4 ding1 zi +缥色 piao3 se4 +瀑河 bao4 he2 +蹊径 xi1 jing4 +前后相属 qian2 hou4 xiang1 zhu3 +翘尾巴 qiao4 wei3 ba +趄坡儿 qie4 po1 er +秦桧 qin2 hui4 +圈马 juan1 ma3 +雀盲眼 qiao3 mang2 
yan3 +雀子 qiao1 zi +三年五载 san1 nian2 wu3 zai3 +加载 jia1 zai3 +山大王 shan1 dai4 wang +苫屋草 shan4 wu1 cao3 +数数 shu3 shu4 +说客 shui4 ke4 +思量 si1 liang2 +伺侯 ci4 hou +踏实 ta1 shi +提溜 di1 liu +调拨 diao4 bo1 +帖子 tie3 zi +铜钿 tong2 tian2 +头昏脑涨 tou2 hun1 nao3 zhang4 +褪色 tui4 se4 +褪着手 tun4 zhe shou3 +圩子 wei2 zi +尾巴 wei3 ba +系好船只 xi4 hao3 chuan2 zhi1 +系好马匹 xi4 hao3 ma3 pi3 +杏脯 xing4 fu3 +姓单 xing4 shan4 +姓葛 xing4 ge3 +姓哈 xing4 ha3 +姓解 xing4 xie4 +姓秘 xing4 bi4 +姓宁 xing4 ning4 +旋风 xuan4 feng1 +旋根车轴 xuan4 gen1 che1 zhou2 +荨麻 qian2 ma2 +一幢楼房 yi1 zhuang4 lou2 fang2 +遗之千金 wei4 zhi1 qian1 jin1 +殷殷 yin3 yin3 +应招 ying4 zhao1 +用称约 yong4 cheng4 yao1 +约斤肉 yao1 jin1 rou4 +晕机 yun4 ji1 +熨贴 yu4 tie1 +咋办 za3 ban4 +咋呼 zha1 hu +仔兽 zi3 shou4 +扎彩 za1 cai3 +扎实 zha1 shi +扎腰带 za1 yao1 dai4 +轧朋友 ga2 peng2 you3 +爪子 zhua3 zi +折腾 zhe1 teng +着实 zhuo2 shi2 +着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2 +枝蔓 zhi1 man4 +中鹄 zhong1 hu2 +中选 zhong4 xuan3 +猪圈 zhu1 juan4 +拽住不放 zhuai4 zhu4 bu4 fang4 +转悠 zhuan4 you +庄稼熟了 zhuang1 jia shou2 le +酌量 zhuo2 liang2 +罪行累累 zui4 xing2 lei3 lei3 +一手 yi4 shou3 +一去不复返 yi2 qu4 bu2 fu4 fan3 +一颗 yi4 ke1 +一件 yi2 jian4 +一斤 yi4 jin1 +一点 yi4 dian3 +一朵 yi4 duo3 +一声 yi4 sheng1 +一身 yi4 shen1 +不要 bu2 yao4 +一人 yi4 ren2 +一个 yi2 ge4 +一把 yi4 ba3 +一门 yi4 men2 +一門 yi4 men2 +一艘 yi4 sou1 +一片 yi2 pian4 +一篇 yi2 pian1 +一份 yi2 fen4 +好嗲 hao3 dia3 +随地 sui2 di4 +扁担长 bian3 dan4 chang3 +一堆 yi4 dui1 +不义 bu2 yi4 +放一放 fang4 yi2 fang4 +一米 yi4 mi3 +一顿 yi2 dun4 +一层楼 yi4 ceng2 lou2 +一条 yi4 tiao2 +一件 yi2 jian4 +一棵 yi4 ke1 +一小股 yi4 xiao3 gu3 +一拐一拐 yi4 guai3 yi4 guai3 +一根 yi4 gen1 +沆瀣一气 hang4 xie4 yi2 qi4 +一丝 yi4 si1 +一毫 yi4 hao2 +一樣 yi2 yang4 +处处 chu4 chu4 +一餐 yi4 can +永不 yong3 bu2 +一看 yi2 kan4 +一架 yi2 jia4 +送还 song4 huan2 +一见 yi2 jian4 +一座 yi2 zuo4 +一块 yi2 kuai4 +一天 yi4 tian1 +一只 yi4 zhi1 +一支 yi4 zhi1 +一字 yi2 zi4 +一句 yi2 ju4 +一张 yi4 zhang1 +一條 yi4 tiao2 +一场 yi4 chang3 +一粒 yi2 li4 +小俩口 xiao3 liang3 kou3 +一首 yi4 shou3 +一对 yi2 dui4 +一手 yi4 shou3 +又一村 you4 yi4 cun1 +一概而论 yi2 gai4 er2 lun4 +一峰峰 yi4 feng1 feng1 +不但 bu2 dan4 +一笑 yi2 xiao4 +挠痒痒 nao2 yang3 yang +不对 bu2 dui4 +拧开 ning3 kai1 +爱不释手 ai4 bu2 shi4 shou3 +一念 yi2 nian4 +夺得 duo2 de2 +一袭 yi4 xi2 +一定 yi2 ding4 +不慎 bu2 shen4 +剽窃 piao2 qie4 +一时 yi4 shi2 +撇开 pie3 kai1 +一祭 yi2 ji4 +发卡 fa4 qia3 +少不了 shao3 bu4 liao3 +千虑一失 qian1 lv4 yi4 shi1 +呛得 qiang4 de2 +切菜 qie1 cai4 +茄盒 qie2 he2 +不去 bu2 qu4 +一大圈 yi2 da4 quan1 +不再 bu2 zai4 +一群 yi4 qun2 +不必 bu2 bi4 +一些 yi4 xie1 +一路 yi2 lu4 +一股 yi4 gu3 +一到 yi2 dao4 +一拨 yi4 bo1 +一排 yi4 pai2 +一空 yi4 kong1 +吮吸着 shun3 xi1 zhe +不适合 bu2 shi4 he2 +一串串 yi2 chuan4 chuan4 +一提起 yi4 ti2 qi3 +一尘不染 yi4 chen2 bu4 ran3 +一生 yi4 sheng1 +一派 yi2 pai4 +不断 bu2 duan4 +一次 yi2 ci4 +不进步 bu2 jin4 bu4 +娃娃 wa2 wa +万户侯 wan4 hu4 hou2 +一方 yi4 fang1 +一番话 yi4 fan1 hua4 +一遍 yi2 bian4 +不计较 bu2 ji4 jiao4 +诇 xiong4 +一边 yi4 bian1 +一束 yi2 shu4 +一听到 yi4 ting1 dao4 +炸鸡 zha2 ji1 +乍暧还寒 zha4 ai4 huan2 han2 +我说诶 wo3 shuo1 ei1 +棒诶 bang4 ei1 +寒碜 han2 chen4 +应采儿 ying4 cai3 er2 +晕车 yun1 che1 +必应 bi4 ying4 +应援 ying4 yuan2 +应力 ying4 li4 \ No newline at end of file diff --git a/egs/aishell3/TTS/local/symbols.py b/egs/aishell3/TTS/local/symbols.py new file mode 100644 index 0000000000..1e68788704 --- /dev/null +++ b/egs/aishell3/TTS/local/symbols.py @@ -0,0 +1,73 @@ +# This file is copied from +# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py +_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] + +_initials = [ + "^", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "x", + "z", + "zh", +] + +_tones = ["1", "2", "3", "4", "5"] + +_finals = 
[ + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "ia", + "ian", + "iang", + "iao", + "ie", + "ii", + "iii", + "in", + "ing", + "iong", + "iou", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "uei", + "uen", + "ueng", + "uo", + "v", + "van", + "ve", + "vn", +] + +symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/egs/aishell3/TTS/local/tokenizer.py b/egs/aishell3/TTS/local/tokenizer.py new file mode 100644 index 0000000000..cbf6c9c773 --- /dev/null +++ b/egs/aishell3/TTS/local/tokenizer.py @@ -0,0 +1,137 @@ +# This file is modified from +# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py + +import logging +from pathlib import Path +from typing import List + +# Note pinyin_dict is from ./pinyin_dict.py +from pinyin_dict import pinyin_dict +from pypinyin import Style +from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin +from pypinyin.converter import DefaultConverter +from pypinyin.core import Pinyin, load_phrases_dict + + +class _MyConverter(NeutralToneWith5Mixin, DefaultConverter): + pass + + +class Tokenizer: + def __init__(self, tokens: str = ""): + self._load_pinyin_dict() + self._pinyin_parser = Pinyin(_MyConverter()) + + if tokens != "": + self._load_tokens(tokens) + + def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]: + """ + Args: + texts: + A list of sentences. + kwargs: + Not used. It is for compatibility with other TTS recipes in icefall. + """ + tokens = [] + + for text in texts: + tokens.append(self.text_to_tokens(text)) + + return self.tokens_to_token_ids(tokens) + + def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]: + ans = [] + + for token_list in tokens: + token_ids = [] + for t in token_list: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) + ans.append(token_ids) + + return ans + + def text_to_tokens(self, text: str) -> List[str]: + # Convert "," to ["sp", "sil"] + # Convert "。" to ["sil"] + # append ["eos"] at the end of a sentence + phonemes = ["sil"] + pinyins = self._pinyin_parser.pinyin( + text, + style=Style.TONE3, + errors=lambda x: [[w] for w in x], + ) + + new_pinyin = [] + for p in pinyins: + p = p[0] + if p == ",": + new_pinyin.extend(["sp", "sil"]) + elif p == "。": + new_pinyin.append("sil") + else: + new_pinyin.append(p) + sub_phonemes = self._get_phoneme4pinyin(new_pinyin) + sub_phonemes.append("eos") + phonemes.extend(sub_phonemes) + return phonemes + + def _get_phoneme4pinyin(self, pinyins): + result = [] + for pinyin in pinyins: + if pinyin in ("sil", "sp"): + result.append(pinyin) + elif pinyin[:-1] in pinyin_dict: + tone = pinyin[-1] + a = pinyin[:-1] + a1, a2 = pinyin_dict[a] + # every word is appended with a #0 + result += [a1, a2 + tone, "#0"] + + return result + + def _load_pinyin_dict(self): + this_dir = Path(__file__).parent.resolve() + my_dict = {} + with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f: + content = f.readlines() + for line in content: + cuts = line.strip().split() + hanzi = cuts[0] + pinyin = cuts[1:] + my_dict[hanzi] = [[p] for p in pinyin] + + load_phrases_dict(my_dict) + + def _load_tokens(self, filename): + token2id: Dict[str, int] = {} + + with open(filename, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + idx = int(info[0]) + else: + token, idx = info[0], int(info[1]) + + 
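+                # tokens.txt must not contain duplicate entries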
assert token not in token2id, token + + token2id[token] = idx + + self.token2id = token2id + self.vocab_size = len(self.token2id) + self.pad_id = self.token2id["#0"] + + +def main(): + tokenizer = Tokenizer() + tokenizer._sentence_to_ids("你好,好的。") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell3/TTS/local/validate_manifest.py b/egs/aishell3/TTS/local/validate_manifest.py new file mode 120000 index 0000000000..b4d52ebca0 --- /dev/null +++ b/egs/aishell3/TTS/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/aishell3/TTS/prepare.sh b/egs/aishell3/TTS/prepare.sh new file mode 100755 index 0000000000..db721e67fa --- /dev/null +++ b/egs/aishell3/TTS/prepare.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Download data" + + # The directory $dl_dir/aishell3 will contain the following files + # and sub directories + # ChangeLog ReadMe.txt phone_set.txt spk-info.txt test train + # If you have pre-downloaded it to /path/to/aishell3, you can create a symlink + # + # ln -sfv /path/to/aishell3 $dl_dir/ + # touch $dl_dir/aishell3/.completed + # + if [ ! -d $dl_dir/aishell3 ]; then + lhotse download aishell3 $dl_dir + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare aishell3 manifest (may take 13 minutes)" + # We assume that you have downloaded the baker corpus + # to $dl_dir/aishell3. + # You can find files like spk-info.txt inside $dl_dir/aishell3 + mkdir -p data/manifests + if [ ! -e data/manifests/.aishell3.done ]; then + lhotse prepare aishell3 $dl_dir/aishell3 data/manifests >/dev/null 2>&1 + touch data/manifests/.aishell3.done + fi +fi + + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute spectrogram for aishell3 (may take 5 minutes)" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.aishell3.done ]; then + ./local/compute_spectrogram_aishell3.py + touch data/spectrogram/.aishell3.done + fi + + if [ ! -e data/spectrogram/.aishell3-validated.done ]; then + log "Validating data/spectrogram for aishell3" + python3 ./local/validate_manifest.py \ + data/spectrogram/aishell3_cuts_train.jsonl.gz + + python3 ./local/validate_manifest.py \ + data/spectrogram/aishell3_cuts_test.jsonl.gz + + touch data/spectrogram/.aishell3-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare tokens for aishell3 (may take 20 seconds)" + if [ ! 
-e data/spectrogram/.aishell3_with_token.done ]; then + + ./local/prepare_tokens_aishell3.py + + mv -v data/spectrogram/aishell3_cuts_with_tokens_train.jsonl.gz \ + data/spectrogram/aishell3_cuts_train.jsonl.gz + + mv -v data/spectrogram/aishell3_cuts_with_tokens_test.jsonl.gz \ + data/spectrogram/aishell3_cuts_test.jsonl.gz + + touch data/spectrogram/.aishell3_with_token.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split the aishell3 cuts into train, valid and test sets (may take 25 seconds)" + if [ ! -e data/spectrogram/.aishell3_split.done ]; then + lhotse subset --last 1000 \ + data/spectrogram/aishell3_cuts_test.jsonl.gz \ + data/spectrogram/aishell3_cuts_valid.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/aishell3_cuts_test.jsonl.gz | wc -l) - 1000 )) + + lhotse subset --first $n \ + data/spectrogram/aishell3_cuts_test.jsonl.gz \ + data/spectrogram/aishell3_cuts_test2.jsonl.gz + + mv data/spectrogram/aishell3_cuts_test2.jsonl.gz data/spectrogram/aishell3_cuts_test.jsonl.gz + + touch data/spectrogram/.aishell3_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate tokens.txt and lexicon.txt " + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py --tokens data/tokens.txt + fi + + if [ ! -e data/lexicon.txt ]; then + ./local/generate_lexicon.py --tokens data/tokens.txt --lexicon data/lexicon.txt + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate speakers file" + if [ ! -e data/speakers.txt ]; then + gunzip -c data/manifests/aishell3_supervisions_train.jsonl.gz \ + | jq '.speaker' | sed 's/"//g' \ + | sort | uniq > data/speakers.txt + fi +fi diff --git a/egs/aishell3/TTS/shared b/egs/aishell3/TTS/shared new file mode 120000 index 0000000000..4cbd91a7e9 --- /dev/null +++ b/egs/aishell3/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/duration_predictor.py b/egs/aishell3/TTS/vits/duration_predictor.py new file mode 120000 index 0000000000..9972b476f9 --- /dev/null +++ b/egs/aishell3/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/export-onnx.py b/egs/aishell3/TTS/vits/export-onnx.py new file mode 100755 index 0000000000..a2afcaeca6 --- /dev/null +++ b/egs/aishell3/TTS/vits/export-onnx.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --speakers ./data/speakers.txt \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate one file inside vits/exp: + - vits-epoch-1000.onnx + +See ./test_onnx.py for how to use the exported ONNX models. 
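+
+The exported model can also be used with sherpa-onnx; see
+.github/scripts/aishell3/TTS/run.sh for an example that runs
+sherpa-onnx-offline-tts on the exported file.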
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + + parser.add_argument( + "--model-type", + type=str, + default="low", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + alpha: float = 1.0, + noise_scale_dur: float = 0.8, + speaker: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + speaker (int): + Speaker ID. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.generator.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + sids=speaker, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + vocab_size: int, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. + opset_version: + The opset version to use. 
+ """ + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + speaker = torch.tensor([1], dtype=torch.int64) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "alpha", + "noise_scale_dur", + "speaker", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + "speaker": {0: "N"}, + }, + ) + + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + + meta_data = { + "model_type": "vits", + "version": "1", + "model_author": "k2-fsa", + "comment": "icefall", # must be icefall for models from icefall + "language": "Chinese", + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + params.vocab_size, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" +Supported languages. + +LJSpeech is using "en-us" from the second column. 
+ +Pty Language Age/Gender VoiceName File Other Languages + 5 af --/M Afrikaans gmw/af + 5 am --/M Amharic sem/am + 5 an --/M Aragonese roa/an + 5 ar --/M Arabic sem/ar + 5 as --/M Assamese inc/as + 5 az --/M Azerbaijani trk/az + 5 ba --/M Bashkir trk/ba + 5 be --/M Belarusian zle/be + 5 bg --/M Bulgarian zls/bg + 5 bn --/M Bengali inc/bn + 5 bpy --/M Bishnupriya_Manipuri inc/bpy + 5 bs --/M Bosnian zls/bs + 5 ca --/M Catalan roa/ca + 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr + 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) + 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) + 5 cs --/M Czech zlw/cs + 5 cv --/M Chuvash trk/cv + 5 cy --/M Welsh cel/cy + 5 da --/M Danish gmq/da + 5 de --/M German gmw/de + 5 el --/M Greek grk/el + 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) + 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) + 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) + 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) + 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) + 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) + 2 en-us --/M English_(America) gmw/en-US (en 3) + 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc + 5 eo --/M Esperanto art/eo + 5 es --/M Spanish_(Spain) roa/es + 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) + 5 et --/M Estonian urj/et + 5 eu --/M Basque eu + 5 fa --/M Persian ira/fa + 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn + 5 fi --/M Finnish urj/fi + 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) + 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) + 5 fr-fr --/M French_(France) roa/fr (fr 5) + 5 ga --/M Gaelic_(Irish) cel/ga + 5 gd --/M Gaelic_(Scottish) cel/gd + 5 gn --/M Guarani sai/gn + 5 grc --/M Greek_(Ancient) grk/grc + 5 gu --/M Gujarati inc/gu + 5 hak --/M Hakka_Chinese sit/hak + 5 haw --/M Hawaiian map/haw + 5 he --/M Hebrew sem/he + 5 hi --/M Hindi inc/hi + 5 hr --/M Croatian zls/hr (hbs 5) + 5 ht --/M Haitian_Creole roa/ht + 5 hu --/M Hungarian urj/hu + 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) + 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) + 5 ia --/M Interlingua art/ia + 5 id --/M Indonesian poz/id + 5 io --/M Ido art/io + 5 is --/M Icelandic gmq/is + 5 it --/M Italian roa/it + 5 ja --/M Japanese jpx/ja + 5 jbo --/M Lojban art/jbo + 5 ka --/M Georgian ccs/ka + 5 kk --/M Kazakh trk/kk + 5 kl --/M Greenlandic esx/kl + 5 kn --/M Kannada dra/kn + 5 ko --/M Korean ko + 5 kok --/M Konkani inc/kok + 5 ku --/M Kurdish ira/ku + 5 ky --/M Kyrgyz trk/ky + 5 la --/M Latin itc/la + 5 lb --/M Luxembourgish gmw/lb + 5 lfn --/M Lingua_Franca_Nova art/lfn + 5 lt --/M Lithuanian bat/lt + 5 ltg --/M Latgalian bat/ltg + 5 lv --/M Latvian bat/lv + 5 mi --/M Māori poz/mi + 5 mk --/M Macedonian zls/mk + 5 ml --/M Malayalam dra/ml + 5 mr --/M Marathi inc/mr + 5 ms --/M Malay poz/ms + 5 mt --/M Maltese sem/mt + 5 mto --/M Totontepec_Mixe miz/mto + 5 my --/M Myanmar_(Burmese) sit/my + 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) + 5 nci --/M Nahuatl_(Classical) azc/nci + 5 ne --/M Nepali inc/ne + 5 nl --/M Dutch gmw/nl + 5 nog --/M Nogai trk/nog + 5 om --/M Oromo cus/om + 5 or --/M Oriya inc/or + 5 pa --/M Punjabi inc/pa + 5 pap --/M Papiamento roa/pap + 5 piqd --/M Klingon art/piqd + 5 pl --/M Polish zlw/pl + 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) + 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) + 5 py --/M 
Pyash art/py + 5 qdb --/M Lang_Belta art/qdb + 5 qu --/M Quechua qu + 5 quc --/M K'iche' myn/quc + 5 qya --/M Quenya art/qya + 5 ro --/M Romanian roa/ro + 5 ru --/M Russian zle/ru + 5 ru-cl --/M Russian_(Classic) zle/ru-cl + 2 ru-lv --/M Russian_(Latvia) zle/ru-LV + 5 sd --/M Sindhi inc/sd + 5 shn --/M Shan_(Tai_Yai) tai/shn + 5 si --/M Sinhala inc/si + 5 sjn --/M Sindarin art/sjn + 5 sk --/M Slovak zlw/sk + 5 sl --/M Slovenian zls/sl + 5 smj --/M Lule_Saami urj/smj + 5 sq --/M Albanian ine/sq + 5 sr --/M Serbian zls/sr + 5 sv --/M Swedish gmq/sv + 5 sw --/M Swahili bnt/sw + 5 ta --/M Tamil dra/ta + 5 te --/M Telugu dra/te + 5 th --/M Thai tai/th + 5 tk --/M Turkmen trk/tk + 5 tn --/M Setswana bnt/tn + 5 tr --/M Turkish trk/tr + 5 tt --/M Tatar trk/tt + 5 ug --/M Uyghur trk/ug + 5 uk --/M Ukrainian zle/uk + 5 ur --/M Urdu inc/ur + 5 uz --/M Uzbek trk/uz + 5 vi --/M Vietnamese_(Northern) aav/vi + 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central + 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south + 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) + 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) +""" diff --git a/egs/aishell3/TTS/vits/flow.py b/egs/aishell3/TTS/vits/flow.py new file mode 120000 index 0000000000..e65d91ea75 --- /dev/null +++ b/egs/aishell3/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/generator.py b/egs/aishell3/TTS/vits/generator.py new file mode 120000 index 0000000000..611679bfa8 --- /dev/null +++ b/egs/aishell3/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/hifigan.py b/egs/aishell3/TTS/vits/hifigan.py new file mode 120000 index 0000000000..5ac025de72 --- /dev/null +++ b/egs/aishell3/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/loss.py b/egs/aishell3/TTS/vits/loss.py new file mode 120000 index 0000000000..672e5ff68d --- /dev/null +++ b/egs/aishell3/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/monotonic_align b/egs/aishell3/TTS/vits/monotonic_align new file mode 120000 index 0000000000..2c4923075e --- /dev/null +++ b/egs/aishell3/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align/ \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/pinyin_dict.py b/egs/aishell3/TTS/vits/pinyin_dict.py new file mode 120000 index 0000000000..b8683bd2dc --- /dev/null +++ b/egs/aishell3/TTS/vits/pinyin_dict.py @@ -0,0 +1 @@ +../local/pinyin_dict.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/posterior_encoder.py b/egs/aishell3/TTS/vits/posterior_encoder.py new file mode 120000 index 0000000000..41d64a3a66 --- /dev/null +++ b/egs/aishell3/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/pypinyin-local.dict b/egs/aishell3/TTS/vits/pypinyin-local.dict new file mode 120000 index 0000000000..5bc9b77282 --- /dev/null +++ b/egs/aishell3/TTS/vits/pypinyin-local.dict @@ -0,0 +1 @@ +../local/pypinyin-local.dict \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/residual_coupling.py b/egs/aishell3/TTS/vits/residual_coupling.py new file mode 120000 index 0000000000..f979adbf00 --- /dev/null +++ 
b/egs/aishell3/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/text_encoder.py b/egs/aishell3/TTS/vits/text_encoder.py new file mode 120000 index 0000000000..0efba277e1 --- /dev/null +++ b/egs/aishell3/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/tokenizer.py b/egs/aishell3/TTS/vits/tokenizer.py new file mode 120000 index 0000000000..0368e07d34 --- /dev/null +++ b/egs/aishell3/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../local/tokenizer.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/train.py b/egs/aishell3/TTS/vits/train.py new file mode 100755 index 0000000000..ad30384855 --- /dev/null +++ b/egs/aishell3/TTS/vits/train.py @@ -0,0 +1,1007 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import Aishell3SpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--model-type", + type=str, + default="low", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
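+ - sampling_rate: Audio sampling rate assumed by this recipe. It has to
+ match the rate used when extracting the spectrogram features.
+
+ - frame_length / frame_shift: STFT window and hop sizes in samples used
+ for the linear spectrogram (feature_dim equals frame_length // 2 + 1).
+
+ - lambda_adv / lambda_mel / lambda_feat_match / lambda_dur / lambda_kl:
+ Scaling coefficients for the adversarial, Mel, feature-matching,
+ duration and KL terms of the VITS loss.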
+ """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 8000, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + model_type=params.model_type, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + 
lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: Dict[str, int], +): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + speakers = ( + torch.Tensor([speaker_map.get(sid, 0) for sid in batch["speakers"]]) + .int() + .to(device) + ) + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
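+ speaker_map:
+ Dict mapping each speaker ID string to an integer index; it is used
+ by prepare_input() to build the sid tensor that is passed to the
+ multi-speaker generator.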
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + speaker_map=speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + sids=speakers[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + speaker_map: Dict[str, int], + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + aishell3 = Aishell3SpeechTtsDataModule(args) + assert aishell3.sampling_rate == params.sampling_rate, ( + aishell3.sampling_rate, + params.sampling_rate, + ) + speaker_map = aishell3.speakers() + params.num_spks = len(speaker_map) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = 
torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + train_cuts = aishell3.train_cuts() + + logging.info(params) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = aishell3.train_dataloaders(train_cuts) + + valid_cuts = aishell3.valid_cuts() + valid_dl = aishell3.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + speaker_map=speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + speaker_map=speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / 
"best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + Aishell3SpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell3/TTS/vits/transform.py b/egs/aishell3/TTS/vits/transform.py new file mode 120000 index 0000000000..962647408b --- /dev/null +++ b/egs/aishell3/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/tts_datamodule.py b/egs/aishell3/TTS/vits/tts_datamodule.py new file mode 100644 index 0000000000..a08c645382 --- /dev/null +++ b/egs/aishell3/TTS/vits/tts_datamodule.py @@ -0,0 +1,349 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class Aishell3SpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in TTS tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + self.sampling_rate = 8000 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.sampling_rate + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + 
return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "aishell3_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "aishell3_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "aishell3_cuts_test.jsonl.gz" + ) + + @lru_cache() + def speakers(self) -> Dict[str, int]: + logging.info("About to get speakers") + with open(self.args.speakers) as f: + speakers = {line.strip(): i for i, line in enumerate(f)} + return speakers diff --git a/egs/aishell3/TTS/vits/utils.py b/egs/aishell3/TTS/vits/utils.py new file mode 120000 index 0000000000..085e764b43 --- /dev/null +++ b/egs/aishell3/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/vits.py b/egs/aishell3/TTS/vits/vits.py new file mode 120000 index 0000000000..1f58cf6fea --- /dev/null +++ b/egs/aishell3/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/aishell3/TTS/vits/wavenet.py b/egs/aishell3/TTS/vits/wavenet.py new file mode 120000 index 0000000000..28f0a78eeb --- /dev/null +++ b/egs/aishell3/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py index 33d75e1765..dc9ddaf489 100644 --- a/egs/ljspeech/TTS/vits/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py @@ -1,7 +1,10 @@ # https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py """Setup cython code.""" -from Cython.Build import cythonize +try: + from Cython.Build import cythonize +except ModuleNotFoundError as ex: + raise RuntimeError(f'{ex}\nPlease run:\n pip install cython') from setuptools import Extension, setup from setuptools.command.build_ext import build_ext as _build_ext diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 3c9046adde..f314cc3624 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -44,11 +44,11 @@ def __init__(self, tokens: str): if len(info) == 1: # case of space token = " " - id = int(info[0]) + idx = int(info[0]) else: - token, id = info[0], int(info[1]) + token, idx = info[0], int(info[1]) assert token not in self.token2id, token - self.token2id[token] = id + self.token2id[token] = idx # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md self.pad_id = self.token2id["_"] # padding diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index e1a9c7b3ca..005e1da494 100644 --- 
a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -66,7 +66,7 @@ class LJSpeechTtsDataModule: - cut concatenation, - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. + This class should be derived for specific corpora used in TTS tasks. """ def __init__(self, args: argparse.Namespace):
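The setup.py tweak above assumes Cython is installed before the monotonic_align extension is compiled (the aishell3 recipe reaches this directory through a symlink to the LJSpeech one). As a minimal sketch, assuming the standard in-place setuptools build is used (the exact build commands are not part of this diff):

    # build the monotonic_align Cython extension in place
    cd egs/ljspeech/TTS/vits/monotonic_align
    python3 setup.py build_ext --inplace
    cd -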