diff --git a/docs/source/changelog/changelog_2.2.rst b/docs/source/changelog/changelog_2.2.rst new file mode 100644 index 00000000..7a9e92f8 --- /dev/null +++ b/docs/source/changelog/changelog_2.2.rst @@ -0,0 +1,18 @@ + +.. _changelog_2.2: + +************* +2.2 Changelog +************* + +2.2.1 +===== + +- Fixed a couple of bugs in training Phonetisaurus models +- Added training of Phonetisaurus models for tokenizer + +2.2.0 +===== + +- Add support for training tokenizers and tokenization +- Migrate most os.path functionality to pathlib diff --git a/docs/source/changelog/index.md b/docs/source/changelog/index.md index 2d1f0fdf..705b527f 100644 --- a/docs/source/changelog/index.md +++ b/docs/source/changelog/index.md @@ -60,6 +60,7 @@ Not tied to 2.1, but in the near-ish term I would like to: :hidden: :maxdepth: 1 +changelog_2.2.rst news_2.1.rst changelog_2.1.rst news_2.0.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index f3eeb191..7285f2d7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -345,7 +345,7 @@ # "image_dark": "logo-dark.svg", }, "analytics": { - "google_analytics_id": "UA-73068199-4", + "google_analytics_id": "353930198", }, # "show_nav_level": 1, # "navigation_depth": 4, diff --git a/docs/source/first_steps/index.rst b/docs/source/first_steps/index.rst index 707b226c..3bac7a0c 100644 --- a/docs/source/first_steps/index.rst +++ b/docs/source/first_steps/index.rst @@ -34,6 +34,11 @@ There are several broad use cases that you might want to use MFA for. Take a lo #. Use the trained G2P model in :ref:`first_steps_g2p_pretrained` to generate a pronunciation dictionary #. Use the generated pronunciation dictionary in :ref:`first_steps_align_train_acoustic_model` to generate aligned TextGrids +#. **Use case 5:** You have a :ref:`speech corpus ` and the language involved is in the list of :xref:`pretrained_acoustic_models`, but the language does not mark word boundaries in its orthography. + + #. Follow :ref:`first_steps_tokenize` to tokenize the corpus + #. Use the tokenized transcripts and follow :ref:`first_steps_align_pretrained` + .. _first_steps_align_pretrained: Aligning a speech corpus with existing pronunciation dictionary and acoustic model @@ -90,8 +95,8 @@ Depending on your use case, you might have a list of words to run G2P over, or j .. code-block:: - mfa g2p english_us_arpa ~/mfa_data/my_corpus ~/mfa_data/new_dictionary.txt # If using a corpus - mfa g2p english_us_arpa ~/mfa_data/my_word_list.txt ~/mfa_data/new_dictionary.txt # If using a word list + mfa g2p ~/mfa_data/my_corpus english_us_arpa ~/mfa_data/new_dictionary.txt # If using a corpus + mfa g2p ~/mfa_data/my_word_list.txt english_us_arpa ~/mfa_data/new_dictionary.txt # If using a word list Running one of the above will output a text file pronunciation dictionary in the format that MFA uses (:ref:`dictionary_format`). I recommend looking over the pronunciations generated and make sure that they look sensible. For languages where the orthography is not transparent, it may be helpful to include :code:`--num_pronunciations 3` so that more pronunciations are generated than just the most likely one. For more details on running G2P, see :ref:`g2p_dictionary_generating`. @@ -170,11 +175,11 @@ Once the G2P model is trained, you should see the exported archive in the folder mfa model save g2p ~/mfa_data/my_g2p_model.zip - mfa g2p my_g2p_model ~/mfa_data/my_new_word_list.txt ~/mfa_data/my_new_dictionary.txt + mfa g2p ~/mfa_data/my_new_word_list.txt my_g2p_model ~/mfa_data/my_new_dictionary.txt # Or - mfa g2p ~/mfa_data/my_g2p_model.zip ~/mfa_data/my_new_word_list.txt ~/mfa_data/my_new_dictionary.txt + mfa g2p ~/mfa_data/my_new_word_list.txt ~/mfa_data/my_g2p_model.zip ~/mfa_data/my_new_dictionary.txt Take a look at :ref:`first_steps_g2p_pretrained` with this new model for a more detailed walk-through of generating a dictionary. @@ -182,6 +187,38 @@ Take a look at :ref:`first_steps_g2p_pretrained` with this new model for a more Please see :ref:`g2p_model_training_example` for an example using toy data. +.. _first_steps_tokenize: + +Tokenize a corpus to add word boundaries +---------------------------------------- + +For the purposes of this example, we'll also assume that you have done nothing else with MFA other than follow the :ref:`installation` instructions and you have the :code:`mfa` command working. Finally, we'll assume that your corpus is in Japanese and is stored in the folder :code:`~/mfa_data/my_corpus`, so when working with your data, this will be the main thing to update. + +To tokenize the Japanese text to add spaces, first download the Japanese tokenizer model via: + + +.. code-block:: + + mfa model download tokenizer japanese_mfa + +Once you have the model downloaded, you can tokenize your corpus via: + +.. code-block:: + + mfa tokenize ~/mfa_data/my_corpus japanese_mfa ~/mfa_data/tokenized_version + +You can check the tokenized text in :code:`~/mfa_data/tokenized_version`, verify that it looks good, and copy the files to replace the untokenized files in :code:`~/mfa_data/my_corpus` for use in alignment. + +.. warning:: + + MFA's tokenizer models are nowhere near state of the art, and I recommend using other tokenizers as they make sense: + + * Japanese: `nagisa `_ + * Chinese: `spacy-pkuseg `_ + * Thai: `sertiscorp/thai-word-segmentation `_ + + The above were used in the initial construction of the training corpora for MFA, though the training segmentations for Japanese have begun to diverge from :code:`nagisa`, as they break up phonological words into morphological parses where for the purposes of acoustic model training and alignment it makes more sense to not split (nagisa: :ipa_inline:`使っ て [ts ɨ k a Q t e]` vs mfa: :ipa_inline:`使って [ts ɨ k a tː e]`). The MFA tokenizer models are provided as an easy start up path as the ones listed above may have extra dependencies and platform restrictions. + .. toctree:: :maxdepth: 1 :hidden: diff --git a/montreal_forced_aligner/alignment/mixins.py b/montreal_forced_aligner/alignment/mixins.py index 56ee9c44..44a265a5 100644 --- a/montreal_forced_aligner/alignment/mixins.py +++ b/montreal_forced_aligner/alignment/mixins.py @@ -551,7 +551,7 @@ def compile_information(self) -> None: average_logdet_frames += data["logdet_frames"] average_logdet_sum += data["logdet"] * data["logdet_frames"] - if hasattr(self, "db_engine"): + if hasattr(self, "session"): csv_path = self.working_directory.joinpath("alignment_log_likelihood.csv") with mfa_open(csv_path, "w") as f, self.session() as session: writer = csv.writer(f) diff --git a/montreal_forced_aligner/command_line/train_tokenizer.py b/montreal_forced_aligner/command_line/train_tokenizer.py index 71325e36..2cf241ed 100644 --- a/montreal_forced_aligner/command_line/train_tokenizer.py +++ b/montreal_forced_aligner/command_line/train_tokenizer.py @@ -12,7 +12,10 @@ common_options, ) from montreal_forced_aligner.config import GLOBAL_CONFIG, MFA_PROFILE_VARIABLE -from montreal_forced_aligner.tokenization.trainer import TokenizerTrainer +from montreal_forced_aligner.tokenization.trainer import ( + PhonetisaurusTokenizerTrainer, + TokenizerTrainer, +) __all__ = ["train_tokenizer_cli"] @@ -48,6 +51,12 @@ "most of the data and validating on an unseen subset.", default=False, ) +@click.option( + "--phonetisaurus", + is_flag=True, + help="Flag for using Phonetisaurus-style models.", + default=False, +) @common_options @click.help_option("-h", "--help") @click.pass_context @@ -63,10 +72,19 @@ def train_tokenizer_cli(context, **kwargs) -> None: config_path = kwargs.get("config_path", None) corpus_directory = kwargs["corpus_directory"] output_model_path = kwargs["output_model_path"] - trainer = TokenizerTrainer( - corpus_directory=corpus_directory, - **TokenizerTrainer.parse_parameters(config_path, context.params, context.args), - ) + phonetisaurus = kwargs["phonetisaurus"] + if phonetisaurus: + trainer = PhonetisaurusTokenizerTrainer( + corpus_directory=corpus_directory, + **PhonetisaurusTokenizerTrainer.parse_parameters( + config_path, context.params, context.args + ), + ) + else: + trainer = TokenizerTrainer( + corpus_directory=corpus_directory, + **TokenizerTrainer.parse_parameters(config_path, context.params, context.args), + ) try: trainer.setup() diff --git a/montreal_forced_aligner/corpus/multiprocessing.py b/montreal_forced_aligner/corpus/multiprocessing.py index 22bf3a24..b30ce347 100644 --- a/montreal_forced_aligner/corpus/multiprocessing.py +++ b/montreal_forced_aligner/corpus/multiprocessing.py @@ -494,11 +494,8 @@ def _no_dictionary_sanitize(self, session): text.append(w) if character_text: character_text.append("") - if self.bracket_regex.match(w): - character_text.append(self.bracketed_word) - else: - for g in w: - character_text.append(g) + for g in w: + character_text.append(g) text = " ".join(text) character_text = " ".join(character_text) yield { diff --git a/montreal_forced_aligner/g2p/mixins.py b/montreal_forced_aligner/g2p/mixins.py index a5ec43ee..3d19ccb9 100644 --- a/montreal_forced_aligner/g2p/mixins.py +++ b/montreal_forced_aligner/g2p/mixins.py @@ -1,6 +1,6 @@ """Mixin module for G2P functionality""" import typing -from abc import ABCMeta, abstractmethod +from abc import ABCMeta from pathlib import Path from typing import Dict, List @@ -36,7 +36,6 @@ def __init__( self.g2p_threshold = g2p_threshold self.include_bracketed = include_bracketed - @abstractmethod def generate_pronunciations(self) -> Dict[str, List[str]]: """ Generate pronunciations @@ -46,13 +45,12 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: dict[str, list[str]] Mappings of keys to their generated pronunciations """ - ... + raise NotImplementedError @property - @abstractmethod def words_to_g2p(self) -> List[str]: """Words to produce pronunciations""" - ... + raise NotImplementedError class G2PTopLevelMixin(MfaWorker, DictionaryMixin, G2PMixin): diff --git a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py index c9d34954..94c79293 100644 --- a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py +++ b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import collections import logging import multiprocessing as mp @@ -93,8 +92,8 @@ class AlignmentInitArguments: deletions: bool insertions: bool restrict: bool - phone_order: int - grapheme_order: int + output_order: int + input_order: int eps: str s1s2_sep: str seq_sep: str @@ -137,8 +136,8 @@ def __init__( self.deletions = args.deletions self.insertions = args.insertions self.restrict = args.restrict - self.phone_order = args.phone_order - self.grapheme_order = args.grapheme_order + self.output_order = args.output_order + self.input_order = args.input_order self.eps = args.eps self.s1s2_sep = args.s1s2_sep self.seq_sep = args.seq_sep @@ -149,6 +148,19 @@ def __init__( self.db_string = args.db_string self.batch_size = args.batch_size + def data_generator(self, session): + query = ( + session.query(Word.word, Pronunciation.pronunciation) + .join(Pronunciation.word) + .join(Word.job) + .filter(Word2Job.training == True) # noqa + .filter(Word2Job.job_id == self.job_name) + ) + for w, p in query: + w = list(w) + p = p.split() + yield w, p + def run(self) -> None: """Run the function""" engine = sqlalchemy.create_engine( @@ -161,59 +173,51 @@ def run(self) -> None: try: symbol_table = pynini.SymbolTable() symbol_table.add_symbol(self.eps) - Session = scoped_session(sessionmaker(bind=engine, autoflush=False, autocommit=False)) - valid_phone_ngrams = set() + valid_output_ngrams = set() base_dir = os.path.dirname(self.far_path) - with mfa_open(os.path.join(base_dir, "phone_ngram.ngrams"), "r") as f: + with mfa_open(os.path.join(base_dir, "output_ngram.ngrams"), "r") as f: for line in f: line = line.strip() - valid_phone_ngrams.add(line) - valid_grapheme_ngrams = set() - with mfa_open(os.path.join(base_dir, "grapheme_ngram.ngrams"), "r") as f: + valid_output_ngrams.add(line) + valid_input_ngrams = set() + with mfa_open(os.path.join(base_dir, "input_ngram.ngrams"), "r") as f: for line in f: line = line.strip() - valid_grapheme_ngrams.add(line) + valid_input_ngrams.add(line) count = 0 data = {} - with mfa_open(self.log_path, "w") as log_file, Session() as session: + with mfa_open(self.log_path, "w") as log_file, sqlalchemy.orm.Session( + engine + ) as session: far_writer = pywrapfst.FarWriter.create(self.far_path, arc_type="log") - query = ( - session.query(Pronunciation.pronunciation, Word.word) - .join(Pronunciation.word) - .join(Word.job) - .filter(Word2Job.training == True) # noqa - .filter(Word2Job.job_id == self.job_name) - ) - for current_index, (phones, graphemes) in enumerate(query): - graphemes = list(graphemes) - phones = phones.split() + for current_index, (input, output) in enumerate(self.data_generator(session)): if self.stopped.stop_check(): continue try: key = f"{current_index:08x}" fst = pynini.Fst(arc_type="log") - final_state = ((len(graphemes) + 1) * (len(phones) + 1)) - 1 + final_state = ((len(input) + 1) * (len(output) + 1)) - 1 for _ in range(final_state + 1): fst.add_state() - for i in range(len(graphemes) + 1): - for j in range(len(phones) + 1): - istate = i * (len(phones) + 1) + j + for i in range(len(input) + 1): + for j in range(len(output) + 1): + istate = i * (len(output) + 1) + j if self.deletions: - for phone_range in range(1, self.phone_order + 1): - if j + phone_range <= len(phones): - subseq_phones = phones[j : j + phone_range] - phone_string = self.seq_sep.join(subseq_phones) + for output_range in range(1, self.output_order + 1): + if j + output_range <= len(output): + subseq_output = output[j : j + output_range] + output_string = self.seq_sep.join(subseq_output) if ( - phone_range > 1 - and phone_string not in valid_phone_ngrams + output_range > 1 + and output_string not in valid_output_ngrams ): continue - symbol = self.s1s2_sep.join([self.skip, phone_string]) + symbol = self.s1s2_sep.join([self.skip, output_string]) ilabel = symbol_table.find(symbol) if ilabel == pynini.NO_LABEL: ilabel = symbol_table.add_symbol(symbol) - ostate = i * (len(phones) + 1) + (j + phone_range) + ostate = i * (len(output) + 1) + (j + output_range) fst.add_arc( istate, pywrapfst.Arc( @@ -224,22 +228,20 @@ def run(self) -> None: ), ) if self.insertions: - for grapheme_range in range(1, self.grapheme_order + 1): - if i + grapheme_range <= len(graphemes): - subseq_graphemes = graphemes[i : i + grapheme_range] - grapheme_string = self.seq_sep.join(subseq_graphemes) + for input_range in range(1, self.input_order + 1): + if i + input_range <= len(input): + subseq_input = input[i : i + input_range] + input_string = self.seq_sep.join(subseq_input) if ( - grapheme_range > 1 - and grapheme_string not in valid_grapheme_ngrams + input_range > 1 + and input_string not in valid_input_ngrams ): continue - symbol = self.s1s2_sep.join( - [grapheme_string, self.skip] - ) + symbol = self.s1s2_sep.join([input_string, self.skip]) ilabel = symbol_table.find(symbol) if ilabel == pynini.NO_LABEL: ilabel = symbol_table.add_symbol(symbol) - ostate = (i + grapheme_range) * (len(phones) + 1) + j + ostate = (i + input_range) * (len(output) + 1) + j fst.add_arc( istate, pywrapfst.Arc( @@ -250,39 +252,39 @@ def run(self) -> None: ), ) - for grapheme_range in range(1, self.grapheme_order + 1): - for phone_range in range(1, self.phone_order + 1): - if i + grapheme_range <= len( - graphemes - ) and j + phone_range <= len(phones): + for input_range in range(1, self.input_order + 1): + for output_range in range(1, self.output_order + 1): + if i + input_range <= len( + input + ) and j + output_range <= len(output): if ( self.restrict - and grapheme_range > 1 - and phone_range > 1 + and input_range > 1 + and output_range > 1 ): continue - subseq_phones = phones[j : j + phone_range] - phone_string = self.seq_sep.join(subseq_phones) + subseq_output = output[j : j + output_range] + output_string = self.seq_sep.join(subseq_output) if ( - phone_range > 1 - and phone_string not in valid_phone_ngrams + output_range > 1 + and output_string not in valid_output_ngrams ): continue - subseq_graphemes = graphemes[i : i + grapheme_range] - grapheme_string = self.seq_sep.join(subseq_graphemes) + subseq_input = input[i : i + input_range] + input_string = self.seq_sep.join(subseq_input) if ( - grapheme_range > 1 - and grapheme_string not in valid_grapheme_ngrams + input_range > 1 + and input_string not in valid_input_ngrams ): continue symbol = self.s1s2_sep.join( - [grapheme_string, phone_string] + [input_string, output_string] ) ilabel = symbol_table.find(symbol) if ilabel == pynini.NO_LABEL: ilabel = symbol_table.add_symbol(symbol) - ostate = (i + grapheme_range) * (len(phones) + 1) + ( - j + phone_range + ostate = (i + input_range) * (len(output) + 1) + ( + j + output_range ) fst.add_arc( istate, @@ -290,7 +292,7 @@ def run(self) -> None: ilabel, ilabel, pynini.Weight( - "log", float(grapheme_range * phone_range) + "log", float(input_range * output_range) ), ostate, ), @@ -706,6 +708,8 @@ class PhonetisaurusTrainerMixin: Threshold of minimum change for early stopping of EM training """ + alignment_init_function = AlignmentInitWorker + def __init__( self, order: int = 8, @@ -742,6 +746,8 @@ def __init__( self.deletions = deletions self.grapheme_order = grapheme_order self.phone_order = phone_order + self.input_order = self.grapheme_order + self.output_order = self.phone_order self.sequence_separator = sequence_separator self.alignment_separator = alignment_separator self.skip = skip @@ -775,7 +781,7 @@ def initialize_alignments(self) -> None: stopped = Stopped() finished_adding = Stopped() procs = [] - for i in range(GLOBAL_CONFIG.num_jobs): + for i in range(1, GLOBAL_CONFIG.num_jobs + 1): args = AlignmentInitArguments( self.db_string, self.working_log_directory.joinpath(f"alignment_init.{i}.log"), @@ -792,7 +798,7 @@ def initialize_alignments(self) -> None: self.batch_size, ) procs.append( - AlignmentInitWorker( + self.alignment_init_function( i, return_queue, stopped, @@ -800,7 +806,7 @@ def initialize_alignments(self) -> None: args, ) ) - procs[i].start() + procs[-1].start() finished_adding.stop() error_list = [] @@ -904,7 +910,7 @@ def maximization(self, last_iteration=False) -> float: return_queue = mp.Queue() stopped = Stopped() procs = [] - for i in range(GLOBAL_CONFIG.num_jobs): + for i in range(1, GLOBAL_CONFIG.num_jobs + 1): args = MaximizationArguments( self.db_string, self.working_directory.joinpath(f"{i}.far"), @@ -912,7 +918,7 @@ def maximization(self, last_iteration=False) -> float: self.batch_size, ) procs.append(MaximizationWorker(i, return_queue, stopped, args)) - procs[i].start() + procs[-1].start() error_list = [] with tqdm.tqdm( @@ -958,14 +964,14 @@ def expectation(self) -> None: stopped = Stopped() error_list = [] procs = [] - for i in range(GLOBAL_CONFIG.num_jobs): + for i in range(1, GLOBAL_CONFIG.num_jobs + 1): args = ExpectationArguments( self.db_string, self.working_directory.joinpath(f"{i}.far"), self.batch_size, ) procs.append(ExpectationWorker(i, return_queue, stopped, args)) - procs[i].start() + procs[-1].start() mappings = {} zero = pynini.Weight.zero("log") with tqdm.tqdm( @@ -1020,7 +1026,7 @@ def train_ngram_model(self) -> None: error_list = [] procs = [] count_paths = [] - for i in range(GLOBAL_CONFIG.num_jobs): + for i in range(1, GLOBAL_CONFIG.num_jobs + 1): args = NgramCountArguments( self.working_log_directory.joinpath(f"ngram_count.{i}.log"), self.working_directory.joinpath(f"{i}.far"), @@ -1029,7 +1035,7 @@ def train_ngram_model(self) -> None: ) procs.append(NgramCountWorker(return_queue, stopped, args)) count_paths.append(args.far_path.with_suffix(".cnts")) - procs[i].start() + procs[-1].start() with tqdm.tqdm( total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet @@ -1060,17 +1066,20 @@ def train_ngram_model(self) -> None: logger.info("Training ngram model...") with mfa_open(self.working_log_directory.joinpath("model.log"), "w") as logf: - ngrammerge_proc = subprocess.Popen( - [ - thirdparty_binary("ngrammerge"), - f'--ofile={self.ngram_path.with_suffix(".cnts")}', - *count_paths, - ], - stderr=logf, - # stdout=subprocess.PIPE, - env=os.environ, - ) - ngrammerge_proc.communicate() + if len(count_paths) > 1: + ngrammerge_proc = subprocess.Popen( + [ + thirdparty_binary("ngrammerge"), + f'--ofile={self.ngram_path.with_suffix(".cnts")}', + *count_paths, + ], + stderr=logf, + # stdout=subprocess.PIPE, + env=os.environ, + ) + ngrammerge_proc.communicate() + else: + os.rename(count_paths[0], self.ngram_path.with_suffix(".cnts")) ngrammake_proc = subprocess.Popen( [ thirdparty_binary("ngrammake"), @@ -1215,21 +1224,6 @@ def train_alignments(self) -> None: if change < self.em_threshold: break - @property - @abc.abstractmethod - def working_directory(self) -> Path: - ... - - @property - @abc.abstractmethod - def working_log_directory(self) -> Path: - ... - - @property - @abc.abstractmethod - def db_string(self) -> str: - ... - @property def data_directory(self) -> Path: """Data directory for trainer""" @@ -1312,7 +1306,7 @@ def export_alignments(self) -> None: error_list = [] procs = [] count_paths = [] - for i in range(GLOBAL_CONFIG.num_jobs): + for i in range(1, GLOBAL_CONFIG.num_jobs + 1): args = AlignmentExportArguments( self.db_string, self.working_log_directory.joinpath(f"ngram_count.{i}.log"), @@ -1321,7 +1315,7 @@ def export_alignments(self) -> None: ) procs.append(AlignmentExporter(return_queue, stopped, args)) count_paths.append(args.far_path.with_suffix(".cnts")) - procs[i].start() + procs[-1].start() with tqdm.tqdm( total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet @@ -1361,8 +1355,8 @@ def export_alignments(self) -> None: stdin=subprocess.PIPE, stdout=subprocess.PIPE, ) - for j in range(GLOBAL_CONFIG.num_jobs): - text_path = self.working_directory.joinpath(f"{j}.far.strings") + for j in range(1, GLOBAL_CONFIG.num_jobs + 1): + text_path = self.working_directory.joinpath(f"{j}.strings") with mfa_open(text_path, "r") as f: for line in f: symbols_proc.stdin.write(line) @@ -1372,146 +1366,18 @@ def export_alignments(self) -> None: self.symbol_table = pynini.SymbolTable.read_text(self.alignment_symbols_path) logger.info("Done exporting alignments!") - -class PhonetisaurusTrainer( - MultispeakerDictionaryMixin, PhonetisaurusTrainerMixin, G2PTrainer, TopLevelMfaWorker -): - """ - Top level trainer class for Phonetisaurus-style models - """ - - def __init__( - self, - **kwargs, - ): - self._data_source = kwargs["dictionary_path"].stem - super().__init__(**kwargs) - self.ler = None - self.wer = None - - @property - def data_directory(self) -> Path: - """Data directory for trainer""" - return self.working_directory - - @property - def configuration(self) -> MetaDict: - """Configuration for G2P trainer""" - config = super().configuration - config.update({"dictionary_path": str(self.dictionary_model.path)}) - return config - - def setup(self) -> None: - """Setup for G2P training""" - super().setup() - self.create_new_current_workflow(WorkflowType.train_g2p) - wf = self.current_workflow - if wf.done: - logger.info("G2P training already done, skipping.") - return - self.dictionary_setup() - os.makedirs(self.phones_dir, exist_ok=True) - self.initialize_training() - self.initialized = True - - def train(self) -> None: - """ - Train a G2P model - """ - os.makedirs(self.working_log_directory, exist_ok=True) - begin = time.time() - self.train_alignments() - logger.debug( - f"Aligning {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds" - ) - self.export_alignments() - begin = time.time() - self.train_ngram_model() - logger.debug( - f"Generating model for {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds" - ) - self.finalize_training() - - def finalize_training(self) -> None: - """Finalize training and run evaluation if specified""" - if self.evaluation_mode: - self.evaluate_g2p_model() - - @property - def meta(self) -> MetaDict: - """Metadata for exported G2P model""" - from datetime import datetime - - from ..utils import get_mfa_version - - m = { - "version": get_mfa_version(), - "architecture": self.architecture, - "train_date": str(datetime.now()), - "phones": sorted(self.non_silence_phones), - "graphemes": self.g2p_training_graphemes, - "grapheme_order": self.grapheme_order, - "phone_order": self.phone_order, - "sequence_separator": self.sequence_separator, - "evaluation": {}, - "training": { - "num_words": self.g2p_num_training_words, - "num_graphemes": len(self.g2p_training_graphemes), - "num_phones": len(self.non_silence_phones), - }, - } - - if self.evaluation_mode: - m["evaluation"]["num_words"] = self.g2p_num_validation_words - m["evaluation"]["word_error_rate"] = self.wer - m["evaluation"]["phone_error_rate"] = self.ler - return m - - def evaluate_g2p_model(self) -> None: - """ - Validate the G2P model against held out data - """ - temp_model_path = self.working_log_directory.joinpath("g2p_model.zip") - self.export_model(temp_model_path) - temp_dir = self.working_directory.joinpath("validation") - os.makedirs(temp_dir, exist_ok=True) - with self.session() as session: - validation_set = collections.defaultdict(set) - query = ( - session.query(Word.word, Pronunciation.pronunciation) - .join(Pronunciation.word) - .join(Word.job) - .filter(Word2Job.training == False) # noqa - ) - for w, pron in query: - validation_set[w].add(pron) - gen = PyniniValidator( - g2p_model_path=temp_model_path, - word_list=list(validation_set.keys()), - num_pronunciations=self.num_pronunciations, - ) - output = gen.generate_pronunciations() - with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f: - for (orthography, pronunciations) in output.items(): - if not pronunciations: - continue - for p in pronunciations: - if not p: - continue - f.write(f"{orthography}\t{p}\n") - gen.compute_validation_errors(validation_set, output) - def compute_initial_ngrams(self) -> None: - word_path = self.working_directory.joinpath("words.txt") - word_ngram_path = self.working_directory.joinpath("grapheme_ngram.fst") - word_symbols_path = self.working_directory.joinpath("grapheme_ngram.syms") + logger.info("Computing initial ngrams...") + input_path = self.working_directory.joinpath("input.txt") + input_ngram_path = self.working_directory.joinpath("input_ngram.fst") + input_symbols_path = self.working_directory.joinpath("input_ngram.syms") symbols_proc = subprocess.Popen( [ thirdparty_binary("ngramsymbols"), "--OOV_symbol=", "--epsilon_symbol=", - word_path, - word_symbols_path, + input_path, + input_symbols_path, ], encoding="utf8", ) @@ -1520,8 +1386,8 @@ def compute_initial_ngrams(self) -> None: [ thirdparty_binary("farcompilestrings"), "--token_type=symbol", - f"--symbols={word_symbols_path}", - word_path, + f"--symbols={input_symbols_path}", + input_path, ], stdout=subprocess.PIPE, env=os.environ, @@ -1531,7 +1397,7 @@ def compute_initial_ngrams(self) -> None: thirdparty_binary("ngramcount"), "--require_symbols=false", "--round_to_int", - f"--order={self.grapheme_order}", + f"--order={self.input_order}", ], stdin=farcompile_proc.stdout, stdout=subprocess.PIPE, @@ -1560,7 +1426,7 @@ def compute_initial_ngrams(self) -> None: print_proc = subprocess.Popen( [ thirdparty_binary("ngramprint"), - f"--symbols={word_symbols_path}", + f"--symbols={input_symbols_path}", ], env=os.environ, stdin=ngramshrink_proc.stdout, @@ -1576,20 +1442,20 @@ def compute_initial_ngrams(self) -> None: ngrams.add(ngram) print_proc.wait() - with mfa_open(word_ngram_path.with_suffix(".ngrams"), "w") as f: + with mfa_open(input_ngram_path.with_suffix(".ngrams"), "w") as f: for ngram in sorted(ngrams): f.write(f"{ngram}\n") - phone_path = self.working_directory.joinpath("pronunciations.txt") - phone_ngram_path = self.working_directory.joinpath("phone_ngram.fst") - phone_symbols_path = self.working_directory.joinpath("phone_ngram.syms") + output_path = self.working_directory.joinpath("output.txt") + output_ngram_path = self.working_directory.joinpath("output_ngram.fst") + output_symbols_path = self.working_directory.joinpath("output_ngram.syms") symbols_proc = subprocess.Popen( [ thirdparty_binary("ngramsymbols"), "--OOV_symbol=", "--epsilon_symbol=", - phone_path, - phone_symbols_path, + output_path, + output_symbols_path, ], encoding="utf8", ) @@ -1598,8 +1464,8 @@ def compute_initial_ngrams(self) -> None: [ thirdparty_binary("farcompilestrings"), "--token_type=symbol", - f"--symbols={phone_symbols_path}", - phone_path, + f"--symbols={output_symbols_path}", + output_path, ], stdout=subprocess.PIPE, env=os.environ, @@ -1609,7 +1475,7 @@ def compute_initial_ngrams(self) -> None: thirdparty_binary("ngramcount"), "--require_symbols=false", "--round_to_int", - f"--order={self.phone_order}", + f"--order={self.output_order}", ], stdin=farcompile_proc.stdout, stdout=subprocess.PIPE, @@ -1636,7 +1502,7 @@ def compute_initial_ngrams(self) -> None: env=os.environ, ) print_proc = subprocess.Popen( - [thirdparty_binary("ngramprint"), f"--symbols={phone_symbols_path}"], + [thirdparty_binary("ngramprint"), f"--symbols={output_symbols_path}"], env=os.environ, stdin=ngramshrink_proc.stdout, stdout=subprocess.PIPE, @@ -1651,10 +1517,139 @@ def compute_initial_ngrams(self) -> None: ngrams.add(ngram) print_proc.wait() - with mfa_open(phone_ngram_path.with_suffix(".ngrams"), "w") as f: + with mfa_open(output_ngram_path.with_suffix(".ngrams"), "w") as f: for ngram in sorted(ngrams): f.write(f"{ngram}\n") + def train(self) -> None: + """ + Train a G2P model + """ + os.makedirs(self.working_log_directory, exist_ok=True) + begin = time.time() + self.train_alignments() + logger.debug( + f"Aligning {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds" + ) + self.export_alignments() + begin = time.time() + self.train_ngram_model() + logger.debug( + f"Generating model for {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds" + ) + self.finalize_training() + + +class PhonetisaurusTrainer( + MultispeakerDictionaryMixin, PhonetisaurusTrainerMixin, G2PTrainer, TopLevelMfaWorker +): + """ + Top level trainer class for Phonetisaurus-style models + """ + + def __init__( + self, + **kwargs, + ): + self._data_source = kwargs["dictionary_path"].stem + super().__init__(**kwargs) + self.ler = None + self.wer = None + + @property + def data_directory(self) -> Path: + """Data directory for trainer""" + return self.working_directory + + @property + def configuration(self) -> MetaDict: + """Configuration for G2P trainer""" + config = super().configuration + config.update({"dictionary_path": str(self.dictionary_model.path)}) + return config + + def setup(self) -> None: + """Setup for G2P training""" + super().setup() + self.create_new_current_workflow(WorkflowType.train_g2p) + wf = self.current_workflow + if wf.done: + logger.info("G2P training already done, skipping.") + return + self.dictionary_setup() + os.makedirs(self.phones_dir, exist_ok=True) + self.initialize_training() + self.initialized = True + + def finalize_training(self) -> None: + """Finalize training and run evaluation if specified""" + if self.evaluation_mode: + self.evaluate_g2p_model() + + @property + def meta(self) -> MetaDict: + """Metadata for exported G2P model""" + from datetime import datetime + + from ..utils import get_mfa_version + + m = { + "version": get_mfa_version(), + "architecture": self.architecture, + "train_date": str(datetime.now()), + "phones": sorted(self.non_silence_phones), + "graphemes": self.g2p_training_graphemes, + "grapheme_order": self.grapheme_order, + "phone_order": self.phone_order, + "sequence_separator": self.sequence_separator, + "evaluation": {}, + "training": { + "num_words": self.g2p_num_training_words, + "num_graphemes": len(self.g2p_training_graphemes), + "num_phones": len(self.non_silence_phones), + }, + } + + if self.evaluation_mode: + m["evaluation"]["num_words"] = self.g2p_num_validation_words + m["evaluation"]["word_error_rate"] = self.wer + m["evaluation"]["phone_error_rate"] = self.ler + return m + + def evaluate_g2p_model(self) -> None: + """ + Validate the G2P model against held out data + """ + temp_model_path = self.working_log_directory.joinpath("g2p_model.zip") + self.export_model(temp_model_path) + temp_dir = self.working_directory.joinpath("validation") + os.makedirs(temp_dir, exist_ok=True) + with self.session() as session: + validation_set = collections.defaultdict(set) + query = ( + session.query(Word.word, Pronunciation.pronunciation) + .join(Pronunciation.word) + .join(Word.job) + .filter(Word2Job.training == False) # noqa + ) + for w, pron in query: + validation_set[w].add(pron) + gen = PyniniValidator( + g2p_model_path=temp_model_path, + word_list=list(validation_set.keys()), + num_pronunciations=self.num_pronunciations, + ) + output = gen.generate_pronunciations() + with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f: + for (orthography, pronunciations) in output.items(): + if not pronunciations: + continue + for p in pronunciations: + if not p: + continue + f.write(f"{orthography}\t{p}\n") + gen.compute_validation_errors(validation_set, output) + def initialize_training(self) -> None: """Initialize training G2P model""" with self.session() as session: @@ -1664,7 +1659,7 @@ def initialize_training(self) -> None: session.query(Job).delete() session.commit() - job_objs = [{"id": j} for j in range(GLOBAL_CONFIG.num_jobs)] + job_objs = [{"id": j} for j in range(1, GLOBAL_CONFIG.num_jobs + 1)] self.g2p_num_training_pronunciations = 0 self.g2p_num_validation_pronunciations = 0 self.g2p_num_training_words = 0 @@ -1673,15 +1668,15 @@ def initialize_training(self) -> None: # so they're not completely overlapping and using more memory num_words = session.query(Word.id).count() words_per_job = int(num_words / GLOBAL_CONFIG.num_jobs) + 1 - current_job = 0 + current_job = 1 words = session.query(Word.id).filter( Word.word_type.in_([WordType.speech, WordType.clitic]) ) mappings = [] for i, (w,) in enumerate(words): if ( - i >= (current_job + 1) * words_per_job - and current_job != GLOBAL_CONFIG.num_jobs + i >= (current_job) * words_per_job + and current_job != GLOBAL_CONFIG.num_jobs + 1 ): current_job += 1 mappings.append({"word_id": w, "job_id": current_job, "training": 1}) @@ -1733,8 +1728,8 @@ def initialize_training(self) -> None: .join(Word.job) .filter(Word2Job.training == True) # noqa ) - with mfa_open(self.working_directory.joinpath("words.txt"), "w") as word_f, mfa_open( - self.working_directory.joinpath("pronunciations.txt"), "w" + with mfa_open(self.working_directory.joinpath("input.txt"), "w") as word_f, mfa_open( + self.working_directory.joinpath("output.txt"), "w" ) as phone_f: for pronunciation, word in query: word = list(word) diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index 59f3a09f..2d692ae1 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -931,7 +931,26 @@ def sym_path(self) -> Path: path = self.dirname.joinpath("graphemes.txt") if path.exists(): return path - return self.dirname.joinpath("graphemes.sym") + path = self.dirname.joinpath("graphemes.sym") + if path.exists(): + return path + return self.dirname.joinpath("graphemes.syms") + + @property + def input_sym_path(self) -> Path: + """Tokenizer model's input symbols path""" + path = self.dirname.joinpath("input.txt") + if path.exists(): + return path + return self.dirname.joinpath("input.syms") + + @property + def output_sym_path(self) -> Path: + """Tokenizer model's output symbols path""" + path = self.dirname.joinpath("output.txt") + if path.exists(): + return path + return self.dirname.joinpath("output.syms") def add_graphemes_path(self, source_directory: Path) -> None: """ @@ -942,8 +961,10 @@ def add_graphemes_path(self, source_directory: Path) -> None: source_directory: :class:`~pathlib.Path` Source directory path """ - if not self.sym_path.exists(): - copyfile(source_directory.joinpath("graphemes.txt"), self.sym_path) + for p in [self.sym_path, self.output_sym_path, self.input_sym_path]: + source_p = source_directory.joinpath(p.name) + if not p.exists() and source_p.exists(): + copyfile(source_p, p) def add_tokenizer_model(self, source_directory: Path) -> None: """ diff --git a/montreal_forced_aligner/tokenization/tokenizer.py b/montreal_forced_aligner/tokenization/tokenizer.py index 1ff162c3..91a83cf9 100644 --- a/montreal_forced_aligner/tokenization/tokenizer.py +++ b/montreal_forced_aligner/tokenization/tokenizer.py @@ -28,7 +28,7 @@ from montreal_forced_aligner.db import File, Utterance, bulk_update from montreal_forced_aligner.dictionary.mixins import DictionaryMixin from montreal_forced_aligner.exceptions import PyniniGenerationError -from montreal_forced_aligner.g2p.generator import Rewriter, RewriterWorker +from montreal_forced_aligner.g2p.generator import PhonetisaurusRewriter, Rewriter, RewriterWorker from montreal_forced_aligner.helper import edit_distance, mfa_open from montreal_forced_aligner.models import TokenizerModel from montreal_forced_aligner.utils import Stopped, run_kaldi_function @@ -97,6 +97,94 @@ def __call__(self, i: str) -> str: # pragma: no cover return "".join(hypothesis) +class TokenizerPhonetisaurusRewriter(PhonetisaurusRewriter): + """ + Helper function for rewriting + + Parameters + ---------- + fst: pynini.Fst + G2P FST model + input_token_type: pynini.SymbolTable + Grapheme symbol table + output_token_type: pynini.SymbolTable + num_pronunciations: int + Number of pronunciations, default to 0. If this is 0, thresholding is used + threshold: float + Threshold to use for pruning rewrite lattice, defaults to 1.5, only used if num_pronunciations is 0 + grapheme_order: int + Maximum number of graphemes to consider single segment + seq_sep: str + Separator to use between grapheme symbols + """ + + def __init__( + self, + fst: Fst, + input_token_type: SymbolTable, + output_token_type: SymbolTable, + input_order: int = 2, + seq_sep: str = "|", + ): + self.fst = fst + self.seq_sep = seq_sep + self.input_token_type = input_token_type + self.output_token_type = output_token_type + self.input_order = input_order + self.rewrite = functools.partial( + rewrite.top_rewrite, + rule=fst, + input_token_type=None, + output_token_type=output_token_type, + ) + + def __call__(self, graphemes: str) -> str: # pragma: no cover + """Call the rewrite function""" + graphemes = graphemes.replace(" ", "") + original = list(graphemes) + unks = [] + normalized = [] + for c in original: + if self.output_token_type.member(c): + normalized.append(c) + else: + unks.append(c) + normalized.append("") + fst = pynini.Fst() + one = pynini.Weight.one(fst.weight_type()) + max_state = 0 + for i in range(len(normalized)): + start_state = fst.add_state() + for j in range(1, self.input_order + 1): + if i + j <= len(normalized): + substring = self.seq_sep.join(normalized[i : i + j]) + ilabel = self.input_token_type.find(substring) + if ilabel != pynini.NO_LABEL: + fst.add_arc(start_state, pynini.Arc(ilabel, ilabel, one, i + j)) + if i + j >= max_state: + max_state = i + j + for _ in range(fst.num_states(), max_state + 1): + fst.add_state() + fst.set_start(0) + fst.set_final(len(normalized), one) + fst.set_input_symbols(self.input_token_type) + fst.set_output_symbols(self.input_token_type) + hypothesis = self.rewrite(fst).split() + unk_index = 0 + output = [] + for i, w in enumerate(hypothesis): + if w == "": + output.append(unks[unk_index]) + unk_index += 1 + elif w == "": + if i > 0 and hypothesis[i - 1] == " ": + continue + output.append(" ") + else: + output.append(w) + return "".join(output).strip() + + @dataclass class TokenizerArguments(MfaArguments): rewriter: Rewriter @@ -142,12 +230,27 @@ def setup(self) -> None: self._create_dummy_dictionary() self.normalize_text() self.fst = pynini.Fst.read(self.tokenizer_model.fst_path) - self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path) - self.rewriter = TokenizerRewriter( - self.fst, - self.grapheme_symbols, - ) + if self.tokenizer_model.meta["architecture"] == "phonetisaurus": + self.output_token_type = pywrapfst.SymbolTable.read_text( + self.tokenizer_model.output_sym_path + ) + self.input_token_type = pywrapfst.SymbolTable.read_text( + self.tokenizer_model.input_sym_path + ) + self.rewriter = TokenizerPhonetisaurusRewriter( + self.fst, + self.input_token_type, + self.output_token_type, + input_order=self.tokenizer_model.meta["input_order"], + ) + else: + self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path) + + self.rewriter = TokenizerRewriter( + self.fst, + self.grapheme_symbols, + ) self.initialized = True def export_files(self, output_directory: Path) -> None: @@ -236,6 +339,8 @@ def __init__(self, utterances_to_tokenize: typing.List[str] = None, **kwargs): if utterances_to_tokenize is None: utterances_to_tokenize = [] self.utterances_to_tokenize = utterances_to_tokenize + self.uer = None + self.cer = None def setup(self): TopLevelMfaWorker.setup(self) @@ -244,15 +349,28 @@ def setup(self): self._current_workflow = "validation" os.makedirs(self.working_log_directory, exist_ok=True) self.fst = pynini.Fst.read(self.tokenizer_model.fst_path) - self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path) - self.rewriter = TokenizerRewriter( - self.fst, - self.grapheme_symbols, - ) + if self.tokenizer_model.meta["architecture"] == "phonetisaurus": + self.output_token_type = pywrapfst.SymbolTable.read_text( + self.tokenizer_model.output_sym_path + ) + self.input_token_type = pywrapfst.SymbolTable.read_text( + self.tokenizer_model.input_sym_path + ) + self.rewriter = TokenizerPhonetisaurusRewriter( + self.fst, + self.input_token_type, + self.output_token_type, + input_order=self.tokenizer_model.meta["input_order"], + ) + else: + self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path) + + self.rewriter = TokenizerRewriter( + self.fst, + self.grapheme_symbols, + ) self.initialized = True - self.uer = None - self.cer = None def tokenize_utterances(self) -> typing.Dict[str, str]: """ @@ -269,7 +387,7 @@ def tokenize_utterances(self) -> typing.Dict[str, str]: self.setup() logger.info("Tokenizing utterances...") to_return = {} - if True or num_utterances < 30 or GLOBAL_CONFIG.num_jobs == 1: + if num_utterances < 30 or GLOBAL_CONFIG.num_jobs == 1: with tqdm.tqdm(total=num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for utterance in self.utterances_to_tokenize: pbar.update(1) diff --git a/montreal_forced_aligner/tokenization/trainer.py b/montreal_forced_aligner/tokenization/trainer.py index 4fc64d3e..fbc3e116 100644 --- a/montreal_forced_aligner/tokenization/trainer.py +++ b/montreal_forced_aligner/tokenization/trainer.py @@ -7,6 +7,7 @@ import time from pathlib import Path +import pynini import pywrapfst import sqlalchemy @@ -14,9 +15,13 @@ from montreal_forced_aligner.config import GLOBAL_CONFIG from montreal_forced_aligner.corpus.text_corpus import TextCorpusMixin from montreal_forced_aligner.data import WorkflowType -from montreal_forced_aligner.db import Utterance +from montreal_forced_aligner.db import M2M2Job, M2MSymbol, Utterance from montreal_forced_aligner.dictionary.mixins import DictionaryMixin from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.g2p.phonetisaurus_trainer import ( + AlignmentInitWorker, + PhonetisaurusTrainerMixin, +) from montreal_forced_aligner.g2p.trainer import G2PTrainer, PyniniTrainerMixin from montreal_forced_aligner.helper import mfa_open from montreal_forced_aligner.models import TokenizerModel @@ -28,16 +33,177 @@ logger = logging.getLogger("mfa") -class TokenizerTrainer( - PyniniTrainerMixin, TextCorpusMixin, G2PTrainer, TopLevelMfaWorker, DictionaryMixin -): - def __init__(self, oov_count_threshold=5, **kwargs): - super().__init__(oov_count_threshold=oov_count_threshold, **kwargs) +class TokenizerAlignmentInitWorker(AlignmentInitWorker): + """ + Multiprocessing worker that initializes alignment FSTs for a subset of the data + + Parameters + ---------- + job_name: int + Integer ID for the job + return_queue: :class:`multiprocessing.Queue` + Queue to return data + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check + finished_adding: :class:`~montreal_forced_aligner.utils.Stopped` + Check for whether the job queue is done + args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.AlignmentInitArguments` + Arguments for initialization + """ + + def data_generator(self, session): + grapheme_table = pywrapfst.SymbolTable.read_text(self.far_path.with_name("graphemes.syms")) + query = session.query(Utterance.normalized_character_text).filter( + Utterance.ignored == False, Utterance.job_id == self.job_name # noqa + ) + for (text,) in query: + tokenized = [x if grapheme_table.member(x) else "" for x in text.split()] + untokenized = [x for x in tokenized if x != ""] + yield untokenized, tokenized + + def run(self) -> None: + """Run the function""" + engine = sqlalchemy.create_engine( + self.db_string, + poolclass=sqlalchemy.NullPool, + pool_reset_on_return=None, + isolation_level="AUTOCOMMIT", + logging_name=f"{type(self).__name__}_engine", + ).execution_options(logging_token=f"{type(self).__name__}_engine") + try: + symbol_table = pynini.SymbolTable() + symbol_table.add_symbol(self.eps) + valid_output_ngrams = set() + base_dir = os.path.dirname(self.far_path) + with mfa_open(os.path.join(base_dir, "output_ngram.ngrams"), "r") as f: + for line in f: + line = line.strip() + valid_output_ngrams.add(line) + valid_input_ngrams = set() + with mfa_open(os.path.join(base_dir, "input_ngram.ngrams"), "r") as f: + for line in f: + line = line.strip() + valid_input_ngrams.add(line) + count = 0 + data = {} + with mfa_open(self.log_path, "w") as log_file, sqlalchemy.orm.Session( + engine + ) as session: + far_writer = pywrapfst.FarWriter.create(self.far_path, arc_type="log") + for current_index, (input, output) in enumerate(self.data_generator(session)): + if self.stopped.stop_check(): + continue + try: + key = f"{current_index:08x}" + fst = pynini.Fst(arc_type="log") + final_state = ((len(input) + 1) * (len(output) + 1)) - 1 + + for _ in range(final_state + 1): + fst.add_state() + for i in range(len(input) + 1): + for j in range(len(output) + 1): + istate = i * (len(output) + 1) + j + + for input_range in range(1, self.input_order + 1): + for output_range in range(input_range, self.output_order + 1): + if i + input_range <= len( + input + ) and j + output_range <= len(output): + if ( + self.restrict + and input_range > 1 + and output_range > 1 + ): + continue + subseq_output = output[j : j + output_range] + output_string = self.seq_sep.join(subseq_output) + if ( + output_range > 1 + and output_string not in valid_output_ngrams + ): + continue + subseq_input = input[i : i + input_range] + input_string = self.seq_sep.join(subseq_input) + if output_range > 1: + if "" not in subseq_output: + continue + if input_string not in output_string: + continue + if ( + output_range == input_range + and input_string != output_string + ): + continue + if ( + input_range > 1 + and input_string not in valid_input_ngrams + ): + continue + symbol = self.s1s2_sep.join( + [input_string, output_string] + ) + ilabel = symbol_table.find(symbol) + if ilabel == pynini.NO_LABEL: + ilabel = symbol_table.add_symbol(symbol) + ostate = (i + input_range) * (len(output) + 1) + ( + j + output_range + ) + fst.add_arc( + istate, + pywrapfst.Arc( + ilabel, + ilabel, + pynini.Weight( + "log", float(input_range * output_range) + ), + ostate, + ), + ) + fst.set_start(0) + fst.set_final(final_state, pywrapfst.Weight.one(fst.weight_type())) + fst = pynini.connect(fst) + for state in fst.states(): + for arc in fst.arcs(state): + sym = symbol_table.find(arc.ilabel) + if sym not in data: + data[sym] = arc.weight + else: + data[sym] = pynini.plus(data[sym], arc.weight) + if count >= self.batch_size: + data = {k: float(v) for k, v in data.items()} + self.return_queue.put((self.job_name, data, count)) + data = {} + count = 0 + log_file.flush() + far_writer[key] = fst + del fst + count += 1 + except Exception as e: # noqa + self.stopped.stop() + self.return_queue.put(e) + if data: + data = {k: float(v) for k, v in data.items()} + self.return_queue.put((self.job_name, data, count)) + symbol_table.write_text(self.far_path.with_suffix(".syms")) + return + except Exception as e: + self.stopped.stop() + self.return_queue.put(e) + finally: + self.finished.stop() + del far_writer + + +class TokenizerMixin(TextCorpusMixin, G2PTrainer, DictionaryMixin, TopLevelMfaWorker): + def __init__(self, **kwargs): + super().__init__(**kwargs) self.training_graphemes = set() self.uer = None self.cer = None self.deletions = False self.insertions = True + self.num_training_utterances = 0 + self.num_validation_utterances = 0 def setup(self) -> None: super().setup() @@ -57,6 +223,202 @@ def setup(self) -> None: raise self.initialized = True + def evaluate_tokenizer(self) -> None: + """ + Validate the tokenizer model against held out data + """ + temp_model_path = self.working_log_directory.joinpath("tokenizer_model.zip") + self.export_model(temp_model_path) + temp_dir = self.working_directory.joinpath("validation") + temp_dir.mkdir(parents=True, exist_ok=True) + with self.session() as session: + validation_set = {} + query = session.query(Utterance.normalized_character_text).filter( + Utterance.ignored == True # noqa + ) + for (text,) in query: + tokenized = text.split() + untokenized = [x for x in tokenized if x != ""] + tokenized = [x if x != "" else " " for x in tokenized] + validation_set[" ".join(untokenized)] = "".join(tokenized) + gen = TokenizerValidator( + tokenizer_model_path=temp_model_path, + corpus_directory=self.corpus_directory, + utterances_to_tokenize=list(validation_set.keys()), + ) + output = gen.tokenize_utterances() + with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f: + for (orthography, pronunciations) in output.items(): + if not pronunciations: + continue + for p in pronunciations: + if not p: + continue + f.write(f"{orthography}\t{p}\n") + gen.compute_validation_errors(validation_set, output) + self.uer = gen.uer + self.cer = gen.cer + + +class PhonetisaurusTokenizerTrainer(PhonetisaurusTrainerMixin, TokenizerMixin): + + alignment_init_function = TokenizerAlignmentInitWorker + + def __init__( + self, input_order: int = 2, output_order: int = 3, oov_count_threshold: int = 5, **kwargs + ): + super().__init__( + oov_count_threshold=oov_count_threshold, + grapheme_order=input_order, + phone_order=output_order, + **kwargs, + ) + + @property + def data_source_identifier(self) -> str: + """Corpus name""" + return self.corpus_directory.name + + @property + def meta(self) -> MetaDict: + """Metadata for exported tokenizer model""" + from datetime import datetime + + from ..utils import get_mfa_version + + m = { + "version": get_mfa_version(), + "architecture": self.architecture, + "train_date": str(datetime.now()), + "evaluation": {}, + "input_order": self.input_order, + "output_order": self.output_order, + "oov_count_threshold": self.oov_count_threshold, + "training": { + "num_utterances": self.num_training_utterances, + "num_graphemes": len(self.training_graphemes), + }, + } + + if self.evaluation_mode: + m["evaluation"]["num_utterances"] = self.num_validation_utterances + m["evaluation"]["utterance_error_rate"] = self.uer + m["evaluation"]["character_error_rate"] = self.cer + return m + + def train(self) -> None: + if os.path.exists(self.fst_path): + return + super().train() + + def initialize_training(self) -> None: + """Initialize training tokenizer model""" + logger.info("Initializing training...") + + self.create_new_current_workflow(WorkflowType.tokenizer_training) + with self.session() as session: + session.query(M2M2Job).delete() + session.query(M2MSymbol).delete() + session.commit() + self.num_validation_utterances = 0 + self.num_training_utterances = 0 + if self.evaluation_mode: + validation_items = int(self.num_utterances * self.validation_proportion) + validation_utterances = ( + sqlalchemy.select(Utterance.id) + .order_by(sqlalchemy.func.random()) + .limit(validation_items) + .scalar_subquery() + ) + query = ( + sqlalchemy.update(Utterance) + .execution_options(synchronize_session="fetch") + .values(ignored=True) + .where(Utterance.id.in_(validation_utterances)) + ) + with session.begin_nested(): + session.execute(query) + session.flush() + session.commit() + self.num_validation_utterances = ( + session.query(Utterance.id).filter(Utterance.ignored == True).count() # noqa + ) + + query = session.query(Utterance.normalized_character_text).filter( + Utterance.ignored == False # noqa + ) + unk_character = "" + self.training_graphemes.add(unk_character) + counts = collections.Counter() + for (text,) in query: + counts.update(text.split()) + with mfa_open( + self.working_directory.joinpath("input.txt"), "w" + ) as untokenized_f, mfa_open( + self.working_directory.joinpath("output.txt"), "w" + ) as tokenized_f: + for (text,) in query: + assert text + tokenized = [ + x if counts[x] >= self.oov_count_threshold else unk_character + for x in text.split() + ] + untokenized = [x for x in tokenized if x != ""] + self.num_training_utterances += 1 + self.training_graphemes.update(tokenized) + untokenized_f.write(" ".join(untokenized) + "\n") + tokenized_f.write(" ".join(tokenized) + "\n") + index = 1 + with mfa_open(self.working_directory.joinpath("graphemes.syms"), "w") as f: + f.write("\t0\n") + for g in sorted(self.training_graphemes): + f.write(f"{g}\t{index}\n") + index += 1 + self.compute_initial_ngrams() + self.g2p_num_training_pronunciations = self.num_training_utterances + + def finalize_training(self) -> None: + """Finalize training""" + shutil.copyfile(self.fst_path, self.working_directory.joinpath("tokenizer.fst")) + shutil.copyfile(self.grapheme_symbols_path, self.working_directory.joinpath("input.syms")) + shutil.copyfile(self.phone_symbols_path, self.working_directory.joinpath("output.syms")) + if self.evaluation_mode: + self.evaluate_tokenizer() + + def export_model(self, output_model_path: Path) -> None: + """ + Export tokenizer model to specified path + + Parameters + ---------- + output_model_path: :class:`~pathlib.Path` + Path to export model + """ + directory = output_model_path.parent + + models_temp_dir = self.working_directory.joinpath("model_archive_temp") + model = TokenizerModel.empty(output_model_path.stem, root_directory=models_temp_dir) + model.add_meta_file(self) + model.add_tokenizer_model(self.working_directory) + model.add_graphemes_path(self.working_directory) + if directory: + os.makedirs(directory, exist_ok=True) + model.dump(output_model_path) + if not GLOBAL_CONFIG.current_profile.debug: + model.clean_up() + # self.clean_up() + logger.info(f"Saved model to {output_model_path}") + + +class TokenizerTrainer(PyniniTrainerMixin, TokenizerMixin): + def __init__(self, oov_count_threshold=5, **kwargs): + super().__init__(oov_count_threshold=oov_count_threshold, **kwargs) + self.training_graphemes = set() + self.uer = None + self.cer = None + self.deletions = False + self.insertions = True + @property def meta(self) -> MetaDict: """Metadata for exported tokenizer model""" @@ -88,11 +450,7 @@ def data_source_identifier(self) -> str: @property def sym_path(self) -> Path: - return self.working_directory.joinpath("graphemes.txt") - - @property - def phone_symbol_table_path(self) -> Path: - return self.working_directory.joinpath("graphemes.txt") + return self.working_directory.joinpath("graphemes.syms") def initialize_training(self) -> None: """Initialize training tokenizer model""" @@ -100,8 +458,10 @@ def initialize_training(self) -> None: with self.session() as session: self.num_validation_utterances = 0 self.num_training_utterances = 0 - self.num_iterations = 2 - self.input_token_type = self.working_directory.joinpath("graphemes.txt") + self.num_iterations = 1 + self.random_starts = 1 + self.input_token_type = self.sym_path + self.output_token_type = self.sym_path if self.evaluation_mode: validation_items = int(self.num_utterances * self.validation_proportion) validation_utterances = ( @@ -167,14 +527,11 @@ def _lexicon_covering(self, input_path=None, output_path=None) -> None: thirdparty_binary("farcompilestrings"), "--fst_type=compact", ] - if self.input_token_type != "utf8": - com.append("--token_type=symbol") - com.append( - f"--symbols={self.input_token_type}", - ) - com.append("--unknown_symbol=") - else: - com.append("--token_type=utf8") + com.append("--token_type=symbol") + com.append( + f"--symbols={self.sym_path}", + ) + com.append("--unknown_symbol=") com.extend([input_path, self.input_far_path]) print(" ".join(map(str, com)), file=log_file) subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file) @@ -182,7 +539,7 @@ def _lexicon_covering(self, input_path=None, output_path=None) -> None: thirdparty_binary("farcompilestrings"), "--fst_type=compact", "--token_type=symbol", - f"--symbols={self.phone_symbol_table_path}", + f"--symbols={self.sym_path}", output_path, self.output_far_path, ] @@ -203,48 +560,6 @@ def _lexicon_covering(self, input_path=None, output_path=None) -> None: assert cg.verify(), "Label acceptor is ill-formed" cg.write(self.cg_path) - def evaluate_tokenizer(self) -> None: - """ - Validate the tokenizer model against held out data - """ - temp_model_path = self.working_log_directory.joinpath("tokenizer_model.zip") - self.export_model(temp_model_path) - temp_dir = self.working_directory.joinpath("validation") - temp_dir.mkdir(parents=True, exist_ok=True) - with self.session() as session: - validation_set = {} - query = session.query(Utterance.normalized_character_text).filter( - Utterance.ignored == True # noqa - ) - for (text,) in query: - tokenized = text.split() - untokenized = [x for x in tokenized if x != ""] - tokenized = [x if x != "" else " " for x in tokenized] - validation_set[" ".join(untokenized)] = "".join(tokenized) - gen = TokenizerValidator( - tokenizer_model_path=temp_model_path, - corpus_directory=self.corpus_directory, - utterances_to_tokenize=list(validation_set.keys()), - ) - output = gen.tokenize_utterances() - with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f: - for (orthography, pronunciations) in output.items(): - if not pronunciations: - continue - for p in pronunciations: - if not p: - continue - f.write(f"{orthography}\t{p}\n") - gen.compute_validation_errors(validation_set, output) - self.uer = gen.uer - self.cer = gen.cer - - def finalize_training(self) -> None: - """Finalize training""" - shutil.copyfile(self.fst_path, self.working_directory.joinpath("tokenizer.fst")) - if self.evaluation_mode: - self.evaluate_tokenizer() - def train(self) -> None: """ Train a tokenizer model @@ -261,6 +576,12 @@ def train(self) -> None: logger.debug(f"Generating model took {time.time() - begin:.3f} seconds") self.finalize_training() + def finalize_training(self) -> None: + """Finalize training""" + shutil.copyfile(self.fst_path, self.working_directory.joinpath("tokenizer.fst")) + if self.evaluation_mode: + self.evaluate_tokenizer() + def export_model(self, output_model_path: Path) -> None: """ Export tokenizer model to specified path diff --git a/tests/conftest.py b/tests/conftest.py index 2a52e0b8..adbf347a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -265,6 +265,11 @@ def test_tokenizer_model(tokenizer_model_dir): return tokenizer_model_dir.joinpath("test_tokenizer_model.zip") +@pytest.fixture(scope="session") +def test_tokenizer_model_phonetisaurus(tokenizer_model_dir): + return tokenizer_model_dir.joinpath("test_tokenizer_model_phonetisaurus.zip") + + @pytest.fixture(scope="session") def transcription_language_model(language_model_dir, generated_dir): return language_model_dir.joinpath("test_lm.zip") diff --git a/tests/data/tokenizer/test_tokenizer_model_phonetisaurus.zip b/tests/data/tokenizer/test_tokenizer_model_phonetisaurus.zip new file mode 100644 index 00000000..e2031a48 Binary files /dev/null and b/tests/data/tokenizer/test_tokenizer_model_phonetisaurus.zip differ diff --git a/tests/test_commandline_tokenize.py b/tests/test_commandline_tokenize.py index 22e2ec6d..5cf0b32d 100644 --- a/tests/test_commandline_tokenize.py +++ b/tests/test_commandline_tokenize.py @@ -1,12 +1,10 @@ import os import click.testing -import pytest from montreal_forced_aligner.command_line.mfa import mfa_cli -@pytest.mark.skip("No pretrained model yet") def test_tokenize_pretrained(japanese_tokenizer_model, japanese_dir, temp_dir, generated_dir): out_directory = generated_dir.joinpath("japanese_tokenized") command = [ @@ -60,6 +58,33 @@ def test_train_tokenizer(combined_corpus_dir, temp_dir, generated_dir): assert os.path.exists(output_path) +def test_train_tokenizer_phonetisaurus(combined_corpus_dir, temp_dir, generated_dir): + output_path = generated_dir.joinpath("test_tokenizer_model_phonetisaurus.zip") + command = [ + "train_tokenizer", + combined_corpus_dir, + output_path, + "-t", + os.path.join(temp_dir, "test_train_tokenizer_phonetisaurus"), + "-q", + "--clean", + "--debug", + "--phonetisaurus", + "--validate", + ] + command = [str(x) for x in command] + result = click.testing.CliRunner(mix_stderr=False, echo_stdin=True).invoke( + mfa_cli, command, catch_exceptions=True + ) + print(result.stdout) + print(result.stderr) + if result.exception: + print(result.exc_info) + raise result.exception + assert not result.return_value + assert os.path.exists(output_path) + + def test_tokenize_textgrid( multilingual_ipa_tg_corpus_dir, test_tokenizer_model, @@ -90,3 +115,35 @@ def test_tokenize_textgrid( raise result.exception assert not result.return_value assert os.path.exists(output_directory) + + +def test_tokenize_textgrid_phonetisaurus( + multilingual_ipa_tg_corpus_dir, + test_tokenizer_model_phonetisaurus, + generated_dir, + temp_dir, + g2p_config_path, +): + output_directory = generated_dir.joinpath("tokenized_tg") + command = [ + "tokenize", + multilingual_ipa_tg_corpus_dir, + test_tokenizer_model_phonetisaurus, + output_directory, + "-t", + os.path.join(temp_dir, "tokenizer_cli"), + "-q", + "--clean", + "--debug", + ] + command = [str(x) for x in command] + result = click.testing.CliRunner(mix_stderr=False, echo_stdin=True).invoke( + mfa_cli, command, catch_exceptions=True + ) + print(result.stdout) + print(result.stderr) + if result.exception: + print(result.exc_info) + raise result.exception + assert not result.return_value + assert os.path.exists(output_directory)