From 19bb53b83882837c2ae2e1c40798e5c07e43829c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 28 Nov 2024 21:02:06 -0800 Subject: [PATCH] Convert the MWT training to use a pytorch dataloader with shuffling. In theory this should also provide some cpu/gpu parallelism at test time, although we haven't done anything to ensure it is using multiprocessing --- stanza/models/mwt/data.py | 80 ++++++++++++++++++++------------ stanza/models/mwt_expander.py | 6 +-- stanza/pipeline/mwt_processor.py | 2 +- 3 files changed, 54 insertions(+), 34 deletions(-) diff --git a/stanza/models/mwt/data.py b/stanza/models/mwt/data.py index 791c23efab..d6c4469e3b 100644 --- a/stanza/models/mwt/data.py +++ b/stanza/models/mwt/data.py @@ -1,10 +1,12 @@ import random import numpy as np import os -from collections import Counter +from collections import Counter, namedtuple import logging import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader as DL import stanza.models.common.seq2seq_constant as constant from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all @@ -14,6 +16,9 @@ logger = logging.getLogger('stanza') +DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text") +DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx") + # enforce that the MWT splitter knows about a couple different alternate apostrophes # including covering some potential " typos # setting the augmentation to a very low value should be enough to teach it @@ -56,12 +61,9 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u indices = list(range(len(data))) random.shuffle(indices) data = [data[i] for i in indices] - self.num_examples = len(data) - # chunk into batches - data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)] self.data = data - logger.debug("{} batches created.".format(len(data))) + self.num_examples = len(data) def init_vocab(self, data): assert self.evaluation == False # for eval vocab must exist @@ -77,17 +79,14 @@ def maybe_augment_apos(self, datum): break return datum - - def process(self, data): - processed = [] - for d in data: - if not self.evaluation and self.augment_apos > 0: - d = self.maybe_augment_apos(d) - src = list(d[0]) - src = [constant.SOS] + src + [constant.EOS] - tgt_in, tgt_out = self.prepare_target(self.vocab, d) - src = self.vocab.map(src) - processed += [[src, tgt_in, tgt_out, d[0]]] + def process(self, sample): + if not self.evaluation and self.augment_apos > 0: + sample = self.maybe_augment_apos(sample) + src = list(sample[0]) + src = [constant.SOS] + src + [constant.EOS] + tgt_in, tgt_out = self.prepare_target(self.vocab, sample) + src = self.vocab.map(src) + processed = [src, tgt_in, tgt_out, sample[0]] return processed def prepare_target(self, vocab, datum): @@ -108,31 +107,52 @@ def __getitem__(self, key): raise TypeError if key < 0 or key >= len(self.data): raise IndexError - batch = self.data[key] - batch = self.process(batch) - batch_size = len(batch) - batch = list(zip(*batch)) - assert len(batch) == 4 + sample = self.data[key] + sample = self.process(sample) + assert len(sample) == 4 + + src = torch.tensor(sample[0]) + tgt_in = torch.tensor(sample[1]) + tgt_out = torch.tensor(sample[2]) + orig_text = sample[3] + result = DataSample(src, tgt_in, tgt_out, orig_text), key + return result - # sort all fields by lens for easy RNN operations - lens = [len(x) for x in batch[0]] - batch, orig_idx = sort_all(batch, lens) + @staticmethod + def __collate_fn(data): + (data, idx) = zip(*data) + (src, tgt_in, tgt_out, orig_text) = zip(*data) + + # collate_fn is given a list of length batch size + batch_size = len(data) + + lens = [len(x) for x in tgt_in] + (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens) + lens = [len(x) for x in tgt_in] # convert to tensors - src = batch[0] - src = get_long_tensor(src, batch_size) + src = pad_sequence(src, True, constant.PAD_ID) src_mask = torch.eq(src, constant.PAD_ID) - tgt_in = get_long_tensor(batch[1], batch_size) - tgt_out = get_long_tensor(batch[2], batch_size) - orig_text = batch[3] + tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID) + tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID) assert tgt_in.size(1) == tgt_out.size(1), \ "Target input and output sequence sizes do not match." - return (src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) + return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) + def to_loader(self): + """Converts self to a DataLoader """ + + batch_size = self.batch_size + shuffle = not self.evaluation + return DL(self, + collate_fn=self.__collate_fn, + batch_size=batch_size, + shuffle=shuffle) + def load_doc(self, doc, evaluation=False): data = doc.get_mwt_expansions(evaluation) if evaluation: data = [[e] for e in data] diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py index 87eadbb8a0..a57b703e39 100644 --- a/stanza/models/mwt_expander.py +++ b/stanza/models/mwt_expander.py @@ -201,7 +201,7 @@ def train(args): # start training for epoch in range(1, args['num_epoch']+1): train_loss = 0 - for i, batch in enumerate(train_batch): + for i, batch in enumerate(train_batch.to_loader()): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step @@ -218,7 +218,7 @@ def train(args): # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] - for i, batch in enumerate(dev_batch): + for i, batch in enumerate(dev_batch.to_loader()): preds = trainer.predict(batch) dev_preds += preds if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False): @@ -296,7 +296,7 @@ def evaluate(args): else: logger.info("Running the seq2seq model...") preds = [] - for i, b in enumerate(batch): + for i, b in enumerate(batch.to_loader()): preds += trainer.predict(b) if loaded_args.get('ensemble_dict', False): diff --git a/stanza/pipeline/mwt_processor.py b/stanza/pipeline/mwt_processor.py index 50b83bfd81..6aaf1b3112 100644 --- a/stanza/pipeline/mwt_processor.py +++ b/stanza/pipeline/mwt_processor.py @@ -37,7 +37,7 @@ def process(self, document): else: with torch.no_grad(): preds = [] - for i, b in enumerate(batch): + for i, b in enumerate(batch.to_loader()): preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab) if self.config.get('ensemble_dict', False):