diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 1317fc073d..cee8367b5a 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -450,6 +450,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d wandb.log({n: torch.linalg.norm(p)}) if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']: + if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)): + logger.info("Setting dropout to 0.0 at epoch %d", trainer.epochs_trained) trainer.model.word_dropout.p = 0 trainer.model.predict_dropout.p = 0 trainer.model.lstm_input_dropout.p = 0