Skip to content

Commit

Permalink
Make prepare_lemma_treebank automatically prepare the lemma_classifier data, make run_lemma automatically attach it
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 22, 2024
1 parent 50609f8 commit 1a39fa9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
14 changes: 13 additions & 1 deletion stanza/utils/datasets/prepare_lemma_treebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@
and it will prepare each of train, dev, test
"""

from stanza.models.common.constant import treebank_to_short_name

import stanza.utils.datasets.common as common
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank

import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier

def add_specific_args(parser) -> None:
    """
    Register the command line flags specific to the lemma treebank preparation

    --no_lemma_classifier stores False in args.lemma_classifier.
    When the flag is not given, args.lemma_classifier is True.
    """
    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True,
                        help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer")

def check_lemmas(train_file):
"""
Check if a treebank has any lemmas in it
Expand Down Expand Up @@ -50,8 +58,12 @@ def process_treebank(treebank, model_type, paths, args):
augment = True
prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["LEMMA_DATA_DIR"], augment=augment)

short_name = treebank_to_short_name(treebank)
if args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING:
prepare_lemma_classifier.main(short_name)

def main():
    """
    Command line entry point for preparing the lemmatizer data

    Hands process_treebank to common.main for ModelType.LEMMA.
    add_specific_args is passed along so that --no_lemma_classifier is
    registered on the parser; process_treebank reads the result as
    args.lemma_classifier.
    """
    common.main(process_treebank, common.ModelType.LEMMA, add_specific_args)

# only run the preparation when executed as a script, not when imported
if __name__ == '__main__':
    main()
Expand Down
12 changes: 8 additions & 4 deletions stanza/utils/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def add_charlm_args(parser):
parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")

def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None):
def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argparse=None, build_model_filename=None, choose_charlm_method=None, args=None):
"""
A main program for each of the run_xyz scripts
Expand All @@ -83,7 +83,11 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
- the charlm, for example, needs this feature, since it makes
both forward and backward models
"""
logger.info("Training program called with:\n" + " ".join(sys.argv))
if args is None:
logger.info("Training program called with:\n" + " ".join(sys.argv))
args = sys.argv[1:]
else:
logger.info("Training program called with:\n" + " ".join(args))

paths = default_paths.get_default_paths()

Expand All @@ -93,9 +97,9 @@ def main(run_treebank, model_dir, model_name, add_specific_args=None, sub_argpar
if '--extra_args' in sys.argv:
idx = sys.argv.index('--extra_args')
extra_args = sys.argv[idx+1:]
command_args = parser.parse_args(sys.argv[1:idx])
command_args = parser.parse_args(sys.argv[:idx])
else:
command_args, extra_args = parser.parse_known_args()
command_args, extra_args = parser.parse_known_args(args=args)

# Pass this through to the underlying model as well as use it here
# we don't put --save_name here for the awkward situation of
Expand Down
22 changes: 22 additions & 0 deletions stanza/utils/training/run_lemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,23 @@

from stanza.models import identity_lemmatizer
from stanza.models import lemmatizer
from stanza.models.lemma import attach_lemma_classifier

from stanza.utils.training import common
from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm
from stanza.utils.training import run_lemma_classifier

from stanza.utils.datasets.prepare_lemma_treebank import check_lemmas
import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier

logger = logging.getLogger('stanza')

def add_lemma_args(parser):
    """
    Add the flags specific to the lemmatizer training script to parser

    Registers the charlm flags via add_charlm_args, then --no_lemma_classifier.
    --no_lemma_classifier stores False in lemma_classifier on the parsed args;
    when the flag is not given, lemma_classifier is True.
    """
    add_charlm_args(parser)

    parser.add_argument('--no_lemma_classifier', dest='lemma_classifier', action='store_false', default=True,
                        help="Don't use the lemma classifier datasets. Default is to build lemma classifier as part of the original lemmatizer")

def build_model_filename(paths, short_name, command_args, extra_args):
"""
Figure out what the model savename will be, taking into account the model settings.
Expand Down Expand Up @@ -142,6 +148,22 @@ def run_treebank(mode, paths, treebank, short_name,
logger.info("Running test lemmatizer for {} with args {}".format(treebank, test_args))
lemmatizer.main(test_args)

use_lemma_classifier = command_args.lemma_classifier and short_name in prepare_lemma_classifier.DATASET_MAPPING
if use_lemma_classifier and mode == Mode.TRAIN:
# TODO: pass along charlm args
lemma_classifier_args = [treebank]
if command_args.force:
lemma_classifier_args.append('--force')
run_lemma_classifier.main(lemma_classifier_args)

save_name = build_model_filename(paths, short_name, command_args, extra_args)
# TODO: use a temp path for the lemma_classifier or keep it somewhere
attach_args = ['--input', save_name,
'--output', save_name,
'--classifier', 'saved_models/lemma_classifier/%s_lemma_classifier.pt' % short_name]
attach_lemma_classifier.main(attach_args)
# TODO: rerun dev set / test set with the attached classifier?

def main():
    """
    Command line entry point for the lemmatizer training script

    Delegates to common.main with run_treebank as the per-treebank callback,
    "lemma" as the model_dir and "lemmatizer" as the model_name.
    add_lemma_args supplies the script specific flags, and the lemmatizer's
    own argparse is passed as sub_argparse.
    """
    common.main(run_treebank, "lemma", "lemmatizer", add_lemma_args, sub_argparse=lemmatizer.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm)

Expand Down
4 changes: 2 additions & 2 deletions stanza/utils/training/run_lemma_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def run_treebank(mode, paths, treebank, short_name,
eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args
evaluate_models.main(eval_args)

def main(args=None):
    """
    Entry point for the lemma classifier training script

    args: optional list of command line arguments, passed through to
      common.main.  Leaving it as None keeps the command line behavior,
      so existing callers which use main() with no arguments still work.
      Supplying a list lets another script (such as run_lemma) invoke
      this one programmatically.
    """
    common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm, args=args)


if __name__ == '__main__':
Expand Down

0 comments on commit 1a39fa9

Please sign in to comment.