diff --git a/CMakeLists.txt b/CMakeLists.txt index e36ac19..f2b7692 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) project(kaldialign CXX) # Please remember to also change line 3 of ./scripts/conda/kaldialign/meta.yaml -set(KALDIALIGN_VERSION "0.7.2") +set(KALDIALIGN_VERSION "0.8.0") if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/README.md b/README.md index 957550f..2c72bf2 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,9 @@ python3 -m pip install --verbose . ## Examples -- `align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence. +### Alignment + +`align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence. ```python from kaldialign import align @@ -42,7 +44,9 @@ ali = align(a, b, EPS) assert ali == [('a', 'a'), ('b', 's'), (EPS, 'x'), ('c', 'c')] ``` -- `edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions. +### Edit distance + +`edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions. ```python from kaldialign import edit_distance @@ -58,9 +62,65 @@ assert results == { } ``` -- For both of the above examples, you can pass `sclite_mode=True` to compute WER or alignments +For alignment and edit distance, you can pass `sclite_mode=True` to compute WER or alignments based on SCLITE style weights, i.e., insertion/deletion cost 3 and substitution cost 4. +### Bootstrapping method to extract WER 95% confidence intervals + +`bootstrap_wer_ci(ref, hyp)` - obtain the 95% confidence intervals for WER using the Bisani and Ney bootstrapping method. 
+ +```python +from kaldialign import bootstrap_wer_ci + +ref = [ + ("a", "b", "c"), + ("d", "e", "f"), +] +hyp = [ + ("a", "b", "d"), + ("e", "f", "f"), +] +ans = bootstrap_wer_ci(ref, hyp) +assert ans["wer"] == 0.4989 +assert ans["ci95"] == 0.2312 +assert ans["ci95min"] == 0.2678 +assert ans["ci95max"] == 0.7301 +``` + +It also supports providing hypotheses from system 1 and system 2 to compute the probability of S2 improving over S1: + +```python +from kaldialign import bootstrap_wer_ci + +ref = [ + ("a", "b", "c"), + ("d", "e", "f"), +] +hyp = [ + ("a", "b", "d"), + ("e", "f", "f"), +] +hyp2 = [ + ("a", "b", "c"), + ("e", "e", "f"), +] +ans = bootstrap_wer_ci(ref, hyp, hyp2) + +s = ans["system1"] +assert s["wer"] == 0.4989 +assert s["ci95"] == 0.2312 +assert s["ci95min"] == 0.2678 +assert s["ci95max"] == 0.7301 + +s = ans["system2"] +assert s["wer"] == 0.1656 +assert s["ci95"] == 0.2312 +assert s["ci95min"] == -0.0656 +assert s["ci95max"] == 0.3968 + +assert ans["p_s2_improv_over_s1"] == 1.0 +``` + ## Motivation The need for this arised from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using Cython, avoiding the issue altogether. diff --git a/kaldialign/__init__.py b/kaldialign/__init__.py index 780ed23..bcd3692 100644 --- a/kaldialign/__init__.py +++ b/kaldialign/__init__.py @@ -1,8 +1,21 @@ +import math from typing import List, Tuple +import random import _kaldialign def edit_distance(a, b, sclite_mode=False): + """ + Compute the edit distance between sequences ``a`` and ``b``. + Both sequences can be strings or lists of strings or ints. + + Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for + compatibility with sclite tool. 
+ + Returns a dict with keys ``ins``, ``del``, ``sub``, ``total``, + which stand for the count of insertions, deletions, substitutions, + and the total number of errors. + """ int2sym = dict(enumerate(sorted(set(a) | set(b)))) sym2int = {v: k for k, v in int2sym.items()} @@ -18,6 +31,19 @@ def edit_distance(a, b, sclite_mode=False): def align(a, b, eps_symbol, sclite_mode=False): + """ + Compute the alignment between sequences ``a`` and ``b``. + Both sequences can be strings or lists of strings or ints. + + ``eps_symbol`` is used as a blank symbol to indicate insertion or deletion. + + Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for + compatibility with sclite tool. + + Returns a list of pairs of alignment symbols. The presence of ``eps_symbol`` + in the first pair index indicates insertion, and in the second pair index, deletion. + Mismatched symbols indicate substitution. + """ int2sym = dict(enumerate(sorted(set(a) | set(b) | {eps_symbol}))) sym2int = {v: k for k, v in int2sym.items()} @@ -34,8 +60,123 @@ def align(a, b, eps_symbol, sclite_mode=False): alignment: List[Tuple[int, int]] = _kaldialign.align(ai, bi, eps_int, sclite_mode) ali = [] - idx = 0 for idx in range(len(alignment)): ali.append((int2sym[alignment[idx][0]], int2sym[alignment[idx][1]])) return ali + + +def bootstrap_wer_ci( + ref_seqs, hyp_seqs, hyp2_seqs=None, replications: int = 10000, seed: int = 0 +): + """ + Compute a bootstrap estimate of WER to extract the 95% confidence interval (CI) + using the bootstrap method of Bisani and Ney [1]. + The implementation is based on Kaldi's ``compute-wer-bootci`` script [2]. + + Args: + ref_seqs: A list of reference sequences (str, list[str], list[int]) + hyp_seqs: A list of hypothesis sequences from system1 (str, list[str], list[int]) + hyp2_seqs: A list of hypothesis sequences from system2 (str, list[str], list[int]). + When provided, we'll compute CI for both systems as well as the probability + of system2 improving over system1. 
+ replications: The number of replications to use for bootstrapping. + seed: The random seed to reproduce the results. + + Returns: + A dict with results. When scoring a single system (``hyp2_seqs=None``), the keys are: + - "wer" (mean WER estimate), + - "ci95" (95% confidence interval size), + - "ci95min" (95% confidence interval lower bound) + - "ci95max" (95% confidence interval upper bound) + When scoring two systems, the keys are "system1", "system2", and "p_s2_improv_over_s1". + The first two keys contain dicts as described for the single-system case, and the last key's + value is a float in the range [0, 1]. + + [1] Bisani, M., & Ney, H. (2004, May). Bootstrap estimates for confidence intervals in ASR performance evaluation. + In 2004 IEEE International Conference on Acoustics, Speech, and Signal Processing (Vol. 1, pp. I-409). IEEE. + + [2] https://github.com/kaldi-asr/kaldi/blob/master/src/bin/compute-wer-bootci.cc + """ + assert len(hyp_seqs) == len( + ref_seqs + ), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp_seqs)}) sequences." + edit_sym_per_hyp = _get_edits(ref_seqs, hyp_seqs) + mean, interval = _get_boostrap_wer_interval( + edit_sym_per_hyp, replications=replications, seed=seed + ) + ans1 = _build_results(mean, interval) + if hyp2_seqs is None: + return ans1 + + assert len(hyp2_seqs) == len( + ref_seqs + ), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp2_seqs)}) sequences for the second system (hyp2_seqs)." 
+ edit_sym_per_hyp2 = _get_edits(ref_seqs, hyp2_seqs) + mean2, interval2 = _get_boostrap_wer_interval( + edit_sym_per_hyp2, replications=replications, seed=seed + ) + p_improv = _get_p_improv( + edit_sym_per_hyp, edit_sym_per_hyp2, replications=replications, seed=seed + ) + return { + "system1": ans1, + "system2": _build_results(mean2, interval2), + "p_s2_improv_over_s1": p_improv, + } + + +def _build_results(mean, interval): + return { + "wer": round(mean, ndigits=4), + "ci95": round(interval, ndigits=4), + "ci95min": round(mean - interval, ndigits=4), + "ci95max": round(mean + interval, ndigits=4), + } + + +def _get_edits(ref_seqs, hyp_seqs): + edit_sym_per_hyp = [] + for ref, hyp in zip(ref_seqs, hyp_seqs): + dist = edit_distance(ref, hyp) + edit_sym_per_hyp.append((dist["total"], len(ref))) + return edit_sym_per_hyp + + +def _get_boostrap_wer_interval(edit_sym_per_hyp, replications, seed): + rng = random.Random(seed) + + wer_accum, wer_mult_accum = 0.0, 0.0 + for i in range(replications): + num_sym, num_errs = 0, 0 + for j in range(len(edit_sym_per_hyp)): + nerr, nsym = rng.choice(edit_sym_per_hyp) + num_sym += nsym + num_errs += nerr + wer_rep = num_errs / num_sym + wer_accum += wer_rep + wer_mult_accum += wer_rep**2 + + mean = wer_accum / replications + _tmp = wer_mult_accum / replications - mean**2 + if _tmp < 0: + interval = 0 + else: + interval = 1.96 * math.sqrt(_tmp) + + return mean, interval + + +def _get_p_improv(edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed): + rng = random.Random(seed) + + improv_accum = 0 + for i in range(replications): + num_errs = 0 + for j in range(len(edit_sym_per_hyp)): + pos = rng.randint(0, len(edit_sym_per_hyp) - 1) + num_errs += edit_sym_per_hyp[pos][0] - edit_sym_per_hyp2[pos][0] + if num_errs > 0: + improv_accum += 1 + + return improv_accum / replications diff --git a/scripts/conda/kaldialign/meta.yaml b/scripts/conda/kaldialign/meta.yaml index 2fc2934..c5e4d01 100644 --- a/scripts/conda/kaldialign/meta.yaml +++ 
b/scripts/conda/kaldialign/meta.yaml @@ -1,6 +1,6 @@ package: name: kaldialign - version: "0.7.2" + version: "0.8.0" source: path: "{{ environ.get('KALDIALIGN_ROOT_DIR') }}" diff --git a/tests/test_align.py b/tests/test_align.py index 6c85a91..3e504c2 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,63 +1,114 @@ -from kaldialign import align, edit_distance +from kaldialign import align, edit_distance, bootstrap_wer_ci -EPS = '*' +EPS = "*" def test_align(): - a = ['a', 'b', 'c'] - b = ['a', 's', 'x', 'c'] + a = ["a", "b", "c"] + b = ["a", "s", "x", "c"] ali = align(a, b, EPS) - assert ali == [('a', 'a'), ('b', 's'), (EPS, 'x'), ('c', 'c')] + assert ali == [("a", "a"), ("b", "s"), (EPS, "x"), ("c", "c")] dist = edit_distance(a, b) - assert dist == { 'ins': 1, 'del': 0, 'sub': 1, 'total': 2} + assert dist == {"ins": 1, "del": 0, "sub": 1, "total": 2} - a = ['a', 'b'] - b = ['b', 'c'] + a = ["a", "b"] + b = ["b", "c"] ali = align(a, b, EPS) - assert ali == [('a', EPS), ('b', 'b'), (EPS, 'c')] + assert ali == [("a", EPS), ("b", "b"), (EPS, "c")] dist = edit_distance(a, b) - assert dist == {'ins': 1, 'del': 1, 'sub': 0, 'total': 2} + assert dist == {"ins": 1, "del": 1, "sub": 0, "total": 2} - a = ['A' ,'B','C'] - b = ['D' ,'C', 'A'] + a = ["A", "B", "C"] + b = ["D", "C", "A"] ali = align(a, b, EPS) - assert ali == [('A', 'D'), ('B', EPS), ('C', 'C'), (EPS, 'A')] + assert ali == [("A", "D"), ("B", EPS), ("C", "C"), (EPS, "A")] dist = edit_distance(a, b) - assert dist == {'ins': 1, 'del': 1, 'sub': 1, 'total': 3} + assert dist == {"ins": 1, "del": 1, "sub": 1, "total": 3} - - a = ['A', 'B', 'C', 'D'] - b = ['C', 'E', 'D', 'F'] + a = ["A", "B", "C", "D"] + b = ["C", "E", "D", "F"] ali = align(a, b, EPS) - assert ali == [('A', EPS), ('B', EPS), ('C', 'C'), (EPS, 'E'), ('D', 'D'), (EPS, 'F')] + assert ali == [ + ("A", EPS), + ("B", EPS), + ("C", "C"), + (EPS, "E"), + ("D", "D"), + (EPS, "F"), + ] dist = edit_distance(a, b) - assert dist == {'ins': 2, 'del': 
2, 'sub': 0, 'total': 4} + assert dist == {"ins": 2, "del": 2, "sub": 0, "total": 4} def test_edit_distance(): - a = ['a', 'b', 'c'] - b = ['a', 's', 'x', 'c'] + a = ["a", "b", "c"] + b = ["a", "s", "x", "c"] results = edit_distance(a, b) - assert results == { - 'ins': 1, - 'del': 0, - 'sub': 1, - 'total': 2 - } + assert results == {"ins": 1, "del": 0, "sub": 1, "total": 2} + def test_edit_distance_sclite(): - a = ['a', 'b'] - b = ['b', 'c'] + a = ["a", "b"] + b = ["b", "c"] results = edit_distance(a, b, sclite_mode=True) - assert results == { - 'ins': 1, - 'del': 1, - 'sub': 0, - 'total': 2 - } + assert results == {"ins": 1, "del": 1, "sub": 0, "total": 2} + + +def test_bootstrap_wer_ci_1system(): + ref = [ + ("a", "b", "c"), + ("d", "e", "f"), + ] + + hyp = [ + ("a", "b", "d"), + ("e", "f", "f"), + ] + + ans = bootstrap_wer_ci(ref, hyp) + + assert ans["wer"] == 0.4989 + assert ans["ci95"] == 0.2312 + assert ans["ci95min"] == 0.2678 + assert ans["ci95max"] == 0.7301 + +def test_bootstrap_wer_ci_2system(): + ref = [ + ("a", "b", "c"), + ("d", "e", "f"), + ] -if __name__ == '__main__': + hyp = [ + ("a", "b", "d"), + ("e", "f", "f"), + ] + + hyp2 = [ + ("a", "b", "c"), + ("e", "e", "f"), + ] + + ans = bootstrap_wer_ci(ref, hyp, hyp2) + + s = ans["system1"] + assert s["wer"] == 0.4989 + assert s["ci95"] == 0.2312 + assert s["ci95min"] == 0.2678 + assert s["ci95max"] == 0.7301 + + s = ans["system2"] + assert s["wer"] == 0.1656 + assert s["ci95"] == 0.2312 + assert s["ci95min"] == -0.0656 + assert s["ci95max"] == 0.3968 + + assert ans["p_s2_improv_over_s1"] == 1.0 + + +if __name__ == "__main__": test_align() test_edit_distance() - + test_edit_distance_sclite() + test_bootstrap_wer_ci_1system() + test_bootstrap_wer_ci_2system()