diff --git a/ci/docker_environment.yaml b/ci/docker_environment.yaml
index 2f7bbb53..d57ef5c0 100644
--- a/ci/docker_environment.yaml
+++ b/ci/docker_environment.yaml
@@ -38,5 +38,6 @@ dependencies:
   - pip
   - rich
   - rich-click
+  - kalpy
   - pip:
     - speechbrain
diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py
index c1cc8659..b38c8511 100644
--- a/montreal_forced_aligner/corpus/acoustic_corpus.py
+++ b/montreal_forced_aligner/corpus/acoustic_corpus.py
@@ -358,6 +358,56 @@ def load_corpus(self) -> None:
         self.normalize_text()
         self.generate_features()
 
+    def reset_features(self):
+        with self.session() as session:
+            session.execute(
+                sqlalchemy.update(Corpus).values(
+                    ivectors_calculated=False,
+                    plda_calculated=False,
+                    xvectors_loaded=False,
+                    features_generated=False,
+                )
+            )
+            session.execute(
+                sqlalchemy.update(Utterance).values(
+                    ivector=None, features=None, xvector=None, plda_vector=None
+                )
+            )
+            session.execute(
+                sqlalchemy.update(Speaker).values(
+                    cmvn=None, fmllr=None, ivector=None, xvector=None, plda_vector=None
+                )
+            )
+            session.commit()
+        paths = [
+            self.output_directory.joinpath("cmvn.ark"),
+            self.output_directory.joinpath("cmvn.scp"),
+            self.output_directory.joinpath("feats.scp"),
+            self.output_directory.joinpath("ivectors.scp"),
+        ]
+        for path in paths:
+            path.unlink(missing_ok=True)
+        for j in self.jobs:
+            paths = [
+                j.construct_path(self.split_directory, "cmvn", "scp"),
+                j.construct_path(self.split_directory, "ivectors", "scp"),
+                j.construct_path(self.split_directory, "ivectors", "ark"),
+            ]
+            for path in paths:
+                path.unlink(missing_ok=True)
+            for d_id in j.dictionary_ids:
+                paths = [
+                    j.construct_path(self.split_directory, "trans", "scp", d_id),
+                    j.construct_path(self.split_directory, "trans", "ark", d_id),
+                    j.construct_path(self.split_directory, "cmvn", "scp", d_id),
+                    j.construct_path(self.split_directory, "feats", "scp", d_id),
+                    j.construct_path(self.split_directory, "feats", "ark", d_id),
+                    j.construct_path(self.split_directory, "final_features", "scp", d_id),
+                    j.construct_path(self.split_directory, "final_features", "ark", d_id),
+                ]
+                for path in paths:
+                    path.unlink(missing_ok=True)
+
     def generate_final_features(self) -> None:
         """
         Generate features for the corpus
@@ -392,7 +442,6 @@ def generate_final_features(self) -> None:
             bulk_update(session, Utterance, list(update_mapping.values()))
             session.commit()
 
-        with self.session() as session:
             non_ignored_check = (
                 session.query(Utterance).filter(Utterance.ignored == False).first()  # noqa
             )
diff --git a/montreal_forced_aligner/corpus/multiprocessing.py b/montreal_forced_aligner/corpus/multiprocessing.py
index 97ff26a7..f0bd6f85 100644
--- a/montreal_forced_aligner/corpus/multiprocessing.py
+++ b/montreal_forced_aligner/corpus/multiprocessing.py
@@ -5,7 +5,6 @@
 from __future__ import annotations
 
 import os
-import re
 import threading
 import typing
 from pathlib import Path
@@ -20,7 +19,7 @@
 from montreal_forced_aligner.data import MfaArguments, WordType
 from montreal_forced_aligner.db import Dictionary, Grapheme, Job, Speaker, Utterance, Word
 from montreal_forced_aligner.exceptions import SoundFileError, TextGridParseError, TextParseError
-from montreal_forced_aligner.helper import make_re_character_set_safe, mfa_open
+from montreal_forced_aligner.helper import mfa_open
 from montreal_forced_aligner.utils import Counter
 
 if typing.TYPE_CHECKING:
@@ -270,240 +269,100 @@ def __init__(self, args: NormalizeTextArguments):
         self.oov_word = args.oov_word
         self.bracketed_word = args.bracketed_word
         self.cutoff_word = args.cutoff_word
-        self.clitic_marker = None
-        self.clitic_cleanup_regex = None
-        self.compound_regex = None
-        self.bracket_regex = None
-        self.cutoff_regex = None
-        self.bracket_sanitize_regex = None
-        self.laughter_regex = None
-        self.word_break_regex = None
-        self.clitic_quote_regex = None
-        self.punctuation_regex = None
-        self.non_speech_regexes = {}
-
-    def compile_regexes(self) -> None:
-        """Compile regular expressions necessary for corpus parsing"""
-        if len(self.clitic_markers) >= 1:
-            other_clitic_markers = self.clitic_markers[1:]
-            if other_clitic_markers:
-                extra = ""
-                if "-" in other_clitic_markers:
-                    extra = "-"
-                    other_clitic_markers = [x for x in other_clitic_markers if x != "-"]
-                self.clitic_cleanup_regex = re.compile(
-                    rf'[{extra}{"".join(other_clitic_markers)}]'
-                )
-            self.clitic_marker = self.clitic_markers[0]
-        if self.compound_markers:
-            extra = ""
-            compound_markers = self.compound_markers
-            if "-" in self.compound_markers:
-                extra = "-"
-                compound_markers = [x for x in compound_markers if x != "-"]
-            self.compound_regex = re.compile(rf"(?<=\w)[{extra}{''.join(compound_markers)}](?=\w)")
-        if self.brackets:
-            left_brackets = [x[0] for x in self.brackets]
-            right_brackets = [x[1] for x in self.brackets]
-            self.cutoff_regex = re.compile(
-                rf"[{re.escape(''.join(left_brackets))}](cutoff|hes).*?[{re.escape(''.join(right_brackets))}]+",
-                flags=re.IGNORECASE,
-            )
-            self.bracket_regex = re.compile(
-                rf"[{re.escape(''.join(left_brackets))}].*?[{re.escape(''.join(right_brackets))}]+"
-            )
-            self.laughter_regex = re.compile(
-                rf"[{re.escape(''.join(left_brackets))}](laugh(ing|ter)?|lachen|lg)[{re.escape(''.join(right_brackets))}]+",
-                flags=re.IGNORECASE,
-            )
-        all_punctuation = set()
-        non_word_character_set = set(self.punctuation)
-        non_word_character_set -= {b for x in self.brackets for b in x}
-
-        if self.clitic_markers:
-            all_punctuation.update(self.clitic_markers)
-        if self.compound_markers:
-            all_punctuation.update(self.compound_markers)
-        self.bracket_sanitize_regex = None
-        if self.brackets:
-            word_break_set = (
-                non_word_character_set | set(self.clitic_markers) | set(self.compound_markers)
-            )
-            if self.word_break_markers:
-                word_break_set |= set(self.word_break_markers)
-            word_break_set = make_re_character_set_safe(word_break_set, [r"\s"])
-            self.bracket_sanitize_regex = re.compile(f"(?<!^){word_break_set}(?!$)")
-
-        word_break_set = (
-            non_word_character_set | set(self.clitic_markers) | set(self.compound_markers)
-        )
-        if self.word_break_markers:
-            word_break_set |= set(self.word_break_markers)
-        word_break_set = make_re_character_set_safe(word_break_set, [r"\s"])
-        self.word_break_regex = re.compile(rf"{word_break_set}+")
-        punctuation_set = make_re_character_set_safe(all_punctuation)
-        if all_punctuation:
-            self.punctuation_regex = re.compile(rf"^{punctuation_set}+$")
-        if len(self.clitic_markers) >= 1:
-            non_clitic_punctuation = all_punctuation - set(self.clitic_markers)
-            non_clitic_punctuation_set = make_re_character_set_safe(non_clitic_punctuation)
-            non_punctuation_set = "[^" + punctuation_set[1:]
-            self.clitic_quote_regex = re.compile(
-                rf"((?<=\W)|(?<=^)){non_clitic_punctuation_set}*{self.clitic_marker}{non_clitic_punctuation_set}*(?P<word>{non_punctuation_set}+){non_clitic_punctuation_set}*{self.clitic_marker}{non_clitic_punctuation_set}*((?=\W)|(?=$))"
-            )
-
-        if self.laughter_regex is not None:
-            self.non_speech_regexes[self.laughter_word] = self.laughter_regex
-        if self.cutoff_regex is not None:
-            self.non_speech_regexes[self.cutoff_word] = self.cutoff_regex
-        if self.bracket_regex is not None:
-            self.non_speech_regexes[self.bracketed_word] = self.bracket_regex
-
-    def _dictionary_sanitize(self, session):
-        from montreal_forced_aligner.dictionary.mixins import SanitizeFunction, SplitWordsFunction
-
-        dictionaries: typing.List[Dictionary] = session.query(Dictionary)
-        grapheme_mapping = {}
-        grapheme_query = session.query(Grapheme.grapheme, Grapheme.mapping_id)
-        for w, m_id in grapheme_query:
-            grapheme_mapping[w] = m_id
-        for d in dictionaries:
-            words_mapping = {}
-            words_query = session.query(Word.word, Word.mapping_id).filter(
-                Word.dictionary_id == d.id
-            )
-            for w, m_id in words_query:
-                words_mapping[w] = m_id
-            sanitize_function = SanitizeFunction(
-                self.clitic_marker,
-                self.clitic_cleanup_regex,
-                self.clitic_quote_regex,
-                self.punctuation_regex,
-                self.word_break_regex,
-                self.bracket_regex,
-                self.bracket_sanitize_regex,
-                self.ignore_case,
-            )
-            clitic_set = set(
-                x[0]
-                for x in session.query(Word.word)
-                .filter(Word.word_type == WordType.clitic)
-                .filter(Word.dictionary_id == d.id)
-            )
-            initial_clitic_regex = None
-            final_clitic_regex = None
-            if self.clitic_marker is not None:
-                initial_clitics = sorted(x for x in clitic_set if x.endswith(self.clitic_marker))
-                final_clitics = sorted(x for x in clitic_set if x.startswith(self.clitic_marker))
-                if initial_clitics:
-                    initial_clitic_regex = re.compile(rf"^({'|'.join(initial_clitics)})(?=\w)")
-                if final_clitics:
-                    final_clitic_regex = re.compile(rf"(?<=\w)({'|'.join(final_clitics)})$")
-
-            non_speech_regexes = {}
-            if self.laughter_regex is not None:
-                non_speech_regexes[d.laughter_word] = self.laughter_regex
-            if self.cutoff_regex is not None:
-                non_speech_regexes[d.cutoff_word] = self.cutoff_regex
-            if self.bracket_regex is not None:
-                non_speech_regexes[d.bracketed_word] = self.bracket_regex
-            split_function = SplitWordsFunction(
-                self.clitic_marker,
-                initial_clitic_regex,
-                final_clitic_regex,
-                self.compound_regex,
-                non_speech_regexes,
-                d.oov_word,
-                words_mapping,
-                grapheme_mapping,
-            )
-            utterances = (
-                session.query(Utterance.id, Utterance.text)
-                .join(Utterance.speaker)
-                .filter(Utterance.text != "")
-                .filter(Utterance.job_id == self.job_name)
-                .filter(Speaker.dictionary_id == d.id)
-            )
-            for u_id, u_text in utterances:
-                words = sanitize_function(u_text)
-                normalized_text = []
-                normalized_character_text = []
-                oovs = set()
-                text = ""
-                for w in words:
-                    for new_w in split_function(w):
-                        if new_w not in words_mapping:
-                            oovs.add(new_w)
-                        normalized_text.append(split_function.to_str(new_w))
-                        if normalized_character_text:
-                            if not self.clitic_marker or (
-                                not normalized_text[-1].endswith(self.clitic_marker)
-                                and not new_w.startswith(self.clitic_marker)
-                            ):
-                                normalized_character_text.append("<space>")
-                        for c in split_function.parse_graphemes(new_w):
-                            normalized_character_text.append(c)
-                    if text:
-                        text += " "
-                    text += w
-                self.callback(
-                    (
-                        {
-                            "id": u_id,
-                            "oovs": " ".join(sorted(oovs)),
-                            "normalized_text": " ".join(normalized_text),
-                            "normalized_character_text": " ".join(normalized_character_text),
-                        },
-                        d.id,
-                    )
-                )
-
-    def _no_dictionary_sanitize(self, session):
-        from montreal_forced_aligner.dictionary.mixins import SanitizeFunction
-
-        sanitize_function = SanitizeFunction(
-            self.clitic_marker,
-            self.clitic_cleanup_regex,
-            self.clitic_quote_regex,
-            self.punctuation_regex,
-            self.word_break_regex,
-            self.bracket_regex,
-            self.bracket_sanitize_regex,
-            self.ignore_case,
-        )
-        utterances = (
-            session.query(Utterance.id, Utterance.text)
-            .join(Utterance.speaker)
-            .filter(Utterance.text != "")
-            .filter(Utterance.job_id == self.job_name)
-        )
-        for u_id, u_text in utterances:
-            text = []
-            character_text = []
-            for w in sanitize_function(u_text):
-                text.append(w)
-                if character_text:
-                    character_text.append("<space>")
-                for g in w:
-                    character_text.append(g)
-            text = " ".join(text)
-            character_text = " ".join(character_text)
-            self.callback(
-                (
-                    {
-                        "id": u_id,
-                        "oovs": "",
"normalized_text": text, - "normalized_character_text": character_text, - }, - None, - ) - ) def _run(self): """Run the function""" - self.compile_regexes() + from montreal_forced_aligner.tokenization.simple import SimpleTokenizer + with self.session() as session: + + grapheme_set = set() + grapheme_query = session.query(Grapheme.grapheme) + for (g,) in grapheme_query: + grapheme_set.add(g) dict_count = session.query(Dictionary).join(Dictionary.words).limit(1).count() if self.use_g2p or dict_count > 0: - self._dictionary_sanitize(session) + dictionaries = session.query(Dictionary) + for d in dictionaries: + + word_set = set( + x[0] for x in session.query(Word.word).filter(Word.dictionary_id == d.id) + ) + clitic_set = set( + x[0] + for x in session.query(Word.word) + .filter(Word.word_type == WordType.clitic) + .filter(Word.dictionary_id == d.id) + ) + tokenizer = SimpleTokenizer( + word_break_markers=self.word_break_markers, + punctuation=self.punctuation, + clitic_markers=self.clitic_markers, + compound_markers=self.compound_markers, + brackets=self.brackets, + laughter_word=self.laughter_word, + oov_word=self.oov_word, + bracketed_word=self.bracketed_word, + cutoff_word=self.cutoff_word, + ignore_case=self.ignore_case, + use_g2p=self.use_g2p, + clitic_set=clitic_set, + word_set=word_set, + grapheme_set=grapheme_set, + ) + utterances = ( + session.query(Utterance.id, Utterance.text) + .join(Utterance.speaker) + .filter(Utterance.text != "") + .filter(Utterance.job_id == self.job_name) + .filter(Speaker.dictionary_id == d.id) + ) + for u_id, u_text in utterances: + normalized_text, normalized_character_text, oovs = tokenizer(u_text) + self.callback( + ( + { + "id": u_id, + "oovs": " ".join(sorted(oovs)), + "normalized_text": normalized_text, + "normalized_character_text": normalized_character_text, + }, + d.id, + ) + ) else: - self._no_dictionary_sanitize(session) + tokenizer = SimpleTokenizer( + word_break_markers=self.word_break_markers, + punctuation=self.punctuation, + clitic_markers=self.clitic_markers, + compound_markers=self.compound_markers, + brackets=self.brackets, + laughter_word=self.laughter_word, + oov_word=self.oov_word, + bracketed_word=self.bracketed_word, + cutoff_word=self.cutoff_word, + ignore_case=self.ignore_case, + use_g2p=self.use_g2p, + grapheme_set=grapheme_set, + ) + utterances = ( + session.query(Utterance.id, Utterance.text) + .filter(Utterance.text != "") + .filter(Utterance.job_id == self.job_name) + ) + for u_id, u_text in utterances: + normalized_text, normalized_character_text, oovs = tokenizer(u_text) + self.callback( + ( + { + "id": u_id, + "oovs": " ".join(sorted(oovs)), + "normalized_text": normalized_text, + "normalized_character_text": normalized_character_text, + }, + None, + ) + ) class ExportKaldiFilesFunction(KaldiFunction): diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index 0193b980..fd7bd782 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -44,6 +44,7 @@ "WORD_END_SYMBOL", "OOV_WORD", "BRACKETED_WORD", + "LAUGHTER_WORD", "CUTOFF_WORD", "SIL_WORD", "SIL_PHONE", @@ -54,6 +55,7 @@ WORD_END_SYMBOL = "#2" OOV_WORD = "" BRACKETED_WORD = "" +LAUGHTER_WORD = "[laughter]" CUTOFF_WORD = "" SIL_WORD = "" SIL_PHONE = "sil" diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py index e1ae6598..1b7588d5 100644 --- a/montreal_forced_aligner/db.py +++ b/montreal_forced_aligner/db.py @@ -399,6 +399,7 @@ def word_table(self): if self.words_symbol_path.exists(): 
             self._word_table = pywrapfst.SymbolTable.read_text(self.words_symbol_path)
             return self._word_table
+        self.temp_directory.mkdir(parents=True, exist_ok=True)
         session = sqlalchemy.orm.Session.object_session(self)
         query = (
             session.query(Word.word, Word.mapping_id)
@@ -419,6 +420,13 @@ def phone_table(self):
             self._phone_table = pywrapfst.SymbolTable.read_text(self.phone_symbol_table_path)
         else:
             return None
+        self.temp_directory.mkdir(parents=True, exist_ok=True)
+        session = sqlalchemy.orm.Session.object_session(self)
+        query = session.query(Phone.kaldi_label, Phone.mapping_id).order_by(Phone.mapping_id)
+        self._phone_table = pywrapfst.SymbolTable()
+        for p, mapping_id in query:
+            self._phone_table.add_symbol(p, mapping_id)
+        self._phone_table.write_text(self.phone_symbol_table_path)
         return self._phone_table
 
     @property
diff --git a/montreal_forced_aligner/diarization/multiprocessing.py b/montreal_forced_aligner/diarization/multiprocessing.py
index 0e587b14..0f1fd1c2 100644
--- a/montreal_forced_aligner/diarization/multiprocessing.py
+++ b/montreal_forced_aligner/diarization/multiprocessing.py
@@ -703,7 +703,7 @@ def __init__(self, args: SpeechbrainArguments):
         self.cuda = args.cuda
         self.cluster = args.cluster
 
-    def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
+    def _run(self) -> None:
         """Run the function"""
         run_opts = None
         if self.cuda:
@@ -742,6 +742,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
             if stopped.is_set():
                 continue
             if isinstance(result, Exception):
+                exception = result
                 stopped.set()
                 continue
 
@@ -756,8 +757,6 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
                 )
                 self.callback((u_id, emb[0]))
                 del emb
-                if self.cuda:
-                    torch.cuda.empty_cache()
 
         loader.join()
         if exception:
@@ -772,8 +771,8 @@ class UtteranceFileLoader(threading.Thread):
     ----------
     job_name: int
         Job identifier
-    db_string: str
-        Connection string for database
+    session: sqlalchemy.orm.scoped_session
+        Session
     return_q: multiprocessing.Queue
         Queue to put waveforms
     stopped: :class:`~threading.Event`
@@ -813,6 +812,7 @@ def run(self) -> None:
                 .join(Utterance.file)
                 .join(File.sound_file)
                 .filter(Utterance.job_id == self.job_name)
+                .filter(Utterance.xvector == None)  # noqa
             )
             for u_id, begin, duration, sound_file_path in utterances:
                 if self.stopped.is_set():
diff --git a/montreal_forced_aligner/diarization/speaker_diarizer.py b/montreal_forced_aligner/diarization/speaker_diarizer.py
index 19e5055e..7d0c9aa6 100644
--- a/montreal_forced_aligner/diarization/speaker_diarizer.py
+++ b/montreal_forced_aligner/diarization/speaker_diarizer.py
@@ -1293,9 +1293,6 @@ def calculate_eer(self) -> typing.Tuple[float, float]:
 
     def load_embeddings(self) -> None:
         """Load embeddings from a speechbrain model"""
-        if self.has_xvectors():
-            logger.info("Embeddings already loaded.")
-            return
         logger.info("Loading SpeechBrain embeddings...")
         with tqdm(
             total=self.num_utterances, disable=config.QUIET
@@ -1308,6 +1305,9 @@ def load_embeddings(self) -> None:
             ]
             embeddings = []
             utterance_ids = []
+            original_use_mp = config.USE_MP
+            if self.cuda:
+                config.USE_MP = False
             for u_id, emb in run_kaldi_function(
                 SpeechbrainEmbeddingFunction, arguments, pbar.update
             ):
@@ -1344,6 +1344,8 @@ def load_embeddings(self) -> None:
             )
             session.query(Corpus).update({Corpus.xvectors_loaded: True})
             session.commit()
+        if self.cuda:
+            config.USE_MP = original_use_mp
         logger.debug(f"Loading embeddings took {time.time() - begin:.3f} seconds")
 
     def refresh_plda_vectors(self):
diff --git a/montreal_forced_aligner/dictionary/__init__.py b/montreal_forced_aligner/dictionary/__init__.py
index acc285b3..2a7a7016 100644
--- a/montreal_forced_aligner/dictionary/__init__.py
+++ b/montreal_forced_aligner/dictionary/__init__.py
@@ -4,7 +4,7 @@
 
 """
 
-from montreal_forced_aligner.dictionary.mixins import DictionaryMixin, SanitizeFunction
+from montreal_forced_aligner.dictionary.mixins import DictionaryMixin
 from montreal_forced_aligner.dictionary.multispeaker import (
     MultispeakerDictionary,
     MultispeakerDictionaryMixin,
@@ -14,7 +14,6 @@
     "multispeaker",
     "mixins",
     "DictionaryMixin",
-    "SanitizeFunction",
     "MultispeakerDictionary",
     "MultispeakerDictionaryMixin",
 ]
diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py
index 1dac0fff..cbb5109a 100644
--- a/montreal_forced_aligner/dictionary/mixins.py
+++ b/montreal_forced_aligner/dictionary/mixins.py
@@ -28,290 +28,7 @@
 DEFAULT_COMPOUND_MARKERS = list("-/")
 DEFAULT_BRACKETS = [("[", "]"), ("{", "}"), ("<", ">"), ("(", ")"), ("＜", "＞")]
 
-__all__ = ["SanitizeFunction", "SplitWordsFunction", "DictionaryMixin", "TemporaryDictionaryMixin"]
-
-
-class SanitizeFunction:
-    """
-    Class for functions that sanitize text and strip punctuation
-
-    Parameters
-    ----------
-    punctuation: list[str]
-        List of characters to treat as punctuation
-    clitic_markers: list[str]
-        Characters that mark clitics
-    compound_markers: list[str]
-        Characters that mark compound words
-    brackets: list[tuple[str, str]]
-        List of bracket sets to not strip from the ends of words
-    ignore_case: bool
-        Flag for whether all items should be converted to lower case, defaults to True
-    quote_markers: list[str], optional
-        Quotation markers to use when parsing text
-    quote_markers: list[str], optional
-        Quotation markers to use when parsing text
-    word_break_markers: list[str], optional
-        Word break markers to use when parsing text
-    """
-
-    def __init__(
-        self,
-        clitic_marker: str,
-        clitic_cleanup_regex: Optional[re.Pattern],
-        clitic_quote_regex: Optional[re.Pattern],
-        punctuation_regex: Optional[re.Pattern],
-        word_break_regex: Optional[re.Pattern],
-        bracket_regex: Optional[re.Pattern],
-        bracket_sanitize_regex: Optional[re.Pattern],
-        ignore_case: bool = True,
-    ):
-        self.clitic_marker = clitic_marker
-        self.clitic_cleanup_regex = clitic_cleanup_regex
-        self.clitic_quote_regex = clitic_quote_regex
-        self.punctuation_regex = punctuation_regex
-        self.word_break_regex = word_break_regex
-        self.bracket_regex = bracket_regex
-        self.bracket_sanitize_regex = bracket_sanitize_regex
-
-        self.ignore_case = ignore_case
-
-    def __call__(self, text) -> typing.Generator[str]:
-        """
-        Sanitize text according to punctuation, quotes, and word break characters
-
-        Parameters
-        ----------
-        text: str
-            Text to sanitize
-
-        Returns
-        -------
-        Generator[str]
-            Sanitized form
-        """
-        if self.ignore_case:
-            text = text.lower()
-        if self.bracket_regex:
-            for word_object in self.bracket_regex.finditer(text):
-                word = word_object.group(0)
-                new_word = self.bracket_sanitize_regex.sub("_", word)
-
-                text = text.replace(word, new_word)
-
-        if self.clitic_cleanup_regex:
-            text = self.clitic_cleanup_regex.sub(self.clitic_marker, text)
-
-        if self.clitic_quote_regex is not None and self.clitic_marker in text:
-            text = self.clitic_quote_regex.sub(r"\g<word>", text)
-
-        words = self.word_break_regex.split(text)
-
-        for w in words:
-            if not w:
-                continue
-            if self.punctuation_regex is not None and self.punctuation_regex.match(w):
-                continue
-            if w:
-                yield w
-
-
-class SplitWordsFunction:
-    """
-    Class for functions that splits words that have compound and clitic markers
-
-    Parameters
-    ----------
-    clitic_markers: list[str]
-        Characters that mark clitics
-    compound_markers: list[str]
-        Characters that mark compound words
-    clitic_set: set[str]
-        Set of clitic words
-    brackets: list[tuple[str, str], optional
-        Character tuples to treat as full brackets around words
-    words_mapping: dict[str, int]
-        Mapping of words to integer IDs
-    specials_set: set[str]
-        Set of special words
-    oov_word : str
-        What to label words not in the dictionary, defaults to None
-    """
-
-    def __init__(
-        self,
-        clitic_marker: str,
-        initial_clitic_regex: Optional[re.Pattern],
-        final_clitic_regex: Optional[re.Pattern],
-        compound_regex: Optional[re.Pattern],
-        non_speech_regexes: Dict[str, re.Pattern],
-        oov_word: Optional[str] = None,
-        word_mapping: Optional[Dict[str, int]] = None,
-        grapheme_mapping: Optional[Dict[str, int]] = None,
-    ):
-        self.clitic_marker = clitic_marker
-        self.compound_regex = compound_regex
-        self.oov_word = oov_word
-        self.specials_set = {self.oov_word, "<s>", "</s>"}
-        if not word_mapping:
-            word_mapping = None
-        self.word_mapping = word_mapping
-        if not grapheme_mapping:
-            grapheme_mapping = None
-        self.grapheme_mapping = grapheme_mapping
-        self.compound_pattern = None
-        self.clitic_pattern = None
-        self.non_speech_regexes = non_speech_regexes
-        self.initial_clitic_regex = initial_clitic_regex
-        self.final_clitic_regex = final_clitic_regex
-        self.has_initial = False
-        self.has_final = False
-        if self.initial_clitic_regex is not None:
-            self.has_initial = True
-        if self.final_clitic_regex is not None:
-            self.has_final = True
-
-    def to_str(self, normalized_text: str) -> str:
-        """
-        Convert normalized text to an integer ID
-
-        Parameters
-        ----------
-        normalized_text:
-            Word to convert
-
-        Returns
-        -------
-        str
-            Normalized string
-        """
-        if normalized_text in self.specials_set:
-            return self.oov_word
-        for word, regex in self.non_speech_regexes.items():
-            if regex.match(normalized_text):
-                return word
-        return normalized_text
-
-    def split_clitics(
-        self,
-        item: str,
-    ) -> List[str]:
-        """
-        Split a word into subwords based on dictionary information
-
-        Parameters
-        ----------
-        item: str
-            Word to split
-
-        Returns
-        -------
-        list[str]
-            List of subwords
-        """
-        split = []
-        if self.compound_regex is not None:
-            s = self.compound_regex.split(item)
-        else:
-            s = [item]
-        if self.word_mapping is None:
-            return [item]
-        clean_initial_quote_regex = re.compile("^'")
-        clean_final_quote_regex = re.compile("'$")
-        benefit = False
-        for seg in s:
-            if not seg:
-                continue
-            if not self.clitic_marker or self.clitic_marker not in seg:
-                split.append(seg)
-                if not benefit and seg in self.word_mapping:
-                    benefit = True
-                continue
-            elif seg.startswith(self.clitic_marker):
-                if seg[1:] in self.word_mapping:
-                    split.append(seg[1:])
-                    benefit = True
-                    continue
-            elif seg.endswith(self.clitic_marker):
-                if seg[:-1] in self.word_mapping:
-                    split.append(seg[:-1])
-                    benefit = True
-                    continue
-
-            initial_clitics = []
-            final_clitics = []
-            if self.has_initial:
-                while True:
-                    clitic = self.initial_clitic_regex.match(seg)
-                    if clitic is None:
-                        break
-                    benefit = True
-                    initial_clitics.append(clitic.group(0))
-                    seg = seg[clitic.end(0) :]
-                    if seg in self.word_mapping:
-                        break
-            if self.has_final:
-                while True:
-                    clitic = self.final_clitic_regex.search(seg)
-                    if clitic is None:
-                        break
-                    benefit = True
-                    final_clitics.append(clitic.group(0))
-                    seg = seg[: clitic.start(0)]
-                    if seg in self.word_mapping:
-                        break
-            final_clitics.reverse()
-            split.extend([clean_initial_quote_regex.sub("", x) for x in initial_clitics])
-            seg = clean_final_quote_regex.sub("", clean_initial_quote_regex.sub("", seg))
-            if seg:
-                split.append(seg)
-            split.extend([clean_final_quote_regex.sub("", x) for x in final_clitics])
-            if not benefit and seg in self.word_mapping:
-                benefit = True
-        if not benefit:
-            return [item]
-        return split
-
-    def parse_graphemes(
-        self,
-        item: str,
-    ) -> typing.Generator[str]:
-        for word, regex in self.non_speech_regexes.items():
-            if regex.match(item):
-                yield word
-                break
-        else:
-            characters = list(item)
-            for c in characters:
-                if self.grapheme_mapping is not None and c in self.grapheme_mapping:
-                    yield c
-                else:
-                    yield self.oov_word
-
-    def __call__(
-        self,
-        item: str,
-    ) -> List[str]:
-        """
-        Return the list of sub words if necessary
-        taking into account clitic and compound markers
-
-        Parameters
-        ----------
-        item: str
-            Word to look up
-
-        Returns
-        -------
-        list[str]
-            List of subwords that are in the dictionary
-        """
-        if self.word_mapping is not None and item in self.word_mapping:
-            return [item]
-        for regex in self.non_speech_regexes.values():
-            if regex.match(item):
-                return [item]
-        return self.split_clitics(item)
+__all__ = ["DictionaryMixin", "TemporaryDictionaryMixin"]
 
 
 class DictionaryMixin:
diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py
index 4bff89bf..3cc96987 100644
--- a/montreal_forced_aligner/models.py
+++ b/montreal_forced_aligner/models.py
@@ -458,7 +458,7 @@ def mfcc_computer(self) -> MfccComputer:
 
     @property
     def pitch_computer(self) -> typing.Optional[PitchComputer]:
-        if self.meta["features"]["use_pitch"]:
+        if self.meta["features"].get("use_pitch", False):
             return PitchComputer(**self.pitch_options)
         return
diff --git a/montreal_forced_aligner/tokenization/simple.py b/montreal_forced_aligner/tokenization/simple.py
new file mode 100644
index 00000000..ee47e2b1
--- /dev/null
+++ b/montreal_forced_aligner/tokenization/simple.py
@@ -0,0 +1,492 @@
+from __future__ import annotations
+
+import re
+import typing
+
+from montreal_forced_aligner.data import BRACKETED_WORD, CUTOFF_WORD, LAUGHTER_WORD, OOV_WORD
+from montreal_forced_aligner.helper import make_re_character_set_safe
+
+__all__ = ["SanitizeFunction", "SplitWordsFunction", "SimpleTokenizer"]
+
+
+class SanitizeFunction:
+    """
+    Class for functions that sanitize text and strip punctuation
+
+    Parameters
+    ----------
+    punctuation: list[str]
+        List of characters to treat as punctuation
+    clitic_markers: list[str]
+        Characters that mark clitics
+    compound_markers: list[str]
+        Characters that mark compound words
+    brackets: list[tuple[str, str]]
+        List of bracket sets to not strip from the ends of words
+    ignore_case: bool
+        Flag for whether all items should be converted to lower case, defaults to True
+    quote_markers: list[str], optional
+        Quotation markers to use when parsing text
+    quote_markers: list[str], optional
+        Quotation markers to use when parsing text
+    word_break_markers: list[str], optional
+        Word break markers to use when parsing text
+    """
+
+    def __init__(
+        self,
+        clitic_marker: str,
+        clitic_cleanup_regex: typing.Optional[re.Pattern],
+        clitic_quote_regex: typing.Optional[re.Pattern],
+        punctuation_regex: typing.Optional[re.Pattern],
+        word_break_regex: typing.Optional[re.Pattern],
+        bracket_regex: typing.Optional[re.Pattern],
+        bracket_sanitize_regex: typing.Optional[re.Pattern],
+        ignore_case: bool = True,
+    ):
+        self.clitic_marker = clitic_marker
+        self.clitic_cleanup_regex = clitic_cleanup_regex
+        self.clitic_quote_regex = clitic_quote_regex
+        self.punctuation_regex = punctuation_regex
+        self.word_break_regex = word_break_regex
+        self.bracket_regex = bracket_regex
+        self.bracket_sanitize_regex = bracket_sanitize_regex
+
+        self.ignore_case = ignore_case
+
+    def __call__(self, text) -> typing.Generator[str]:
+        """
+        Sanitize text according to punctuation, quotes, and word break characters
+
+        Parameters
+        ----------
+        text: str
+            Text to sanitize
+
+        Returns
+        -------
+        Generator[str]
+            Sanitized form
+        """
+        if self.ignore_case:
+            text = text.lower()
+        if self.bracket_regex:
+            for word_object in self.bracket_regex.finditer(text):
+                word = word_object.group(0)
+                new_word = self.bracket_sanitize_regex.sub("_", word)
+
+                text = text.replace(word, new_word)
+
+        if self.clitic_cleanup_regex:
+            text = self.clitic_cleanup_regex.sub(self.clitic_marker, text)
+
+        if self.clitic_quote_regex is not None and self.clitic_marker in text:
+            text = self.clitic_quote_regex.sub(r"\g<word>", text)
+
+        words = self.word_break_regex.split(text)
+
+        for w in words:
+            if not w:
+                continue
+            if self.punctuation_regex is not None and self.punctuation_regex.match(w):
+                continue
+            if w:
+                yield w
+
+
+class SplitWordsFunction:
+    """
+    Class for functions that splits words that have compound and clitic markers
+
+    Parameters
+    ----------
+    clitic_markers: list[str]
+        Characters that mark clitics
+    compound_markers: list[str]
+        Characters that mark compound words
+    clitic_set: set[str]
+        Set of clitic words
+    brackets: list[tuple[str, str], optional
+        Character tuples to treat as full brackets around words
+    words_mapping: dict[str, int]
+        Mapping of words to integer IDs
+    specials_set: set[str]
+        Set of special words
+    oov_word : str
+        What to label words not in the dictionary, defaults to None
+    """
+
+    def __init__(
+        self,
+        clitic_marker: str,
+        initial_clitic_regex: typing.Optional[re.Pattern],
+        final_clitic_regex: typing.Optional[re.Pattern],
+        compound_regex: typing.Optional[re.Pattern],
+        non_speech_regexes: typing.Dict[str, re.Pattern],
+        oov_word: typing.Optional[str] = None,
+        word_set: typing.Optional[typing.Collection[str]] = None,
+        grapheme_set: typing.Optional[typing.Collection[str]] = None,
+    ):
+        self.clitic_marker = clitic_marker
+        self.compound_regex = compound_regex
+        self.oov_word = oov_word
+        self.specials_set = {self.oov_word, "<s>", "</s>"}
+        if not word_set:
+            word_set = None
+        self.word_set = word_set
+        if not grapheme_set:
+            grapheme_set = None
+        self.grapheme_set = grapheme_set
+        self.compound_pattern = None
+        self.clitic_pattern = None
+        self.non_speech_regexes = non_speech_regexes
+        self.initial_clitic_regex = initial_clitic_regex
+        self.final_clitic_regex = final_clitic_regex
+        self.has_initial = False
+        self.has_final = False
+        if self.initial_clitic_regex is not None:
+            self.has_initial = True
+        if self.final_clitic_regex is not None:
+            self.has_final = True
+
+    def to_str(self, normalized_text: str) -> str:
+        """
+        Convert normalized text to an integer ID
+
+        Parameters
+        ----------
+        normalized_text:
+            Word to convert
+
+        Returns
+        -------
+        str
+            Normalized string
+        """
+        if normalized_text in self.specials_set:
+            return self.oov_word
+        for word, regex in self.non_speech_regexes.items():
+            if regex.match(normalized_text):
+                return word
+        return normalized_text
+
+    def split_clitics(
+        self,
+        item: str,
+    ) -> typing.List[str]:
+        """
+        Split a word into subwords based on dictionary information
+
+        Parameters
+        ----------
+        item: str
+            Word to split
+
+        Returns
+        -------
+        list[str]
+            List of subwords
+        """
+        split = []
+        if self.compound_regex is not None:
+            s = self.compound_regex.split(item)
+        else:
+            s = [item]
+        if self.word_set is None:
+            return [item]
+        clean_initial_quote_regex = re.compile("^'")
+        clean_final_quote_regex = re.compile("'$")
+        benefit = False
+        for seg in s:
+            if not seg:
+                continue
+            if not self.clitic_marker or self.clitic_marker not in seg:
+                split.append(seg)
+                if not benefit and seg in self.word_set:
+                    benefit = True
+                continue
+            elif seg.startswith(self.clitic_marker):
+                if seg[1:] in self.word_set:
+                    split.append(seg[1:])
+                    benefit = True
+                    continue
+            elif seg.endswith(self.clitic_marker):
+                if seg[:-1] in self.word_set:
+                    split.append(seg[:-1])
+                    benefit = True
+                    continue
+
+            initial_clitics = []
+            final_clitics = []
+            if self.has_initial:
+                while True:
+                    clitic = self.initial_clitic_regex.match(seg)
+                    if clitic is None:
+                        break
+                    benefit = True
+                    initial_clitics.append(clitic.group(0))
+                    seg = seg[clitic.end(0) :]
+                    if seg in self.word_set:
+                        break
+            if self.has_final:
+                while True:
+                    clitic = self.final_clitic_regex.search(seg)
+                    if clitic is None:
+                        break
+                    benefit = True
+                    final_clitics.append(clitic.group(0))
+                    seg = seg[: clitic.start(0)]
+                    if seg in self.word_set:
+                        break
+            final_clitics.reverse()
+            split.extend([clean_initial_quote_regex.sub("", x) for x in initial_clitics])
+            seg = clean_final_quote_regex.sub("", clean_initial_quote_regex.sub("", seg))
+            if seg:
+                split.append(seg)
+            split.extend([clean_final_quote_regex.sub("", x) for x in final_clitics])
+            if not benefit and seg in self.word_set:
+                benefit = True
+        if not benefit:
+            return [item]
+        return split
+
+    def parse_graphemes(
+        self,
+        item: str,
+    ) -> typing.Generator[str]:
+        for word, regex in self.non_speech_regexes.items():
+            if regex.match(item):
+                yield word
+                break
+        else:
+            characters = list(item)
+            for c in characters:
+                if self.grapheme_set is not None and c in self.grapheme_set:
+                    yield c
+                else:
+                    yield self.oov_word
+
+    def __call__(
+        self,
+        item: str,
+    ) -> typing.List[str]:
+        """
+        Return the list of sub words if necessary
+        taking into account clitic and compound markers
+
+        Parameters
+        ----------
+        item: str
+            Word to look up
+
+        Returns
+        -------
+        list[str]
+            List of subwords that are in the dictionary
+        """
+        if self.word_set is not None and item in self.word_set:
+            return [item]
+        for regex in self.non_speech_regexes.values():
+            if regex.match(item):
+                return [item]
+        return self.split_clitics(item)
+
+
+class SimpleTokenizer:
+    def __init__(
+        self,
+        word_break_markers: typing.List[str],
+        punctuation: typing.List[str],
+        clitic_markers: typing.List[str],
+        compound_markers: typing.List[str],
+        brackets: typing.List[typing.Tuple[str, str]],
+        laughter_word: str = LAUGHTER_WORD,
+        oov_word: str = OOV_WORD,
+        bracketed_word: str = BRACKETED_WORD,
+        cutoff_word: str = CUTOFF_WORD,
+        ignore_case: bool = True,
+        use_g2p: bool = False,
+        clitic_set: typing.Iterable = None,
+        word_set: typing.Iterable = None,
+        grapheme_set: typing.Iterable = None,
+    ):
+        self.word_break_markers = word_break_markers
+        self.punctuation = punctuation
+        self.clitic_markers = clitic_markers
+        self.compound_markers = compound_markers
+        self.brackets = brackets
+        self.laughter_word = laughter_word
+        self.oov_word = oov_word
+        self.bracketed_word = bracketed_word
+        self.cutoff_word = cutoff_word
+        self.ignore_case = ignore_case
+        self.use_g2p = use_g2p
+        self.clitic_set = set()
+        if clitic_set is not None:
+            self.clitic_set.update(clitic_set)
+
+        self.word_set = set()
+        if word_set is not None:
+            self.word_set.update(word_set)
+
+        self.grapheme_set = set()
+        if grapheme_set is not None:
+            self.grapheme_set.update(grapheme_set)
+
+        self.clitic_marker = None
+        self.clitic_cleanup_regex = None
+        self.compound_regex = None
+        self.bracket_regex = None
+        self.cutoff_regex = None
+        self.bracket_sanitize_regex = None
+        self.laughter_regex = None
+        self.word_break_regex = None
+        self.clitic_quote_regex = None
+        self.punctuation_regex = None
+        self.initial_clitic_regex = None
+        self.final_clitic_regex = None
+        self.non_speech_regexes = {}
+        self._compile_regexes()
+        self.sanitize_function = SanitizeFunction(
+            self.clitic_marker,
+            self.clitic_cleanup_regex,
+            self.clitic_quote_regex,
+            self.punctuation_regex,
+            self.word_break_regex,
+            self.bracket_regex,
+            self.bracket_sanitize_regex,
+            self.ignore_case,
+        )
+        self.split_function = SplitWordsFunction(
+            self.clitic_marker,
+            self.initial_clitic_regex,
+            self.final_clitic_regex,
+            self.compound_regex,
+            self.non_speech_regexes,
+            self.oov_word,
+            self.word_set,
+            self.grapheme_set,
+        )
+
+    def _compile_regexes(self) -> None:
+        """Compile regular expressions necessary for corpus parsing"""
+        if len(self.clitic_markers) >= 1:
+            other_clitic_markers = self.clitic_markers[1:]
+            if other_clitic_markers:
+                extra = ""
+                if "-" in other_clitic_markers:
+                    extra = "-"
+                    other_clitic_markers = [x for x in other_clitic_markers if x != "-"]
+                self.clitic_cleanup_regex = re.compile(
+                    rf'[{extra}{"".join(other_clitic_markers)}]'
+                )
+            self.clitic_marker = self.clitic_markers[0]
+        if self.compound_markers:
+            extra = ""
+            compound_markers = self.compound_markers
+            if "-" in self.compound_markers:
+                extra = "-"
+                compound_markers = [x for x in compound_markers if x != "-"]
+            self.compound_regex = re.compile(rf"(?<=\w)[{extra}{''.join(compound_markers)}](?=\w)")
+        if self.brackets:
+            left_brackets = [x[0] for x in self.brackets]
+            right_brackets = [x[1] for x in self.brackets]
+            self.cutoff_regex = re.compile(
+                rf"[{re.escape(''.join(left_brackets))}](cutoff|hes).*?[{re.escape(''.join(right_brackets))}]+",
+                flags=re.IGNORECASE,
+            )
+            self.bracket_regex = re.compile(
+                rf"[{re.escape(''.join(left_brackets))}].*?[{re.escape(''.join(right_brackets))}]+"
+            )
+            self.laughter_regex = re.compile(
+                rf"[{re.escape(''.join(left_brackets))}](laugh(ing|ter)?|lachen|lg)[{re.escape(''.join(right_brackets))}]+",
+                flags=re.IGNORECASE,
+            )
+        all_punctuation = set()
+        non_word_character_set = set(self.punctuation)
+        non_word_character_set -= {b for x in self.brackets for b in x}
+
+        if self.clitic_markers:
+            all_punctuation.update(self.clitic_markers)
+        if self.compound_markers:
+            all_punctuation.update(self.compound_markers)
+        self.bracket_sanitize_regex = None
+        if self.brackets:
+            word_break_set = (
+                non_word_character_set | set(self.clitic_markers) | set(self.compound_markers)
+            )
+            if self.word_break_markers:
+                word_break_set |= set(self.word_break_markers)
+            word_break_set = make_re_character_set_safe(word_break_set, [r"\s"])
+            self.bracket_sanitize_regex = re.compile(f"(?<!^){word_break_set}(?!$)")
+
+        word_break_set = (
+            non_word_character_set | set(self.clitic_markers) | set(self.compound_markers)
+        )
+        if self.word_break_markers:
+            word_break_set |= set(self.word_break_markers)
+        word_break_set = make_re_character_set_safe(word_break_set, [r"\s"])
+        self.word_break_regex = re.compile(rf"{word_break_set}+")
+        punctuation_set = make_re_character_set_safe(all_punctuation)
+        if all_punctuation:
+            self.punctuation_regex = re.compile(rf"^{punctuation_set}+$")
+        if len(self.clitic_markers) >= 1:
+            non_clitic_punctuation = all_punctuation - set(self.clitic_markers)
+            non_clitic_punctuation_set = make_re_character_set_safe(non_clitic_punctuation)
+            non_punctuation_set = "[^" + punctuation_set[1:]
+            self.clitic_quote_regex = re.compile(
rf"((?<=\W)|(?<=^)){non_clitic_punctuation_set}*{self.clitic_marker}{non_clitic_punctuation_set}*(?P{non_punctuation_set}+){non_clitic_punctuation_set}*{self.clitic_marker}{non_clitic_punctuation_set}*((?=\W)|(?=$))" + ) + + if self.laughter_regex is not None: + self.non_speech_regexes[self.laughter_word] = self.laughter_regex + if self.cutoff_regex is not None: + self.non_speech_regexes[self.cutoff_word] = self.cutoff_regex + if self.bracket_regex is not None: + self.non_speech_regexes[self.bracketed_word] = self.bracket_regex + + if self.clitic_marker is not None: + initial_clitics = sorted(x for x in self.clitic_set if x.endswith(self.clitic_marker)) + final_clitics = sorted(x for x in self.clitic_set if x.startswith(self.clitic_marker)) + if initial_clitics: + self.initial_clitic_regex = re.compile(rf"^({'|'.join(initial_clitics)})(?=\w)") + if final_clitics: + self.final_clitic_regex = re.compile(rf"(?<=\w)({'|'.join(final_clitics)})$") + + def _dictionary_sanitize(self, text): + + words = self.sanitize_function(text) + normalized_text = [] + normalized_character_text = [] + oovs = set() + for w in words: + for new_w in self.split_function(w): + if new_w not in self.word_set: + oovs.add(new_w) + normalized_text.append(self.split_function.to_str(new_w)) + if normalized_character_text: + if not self.clitic_marker or ( + not normalized_text[-1].endswith(self.clitic_marker) + and not new_w.startswith(self.clitic_marker) + ): + normalized_character_text.append("") + for c in self.split_function.parse_graphemes(new_w): + normalized_character_text.append(c) + normalized_text = " ".join(normalized_text) + normalized_character_text = " ".join(normalized_character_text) + return normalized_text, normalized_character_text, sorted(oovs) + + def _no_dictionary_sanitize(self, text): + normalized_text = [] + normalized_character_text = [] + for w in self.sanitize_function(text): + normalized_text.append(w) + if normalized_character_text: + normalized_character_text.append("") + for g in w: + normalized_character_text.append(g) + normalized_text = " ".join(normalized_text) + normalized_character_text = " ".join(normalized_character_text) + return normalized_text, normalized_character_text, [] + + def __call__(self, text): + """Run the function""" + if self.word_set or self.grapheme_set: + return self._dictionary_sanitize(text) + else: + return self._no_dictionary_sanitize(text) diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index 6ffefba2..fce219cd 100644 --- a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -687,3 +687,7 @@ def run_kaldi_function( finally: p.join() + + if error_dict: + for v in error_dict.values(): + raise v diff --git a/tests/test_acoustic_modeling.py b/tests/test_acoustic_modeling.py index 6455de55..764eeb17 100644 --- a/tests/test_acoustic_modeling.py +++ b/tests/test_acoustic_modeling.py @@ -18,6 +18,7 @@ def test_trainer(basic_dict_path, temp_dir, basic_corpus_dir): assert a.training_configs[a.final_identifier].subset == 0 assert a.training_configs[a.final_identifier].num_leaves == 7000 assert a.training_configs[a.final_identifier].max_gaussians == 150000 + a.cleanup() def test_basic_mono(