Skip to content

Commit

Permalink
Add a flag for turning off dropout, as in early dropout
Browse files Browse the repository at this point in the history
Add a test that the early dropout is turning off all the dropouts in a model
  • Loading branch information
AngledLuffa committed Dec 10, 2024
1 parent 2216cb5 commit d4a81a7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
5 changes: 5 additions & 0 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,11 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
if watch_regex.search(n):
wandb.log({n: torch.linalg.norm(p)})

if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']:
trainer.model.word_dropout.p = 0
trainer.model.predict_dropout.p = 0
trainer.model.lstm_input_dropout.p = 0

# recreate the optimizer and alter the model as needed if we hit a new multistage split
if args['multistage'] and trainer.epochs_trained in multistage_splits:
# we may be loading a saved model from an earlier epoch if the scores stopped increasing
Expand Down
1 change: 1 addition & 0 deletions stanza/models/constituency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def build_argparse():
parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')

parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')
# When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
# 0.0: 0.9085
# 0.2: 0.9165
Expand Down
24 changes: 24 additions & 0 deletions stanza/tests/constituency/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from torch import nn
from torch import optim

from stanza import Pipeline
Expand Down Expand Up @@ -253,6 +254,29 @@ def test_train(self, wordvec_pretrain_file):
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
self.run_train_test(wordvec_pretrain_file, tmpdirname)

def test_early_dropout(self, wordvec_pretrain_file):
    """
    Test that --early_dropout turns off the dropouts in the model

    Trains for 6 epochs on the fake data twice.  With
    --early_dropout 3, every nn.Dropout which is a direct child of
    the trained model must end up with p == 0.  With
    --early_dropout -1 (the disabled setting), at least one of
    those dropouts must still be nonzero.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
        args = ['--early_dropout', '3']
        _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
        model = model.model
        # named_children only yields direct children, so dropouts nested
        # inside submodules are not examined by this test
        dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
        assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
        for name, module in dropouts:
            # interpolate the name so a failure reports which dropout was missed
            assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout" % name

    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
        # test that when turned off, early_dropout doesn't happen
        args = ['--early_dropout', '-1']
        _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
        model = model.model
        dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
        assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
        # dropouts holds (name, module) pairs, so unpack before reading .p
        if all(module.p == 0.0 for _, module in dropouts):
            raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")

def test_train_silver(self, wordvec_pretrain_file):
"""
Test the whole thing for a few iterations on the fake data
Expand Down

0 comments on commit d4a81a7

Please sign in to comment.