From afd17ffd5e01a728b75c5c2e69c707ce3b2ca92b Mon Sep 17 00:00:00 2001 From: Eric Joanis Date: Fri, 14 Feb 2025 15:04:54 -0500 Subject: [PATCH] perf: add caching to the g2p cascade Since g2p is the most expensive operation, and since we assume Zipfian distribution of words apply, we expect a significant speed up benefit from caching the words converted in one document. Only cache within a document, however, to keep the cache's memory use reasonable and to avoid leaking info between documents. --- readalongs/text/convert_xml.py | 70 +++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/readalongs/text/convert_xml.py b/readalongs/text/convert_xml.py index 6dc05954..7ba22b4f 100644 --- a/readalongs/text/convert_xml.py +++ b/readalongs/text/convert_xml.py @@ -36,7 +36,7 @@ import copy import re from time import perf_counter -from typing import Optional +from typing import Dict, Optional, Tuple from readalongs.log import LOGGER from readalongs.text.util import get_attrib_recursive, get_word_text, iterate_over_text @@ -135,6 +135,13 @@ def convert_word(word: str, lang: str) -> Tuple[str, bool]: converter.check(tg, shallow=False, display_warnings=verbose_warnings) return text, valid + # We cache words passed through the g2p cascade, but only internally to this function + # so that the cache is specific to a document, and does not leak between documents + # or excessively grow the memory requirement over time. + convert_word_with_cascade_cache: Dict[ + Tuple[str, str, str], Tuple[str, bool, Optional[str]] + ] = {} + def convert_word_with_cascade( text_to_g2p: str, g2p_lang: str, g2p_fallbacks: str ) -> Tuple[str, bool, Optional[str]]: @@ -154,34 +161,45 @@ def convert_word_with_cascade( if a fallback lang was successfully used: that lang's code if no valid conversion was found: None (and valid==False) """ + cached_result = convert_word_with_cascade_cache.get( + (text_to_g2p, g2p_lang, g2p_fallbacks) + ) + if cached_result: + return cached_result + + result: Tuple[str, bool, Optional[str]] g2p_text, valid = convert_word(text_to_g2p, g2p_lang) if valid: - return g2p_text, True, None - - # This is where we apply the g2p cascade - for lang in re.split(r"[,:]", g2p_fallbacks) if g2p_fallbacks else []: - _, langs = get_langs() - nonlocal g2p_fallback_warning_count - if g2p_fallback_warning_count < 2 or verbose_warnings: - g2p_fallback_warning_count += 1 - LOGGER.warning( - f'Could not g2p "{text_to_g2p}" as {langs.get(g2p_lang, "")} ({g2p_lang}). ' - f"Trying fallback: {langs.get(lang, '')} ({lang})." - ) - g2p_lang = lang.strip() - g2p_text, valid = convert_word(text_to_g2p, g2p_lang) - if valid: - return g2p_text, True, g2p_lang + result = g2p_text, True, None else: - nonlocal g2p_fail_warning_count - if g2p_fail_warning_count < 2 or verbose_warnings: - g2p_fail_warning_count += 1 - LOGGER.warning( - f'No valid g2p conversion found for "{text_to_g2p}". ' - f"Check its orthography and language code, " - f"or pick suitable g2p fallback languages." - ) - return g2p_text, False, None + # This is where we apply the g2p cascade + for lang in re.split(r"[,:]", g2p_fallbacks) if g2p_fallbacks else []: + _, langs = get_langs() + nonlocal g2p_fallback_warning_count + if g2p_fallback_warning_count < 2 or verbose_warnings: + g2p_fallback_warning_count += 1 + LOGGER.warning( + f'Could not g2p "{text_to_g2p}" as {langs.get(g2p_lang, "")} ({g2p_lang}). ' + f"Trying fallback: {langs.get(lang, '')} ({lang})." + ) + g2p_lang = lang.strip() + g2p_text, valid = convert_word(text_to_g2p, g2p_lang) + if valid: + result = g2p_text, True, g2p_lang + break + else: + nonlocal g2p_fail_warning_count + if g2p_fail_warning_count < 2 or verbose_warnings: + g2p_fail_warning_count += 1 + LOGGER.warning( + f'No valid g2p conversion found for "{text_to_g2p}". ' + f"Check its orthography and language code, " + f"or pick suitable g2p fallback languages." + ) + result = g2p_text, False, None + + convert_word_with_cascade_cache[(text_to_g2p, g2p_lang, g2p_fallbacks)] = result + return result all_g2p_valid = True start_time = perf_counter()