
Commit

Directly integrate the context_graph into _ctc_prefix_beam_search
kaixunhuang0 committed Aug 1, 2023
1 parent 6e18c0d commit 3d18735
Showing 1 changed file with 95 additions and 154 deletions.
249 changes: 95 additions & 154 deletions wenet/transformer/asr_model.py
@@ -357,6 +357,7 @@ def _ctc_prefix_beam_search(
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
context_graph: ContextGraph = None,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation
@@ -393,140 +394,93 @@ def _ctc_prefix_beam_search(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, non_blank_ending_score))
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)

# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps, encoder_out
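
For reference, the log_add helper used throughout this diff (imported in this file from wenet.utils.common) combines the log-probabilities of the different paths that reach the same prefix. A minimal sketch, assuming it follows the standard log-sum-exp pattern:

import math
from typing import List

def log_add(args: List[float]) -> float:
    # Stable log(sum(exp(a) for a in args)).
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    # Shift by the max before exponentiating to avoid overflow.
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))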

def _ctc_prefix_beam_search_with_bias(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
beam_size: int,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
context_graph: ContextGraph = None,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation, rescoring with
context graph
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_lengths (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training; it's prohibited here
simulate_streaming (bool): whether to do encoder forward in a
streaming fashion
Returns:
List[List[int]]: nbest results
torch.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
batch_size = speech.shape[0]
# For CTC prefix beam search, we only support batch_size=1
assert batch_size == 1
# Let's assume B = batch_size and N = beam_size
# 1. Encoder forward and get CTC score
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)

# cur_hyps: (prefix, (blank_ending_score, non_blank_ending_score,
# context_graph_state, context_score))
cur_hyps = [(tuple(), (0.0, -float('inf'), 0, 0.0))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb, context_state, context_score),
# default value (-inf, -inf, 0, 0.0)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'), 0, 0.0))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb, c_state, c_score) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, _, _ = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb, c_state, c_score)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb, _, _ = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb, c_state, c_score)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, _, _ = next_hyps[n_prefix]
new_c_state, new_c_score = context_graph. \
find_next_state(c_state, s)
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb, new_c_state,
c_score + new_c_score)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, _, _ = next_hyps[n_prefix]
new_c_state, new_c_score = context_graph. \
find_next_state(c_state, s)
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb, new_c_state,
c_score + new_c_score)

# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: log_add([x[1][0], x[1][1]]) + x[1][3],
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]]) + y[1][3]) for y in cur_hyps]
if context_graph is None:
cur_hyps = [(tuple(), (0.0, -float('inf')))]
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)

# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
else:
cur_hyps = [(tuple(), (0.0, -float('inf'), 0, 0.0))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb, context_state, context_score),
# default value(-inf, -inf, 0, 0.0)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'), 0, 0.0))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb, c_state, c_score) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, _, _ = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb, c_state, c_score)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb, _, _ = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb, c_state, c_score)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, _, _ = next_hyps[n_prefix]
new_c_state, new_c_score = context_graph. \
find_next_state(c_state, s)
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb, new_c_state,
c_score + new_c_score)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, _, _ = next_hyps[n_prefix]
new_c_state, new_c_score = context_graph. \
find_next_state(c_state, s)
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb, new_c_state,
c_score + new_c_score)

# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: log_add([x[1][0], x[1][1]]) + x[1][3],
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]]) + y[1][3]) for y in cur_hyps]
return hyps, encoder_out
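
The integrated branch above only assumes that ContextGraph exposes find_next_state(state, token), returning the next graph state together with a score increment for the hypothesis. A simplified illustrative sketch of such a graph, built as a trie over context-phrase token ids with a fixed per-token bonus; this is a stand-in for exposition, not wenet's actual implementation (which lives in wenet/utils/context_graph.py and also handles cancelling the bonus when a partial match dies out):

from typing import Dict, List, Tuple

class ContextGraph:
    def __init__(self, context_list: List[List[int]], bonus: float = 3.0):
        self.bonus = bonus
        # State 0 is the root; each node maps token id -> child state.
        self.trie: List[Dict[int, int]] = [{}]
        for phrase in context_list:
            state = 0
            for token in phrase:
                if token not in self.trie[state]:
                    self.trie.append({})
                    self.trie[state][token] = len(self.trie) - 1
                state = self.trie[state][token]

    def find_next_state(self, state: int, token: int) -> Tuple[int, float]:
        # Advance along the current phrase and reward the matched token ...
        if token in self.trie[state]:
            return self.trie[state][token], self.bonus
        # ... otherwise fall back to the root and retry the token there.
        if token in self.trie[0]:
            return self.trie[0][token], self.bonus
        return 0, 0.0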

def ctc_prefix_beam_search(
@@ -556,19 +510,11 @@ def ctc_prefix_beam_search(
Returns:
List[int]: CTC prefix beam search nbest results
"""
if context_graph is None:
hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths,
beam_size, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming)
else:
hyps, _ = self._ctc_prefix_beam_search_with_bias(speech,
speech_lengths,
beam_size,
decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming,
context_graph)
hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths,
beam_size, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming,
context_graph)
return hyps[0]
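
A hypothetical calling sketch, using the illustrative ContextGraph above; model, feats, feats_lengths, and the token ids are assumptions for illustration and are not defined in this diff:

# Bias decoding toward one hot phrase given as token ids.
context_graph = ContextGraph(context_list=[[1024, 567, 89]])
hyp = model.ctc_prefix_beam_search(
    feats, feats_lengths, beam_size=10,
    context_graph=context_graph)
# Omitting the graph (context_graph=None) gives plain CTC prefix beam search.
hyp_plain = model.ctc_prefix_beam_search(feats, feats_lengths, beam_size=10)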

def attention_rescoring(
@@ -614,14 +560,9 @@ def attention_rescoring(
# For attention rescoring we only support batch_size=1
assert batch_size == 1
# encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
if context_graph is None:
hyps, encoder_out = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
else:
hyps, encoder_out = self._ctc_prefix_beam_search_with_bias(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming, context_graph)
hyps, encoder_out = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming, context_graph)

assert len(hyps) == beam_size
hyps_pad = pad_sequence([
