Skip to content

Commit

Permalink
Add support for boostraping WER 95% confidence intervals following Ka…
Browse files Browse the repository at this point in the history
…ldi impl
  • Loading branch information
pzelasko committed Mar 4, 2024
1 parent 6fe7314 commit 8dc3537
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 43 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(kaldialign CXX)

# Please remember to also change line 3 of ./scripts/conda/kaldialign/meta.yaml
set(KALDIALIGN_VERSION "0.7.2")
set(KALDIALIGN_VERSION "0.8.0")

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand Down
66 changes: 63 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ python3 -m pip install --verbose .

## Examples

- `align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence.
### Alignment

`align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence.

```python
from kaldialign import align
Expand All @@ -42,7 +44,9 @@ ali = align(a, b, EPS)
assert ali == [('a', 'a'), ('b', 's'), (EPS, 'x'), ('c', 'c')]
```

- `edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions.
### Edit distance

`edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions.

```python
from kaldialign import edit_distance
Expand All @@ -58,9 +62,65 @@ assert results == {
}
```

- For both of the above examples, you can pass `sclite_mode=True` to compute WER or alignments
For alignment and edit distance, you can pass `sclite_mode=True` to compute WER or alignments
based on SCLITE style weights, i.e., insertion/deletion cost 3 and substitution cost 4.

### Bootstrapping method to extract WER 95% confidence intervals

`boostrap_wer_ci(ref, hyp)` - obtain the 95% confidence intervals for WER using Bisani and Ney boostrapping method.

```python
from kaldialign import bootstrap_wer_ci

ref = [
("a", "b", "c"),
("d", "e", "f"),
]
hyp = [
("a", "b", "d"),
("e", "f", "f"),
]
ans = bootstrap_wer_ci(ref, hyp)
assert ans["wer"] == 0.4989
assert ans["ci95"] == 0.2312
assert ans["ci95min"] == 0.2678
assert ans["ci95max"] == 0.7301
```

It also supports providing hypotheses from system 1 and system 2 to compute the probability of S2 improving over S1:

```python
from kaldialign import bootstrap_wer_ci

ref = [
("a", "b", "c"),
("d", "e", "f"),
]
hyp = [
("a", "b", "d"),
("e", "f", "f"),
]
hyp2 = [
("a", "b", "c"),
("e", "e", "f"),
]
ans = bootstrap_wer_ci(ref, hyp, hyp2)

s = ans["system1"]
assert s["wer"] == 0.4989
assert s["ci95"] == 0.2312
assert s["ci95min"] == 0.2678
assert s["ci95max"] == 0.7301

s = ans["system2"]
assert s["wer"] == 0.1656
assert s["ci95"] == 0.2312
assert s["ci95min"] == -0.0656
assert s["ci95max"] == 0.3968

assert ans["p_s2_improv_over_s1"] == 1.0
```

## Motivation

The need for this arised from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using Cython, avoiding the issue altogether.
143 changes: 142 additions & 1 deletion kaldialign/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import math
from typing import List, Tuple
import random
import _kaldialign


def edit_distance(a, b, sclite_mode=False):
"""
Compute the edit distance between sequences ``a`` and ``b``.
Both sequences can be strings or lists of strings or ints.
Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for
compatibility with sclite tool.
Returns a dict with keys ``ins``, ``del``, ``sub``, ``total``,
which stand for the count of insertions, deletions, substitutions,
and the total number of errors.
"""
int2sym = dict(enumerate(sorted(set(a) | set(b))))
sym2int = {v: k for k, v in int2sym.items()}

Expand All @@ -18,6 +31,19 @@ def edit_distance(a, b, sclite_mode=False):


def align(a, b, eps_symbol, sclite_mode=False):
"""
Compute the alignment between sequences ``a`` and ``b``.
Both sequences can be strings or lists of strings or ints.
``eps_symbol`` is used as a blank symbol to indicate insertion or deletion.
Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for
compatibility with sclite tool.
Returns a list of pairs of alignment symbols. The presence of ``eps_symbol``
in the first pair index indicates insertion, and in the second pair index, deletion.
Mismatched symbols indicate substitution.
"""
int2sym = dict(enumerate(sorted(set(a) | set(b) | {eps_symbol})))
sym2int = {v: k for k, v in int2sym.items()}

Expand All @@ -34,8 +60,123 @@ def align(a, b, eps_symbol, sclite_mode=False):
alignment: List[Tuple[int, int]] = _kaldialign.align(ai, bi, eps_int, sclite_mode)

ali = []
idx = 0
for idx in range(len(alignment)):
ali.append((int2sym[alignment[idx][0]], int2sym[alignment[idx][1]]))

return ali


def bootstrap_wer_ci(
ref_seqs, hyp_seqs, hyp2_seqs=None, replications: int = 10000, seed: int = 0
):
"""
Compute a boostrapping of WER to extract the 95% confidence interval (CI)
using the bootstrap method of Bisani and Ney [1].
The implementation is based on Kaldi's ``compute-wer-bootci`` script [2].
Args:
ref_seqs: A list of reference sequences (str, list[str], list[int])
hyp_seqs: A list of hypothesis sequences from system1 (str, list[str], list[int])
hyp2_seqs: A list of hypothesis sequences from system2 (str, list[str], list[int]).
When provided, we'll compute CI for both systems as well as the probability
of system2 improving over system1.
replications: The number of replications to use for bootstrapping.
seed: The random seed to reproduce the results.
Returns:
A dict with results. When scoring a single system (``hyp2_seqs=None``), the keys are:
- "wer" (mean WER estimate),
- "ci95" (95% confidence interval size),
- "ci95min" (95% confidence interval lower bound)
- "ci95max" (95% confidence interval upper bound)
When scoring two systems, the keys are "system1", "system2", and "p_s2_improv_over_s1".
The first two keys contain dicts as described for the single-system case, and the last key's
value is a float in the range [0, 1].
[1] Bisani, M., & Ney, H. (2004, May). Bootstrap estimates for confidence intervals in ASR performance evaluation.
In 2004 IEEE International Conference on Acoustics, Speech, and Signal Processing (Vol. 1, pp. I-409). IEEE.
[2] https://github.com/kaldi-asr/kaldi/blob/master/src/bin/compute-wer-bootci.cc
"""
assert len(hyp_seqs) == len(
ref_seqs
), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp_seqs)}) sequences."
edit_sym_per_hyp = _get_edits(ref_seqs, hyp_seqs)
mean, interval = _get_boostrap_wer_interval(
edit_sym_per_hyp, replications=replications, seed=seed
)
ans1 = _build_results(mean, interval)
if hyp2_seqs is None:
return ans1

assert len(hyp2_seqs) == len(
ref_seqs
), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp2_seqs)}) sequences for the second system (hyp2_seqs)."
edit_sym_per_hyp2 = _get_edits(ref_seqs, hyp2_seqs)
mean2, interval2 = _get_boostrap_wer_interval(
edit_sym_per_hyp2, replications=replications, seed=seed
)
p_improv = _get_p_improv(
edit_sym_per_hyp, edit_sym_per_hyp2, replications=replications, seed=seed
)
return {
"system1": ans1,
"system2": _build_results(mean2, interval2),
"p_s2_improv_over_s1": p_improv,
}


def _build_results(mean, interval):
return {
"wer": round(mean, ndigits=4),
"ci95": round(interval, ndigits=4),
"ci95min": round(mean - interval, ndigits=4),
"ci95max": round(mean + interval, ndigits=4),
}


def _get_edits(ref_seqs, hyp_seqs):
edit_sym_per_hyp = []
for ref, hyp in zip(ref_seqs, hyp_seqs):
dist = edit_distance(ref, hyp)
edit_sym_per_hyp.append((dist["total"], len(ref)))
return edit_sym_per_hyp


def _get_boostrap_wer_interval(edit_sym_per_hyp, replications, seed):
rng = random.Random(seed)

wer_accum, wer_mult_accum = 0.0, 0.0
for i in range(replications):
num_sym, num_errs = 0, 0
for j in range(len(edit_sym_per_hyp)):
nerr, nsym = rng.choice(edit_sym_per_hyp)
num_sym += nsym
num_errs += nerr
wer_rep = num_errs / num_sym
wer_accum += wer_rep
wer_mult_accum += wer_rep**2

mean = wer_accum / replications
_tmp = wer_mult_accum / replications - mean**2
if _tmp < 0:
interval = 0
else:
interval = 1.96 * math.sqrt(_tmp)

return mean, interval


def _get_p_improv(edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed):
rng = random.Random(seed)

improv_accum = 0
for i in range(replications):
num_errs = 0
for j in range(len(edit_sym_per_hyp)):
pos = rng.randint(0, len(edit_sym_per_hyp) - 1)
num_errs += edit_sym_per_hyp[pos][0] - edit_sym_per_hyp2[pos][0]
if num_errs > 0:
improv_accum += 1

return improv_accum / replications
2 changes: 1 addition & 1 deletion scripts/conda/kaldialign/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: kaldialign
version: "0.7.2"
version: "0.8.0"

source:
path: "{{ environ.get('KALDIALIGN_ROOT_DIR') }}"
Expand Down
Loading

0 comments on commit 8dc3537

Please sign in to comment.