From 2569caaed42b8618ecf7fa19102251367383516b Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 16 Apr 2023 19:36:38 -0700 Subject: [PATCH] 2.2.8 (#615) * Bug fixes * Fix bug in filter training utterances --- Dockerfile | 2 +- docs/source/changelog/changelog_2.2.rst | 7 + docs/source/installation.rst | 4 +- environment.yml | 3 + .../acoustic_modeling/trainer.py | 32 +++- .../alignment/multiprocessing.py | 82 +++++++--- .../alignment/pretrained.py | 3 +- montreal_forced_aligner/command_line/mfa.py | 8 +- .../command_line/server.py | 5 + .../command_line/train_tokenizer.py | 2 +- montreal_forced_aligner/command_line/utils.py | 21 +-- montreal_forced_aligner/config.py | 10 +- montreal_forced_aligner/corpus/base.py | 4 +- montreal_forced_aligner/corpus/features.py | 54 ++++--- montreal_forced_aligner/corpus/helper.py | 102 +----------- .../corpus/multiprocessing.py | 35 ++--- montreal_forced_aligner/db.py | 16 ++ .../dictionary/multispeaker.py | 148 +++--------------- montreal_forced_aligner/exceptions.py | 1 + montreal_forced_aligner/g2p/generator.py | 13 +- .../g2p/phonetisaurus_trainer.py | 11 +- montreal_forced_aligner/helper.py | 22 +++ montreal_forced_aligner/textgrid.py | 2 +- .../tokenization/trainer.py | 8 +- requirements.txt | 10 +- tests/conftest.py | 3 +- tests/test_commandline_align.py | 15 -- tests/test_corpus.py | 2 - 28 files changed, 284 insertions(+), 341 deletions(-) diff --git a/Dockerfile b/Dockerfile index fd530be2..e436af43 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ RUN useradd -ms /bin/bash mfauser RUN chown -R mfauser /mfa RUN chown -R mfauser /env USER mfauser -ENV MFA_ROOT_ENVIRONMENT_VARIABLE=/mfa +ENV MFA_ROOT_DIR=/mfa RUN conda run -p /env mfa server init RUN echo "source activate /env && mfa server start" > ~/.bashrc diff --git a/docs/source/changelog/changelog_2.2.rst b/docs/source/changelog/changelog_2.2.rst index e3fb1093..dbd30f2a 100644 --- a/docs/source/changelog/changelog_2.2.rst +++ b/docs/source/changelog/changelog_2.2.rst @@ -5,6 +5,13 @@ 2.2 Changelog ************* +2.2.8 +===== +- Fixed a bug introduced in 2.2.4 that made segments overlap with silence intervals when using textgrid cleanup +- Changed databases to always use the root MFA rather than rely on temporary directories to make it more consistent where database files and sockets will get placed. This root directory can be changed via the environment variable :code:`MFA_ROOT_DIR` +- Optimized training graph and collecting alignments after changes to how unknown words were represented internally +- Changed feature generation to use piped audio loaded via PySoundFile rather than via calls to sox/ffmpeg directly + 2.2.7 ===== diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 764b50df..2d6d8313 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -70,7 +70,7 @@ A simple Dockerfile for installing MFA would be: RUN chown -R mfauser /mfa RUN chown -R mfauser /env USER mfauser - ENV MFA_ROOT_ENVIRONMENT_VARIABLE=/mfa + ENV MFA_ROOT_DIR=/mfa RUN conda run -p /env mfa server init RUN echo "source activate /env && mfa server start" > ~/.bashrc @@ -84,7 +84,7 @@ Crucially, note the useradd and subsequent user commands: RUN chown -R mfauser /mfa RUN chown -R mfauser /env USER mfauser - ENV MFA_ROOT_ENVIRONMENT_VARIABLE=/mfa + ENV MFA_ROOT_DIR=/mfa RUN conda run -p /env mfa server init These lines ensure that the database is initialized without using Docker's default root user, avoiding a permissions error thrown by PostGreSQL. diff --git a/environment.yml b/environment.yml index 9aed87c4..82a95213 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ dependencies: - python>=3.8 - numpy - librosa + - pysoundfile - tqdm - requests - pyyaml @@ -13,8 +14,10 @@ dependencies: - kaldi=*=*cpu* - sox - ffmpeg + - scipy - pynini - openfst + - scikit-learn - hdbscan - baumwelch - ngram diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index 9d0940f5..471c4271 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -21,7 +21,14 @@ from montreal_forced_aligner.abc import KaldiFunction, ModelExporterMixin, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG from montreal_forced_aligner.data import MfaArguments, WorkflowType -from montreal_forced_aligner.db import CorpusWorkflow, Dictionary, Job +from montreal_forced_aligner.db import ( + CorpusWorkflow, + Dictionary, + Job, + Speaker, + Utterance, + bulk_update, +) from montreal_forced_aligner.exceptions import ConfigError, KaldiProcessingError from montreal_forced_aligner.helper import load_configuration, mfa_open, parse_old_features from montreal_forced_aligner.models import AcousticModel, DictionaryModel @@ -333,6 +340,28 @@ def setup_trainers(self): wf.current = True session.commit() + def filter_training_utterances(self): + logger.info("Filtering utterances with only unknown words...") + with self.session() as session: + dictionaries = session.query(Dictionary) + for d in dictionaries: + update_mapping = [] + word_mapping = d.word_mapping + utterances = ( + session.query(Utterance.id, Utterance.normalized_text) + .join(Utterance.speaker) + .filter(Utterance.ignored == False) # noqa + .filter(Speaker.dictionary_id == d.id) + ) + for u_id, text in utterances: + words = text.split() + if any(x in word_mapping for x in words): + continue + update_mapping.append({"id": u_id, "ignored": True}) + if update_mapping: + bulk_update(session, Utterance, update_mapping) + session.commit() + def setup(self) -> None: """Setup for acoustic model training""" super().setup() @@ -342,6 +371,7 @@ def setup(self) -> None: try: self.load_corpus() self.setup_trainers() + self.filter_training_utterances() except Exception as e: if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs) diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index 4d9e9e98..6b8e3b54 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -24,6 +24,7 @@ import pynini import pywrapfst import sqlalchemy +from pynini.lib import rewrite from sqlalchemy.orm import Session, joinedload, selectinload, subqueryload from montreal_forced_aligner.corpus.features import ( @@ -53,7 +54,7 @@ Word, ) from montreal_forced_aligner.exceptions import AlignmentExportError, FeatureGenerationError -from montreal_forced_aligner.helper import mfa_open, split_phone_position +from montreal_forced_aligner.helper import align_pronunciations, mfa_open, split_phone_position from montreal_forced_aligner.textgrid import ( construct_output_path, construct_output_tiers, @@ -95,6 +96,8 @@ "GeneratePronunciationsFunction", ] +logger = logging.getLogger("mfa") + def phones_to_prons( text: str, @@ -104,9 +107,11 @@ def phones_to_prons( phone_symbol_table: pywrapfst.SymbolTableView, optional_silence_phone: str, transcription: bool = False, - clitic_marker=None, + clitic_marker: str = None, + oov_word: str = None, + use_g2p: bool = False, ): - if "" in text: + if use_g2p: words = [x.replace(" ", "") for x in text.split("")] else: words = text.split() @@ -114,7 +119,11 @@ def phones_to_prons( word_end = "#2" word_begin_symbol = phone_symbol_table.find(word_begin) word_end_symbol = phone_symbol_table.find(word_end) - acceptor = pynini.accep(text, token_type=word_symbol_table) + if use_g2p: + kaldi_text = text + else: + kaldi_text = " ".join([x if word_symbol_table.member(x) else oov_word for x in words]) + acceptor = pynini.accep(kaldi_text, token_type=word_symbol_table) phone_to_word = pynini.compose(align_lexicon_fst, acceptor) phone_fst = pynini.Fst() current_state = phone_fst.add_state() @@ -183,25 +192,27 @@ def phones_to_prons( try: path_string = pynini.shortestpath(lattice).project("input").string(phone_symbol_table) except Exception: - logging.debug("For the text and intervals:") - logging.debug(text) - logging.debug([x.label for x in intervals]) - logging.debug("There was an issue composing word and phone FSTs") - logging.debug("PHONE FST:") + logger.debug("For the text and intervals:") + logger.debug(text) + logger.debug(kaldi_text) + logger.debug([x.label for x in intervals]) + logger.debug("There was an issue composing word and phone FSTs") + logger.debug("PHONE FST:") phone_fst.set_input_symbols(phone_symbol_table) phone_fst.set_output_symbols(phone_symbol_table) - logging.debug(phone_fst) - logging.debug("PHONE_TO_WORD FST:") + logger.debug(phone_fst) + logger.debug("PHONE_TO_WORD FST:") phone_to_word.set_input_symbols(phone_symbol_table) phone_to_word.set_output_symbols(word_symbol_table) - logging.debug(phone_to_word) + logger.debug(phone_to_word) raise path_string = path_string.replace(f"{word_end} {word_begin}", word_begin) path_string = path_string.replace(f"{word_end}", word_begin) + path_string = re.sub(f"^{word_begin} ", "", path_string) word_splits = re.split(rf" ?{word_begin} ?", path_string) - word_splits = [x.split() for x in word_splits if x != optional_silence_phone and x] - - return list(zip(words, word_splits)) + word_splits = [x.split() for x in word_splits if x != optional_silence_phone] + pronunciations = align_pronunciations(words, list(zip(words, word_splits)), oov_word) + return pronunciations @dataclass @@ -568,20 +579,27 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: workflow.working_directory, f"{self.job_name}.ha_out_disambig.temp" ) text_int_paths = job.per_dictionary_text_int_scp_paths + batch_size = 1000 if self.use_g2p: - import pynini - from pynini.lib import rewrite from montreal_forced_aligner.g2p.generator import threshold_lattice_to_dfa for d in job.dictionaries: + log_file.write(f"Compiling graphs for {d.name} ({d.id})...\n") fst = pynini.Fst.read(d.lexicon_fst_path) - token_type = pynini.SymbolTable.read_text(d.grapheme_symbol_table_path) + words = d.word_mapping + if self.use_g2p: + token_type = pywrapfst.SymbolTable.read_text(d.grapheme_symbol_table_path) + text_column = Utterance.normalized_character_text + else: + token_type = pywrapfst.SymbolTable.read_text(d.words_symbol_path) + text_column = Utterance.normalized_text + fst.invert() utterances = ( - session.query(Utterance.kaldi_id, Utterance.normalized_character_text) + session.query(Utterance.kaldi_id, text_column) .join(Utterance.speaker) .filter(Utterance.ignored == False) # noqa - .filter(Utterance.normalized_character_text != "") + .filter(text_column != "") .filter(Utterance.job_id == self.job_name) .filter(Speaker.dictionary_id == d.id) .order_by(Utterance.kaldi_id) @@ -593,8 +611,19 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: with mfa_open(fst_ark_path, "wb") as fst_output_file: for utt_id, full_text in utterances: try: - lattice = rewrite.rewrite_lattice(full_text, fst, token_type) - lattice = threshold_lattice_to_dfa(lattice, 2.0) + if self.use_g2p: + lattice = rewrite.rewrite_lattice(full_text, fst, token_type) + lattice = threshold_lattice_to_dfa(lattice, 2.0) + else: + text = " ".join( + [ + x if x in words else d.oov_word + for x in full_text.split() + ] + ) + a = pynini.accep(text, token_type=token_type) + lattice = rewrite.rewrite_lattice(a, fst) + lattice.invert() input = lattice.write_to_string() except pynini.lib.rewrite.Error: log_file.write(f'Error composing "{full_text}"\n') @@ -703,6 +732,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: else: for d in job.dictionaries: + log_file.write(f"Compiling graphs for {d}") fst_ark_path = job.construct_path( workflow.working_directory, "fsts", "ark", d.id ) @@ -711,6 +741,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: [ thirdparty_binary("compile-train-graphs"), f"--read-disambig-syms={d.disambiguation_symbols_int_path}", + f"--batch-size={batch_size}", self.tree_path, self.model_path, d.lexicon_fst_path, @@ -723,6 +754,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int]]: ) for line in proc.stderr: log_file.write(line) + log_file.flush() m = self.progress_pattern.match(line.strip()) if m: yield int(m.group("succeeded")), int(m.group("failed")) @@ -1766,6 +1798,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, str]]: self.word_symbol_table, self.phone_symbol_table, self.optional_silence_phone, + oov_word=self.oov_word, ) if d.position_dependent_phones: word_pronunciations = [ @@ -1837,7 +1870,7 @@ def compile_information_func( if decode_error_match: data["unaligned"].append(decode_error_match.group("utt")) continue - log_like_match = re.match(log_like_pattern, line) + log_like_match = re.search(log_like_pattern, line) if log_like_match: log_like = log_like_match.group("log_like") frames = log_like_match.group("frames") @@ -1923,6 +1956,7 @@ def cleanup_intervals( self.phone_symbol_table, self.optional_silence_phone, self.transcription, + oov_word=self.oov_word, ) actual_phone_intervals = [] actual_word_intervals = [] @@ -2018,6 +2052,8 @@ def cleanup_g2p_intervals( self.phone_symbol_table, self.optional_silence_phone, clitic_marker=self.clitic_marker, + oov_word=self.oov_word, + use_g2p=True, ) actual_phone_intervals = [] actual_word_intervals = [] diff --git a/montreal_forced_aligner/alignment/pretrained.py b/montreal_forced_aligner/alignment/pretrained.py index 4ffd1033..1d3fd35e 100644 --- a/montreal_forced_aligner/alignment/pretrained.py +++ b/montreal_forced_aligner/alignment/pretrained.py @@ -296,10 +296,11 @@ def align_one_utterance(self, utterance: Utterance, session: Session) -> None: if not sox_string: sox_string = utterance.file.sound_file.sound_file_path text_int_path = self.working_directory.joinpath("text.int") + word_mapping = self.word_mapping(utterance.speaker.dictionary_id) with mfa_open(text_int_path, "w") as f: normalized_text_int = " ".join( [ - str(self.word_mapping(utterance.speaker.dictionary_id)[x]) + str(word_mapping[x]) if x in word_mapping else str(word_mapping[self.oov_word]) for x in utterance.normalized_text.split() ] ) diff --git a/montreal_forced_aligner/command_line/mfa.py b/montreal_forced_aligner/command_line/mfa.py index 526c6079..8d9572d2 100644 --- a/montreal_forced_aligner/command_line/mfa.py +++ b/montreal_forced_aligner/command_line/mfa.py @@ -3,7 +3,6 @@ import atexit import multiprocessing as mp -import os import sys import time import warnings @@ -34,11 +33,7 @@ validate_corpus_cli, validate_dictionary_cli, ) -from montreal_forced_aligner.config import ( - GLOBAL_CONFIG, - MFA_PROFILE_VARIABLE, - update_command_history, -) +from montreal_forced_aligner.config import GLOBAL_CONFIG, update_command_history from montreal_forced_aligner.utils import check_third_party BEGIN = time.time() @@ -118,7 +113,6 @@ def mfa_cli(ctx: click.Context) -> None: auto_server = False run_check = True if ctx.invoked_subcommand == "anchor": - os.environ[MFA_PROFILE_VARIABLE] = "anchor" GLOBAL_CONFIG.current_profile.clean = False GLOBAL_CONFIG.save() diff --git a/montreal_forced_aligner/command_line/server.py b/montreal_forced_aligner/command_line/server.py index 519ef24e..a0ad95bc 100644 --- a/montreal_forced_aligner/command_line/server.py +++ b/montreal_forced_aligner/command_line/server.py @@ -4,6 +4,7 @@ import rich_click as click from montreal_forced_aligner.command_line.utils import ( + common_options, delete_server, initialize_server, start_server, @@ -28,6 +29,7 @@ def server_cli(): default=None, ) @click.help_option("-h", "--help") +@common_options @click.pass_context def init_cli(context, **kwargs): if kwargs.get("profile", None) is not None: @@ -46,6 +48,7 @@ def init_cli(context, **kwargs): default=None, ) @click.help_option("-h", "--help") +@common_options @click.pass_context def start_cli(context, **kwargs): if kwargs.get("profile", None) is not None: @@ -71,6 +74,7 @@ def start_cli(context, **kwargs): default="fast", ) @click.help_option("-h", "--help") +@common_options @click.pass_context def stop_cli(context, **kwargs): if kwargs.get("profile", None) is not None: @@ -89,6 +93,7 @@ def stop_cli(context, **kwargs): default=None, ) @click.help_option("-h", "--help") +@common_options @click.pass_context def delete_cli(context, **kwargs): if kwargs.get("profile", None) is not None: diff --git a/montreal_forced_aligner/command_line/train_tokenizer.py b/montreal_forced_aligner/command_line/train_tokenizer.py index cfbe8bc4..c2419c76 100644 --- a/montreal_forced_aligner/command_line/train_tokenizer.py +++ b/montreal_forced_aligner/command_line/train_tokenizer.py @@ -1,4 +1,4 @@ -"""Command line functions for training G2P models""" +"""Command line functions for training tokenizer models""" from __future__ import annotations import os diff --git a/montreal_forced_aligner/command_line/utils.py b/montreal_forced_aligner/command_line/utils.py index c55910fa..f760d1c2 100644 --- a/montreal_forced_aligner/command_line/utils.py +++ b/montreal_forced_aligner/command_line/utils.py @@ -292,13 +292,13 @@ def initialize_server() -> None: logger = logging.getLogger("mfa") logger.info(f"Initializing the {GLOBAL_CONFIG.current_profile_name} MFA database server...") - db_directory = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + db_directory = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_mfa_{GLOBAL_CONFIG.current_profile_name}" ) - init_log_path = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + init_log_path = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_init_log_{GLOBAL_CONFIG.current_profile_name}.txt" ) - GLOBAL_CONFIG.current_profile.temporary_directory.mkdir(parents=True, exist_ok=True) + GLOBAL_CONFIG.root_temporary_directory.mkdir(parents=True, exist_ok=True) if db_directory.exists(): logger.error( "The server directory already exists, if you would like to make a new server, please run `mfa server delete` first, or run `mfa server start` to start the existing one." @@ -350,10 +350,10 @@ def check_server() -> None: GLOBAL_CONFIG.load() logger = logging.getLogger("mfa") - db_directory = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + db_directory = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_mfa_{GLOBAL_CONFIG.current_profile_name}" ) - log_path = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + log_path = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_log_{GLOBAL_CONFIG.current_profile_name}.txt" ) if not db_directory.exists(): @@ -398,7 +398,7 @@ def start_server() -> None: except Exception: pass - db_directory = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + db_directory = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_mfa_{GLOBAL_CONFIG.current_profile_name}" ) if not db_directory.exists(): @@ -407,8 +407,9 @@ def start_server() -> None: ) initialize_server() return + assert os.path.exists(GLOBAL_CONFIG.database_socket) logger.info(f"Starting the {GLOBAL_CONFIG.current_profile_name} MFA database server...") - log_path = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + log_path = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_log_{GLOBAL_CONFIG.current_profile_name}.txt" ) try: @@ -444,10 +445,10 @@ def stop_server(mode: str = "fast") -> None: logger = logging.getLogger("mfa") GLOBAL_CONFIG.load() - db_directory = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + db_directory = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_mfa_{GLOBAL_CONFIG.current_profile_name}" ) - log_path = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + log_path = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_log_{GLOBAL_CONFIG.current_profile_name}.txt" ) if not db_directory.exists(): @@ -481,7 +482,7 @@ def delete_server() -> None: logger = logging.getLogger("mfa") GLOBAL_CONFIG.load() - db_directory = GLOBAL_CONFIG.current_profile.temporary_directory.joinpath( + db_directory = GLOBAL_CONFIG.root_temporary_directory.joinpath( f"pg_mfa_{GLOBAL_CONFIG.current_profile_name}" ) if db_directory.exists(): diff --git a/montreal_forced_aligner/config.py b/montreal_forced_aligner/config.py index 8913c613..3a080a04 100644 --- a/montreal_forced_aligner/config.py +++ b/montreal_forced_aligner/config.py @@ -209,8 +209,14 @@ def __getitem__(self, item): return getattr(self.current_profile, item) @property - def database_socket(self): - p = get_temporary_directory().joinpath(f"pg_mfa_{self.current_profile_name}_socket") + def root_temporary_directory(self): + return pathlib.Path( + os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, "~/Documents/MFA") + ).expanduser() + + @property + def database_socket(self) -> str: + p = self.root_temporary_directory.joinpath(f"pg_mfa_{self.current_profile_name}_socket") p.mkdir(parents=True, exist_ok=True) return p.as_posix() diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index ede1753b..1d191b3c 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -1074,7 +1074,7 @@ def create_subset(self, subset: int) -> None: session.query(Utterance.id) .join(Utterance.speaker) .filter(Speaker.dictionary_id == dict_id) - .filter(Utterance.text.like("% %")) + .filter(Utterance.text.op("~")(r" [^ ]+ ")) .filter(Utterance.ignored == False) # noqa .order_by(Utterance.duration) .limit(larger_subset_num) @@ -1142,7 +1142,7 @@ def create_subset(self, subset: int) -> None: # Get all shorter utterances that are not one word long larger_subset_query = ( session.query(Utterance.id) - .filter(Utterance.text.like("% %")) + .filter(Utterance.text.op("~")(r"\s\S+\s")) .filter(Utterance.ignored == False) # noqa .order_by(Utterance.duration) .limit(larger_subset_num) diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py index 458dd8ec..6297674d 100644 --- a/montreal_forced_aligner/corpus/features.py +++ b/montreal_forced_aligner/corpus/features.py @@ -9,12 +9,15 @@ import subprocess import typing from abc import abstractmethod +from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Union import dataclassy +import librosa import numba import numpy as np +import soundfile import sqlalchemy from numba import njit from scipy.sparse import csr_matrix @@ -23,7 +26,7 @@ from montreal_forced_aligner.abc import KaldiFunction from montreal_forced_aligner.config import IVECTOR_DIMENSION, PLDA_DIMENSION from montreal_forced_aligner.data import M_LOG_2PI, MfaArguments -from montreal_forced_aligner.db import Job, Utterance +from montreal_forced_aligner.db import File, Job, SoundFile, Utterance from montreal_forced_aligner.exceptions import KaldiProcessingError from montreal_forced_aligner.helper import mfa_open from montreal_forced_aligner.utils import read_feats, thirdparty_binary @@ -482,25 +485,11 @@ def _run(self) -> typing.Generator[int]: job: typing.Optional[Job] = session.get(Job, self.job_name) feats_scp_path = job.construct_path(self.data_directory, "feats", "scp") pitch_scp_path = job.construct_path(self.data_directory, "pitch", "scp") - segments_scp_path = job.construct_path(self.data_directory, "segments", "scp") wav_path = job.construct_path(self.data_directory, "wav", "scp") raw_ark_path = job.construct_path(self.data_directory, "feats", "ark") raw_pitch_ark_path = job.construct_path(self.data_directory, "pitch", "ark") if os.path.exists(raw_ark_path): return - min_length = 0.1 - seg_proc = subprocess.Popen( - [ - thirdparty_binary("extract-segments"), - f"--min-segment-length={min_length}", - f"scp:{wav_path}", - segments_scp_path, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) mfcc_proc = compute_mfcc_process( log_file, wav_path, subprocess.PIPE, self.mfcc_options ) @@ -531,14 +520,39 @@ def _run(self) -> typing.Generator[int]: stderr=log_file, env=os.environ, ) - for line in seg_proc.stdout: - mfcc_proc.stdin.write(line) + min_length = 0.1 + utterances = ( + session.query(Utterance, SoundFile) + .join(Utterance.file) + .join(File.sound_file) + .filter( + Utterance.job_id == self.job_name, + Utterance.ignored == False, # noqa + Utterance.duration >= min_length, + ) + .order_by(Utterance.kaldi_id) + ) + for u, sf in utterances: + + wave, _ = librosa.load( + sf.sound_file_path, + sr=16000, + offset=u.begin, + duration=u.duration, + mono=False, + ) + if len(wave.shape) == 2: + wave = wave[u.channel, :] + bio = BytesIO() + soundfile.write(bio, wave, samplerate=16000, format="WAV") + mfcc_proc.stdin.write(f"{u.kaldi_id}\t".encode("utf8")) + mfcc_proc.stdin.write(bio.getvalue()) mfcc_proc.stdin.flush() if use_pitch: - pitch_proc.stdin.write(line) + pitch_proc.stdin.write(f"{u.kaldi_id}\t".encode("utf8")) + pitch_proc.stdin.write(bio.getvalue()) pitch_proc.stdin.flush() - if re.search(rb"\d+-\d+ ", line): - yield 1 + yield 1 mfcc_proc.stdin.close() if use_pitch: pitch_proc.stdin.close() diff --git a/montreal_forced_aligner/corpus/helper.py b/montreal_forced_aligner/corpus/helper.py index 0b00a184..50ab6cbd 100644 --- a/montreal_forced_aligner/corpus/helper.py +++ b/montreal_forced_aligner/corpus/helper.py @@ -1,16 +1,12 @@ """Helper functions for corpus parsing and loading""" from __future__ import annotations -import re -import subprocess -import sys import typing from pathlib import Path import soundfile from montreal_forced_aligner.data import FileExtensions, SoundFileInformation -from montreal_forced_aligner.exceptions import SoundFileError from montreal_forced_aligner.helper import mfa_open SoundFileInfoDict = typing.Dict[str, typing.Union[int, float, str]] @@ -112,99 +108,11 @@ def get_wav_info( if isinstance(file_path, str): file_path = Path(file_path) format = file_path.suffix.lower() - num_channels = 0 - sample_rate = 0 - duration = 0 sox_string = "" - if format in {".mp3", ".opus"}: - if sys.platform != "win32" and format == ".mp3": - sox_proc = subprocess.Popen( - ["soxi", file_path], stderr=subprocess.PIPE, stdout=subprocess.PIPE, text=True - ) - stdout, stderr = sox_proc.communicate() - if stderr: - raise SoundFileError(file_path, stderr) - for line in stdout.splitlines(): - if line.startswith("Channels"): - num_channels = int(line.split(":")[-1].strip()) - elif line.startswith("Sample Rate"): - sample_rate = int(line.split(":")[-1].strip()) - elif line.startswith("Duration"): - m = re.search(r"= (?P\d+) samples", line) - if m: - num_samples = int(m.group("num_samples")) - duration = round(num_samples / sample_rate, 6) - else: - raise SoundFileError(file_path, "Could not parse number of samples") - break - sample_rate_string = "" - if enforce_sample_rate is not None: - sample_rate_string = f" -r {enforce_sample_rate}" - sox_string = f'sox "{file_path}" -t wav -b 16{sample_rate_string} - |' - else: # Fall back use ffmpeg if sox doesn't support the format - ffmpeg_proc = subprocess.Popen( - [ - "ffprobe", - "-v", - "error", - "-hide_banner", - "-show_entries", - "stream=duration,channels,sample_rate", - "-of", - "default=noprint_wrappers=1", - "-i", - file_path, - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - stdout, stderr = ffmpeg_proc.communicate() - if stderr: - raise SoundFileError(file_path, stderr) - for line in stdout.splitlines(): - try: - key, value = line.strip().split("=") - if key == "duration": - duration = float(value) - elif key == "sample_rate": - sample_rate = int(value) - else: - num_channels = int(value) - except ValueError: - pass - mono_string = "" - sample_rate_string = "" - if num_channels > 1 and enforce_mono: - mono_string = ' -af "pan=mono|FC=FL"' - if enforce_sample_rate is not None: - sample_rate_string = f" -ar {enforce_sample_rate}" - sox_string = f'ffmpeg -nostdin -hide_banner -loglevel error -nostats -i "{file_path}" -acodec pcm_s16le -f wav{mono_string}{sample_rate_string} - |' - else: - use_sox = False - with soundfile.SoundFile(file_path) as inf: - frames = inf.frames - sample_rate = inf.samplerate - duration = frames / sample_rate - num_channels = inf.channels - try: - bit_depth = int(inf.subtype.split("_")[-1]) - if bit_depth != 16: - use_sox = True - except Exception: - use_sox = True - sample_rate_string = "" - if enforce_sample_rate is not None: - sample_rate_string = f" -r {enforce_sample_rate}" - if format != "wav": - use_sox = True - if num_channels > 1 and enforce_mono: - use_sox = True - elif enforce_sample_rate is not None and sample_rate != enforce_sample_rate: - use_sox = True - if num_channels > 1 and enforce_mono: - sox_string = f'sox "{file_path}" -t wav -b 16{sample_rate_string} - remix 1 |' - elif use_sox: - sox_string = f'sox "{file_path}" -t wav -b 16{sample_rate_string} - |' + with soundfile.SoundFile(file_path) as inf: + frames = inf.frames + sample_rate = inf.samplerate + duration = frames / sample_rate + num_channels = inf.channels return SoundFileInformation(format, sample_rate, duration, num_channels, sox_string) diff --git a/montreal_forced_aligner/corpus/multiprocessing.py b/montreal_forced_aligner/corpus/multiprocessing.py index 8391153d..b013bc56 100644 --- a/montreal_forced_aligner/corpus/multiprocessing.py +++ b/montreal_forced_aligner/corpus/multiprocessing.py @@ -524,9 +524,8 @@ def output_for_features(self, session: Session) -> None: .first() ) wav_scp_path = job.wav_scp_path - segments_scp_path = job.segments_scp_path utt2spk_scp_path = job.utt2spk_scp_path - if os.path.exists(segments_scp_path): + if os.path.exists(utt2spk_scp_path): return with mfa_open(wav_scp_path, "w") as wav_file: files = ( @@ -543,23 +542,13 @@ def output_for_features(self, session: Session) -> None: wav_file.write(f"{f_id} {sox_string}\n") yield 1 - with mfa_open(segments_scp_path, "w") as segments_file, mfa_open( - utt2spk_scp_path, "w" - ) as utt2spk_file: + with mfa_open(utt2spk_scp_path, "w") as utt2spk_file: utterances = ( - session.query( - Utterance.kaldi_id, - Utterance.file_id, - Utterance.speaker_id, - Utterance.begin, - Utterance.end, - Utterance.channel, - ) + session.query(Utterance.kaldi_id, Utterance.speaker_id) .filter(Utterance.job_id == job.id) .order_by(Utterance.kaldi_id) ) - for u_id, f_id, s_id, begin, end, channel in utterances: - segments_file.write(f"{u_id} {f_id} {begin} {end} {channel}\n") + for u_id, s_id in utterances: utt2spk_file.write(f"{u_id} {s_id}\n") yield 1 @@ -701,12 +690,7 @@ def output_to_directory(self, session) -> None: text_ints_paths = job.per_dictionary_text_int_scp_paths for d in job.dictionaries: - words_mapping = {} - words_query = session.query(Word.word, Word.mapping_id).filter( - Word.dictionary_id == d.id - ) - for w, m_id in words_query: - words_mapping[w] = m_id + words_mapping = d.word_mapping spk2utt = {} feats = {} cmvns = {} @@ -730,7 +714,14 @@ def output_to_directory(self, session) -> None: feats[utterance] = features cmvns[speaker] = cmvn words = normalized_text.split() - text_ints[utterance] = " ".join([str(words_mapping[x]) for x in words]) + text_ints[utterance] = " ".join( + [ + str(words_mapping[x]) + if x in words_mapping + else str(words_mapping[d.oov_word]) + for x in words + ] + ) yield 1 with mfa_open(spk2utt_paths[d.id], "w") as f: diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py index 68e4e5b9..84c084ff 100644 --- a/montreal_forced_aligner/db.py +++ b/montreal_forced_aligner/db.py @@ -372,6 +372,22 @@ class Dictionary(MfaSqlBase): back_populates="dictionaries", ) + @property + def word_mapping(self): + if not hasattr(self, "_word_mapping"): + session = sqlalchemy.orm.Session.object_session(self) + query = ( + session.query(Word.word, Word.mapping_id) + .filter(Word.dictionary_id == self.id) + .filter(sqlalchemy.or_(Word.word_type != WordType.oov, Word.word == self.oov_word)) + .filter(Word.word_type != WordType.bracketed) + .order_by(Word.mapping_id) + ) + self._word_mapping = {} + for w, mapping_id in query: + self._word_mapping[w] = mapping_id + return self._word_mapping + @property def special_set(self) -> typing.Set[str]: return { diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index 2b292c0b..091b2882 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -43,7 +43,7 @@ DictionaryFileError, KaldiProcessingError, ) -from montreal_forced_aligner.helper import mfa_open, split_phone_position +from montreal_forced_aligner.helper import comma_join, mfa_open, split_phone_position from montreal_forced_aligner.models import DictionaryModel, PhoneSetType from montreal_forced_aligner.utils import parse_dictionary_file, thirdparty_binary @@ -123,6 +123,17 @@ def load_phone_groups(self) -> None: self._phone_groups = {k: v for k, v in enumerate(self._phone_groups)} for k, v in self._phone_groups.items(): self._phone_groups[k] = [x for x in v if x in self.non_silence_phones] + found_phones = set() + for phones in self._phone_groups.values(): + found_phones.update(phones) + missing_phones = self.non_silence_phones - found_phones + if missing_phones: + logger.warning( + f"The following phones were missing from the phone group: " + f"{comma_join(sorted(missing_phones))}" + ) + else: + logger.debug("All phones were included in phone groups") @property def speaker_mapping(self) -> typing.Dict[str, int]: @@ -183,13 +194,9 @@ def word_mapping(self, dictionary_id: int = None) -> Dict[str, int]: if dictionary_id is None: dictionary_id = self._default_dictionary_id if dictionary_id not in self._words_mappings: - self._words_mappings[dictionary_id] = {} with self.session() as session: - words = session.query(Word.word, Word.mapping_id).filter( - Word.dictionary_id == dictionary_id - ) - for w, index in words: - self._words_mappings[dictionary_id][w] = index + d = session.get(Dictionary, dictionary_id) + self._words_mappings[dictionary_id] = d.word_mapping return self._words_mappings[dictionary_id] def reversed_word_mapping(self, dictionary_id: int = 1) -> Dict[int, str]: @@ -886,8 +893,9 @@ def _write_probabilistic_fst_text( session.query(bn) .join(Pronunciation.word) .filter(Word.dictionary_id == dictionary.id) - .filter(Word.word_type != WordType.oov) - .filter(Word.count > 0) + .filter(sqlalchemy.or_(Word.word_type != WordType.oov, Word.word == self.oov_word)) + .filter(Word.word_type != WordType.bracketed) + .filter(sqlalchemy.or_(Word.count > 0, Word.word == self.oov_word)) .filter(Word.word_type != WordType.silence) ) for row in pronunciation_query: @@ -902,7 +910,7 @@ def _write_probabilistic_fst_text( silence_after_probability = data.get("silence_after_probability", None) if silence_after_probability is not None: silence_following_cost = -math.log(silence_after_probability) - non_silence_following_cost = -math.log(1 - (silence_after_probability)) + non_silence_following_cost = -math.log(1 - silence_after_probability) silence_before_correction = data.get("silence_before_correction", None) if silence_before_correction is not None: @@ -936,66 +944,6 @@ def _write_probabilistic_fst_text( f"{silence_state}\t{new_state}\t{phones[0]}\t{data['word']}\t{pron_cost+silence_before_cost}\n" ) - next_state += 1 - current_state = new_state - for i in range(1, len(phones)): - new_state = next_state - next_state += 1 - outf.write(f"{current_state}\t{new_state}\t{phones[i]}\t\n") - current_state = new_state - outf.write( - f"{current_state}\t{non_silence_state}\t{silence_disambiguation_symbol}\t\t{non_silence_following_cost}\n" - ) - outf.write( - f"{current_state}\t{silence_state}\t{self.optional_silence_phone}\t\t{silence_following_cost}\n" - ) - oov_pron = ( - session.query(Pronunciation) - .join(Pronunciation.word) - .filter(Word.word == self.oov_word) - .first() - ) - if not disambiguation: - oovs = ( - session.query(Word.word) - .filter(Word.word_type == WordType.oov, Word.dictionary_id == dictionary.id) - .filter(sqlalchemy.or_(Word.count > 0, Word.word.in_(self.specials_set))) - ) - else: - oovs = session.query(Word.word).filter(Word.word == self.oov_word) - - phones = [self.oov_phone] - if self.position_dependent_phones: - phones[0] += "_S" - if alignment: - phones = ["#1"] + phones + ["#2"] - for (w,) in oovs: - silence_before_cost = 0.0 - non_silence_before_cost = 0.0 - silence_following_cost = base_silence_following_cost - non_silence_following_cost = base_non_silence_following_cost - - silence_after_probability = oov_pron.silence_after_probability - if silence_after_probability is not None: - silence_following_cost = -math.log(silence_after_probability) - non_silence_following_cost = -math.log(1 - (silence_after_probability)) - - silence_before_correction = oov_pron.silence_before_correction - if silence_before_correction is not None: - silence_before_cost = -math.log(silence_before_correction) - - non_silence_before_correction = oov_pron.non_silence_before_correction - if non_silence_before_correction is not None: - non_silence_before_cost = -math.log(non_silence_before_correction) - pron_cost = 0.0 - new_state = next_state - outf.write( - f"{non_silence_state}\t{new_state}\t{phones[0]}\t{w}\t{pron_cost+non_silence_before_cost}\n" - ) - outf.write( - f"{silence_state}\t{new_state}\t{phones[0]}\t{w}\t{pron_cost+silence_before_cost}\n" - ) - next_state += 1 current_state = new_state for i in range(1, len(phones)): @@ -1079,54 +1027,14 @@ def _write_align_lexicon( ), ) - oovs = session.query(Word.word).filter( - Word.word_type == WordType.oov, - sqlalchemy.or_(Word.count > 0, Word.word.in_(self.specials_set)), - ) - for (w,) in oovs: - pron = [self.oov_phone] - if self.position_dependent_phones: - pron[0] += "_S" - pron = ["#1"] + pron + ["#2"] - current_state = loop_state - for i in range(len(pron) - 1): - p_s = phone_symbol_table.find(pron[i]) - if i == 0: - w_s = word_symbol_table.find(w) - else: - w_s = word_eps_symbol - fst.add_arc( - current_state, - pywrapfst.Arc(p_s, w_s, pywrapfst.Weight.one(fst.weight_type()), next_state), - ) - current_state = next_state - next_state = fst.add_state() - i = len(pron) - 1 - if i >= 0: - p_s = phone_symbol_table.find(pron[i]) - else: - p_s = phone_eps_symbol - if i <= 0: - w_s = word_symbol_table.find(w) - else: - w_s = word_eps_symbol - fst.add_arc( - current_state, - pywrapfst.Arc( - p_s, w_s, pywrapfst.Weight(fst.weight_type(), non_sil_cost), loop_state - ), - ) - fst.add_arc( - current_state, - pywrapfst.Arc(p_s, w_s, pywrapfst.Weight(fst.weight_type(), sil_cost), sil_state), - ) pronunciation_query = ( session.query(Word.word, Pronunciation.pronunciation) .join(Pronunciation.word) .filter(Word.dictionary_id == dictionary.id) - .filter( - Word.word_type != WordType.silence, Word.word_type != WordType.oov, Word.count > 0 - ) + .filter(sqlalchemy.or_(Word.word_type != WordType.oov, Word.word == self.oov_word)) + .filter(Word.word_type != WordType.bracketed) + .filter(sqlalchemy.or_(Word.count > 0, Word.word == self.oov_word)) + .filter(Word.word_type != WordType.silence) ) for w, pron in pronunciation_query: pron = pron.split() @@ -1633,11 +1541,12 @@ def save_oovs_found(self, directory: str) -> None: def find_all_cutoffs(self) -> None: """Find all instances of cutoff words followed by actual words""" - logger.info("Finding all cutoffs...") with self.session() as session: c = session.query(Corpus).first() if c.cutoffs_found: + logger.debug("Cutoffs already found") return + logger.info("Finding all cutoffs...") initial_brackets = re.escape("".join(x[0] for x in self.brackets)) final_brackets = re.escape("".join(x[1] for x in self.brackets)) pronunciation_mapping = {} @@ -1805,13 +1714,8 @@ def _write_word_file(self, dictionary: Dictionary) -> None: Write the word mapping to the temporary directory """ self._words_mappings = {} - with mfa_open(dictionary.words_symbol_path, "w") as f, self.session() as session: - words = ( - session.query(Word.word, Word.mapping_id) - .filter(Word.dictionary_id == dictionary.id) - .order_by(Word.mapping_id) - ) - for w, i in words: + with mfa_open(dictionary.words_symbol_path, "w") as f: + for w, i in dictionary.word_mapping.items(): f.write(f"{w} {i}\n") diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py index 21e09b42..739b7bcf 100644 --- a/montreal_forced_aligner/exceptions.py +++ b/montreal_forced_aligner/exceptions.py @@ -31,6 +31,7 @@ "LanguageModelNotFoundError", "ModelExtensionError", "ThirdpartyError", + "MultiprocessingError", "TrainerError", "ModelError", "CorpusError", diff --git a/montreal_forced_aligner/g2p/generator.py b/montreal_forced_aligner/g2p/generator.py index bb054cb1..07e0e9f6 100644 --- a/montreal_forced_aligner/g2p/generator.py +++ b/montreal_forced_aligner/g2p/generator.py @@ -171,10 +171,14 @@ def __call__(self, graphemes: str) -> List[str]: # pragma: no cover hypotheses = [] for w in words: w_fst = self.create_word_fst(w) + if not w_fst: + continue hypotheses.append(self.rewrite(w_fst)) hypotheses = sorted(set(" ".join(x) for x in itertools.product(*hypotheses))) else: fst = self.create_word_fst(graphemes) + if not fst: + return [] hypotheses = self.rewrite(fst) return [x for x in hypotheses if x] @@ -234,9 +238,11 @@ def __init__( output_token_type=output_token_type, ) - def create_word_fst(self, word: str) -> pynini.Fst: + def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]: if self.graphemes is not None: word = [x for x in word if x in self.graphemes] + if not word: + return None fst = pynini.Fst() one = pywrapfst.Weight.one(fst.weight_type()) max_state = 0 @@ -265,10 +271,14 @@ def __call__(self, graphemes: str) -> List[str]: # pragma: no cover hypotheses = [] for w in words: w_fst = self.create_word_fst(w) + if not w_fst: + continue hypotheses.append(self.rewrite(w_fst)) hypotheses = sorted(set(" ".join(x) for x in itertools.product(*hypotheses))) else: fst = self.create_word_fst(graphemes) + if not fst: + return [] hypotheses = self.rewrite(fst) hypotheses = [x.replace(self.seq_sep, " ") for x in hypotheses if x] return hypotheses @@ -483,6 +493,7 @@ def setup(self): self.output_token_type, num_pronunciations=self.num_pronunciations, threshold=self.g2p_threshold, + graphemes=self.g2p_model.meta["graphemes"], ) def generate_pronunciations(self) -> Dict[str, List[str]]: diff --git a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py index a041bd87..47d608f5 100644 --- a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py +++ b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py @@ -291,7 +291,7 @@ def run(self) -> None: pywrapfst.Arc( ilabel, ilabel, - pynini.Weight( + pywrapfst.Weight( "log", float(input_range * output_range) ), ostate, @@ -306,7 +306,7 @@ def run(self) -> None: if sym not in data: data[sym] = arc.weight else: - data[sym] = pynini.plus(data[sym], arc.weight) + data[sym] = pywrapfst.plus(data[sym], arc.weight) if count >= self.batch_size: data = {k: float(v) for k, v in data.items()} self.return_queue.put((self.job_name, data, count)) @@ -1359,9 +1359,12 @@ def export_alignments(self) -> None: logger.info("Done exporting alignments!") def compute_initial_ngrams(self) -> None: - logger.info("Computing initial ngrams...") input_path = self.working_directory.joinpath("input.txt") input_ngram_path = self.working_directory.joinpath("input_ngram.fst") + if input_ngram_path.with_suffix(".ngrams").exists(): + logger.info("Initial ngrams already computed") + return + logger.info("Computing initial ngrams...") input_symbols_path = self.working_directory.joinpath("input_ngram.syms") symbols_proc = subprocess.Popen( [ @@ -1517,6 +1520,8 @@ def train(self) -> None: """ Train a G2P model """ + if os.path.exists(self.fst_path): + return os.makedirs(self.working_log_directory, exist_ok=True) begin = time.time() self.train_alignments() diff --git a/montreal_forced_aligner/helper.py b/montreal_forced_aligner/helper.py index c873301e..46c32b89 100644 --- a/montreal_forced_aligner/helper.py +++ b/montreal_forced_aligner/helper.py @@ -42,6 +42,7 @@ "overlap_scoring", "align_phones", "split_phone_position", + "align_pronunciations", "configure_logger", "mfa_open", "load_configuration", @@ -573,6 +574,27 @@ def default(self, o: Any) -> Any: return dataclassy.asdict(o) +def align_pronunciations( + ref_text: List[str], pronunciations: List[Tuple[str, str]], unknown_word: str +): + def score_function(ref: str, pron: Tuple[str, str]): + if ref == pron[0]: + return 0 + if pron[0] == unknown_word: + return 0 + return -2 + + alignments = pairwise2.align.globalcs( + ref_text, pronunciations, score_function, -5, -5, gap_char=["-"], one_alignment_only=True + ) + transformed_pronunciations = [] + for a in alignments: + for i, sa in enumerate(a.seqA): + sb = a.seqB[i] + transformed_pronunciations.append((sa, sb[1])) + return transformed_pronunciations + + def align_phones( ref: List[CtmInterval], test: List[CtmInterval], diff --git a/montreal_forced_aligner/textgrid.py b/montreal_forced_aligner/textgrid.py index 0a485c4a..8fd54464 100644 --- a/montreal_forced_aligner/textgrid.py +++ b/montreal_forced_aligner/textgrid.py @@ -372,7 +372,7 @@ def export_textgrid( for i, a in enumerate(sorted(intervals, key=lambda x: x.begin)): if duration - a.end < (frame_shift * 2): # Fix rounding issues a.end = duration - if i > 0 and a.to_tg_interval().start > tier.entries[-1].end: + if i > 0 and tier.entries[-1].end > a.to_tg_interval().start: a.begin = tier.entries[-1].end tier.insertEntry(a.to_tg_interval(duration)) if has_data: diff --git a/montreal_forced_aligner/tokenization/trainer.py b/montreal_forced_aligner/tokenization/trainer.py index 0e735b79..0cc31dfc 100644 --- a/montreal_forced_aligner/tokenization/trainer.py +++ b/montreal_forced_aligner/tokenization/trainer.py @@ -71,7 +71,7 @@ def run(self) -> None: logging_name=f"{type(self).__name__}_engine", ).execution_options(logging_token=f"{type(self).__name__}_engine") try: - symbol_table = pynini.SymbolTable() + symbol_table = pywrapfst.SymbolTable() symbol_table.add_symbol(self.eps) valid_output_ngrams = set() base_dir = os.path.dirname(self.far_path) @@ -143,7 +143,7 @@ def run(self) -> None: [input_string, output_string] ) ilabel = symbol_table.find(symbol) - if ilabel == pynini.NO_LABEL: + if ilabel == pywrapfst.NO_LABEL: ilabel = symbol_table.add_symbol(symbol) ostate = (i + input_range) * (len(output) + 1) + ( j + output_range @@ -153,7 +153,7 @@ def run(self) -> None: pywrapfst.Arc( ilabel, ilabel, - pynini.Weight( + pywrapfst.Weight( "log", float(input_range * output_range) ), ostate, @@ -168,7 +168,7 @@ def run(self) -> None: if sym not in data: data[sym] = arc.weight else: - data[sym] = pynini.plus(data[sym], arc.weight) + data[sym] = pywrapfst.plus(data[sym], arc.weight) if count >= self.batch_size: data = {k: float(v) for k, v in data.items()} self.return_queue.put((self.job_name, data, count)) diff --git a/requirements.txt b/requirements.txt index 249f209a..67f73aa2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,15 @@ -praatio>=5.0 +praatio>=6.0 tqdm pyyaml librosa +numpy +scipy +scikit-learn requests -biopython +biopython=1.79 dataclassy -sqlalchemy +sqlalchemy>=2.0 +click rich rich-click numba diff --git a/tests/conftest.py b/tests/conftest.py index 648ebe19..7678b1a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import sqlalchemy.orm import yaml -from montreal_forced_aligner.config import GLOBAL_CONFIG +from montreal_forced_aligner.config import GLOBAL_CONFIG, get_temporary_directory from montreal_forced_aligner.helper import mfa_open @@ -85,6 +85,7 @@ def global_config(): GLOBAL_CONFIG.current_profile.use_mp = False GLOBAL_CONFIG.current_profile.database_limited_mode = True GLOBAL_CONFIG.current_profile.auto_server = False + GLOBAL_CONFIG.current_profile.temporary_directory = get_temporary_directory() GLOBAL_CONFIG.save() yield GLOBAL_CONFIG diff --git a/tests/test_commandline_align.py b/tests/test_commandline_align.py index 3ff0f344..17852c58 100644 --- a/tests/test_commandline_align.py +++ b/tests/test_commandline_align.py @@ -1,13 +1,11 @@ import os import click.testing -import pytest from praatio import textgrid as tgio from montreal_forced_aligner.command_line.mfa import mfa_cli -@pytest.mark.timeout(200) def test_align_no_speaker_adaptation( basic_corpus_dir, generated_dir, english_dictionary, temp_dir, english_acoustic_model, db_setup ): @@ -31,7 +29,6 @@ def test_align_no_speaker_adaptation( assert os.path.exists(output_directory) -@pytest.mark.timeout(200) def test_align_single_speaker( basic_corpus_dir, generated_dir, @@ -84,7 +81,6 @@ def test_align_single_speaker( assert os.path.exists(path) -@pytest.mark.timeout(200) def test_align_duplicated( duplicated_name_corpus_dir, generated_dir, @@ -132,7 +128,6 @@ def test_align_duplicated( assert os.path.exists(path) -@pytest.mark.timeout(200) def test_align_multilingual( multilingual_ipa_corpus_dir, english_uk_mfa_dictionary, @@ -171,7 +166,6 @@ def test_align_multilingual( assert not result.return_value -@pytest.mark.timeout(200) def test_align_multilingual_speaker_dict( multilingual_ipa_corpus_dir, mfa_speaker_dict_path, @@ -210,7 +204,6 @@ def test_align_multilingual_speaker_dict( assert not result.return_value -@pytest.mark.timeout(200) def test_align_multilingual_tg_speaker_dict( multilingual_ipa_tg_corpus_dir, mfa_speaker_dict_path, @@ -248,7 +241,6 @@ def test_align_multilingual_tg_speaker_dict( assert not result.return_value -@pytest.mark.timeout(200) def test_align_evaluation( basic_corpus_dir, english_us_mfa_dictionary, @@ -294,7 +286,6 @@ def test_align_evaluation( assert not result.return_value -@pytest.mark.timeout(200) def test_align_split( basic_split_dir, english_us_mfa_dictionary, @@ -336,7 +327,6 @@ def test_align_split( assert not result.return_value -@pytest.mark.timeout(200) def test_align_stereo( stereo_corpus_dir, generated_dir, @@ -378,7 +368,6 @@ def test_align_stereo( assert len(tg.tierNames) == 4 -@pytest.mark.timeout(200) def test_align_mp3s( mp3_corpus_dir, generated_dir, @@ -420,7 +409,6 @@ def test_align_mp3s( assert len(tg.tierNames) == 2 -@pytest.mark.timeout(200) def test_align_opus( opus_corpus_dir, generated_dir, @@ -462,7 +450,6 @@ def test_align_opus( assert len(tg.tierNames) == 2 -@pytest.mark.timeout(200) def test_swedish_cv( swedish_dir, generated_dir, @@ -512,7 +499,6 @@ def test_swedish_cv( assert len(tg.tierNames) == 2 -@pytest.mark.timeout(200) def test_swedish_mfa( swedish_dir, generated_dir, @@ -562,7 +548,6 @@ def test_swedish_mfa( assert len(tg.tierNames) == 2 -@pytest.mark.timeout(200) def test_acoustic_g2p_model( basic_corpus_dir, acoustic_model_dir, diff --git a/tests/test_corpus.py b/tests/test_corpus.py index 05ebb413..e54e84bd 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -14,13 +14,11 @@ def test_mp3(mp3_test_path): info = get_wav_info(str(mp3_test_path)) - assert info.sox_string assert info.duration > 0 def test_opus(opus_test_path): info = get_wav_info(str(opus_test_path)) - assert info.sox_string assert info.duration > 0