
Commit

Re-run all experiments using built-in vocabulary/tokenizer from wav2vec 2.0 Robust. Add early stopping.
fauxneticien committed Apr 22, 2022
1 parent 5447bdb commit 6b6a43d
Showing 4 changed files with 45 additions and 14 deletions.
3 changes: 3 additions & 0 deletions scripts/exp_asr-eval.py
@@ -13,6 +13,9 @@
("data/exps/asr/checkpoints/train-60", "data/exps/asr/datasets/test.tsv"),
("data/exps/asr/checkpoints/train-40", "data/exps/asr/datasets/test.tsv"),
("data/exps/asr/checkpoints/train-20", "data/exps/asr/datasets/test.tsv"),
("data/exps/asr/checkpoints/train-10", "data/exps/asr/datasets/test.tsv"),
("data/exps/asr/checkpoints/train-05", "data/exps/asr/datasets/test.tsv"),
("data/exps/asr/checkpoints/train-01", "data/exps/asr/datasets/test.tsv"),
# Baseline model with no additional fine-tuning
("facebook/wav2vec2-large-robust-ft-swbd-300h", "data/exps/asr/datasets/test.tsv"),

5 changes: 3 additions & 2 deletions scripts/exp_asr-train.sh
@@ -1,15 +1,16 @@
#!/usr/bin/bash

# Data subset experiments
declare -a subsets=("train-100" "train-80" "train-60" "train-40" "train-20")
declare -a subsets=("train-100" "train-80" "train-60" "train-40" "train-20" "train-10" "train-05" "train-01")

for i in "${subsets[@]}"
do
python scripts/train_asr-by-w2v2-ft.py \
facebook/wav2vec2-large-robust-ft-swbd-300h \
"data/exps/asr/checkpoints/$i" \
"data/exps/asr/datasets/$i.tsv" \
data/exps/asr/datasets/test.tsv
data/exps/asr/datasets/test.tsv \
--use_target_vocab False
done

# Cross-validation experiments without language model
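Note: the --use_target_vocab False value added above reaches Python as the literal string "False", which is truthy, so the training script (see train_asr-by-w2v2-ft.py below) compares the string itself. A minimal sketch of the pitfall and one common workaround, using a hypothetical str2bool converter that is not part of this repository:

import argparse

def str2bool(value):
    # bool("False") is True, so parse the string explicitly
    return str(value).strip().lower() in ("true", "1", "yes")

parser = argparse.ArgumentParser()
parser.add_argument("--use_target_vocab", type=str2bool, default=True)

print(bool("False"))                                                        # True, the pitfall
print(parser.parse_args(["--use_target_vocab", "False"]).use_target_vocab)  # False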
15 changes: 11 additions & 4 deletions scripts/helpers/asr.py
@@ -38,7 +38,7 @@ def dataset_from_dict(dataset_dict):

def remove_special_characters(batch):
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"])

return batch

@@ -75,7 +75,6 @@ def preprocess_text(dataset_dict):
print("Pre-processing transcriptions ...")
dataset_dict = dataset_dict.map(remove_special_characters)

print("Creating vocabulary ...")
vocab_path = create_vocab(dataset_dict)

enable_progress_bar()
@@ -143,6 +142,7 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
def get_metrics_computer(processor):

wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

def compute_metrics(pred):

@@ -159,9 +159,15 @@ def compute_metrics(pred):
# Retrieve labels as characters, e.g. 'hello', from label_ids, e.g. [5, 3, 10, 10, 2] (where 5 = 'h')
label_str = processor.tokenizer.batch_decode(pred.label_ids, group_tokens=False)

print(pd.DataFrame({
"pred_str" : pred_str,
"label_str" : label_str
}))

wer = wer_metric.compute(predictions=pred_str, references=label_str)
cer = cer_metric.compute(predictions=pred_str, references=label_str)

return {"wer": wer}
return {"wer": wer, "cer": cer}

return compute_metrics

@@ -170,7 +176,7 @@ def configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config={}):
feature_extractor_kwargs = w2v2_config["feature_extractor"] if "feature_extractor" in w2v2_config.keys() else {}
model_kwargs = w2v2_config["model_kwargs"] if "model_kwargs" in w2v2_config.keys() else {}

if args.use_target_vocab:
if args.use_target_vocab is True:
vocab_path = os.path.join(args.output_dir, 'vocab.json')

print(f"Writing created vocabulary to {vocab_path}")
@@ -183,6 +189,7 @@ def configure_w2v2_for_training(dataset, args, vocab_dict, w2v2_config={}):

else:

print("Using vocabulary from tokenizer ...")
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(args.repo_path_or_name)

feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_kwargs)
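The helper changes above add a character error rate (CER) alongside word error rate (WER) and print a prediction/reference table at each evaluation. A small sketch of how the two metrics differ, using the same datasets.load_metric API the helper already uses (both metrics rely on the jiwer package; the example strings and scores are illustrative only):

from datasets import load_metric

wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

predictions = ["hello wurld"]
references  = ["hello world"]

# WER counts word-level edits; CER counts character-level edits
print(wer_metric.compute(predictions=predictions, references=references))  # 0.5 (1 of 2 words wrong)
print(cer_metric.compute(predictions=predictions, references=references))  # ~0.09 (1 of 11 characters wrong)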
36 changes: 28 additions & 8 deletions scripts/train_asr-by-w2v2-ft.py
@@ -1,4 +1,5 @@
import json
import math
import os
import torch

@@ -14,6 +15,7 @@
process_data
)
from transformers import (
EarlyStoppingCallback,
logging,
Trainer,
TrainingArguments
@@ -38,13 +40,17 @@

args = parser.parse_args()

# Turns out bool('False') evaluates to True in Python (only bool('') is False)
args.use_target_vocab = False if args.use_target_vocab == 'False' else True

logging.set_verbosity(args.hft_logging)

# For debugging
# args.repo_path_or_name = "facebook/wav2vec2-large-robust-ft-swbd-300h"
# args.train_tsv = 'data/train-asr/train.tsv'
# args.eval_tsv = 'data/train-asr/test.tsv'
# args.output_dir = 'data/asr-temp'
# args.use_target_vocab = False

os.makedirs(args.output_dir, exist_ok=True)

@@ -76,25 +82,38 @@
# Set logging to 'INFO' or else progress bar gets hidden
logging.set_verbosity(20)

n_epochs = 50
batch_size = 32

# How many epochs between evals?
eps_b_eval = 5
# Save/Eval/Logging steps
sel_steps = int(math.ceil(len(dataset['train']) / batch_size) * eps_b_eval)

training_args = TrainingArguments(
output_dir=args.output_dir,
group_by_length=True,
per_device_train_batch_size=32,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=1,
evaluation_strategy="steps",
num_train_epochs=50,
num_train_epochs=n_epochs,
fp16=True if torch.cuda.is_available() else False,
seed=7135,
save_steps=400,
eval_steps=400,
logging_steps=400,
save_steps=sel_steps,
eval_steps=sel_steps,
logging_steps=sel_steps,
learning_rate=1e-4,
warmup_steps=500,
save_total_limit=1,
# Warm up: 100 steps or 10% of total optimisation steps
warmup_steps=min(100, int(0.1 * sel_steps * n_epochs)),
report_to="none",
# 2022-03-09: manually set optimizer to the PyTorch implementation torch.optim.AdamW
# 'adamw_torch' to get rid of deprecation warning for default optimizer 'adamw_hf'
optim="adamw_torch"
optim="adamw_torch",
metric_for_best_model="wer",
save_total_limit=5,
load_best_model_at_end = True,
# Lower WER is better
greater_is_better=False
)

trainer = Trainer(
@@ -105,6 +124,7 @@
train_dataset=dataset['train'],
eval_dataset=dataset['eval'],
tokenizer=processor.feature_extractor,
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)

print("Training model ...")
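The training-script changes above tie evaluation, checkpointing, and early stopping together: save/eval/logging steps are derived from the train-set size so that evaluation runs every eps_b_eval epochs, EarlyStoppingCallback(early_stopping_patience=3) stops training once the metric named in metric_for_best_model ("wer", with greater_is_better=False) has not improved for three consecutive evaluations, and load_best_model_at_end=True then restores the best checkpoint. A rough worked example of the step arithmetic, assuming a train set of 1,000 utterances (an illustrative figure, not one from the repository):

import math

n_train    = 1000   # assumed number of training utterances (illustrative only)
batch_size = 32
eps_b_eval = 5      # evaluate every 5 epochs, as in the script above

steps_per_epoch = math.ceil(n_train / batch_size)   # 32 optimisation steps per epoch
sel_steps       = steps_per_epoch * eps_b_eval      # 160: save/eval/log every 160 steps

# With early_stopping_patience=3, training stops after WER fails to improve
# for 3 evaluations in a row, i.e. at most about 3 * eps_b_eval = 15 extra epochs.
print(steps_per_epoch, sel_steps)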
