From 5447bdbfdd432c90d0c0476a67164d99b580ccf1 Mon Sep 17 00:00:00 2001
From: Nay San
Date: Tue, 19 Apr 2022 10:32:46 -0700
Subject: [PATCH] Use processor.tokenizer to get labels

---
 scripts/helpers/asr.py          | 17 +++++------------
 scripts/train_asr-by-w2v2-ft.py |  4 ++--
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/scripts/helpers/asr.py b/scripts/helpers/asr.py
index df7ee8c..2376017 100644
--- a/scripts/helpers/asr.py
+++ b/scripts/helpers/asr.py
@@ -147,25 +147,18 @@ def get_metrics_computer(processor):
 
     def compute_metrics(pred):
         pred_logits = pred.predictions
-        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
 
         if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
-
             pred_str = processor.batch_decode(pred_logits).text
-
-            label_chars = [ processor.tokenizer.convert_ids_to_tokens(l) for l in pred.label_ids ]
-            label_str = [ "".join([ id for id in l if id not in processor.tokenizer.unique_no_split_tokens ]) for l in label_chars ]
-            label_str = [ l.replace(processor.tokenizer.word_delimiter_token, " ").strip() for l in label_str ]
-
         else:
-
-            pred_logits = pred.predictions
 
             pred_ids = np.argmax(pred_logits, axis=-1)
-
             pred_str = processor.batch_decode(pred_ids)
-            # we do not want to group tokens when computing the metrics
-            label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
+        # Replace data collator padding with tokenizer's padding
+        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
+        # Retrieve labels as characters, e.g. 'hello', from label_ids, e.g. [5, 3, 10, 10, 2] (where 5 = 'h')
+        label_str = processor.tokenizer.batch_decode(pred.label_ids, group_tokens=False)
+
         wer = wer_metric.compute(predictions=pred_str, references=label_str)
         return {"wer": wer}
 
diff --git a/scripts/train_asr-by-w2v2-ft.py b/scripts/train_asr-by-w2v2-ft.py
index 2cfd609..71f95f5 100644
--- a/scripts/train_asr-by-w2v2-ft.py
+++ b/scripts/train_asr-by-w2v2-ft.py
@@ -6,7 +6,7 @@ from datasets import load_metric
 
 from helpers.asr import (
     configure_lm,
-    configure_w2v2_components,
+    configure_w2v2_for_training,
     DataCollatorCTCWithPadding,
     dataset_from_dict,
     get_metrics_computer,
@@ -66,7 +66,7 @@
 
 dataset, vocab_dict = preprocess_text(dataset)
 
-model, processor = configure_w2v2_components(dataset, args, vocab_dict, w2v2_config)
+model, processor = configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config)
 
 if args.lm_arpa is not None:
     processor = configure_lm(processor, args.lm_arpa, args.output_dir)
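
Note for reviewers (not part of the commit): below is a minimal, self-contained sketch of the label-decoding path this patch settles on. The vocabulary, the ids, and the temp-file handling are invented for illustration; only Wav2Vec2CTCTokenizer and its batch_decode(..., group_tokens=False) are real transformers APIs. The group_tokens=False argument matters because label_ids hold reference text rather than CTC frame output, so repeated characters (the double 'l' in "hello") must not be collapsed the way repeated CTC frames are.

# Sketch only: demonstrates the -100 -> pad_token_id replacement and
# ungrouped decoding on a hypothetical character vocabulary.
import json
import tempfile

import numpy as np
from transformers import Wav2Vec2CTCTokenizer

# Hypothetical vocab; '|' is wav2vec 2.0's default word delimiter token.
vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4,
         "h": 5, "e": 6, "l": 7, "o": 8}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(vocab, f)

tokenizer = Wav2Vec2CTCTokenizer(f.name)

# Simulated label_ids for "hello", right-padded with -100 as
# DataCollatorCTCWithPadding does so that the CTC loss ignores padding.
label_ids = np.array([[5, 6, 7, 7, 8, -100, -100]])

# Replace data collator padding with the tokenizer's padding token...
label_ids[label_ids == -100] = tokenizer.pad_token_id

# ...then decode without grouping repeated tokens, as compute_metrics now does.
label_str = tokenizer.batch_decode(label_ids, group_tokens=False)
print(label_str)  # ['hello']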