[WIP] Implements RNNT+MMI #1030

Draft · wants to merge 9 commits into base: master

Changes from all commits
355 changes: 350 additions & 5 deletions k2/csrc/fsa_algo.cu

Large diffs are not rendered by default.

44 changes: 42 additions & 2 deletions k2/csrc/fsa_algo.h
@@ -842,7 +842,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin,
weight 'x.weight', then the reverse of 'src' accepts the reverse of string 'x'
with weight 'x.weight.reverse'.

Implementation notss:
Implementation notes:
An Fsa in k2 has only one start state (state 0), and a single final state,
the state with the largest number, whose incoming arcs all have "-1" as the label.
So, 1) the start state of 'dest' will correspond to the final state of 'src'.
@@ -864,13 +864,53 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin,
@param [out] dest Output Fsa or FsaVec. At exit, it will be equivalent
to the reverse Fsa of 'src'. Caution: the reverse will
ignore the "-1" label.

@param [out,optional] arc_map For each arc in `dest`, gives the index of
the arc in `src` that it corresponds to.

*/
void Reverse(FsaVec &src, FsaVec *dest, Array1<int32_t> *arc_map = nullptr);
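
As a minimal illustration of the contract above, here is a hedged Python
sketch; it is not the k2 implementation (which operates on Ragged arrays,
possibly on GPU), and the handling of the "-1" final arcs, including the
extra super-final state, is an assumption:

# Hypothetical sketch: an Fsa is a list of (src, dst, label, score) arcs;
# state 0 is the start state and the highest-numbered state is final.
def reverse_fsa(arcs, num_states):
    final = num_states - 1
    out = []
    for src, dst, label, score in arcs:
        if label == -1:
            # Arcs into the old final state become epsilon arcs leaving the
            # new start state; the "-1" label is ignored (see the caution).
            out.append((0, final - src, 0, score))
        else:
            out.append((final - dst, final - src, label, score))
    # The old start state (now state `final`) gets a "-1" arc into a new
    # super-final state, restoring k2's final-state convention.
    out.append((final, final + 1, -1, 0.0))
    return sorted(out), num_states + 1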

/*
* Generate denominator lattice from sampled linear paths for RNN-T+MMI
* training.
*
* Implementation notes:
* 1) Generate "states" for each sampled symbol from their left_symbols and
* the frame_ids they are sampled from.
* 2) Sort those "states" for each sequence and then merge the same "states".
* 3) Map all of the sampled symbols to the merged "states".
* 4) Remove duplicate arcs.
*
* @param [in] sampled_paths The sampled symbols; it has a regular shape of
*             [seq][num_path][path_length]. All its elements MUST satisfy
*             `0 <= value < vocab_size`.
* @param [in] frame_ids It contains the frame index at which each symbol
*             was sampled; it has the same shape as sampled_paths.
* @param [in] left_symbols The left symbols of the sampled symbols; it has a
*             regular shape of [seq][num_path][path_length][context], whose
*             first three indexes are the same as those of sampled_paths.
*             Each sublist along axis 3 has `context_size` elements. All its
*             elements MUST satisfy `0 <= value < vocab_size`.
* @param [in] sampling_probs It contains the probability with which each
*             symbol was sampled; it has the same shape as sampled_paths.
* @param [in] boundary It contains the number of frames for each sequence.
* @param [in] vocab_size The vocabulary size.
* @param [in] context_size The number of left symbols.
* @param [out] arc_map For each arc in the returned Fsa, gives the original
*             index (idx012) in sampled_paths that it corresponds to.
*
* @return Return the generated lattice.
*/
FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
Ragged<int32_t> &frame_ids,
Ragged<int32_t> &left_symbols,
Ragged<float> &sampling_probs,
Array1<int32_t> &boundary,
int32_t vocab_size,
int32_t context_size,
Array1<int32_t> *arc_map);
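
To make steps 1) to 4) above concrete, here is a hypothetical
single-sequence Python sketch; the real implementation batches this over
all sequences and paths and also builds final arcs from `boundary`, which
this sketch omits:

def denominator_lattice_arcs(sampled_paths, frame_ids, left_symbols):
    """sampled_paths/frame_ids: [num_path][path_length] ints;
    left_symbols: [num_path][path_length][context] ints.
    Returns deduplicated (src_state, dst_state, symbol) arcs."""
    # Step 1): a lattice "state" is (frame_id, left-symbol context); symbols
    # sampled at the same frame with the same context share a state.
    states = set()
    for p in range(len(sampled_paths)):
        for t in range(len(sampled_paths[p])):
            states.add((frame_ids[p][t], tuple(left_symbols[p][t])))
    # Step 2): sort the merged states and assign consecutive ids.
    state_id = {s: i for i, s in enumerate(sorted(states))}
    # Steps 3) and 4): map each sampled symbol to an arc between merged
    # states; collecting arcs in a set removes duplicates.
    arcs = set()
    for p in range(len(sampled_paths)):
        path, frames, ctxs = sampled_paths[p], frame_ids[p], left_symbols[p]
        for t in range(len(path) - 1):
            src = state_id[(frames[t], tuple(ctxs[t]))]
            dst = state_id[(frames[t + 1], tuple(ctxs[t + 1]))]
            arcs.add((src, dst, path[t]))
    return sorted(arcs)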

} // namespace k2

#endif // K2_CSRC_FSA_ALGO_H_
44 changes: 44 additions & 0 deletions k2/csrc/fsa_algo_test.cu
@@ -1383,4 +1383,48 @@ TEST(FsaAlgo, TestLevenshteinGraph) {
}
}

TEST(FsaAlgo, TestGenerateDenominatorLattice) {
for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
Ragged<int32_t> sampled_paths(c, "[ [ [ 3 5 0 4 6 0 2 1 ] "
" [ 2 0 5 4 0 6 1 2 ] "
" [ 3 5 2 0 0 1 6 4 ] ] "
" [ [ 7 0 4 0 6 0 3 0 ] "
" [ 0 7 3 0 2 0 4 5 ] "
" [ 7 0 3 4 0 1 2 0 ] ] ]");
Ragged<int32_t> frame_ids(c, "[ [ [ 0 0 0 1 1 1 2 2 ] "
" [ 0 0 1 1 1 2 2 2 ] "
" [ 0 0 0 0 1 2 2 2 ] ] "
" [ [ 0 0 1 1 2 2 3 3 ] "
" [ 0 1 1 1 2 2 3 1 ] "
" [ 0 0 1 1 1 2 2 2 ] ] ]");
Ragged<int32_t> left_symbols(c,
"[ [ [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 3 5 ] [ 5 4 ] [ 4 6 ] [ 4 6 ] [ 6 2 ] ] "
" [ [ 0 0 ] [ 0 2 ] [ 0 2 ] [ 2 5 ] [ 5 4 ] [ 5 4 ] [ 4 6 ] [ 6 1 ] ] "
" [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 5 2 ] [ 5 2 ] [ 5 2 ] [ 2 1 ] [ 1 6 ] ] "
" ] "
" [ [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 4 ] [ 7 4 ] [ 4 6 ] [ 4 6 ] [ 6 3 ] ] "
" [ [ 0 0 ] [ 0 0 ] [ 0 7 ] [ 7 3 ] [ 7 3 ] [ 3 2 ] [ 3 2 ] [ 0 0 ] ] "
" [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 3 ] [ 3 4 ] [ 3 4 ] [ 4 1 ] [ 1 2 ] ] "
" ] ]");

Ragged<float> sampling_probs(c, "[ [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] "
" [ 0.2 0.2 0.2 0.1 0.2 0.2 0.1 0.2 ] "
" [ 0.1 0.1 0.1 0.3 0.2 0.3 0.3 0.3 ] ] "
" [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] "
" [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] "
" [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] "
"]");
Array1<int32_t> boundary(c, "[ 3 4 ]");

Array1<int32_t> arc_map;
FsaVec lattice = GenerateDenominatorLattice(
sampled_paths, frame_ids, left_symbols, sampling_probs, boundary,
10 /*vocab_size*/, 2 /*context_size*/, &arc_map);
K2_LOG(INFO) << arc_map;
K2_LOG(INFO) << lattice;
K2_LOG(INFO) << FsaToString(lattice.Index(0, 0));
K2_LOG(INFO) << FsaToString(lattice.Index(0, 1));
}
}
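
As a sanity check, the test data above appears to follow RNN-T sampling
semantics: symbol 0 acts as the blank, advancing the frame index while
leaving the left-symbol context unchanged, and a non-blank symbol shifts
into the context while staying on the same frame. This convention is
inferred from the data, not documented in the PR; a small Python check on
the first path:

def check_path(path, frames, contexts, context_size=2):
    ctx, frame = [0] * context_size, 0
    for sym, f, c in zip(path, frames, contexts):
        assert f == frame and list(c) == ctx
        if sym == 0:
            frame += 1                # blank: move to the next frame
        else:
            ctx = ctx[1:] + [sym]     # non-blank: shift into the context

check_path([3, 5, 0, 4, 6, 0, 2, 1],
           [0, 0, 0, 1, 1, 1, 2, 2],
           [(0, 0), (0, 3), (3, 5), (3, 5), (5, 4), (4, 6), (4, 6), (6, 2)])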

} // namespace k2
26 changes: 26 additions & 0 deletions k2/python/csrc/torch/fsa_algo.cu
@@ -807,6 +807,31 @@ static void PybindReverse(py::module &m) {
py::arg("src"), py::arg("need_arc_map") = true);
}

static void PybindGenerateDenominatorLattice(py::module &m) {
m.def(
"generate_denominator_lattice",
[](RaggedAny &sampled_paths, RaggedAny &frame_ids,
RaggedAny &left_symbols, RaggedAny &sampling_probs,
torch::Tensor &boundary, int32_t vocab_size, int32_t context_size)
-> std::pair<FsaVec, torch::Tensor> {
DeviceGuard guard(sampled_paths.any.Context());
Array1<int32_t> arc_map;
Array1<int32_t> boundary_array = FromTorch<int32_t>(boundary);
FsaVec lattice = GenerateDenominatorLattice(
sampled_paths.any.Specialize<int32_t>(),
frame_ids.any.Specialize<int32_t>(),
left_symbols.any.Specialize<int32_t>(),
sampling_probs.any.Specialize<float>(),
boundary_array,
vocab_size, context_size, &arc_map);
auto arc_map_tensor = ToTorch(arc_map);
return std::make_pair(lattice, arc_map_tensor);
},
py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"),
py::arg("sampling_probs"), py::arg("boundary"), py::arg("vocab_size"),
py::arg("context_size"));
}

} // namespace k2

void PybindFsaAlgo(py::module &m) {
@@ -820,6 +845,7 @@ void PybindFsaAlgo(py::module &m) {
k2::PybindDeterminize(m);
k2::PybindExpandArcs(m);
k2::PybindFixFinalLabels(m);
k2::PybindGenerateDenominatorLattice(m);
k2::PybindIntersect(m);
k2::PybindIntersectDense(m);
k2::PybindIntersectDensePruned(m);
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
@@ -58,6 +58,7 @@
from .fsa_algo import ctc_topo
from .fsa_algo import determinize
from .fsa_algo import expand_ragged_attributes
from .fsa_algo import generate_denominator_lattice
from .fsa_algo import intersect
from .fsa_algo import intersect_device
from .fsa_algo import invert
66 changes: 66 additions & 0 deletions k2/python/k2/fsa_algo.py
@@ -24,6 +24,7 @@
import torch
import _k2
import k2
import logging

from . import fsa_properties
from .fsa import Fsa
@@ -1473,3 +1474,68 @@ def union(fsas: Fsa) -> Fsa:
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc,
arc_map)
return out_fsa


def generate_denominator_lattice(
sampled_paths: torch.Tensor,
frame_ids: torch.Tensor,
left_symbols: torch.Tensor,
sampling_probs: torch.Tensor,
path_scores: torch.Tensor,
boundary: torch.Tensor,
vocab_size: int,
context_size: int,
return_arc_map: bool = False,
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:
"""Generate denominator lattice from sampled linear paths for RNN-T+MMI
training.

Args:
sampled_paths:
The sampled symbols; it has a shape of (seq, num_path, path_length).
All its elements MUST satisfy `0 <= value < vocab_size`.
frame_ids:
It contains the frame index at which each symbol was sampled; it has
the same shape as sampled_paths.
left_symbols:
The left symbols of the sampled symbols; it has a shape of
(seq, num_path, path_length, context_size), whose first three indexes
are the same as those of sampled_paths. All its elements MUST satisfy
`0 <= value < vocab_size`.
sampling_probs:
It contains the probability with which each symbol was sampled; it has
the same shape as sampled_paths. It normally comes from the output of
the "predictor" head.
path_scores:
It contains the score of each sampled symbol; it has the same shape as
sampled_paths. It might contain the output of the hybrid head and the
extra language-model output. Note: autograd is supported for this tensor.
boundary:
It contains the number of frames for each sequence.
vocab_size:
The vocabulary size.
context_size:
The number of left symbols.
return_arc_map:
Whether to return arc_map.

Returns:
The generated lattice; if return_arc_map is True, also the arc map
(a tensor mapping each arc to its original index in sampled_paths).
"""
ragged_arc, arc_map = _k2.generate_denominator_lattice(
sampled_paths=k2.RaggedTensor(sampled_paths),
frame_ids=k2.RaggedTensor(frame_ids),
left_symbols=k2.RaggedTensor(left_symbols),
sampling_probs=k2.RaggedTensor(sampling_probs),
boundary=boundary,
vocab_size=vocab_size,
context_size=context_size,
)
lattice = Fsa(ragged_arc)
# a_value holds the scores assigned by the C++ routine; they are expected
# to be non-negative (see the assert below).
a_value = getattr(lattice, "scores")
# Gather the per-symbol path_scores for each arc; index_select keeps the
# autograd graph, so gradients flow back into path_scores.
b_value = index_select(path_scores.flatten(), arc_map)
assert torch.all(a_value >= 0), a_value
value = b_value + a_value
setattr(lattice, "scores", value)
if return_arc_map:
return lattice, arc_map
else:
return lattice
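

For reference, a hypothetical end-to-end usage sketch. The shapes mirror
the C++ test above; the blank-advances-the-frame convention used here to
build self-consistent inputs is inferred from the test data rather than
documented, and since this PR is WIP the exact input requirements may
differ:

import torch
import k2

seq, num_path, path_len, context_size, vocab = 2, 3, 8, 2, 10
num_frames = [3, 4]  # frames per sequence (the `boundary` argument)

paths = torch.zeros(seq, num_path, path_len, dtype=torch.int32)
frames = torch.zeros(seq, num_path, path_len, dtype=torch.int32)
contexts = torch.zeros(seq, num_path, path_len, context_size,
                       dtype=torch.int32)
for s in range(seq):
    for p in range(num_path):
        ctx, frame = [0] * context_size, 0
        for t in range(path_len):
            frames[s, p, t] = frame
            contexts[s, p, t] = torch.tensor(ctx, dtype=torch.int32)
            sym = int(torch.randint(0, vocab, (1,)).item())
            if sym == 0 and frame + 1 < num_frames[s]:
                frame += 1               # blank advances the frame
            else:
                if sym == 0:
                    sym = 1              # avoid crossing the boundary
                ctx = ctx[1:] + [sym]    # non-blank updates the context
            paths[s, p, t] = sym

sampling_probs = torch.full((seq, num_path, path_len), 0.1)
path_scores = torch.rand(seq, num_path, path_len, requires_grad=True)

lattice, arc_map = k2.generate_denominator_lattice(
    sampled_paths=paths, frame_ids=frames, left_symbols=contexts,
    sampling_probs=sampling_probs, path_scores=path_scores,
    boundary=torch.tensor(num_frames, dtype=torch.int32),
    vocab_size=vocab, context_size=context_size, return_arc_map=True)
lattice.scores.sum().backward()          # gradients reach path_scores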