diff --git a/scripts/exp_asr-eval.py b/scripts/exp_asr-eval.py index 02d41ce..a941460 100644 --- a/scripts/exp_asr-eval.py +++ b/scripts/exp_asr-eval.py @@ -13,6 +13,9 @@ ("data/exps/asr/checkpoints/train-60", "data/exps/asr/datasets/test.tsv"), ("data/exps/asr/checkpoints/train-40", "data/exps/asr/datasets/test.tsv"), ("data/exps/asr/checkpoints/train-20", "data/exps/asr/datasets/test.tsv"), + ("data/exps/asr/checkpoints/train-10", "data/exps/asr/datasets/test.tsv"), + ("data/exps/asr/checkpoints/train-05", "data/exps/asr/datasets/test.tsv"), + ("data/exps/asr/checkpoints/train-01", "data/exps/asr/datasets/test.tsv"), # Baseline model with no additional fine-tuning ("facebook/wav2vec2-large-robust-ft-swbd-300h", "data/exps/asr/datasets/test.tsv"), diff --git a/scripts/exp_asr-train.sh b/scripts/exp_asr-train.sh index 8ae3f3c..2c4660b 100755 --- a/scripts/exp_asr-train.sh +++ b/scripts/exp_asr-train.sh @@ -1,7 +1,7 @@ #!/usr/bin/bash # Data subset experiments -declare -a subsets=("train-100" "train-80" "train-60" "train-40" "train-20") +declare -a subsets=("train-100" "train-80" "train-60" "train-40" "train-20" "train-10" "train-05" "train-01") for i in "${subsets[@]}" do @@ -9,7 +9,8 @@ do facebook/wav2vec2-large-robust-ft-swbd-300h \ "data/exps/asr/checkpoints/$i" \ "data/exps/asr/datasets/$i.tsv" \ - data/exps/asr/datasets/test.tsv + data/exps/asr/datasets/test.tsv \ + --use_target_vocab False done # Cross-validation experiments without language model diff --git a/scripts/helpers/asr.py b/scripts/helpers/asr.py index 2376017..f23582f 100644 --- a/scripts/helpers/asr.py +++ b/scripts/helpers/asr.py @@ -38,7 +38,7 @@ def dataset_from_dict(dataset_dict): def remove_special_characters(batch): chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]' - batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " " + batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]) return batch @@ -75,7 +75,6 @@ def preprocess_text(dataset_dict): print("Pre-processing transcriptions ...") dataset_dict = dataset_dict.map(remove_special_characters) - print("Creating vocabulary ...") vocab_path = create_vocab(dataset_dict) enable_progress_bar() @@ -143,6 +142,7 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> def get_metrics_computer(processor): wer_metric = load_metric("wer") + cer_metric = load_metric("cer") def compute_metrics(pred): @@ -159,9 +159,15 @@ def compute_metrics(pred): # Retrieve labels as characters, e.g. 'hello', from label_ids, e.g. [5, 3, 10, 10, 2] (where 5 = 'h') label_str = processor.tokenizer.batch_decode(pred.label_ids, group_tokens=False) + print(pd.DataFrame({ + "pred_str" : pred_str, + "label_str" : label_str + })) + wer = wer_metric.compute(predictions=pred_str, references=label_str) + cer = cer_metric.compute(predictions=pred_str, references=label_str) - return {"wer": wer} + return {"wer": wer, "cer": cer} return compute_metrics @@ -170,7 +176,7 @@ def configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config={}): feature_extractor_kwargs = w2v2_config["feature_extractor"] if "feature_extractor" in w2v2_config.keys() else {} model_kwargs = w2v2_config["model_kwargs"] if "model_kwargs" in w2v2_config.keys() else {} - if args.use_target_vocab: + if args.use_target_vocab is True: vocab_path = os.path.join(args.output_dir, 'vocab.json') print(f"Writing created vocabulary to {vocab_path}") @@ -183,6 +189,7 @@ def configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config={}): else: + print("Using vocabulary from tokenizer ...") tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(args.repo_path_or_name) feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_kwargs) diff --git a/scripts/train_asr-by-w2v2-ft.py b/scripts/train_asr-by-w2v2-ft.py index 71f95f5..a21bc7c 100644 --- a/scripts/train_asr-by-w2v2-ft.py +++ b/scripts/train_asr-by-w2v2-ft.py @@ -1,4 +1,5 @@ import json +import math import os import torch @@ -14,6 +15,7 @@ process_data ) from transformers import ( + EarlyStoppingCallback, logging, Trainer, TrainingArguments @@ -38,6 +40,9 @@ args = parser.parse_args() +# Turns out bool('False') evaluates to True in Python (only bool('') is False) +args.use_target_vocab = False if args.use_target_vocab == 'False' else True + logging.set_verbosity(args.hft_logging) # For debugging @@ -45,6 +50,7 @@ # args.train_tsv = 'data/train-asr/train.tsv' # args.eval_tsv = 'data/train-asr/test.tsv' # args.output_dir = 'data/asr-temp' +# args.use_target_vocab = False os.makedirs(args.output_dir, exist_ok=True) @@ -76,25 +82,38 @@ # Set logging to 'INFO' or else progress bar gets hidden logging.set_verbosity(20) +n_epochs = 50 +batch_size = 32 + +# How many epochs between evals? +eps_b_eval = 5 +# Save/Eval/Logging steps +sel_steps = int(math.ceil(len(dataset['train']) / batch_size) * eps_b_eval) + training_args = TrainingArguments( output_dir=args.output_dir, group_by_length=True, - per_device_train_batch_size=32, + per_device_train_batch_size=batch_size, gradient_accumulation_steps=1, evaluation_strategy="steps", - num_train_epochs=50, + num_train_epochs=n_epochs, fp16=True if torch.cuda.is_available() else False, seed=7135, - save_steps=400, - eval_steps=400, - logging_steps=400, + save_steps=sel_steps, + eval_steps=sel_steps, + logging_steps=sel_steps, learning_rate=1e-4, - warmup_steps=500, - save_total_limit=1, + # Warm up: 100 steps or 10% of total optimisation steps + warmup_steps=min(100, int(0.1 * sel_steps * n_epochs)), report_to="none", # 2022-03-09: manually set optmizier to PyTorch implementation torch.optim.AdamW # 'adamw_torch' to get rid of deprecation warning for default optimizer 'adamw_hf' - optim="adamw_torch" + optim="adamw_torch", + metric_for_best_model="wer", + save_total_limit=5, + load_best_model_at_end = True, + # Lower WER is better + greater_is_better=False ) trainer = Trainer( @@ -105,6 +124,7 @@ train_dataset=dataset['train'], eval_dataset=dataset['eval'], tokenizer=processor.feature_extractor, + callbacks = [EarlyStoppingCallback(early_stopping_patience=3)] ) print("Training model ...")