Use processor.tokenizer to get labels

fauxneticien committed Apr 19, 2022
1 parent b0bd904 commit 5447bdb

Showing 2 changed files with 7 additions and 14 deletions.
17 changes: 5 additions & 12 deletions scripts/helpers/asr.py
@@ -147,25 +147,18 @@ def get_metrics_computer(processor):
     def compute_metrics(pred):
 
         pred_logits = pred.predictions
-        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
 
         if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
 
             pred_str = processor.batch_decode(pred_logits).text
 
-            label_chars = [ processor.tokenizer.convert_ids_to_tokens(l) for l in pred.label_ids ]
-            label_str = [ "".join([ id for id in l if id not in processor.tokenizer.unique_no_split_tokens ]) for l in label_chars ]
-            label_str = [ l.replace(processor.tokenizer.word_delimiter_token, " ").strip() for l in label_str ]
-
         else:
 
-            pred_logits = pred.predictions
             pred_ids = np.argmax(pred_logits, axis=-1)
 
             pred_str = processor.batch_decode(pred_ids)
-            # we do not want to group tokens when computing the metrics
-            label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
+
+        # Replace data collator padding with tokenizer's padding
+        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
+        # Retrieve labels as characters, e.g. 'hello', from label_ids, e.g. [5, 3, 10, 10, 2] (where 5 = 'h')
+        label_str = processor.tokenizer.batch_decode(pred.label_ids, group_tokens=False)
 
         wer = wer_metric.compute(predictions=pred_str, references=label_str)
 
         return {"wer": wer}
4 changes: 2 additions & 2 deletions scripts/train_asr-by-w2v2-ft.py
@@ -6,7 +6,7 @@
 from datasets import load_metric
 from helpers.asr import (
     configure_lm,
-    configure_w2v2_components,
+    configure_w2v2_for_training,
     DataCollatorCTCWithPadding,
     dataset_from_dict,
     get_metrics_computer,
@@ -66,7 +66,7 @@

 dataset, vocab_dict = preprocess_text(dataset)
 
-model, processor = configure_w2v2_components(dataset, args, vocab_dict, w2v2_config)
+model, processor = configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config)
 
 if args.lm_arpa is not None:
     processor = configure_lm(processor, args.lm_arpa, args.output_dir)
