diff --git a/stanza/utils/datasets/prepare_lemma_treebank.py b/stanza/utils/datasets/prepare_lemma_treebank.py index a666c4ed18..16a88c5f8e 100644 --- a/stanza/utils/datasets/prepare_lemma_treebank.py +++ b/stanza/utils/datasets/prepare_lemma_treebank.py @@ -9,9 +9,17 @@ and it will prepare each of train, dev, test """ +from stanza.models.common.constant import treebank_to_short_name + import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank +import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier + +def add_specific_args(parser) -> None: + parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True, + help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer") + def check_lemmas(train_file): """ Check if a treebank has any lemmas in it @@ -50,8 +58,12 @@ def process_treebank(treebank, model_type, paths, args): augment = True prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["LEMMA_DATA_DIR"], augment=augment) + short_name = treebank_to_short_name(treebank) + if args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING: + prepare_lemma_classifier.main(short_name) + def main(): - common.main(process_treebank, common.ModelType.LEMMA) + common.main(process_treebank, common.ModelType.LEMMA, add_specific_args) if __name__ == '__main__': main() diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py index 6e56a91598..267c4933d4 100644 --- a/stanza/utils/training/common.py +++ b/stanza/utils/training/common.py @@ -72,7 +72,7 @@ def add_charlm_args(parser): parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm') parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") -def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None): +def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None): """ A main program for each of the run_xyz scripts @@ -83,7 +83,11 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar - the charlm, for example, needs this feature, since it makes both forward and backward models """ - logger.info("Training program called with:\n" + " ".join(sys.argv)) + if args is None: + logger.info("Training program called with:\n" + " ".join(sys.argv)) + args = sys.argv[1:] + else: + logger.info("Training program called with:\n" + " ".join(args)) paths = default_paths.get_default_paths() @@ -93,9 +97,9 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar if '--extra_args' in sys.argv: idx = sys.argv.index('--extra_args') extra_args = sys.argv[idx+1:] - command_args = parser.parse_args(sys.argv[1:idx]) + command_args = parser.parse_args(sys.argv[:idx]) else: - command_args, extra_args = parser.parse_known_args() + command_args, extra_args = parser.parse_known_args(args=args) # Pass this through to the underlying model as well as use it here # we don't put --save_name here for the awkward situation of diff --git a/stanza/utils/training/run_lemma.py b/stanza/utils/training/run_lemma.py index 835f2d8904..026a6c5fa3 100644 --- a/stanza/utils/training/run_lemma.py +++ b/stanza/utils/training/run_lemma.py @@ -20,17 +20,23 @@ from stanza.models import identity_lemmatizer from stanza.models import lemmatizer +from stanza.models.lemma import attach_lemma_classifier from stanza.utils.training import common from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm +from stanza.utils.training import run_lemma_classifier from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas +import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier logger = logging.getLogger('stanza') def add_lemma_args(parser): add_charlm_args(parser) + parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True, + help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer") + def build_model_filename(paths, short_name, command_args, extra_args): """ Figure out what the model savename will be, taking into account the model settings. @@ -142,6 +148,22 @@ def run_treebank(mode, paths, treebank, short_name, logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args)) lemmatizer.main(test_args) + use_lemma_classifier = command_args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING + if use_lemma_classifier and mode == Mode.TRAIN: + # TODO: pass along charlm args + lemma_classifier_args = [treebank] + if command_args.force: + lemma_classifier_args.append('--force') + run_lemma_classifier.main(lemma_classifier_args) + + save_name = build_model_filename(paths, short_name, command_args, extra_args) + # TODO: use a temp path for the lemma_classifier or keep it somewhere + attach_args = ['--input', save_name, + '--output', save_name, + '--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name] + attach_lemma_classifier.main(attach_args) + # TODO: rerun dev set / test set with the attached classifier? + def main(): common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm) diff --git a/stanza/utils/training/run_lemma_classifier.py b/stanza/utils/training/run_lemma_classifier.py index 1a8420814a..73669e79b0 100644 --- a/stanza/utils/training/run_lemma_classifier.py +++ b/stanza/utils/training/run_lemma_classifier.py @@ -79,8 +79,8 @@ def run_treebank(mode, paths, treebank, short_name, eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args evaluate_models.main(eval_args) -def main(): - common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm) +def main(args=None): + common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args) if __name__ == '__main__':