Skip to content

Commit

Permalink
Convert the MWT training to use a pytorch dataloader with shuffling. …
Browse files Browse the repository at this point in the history
…In theory this should also provide some cpu/gpu parallelism at test time, although we haven't done anything to ensure it is using multiprocessing
  • Loading branch information
AngledLuffa committed Nov 29, 2024
1 parent 5c34924 commit 0f9e8d1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
81 changes: 50 additions & 31 deletions stanza/models/mwt/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import random
import numpy as np
import os
from collections import Counter
from collections import Counter, namedtuple
import logging

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as DL

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
Expand All @@ -14,6 +16,9 @@

logger = logging.getLogger('stanza')

DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")

# enforce that the MWT splitter knows about a couple different alternate apostrophes
# including covering some potential " typos
# setting the augmentation to a very low value should be enough to teach it
Expand All @@ -22,7 +27,6 @@
# 0x22, 0x27, 0x02BC, 0x02CA, 0x2019
APOS = ('"', "'", 'ʼ', 'ˊ', '’')

# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
Expand Down Expand Up @@ -56,12 +60,9 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
indices = list(range(len(data)))
random.shuffle(indices)
data = [data[i] for i in indices]
self.num_examples = len(data)

# chunk into batches
data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
self.data = data
logger.debug("{} batches created.".format(len(data)))
self.num_examples = len(data)

def init_vocab(self, data):
assert self.evaluation == False # for eval vocab must exist
Expand All @@ -77,17 +78,14 @@ def maybe_augment_apos(self, datum):
break
return datum


def process(self, data):
processed = []
for d in data:
if not self.evaluation and self.augment_apos > 0:
d = self.maybe_augment_apos(d)
src = list(d[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, d)
src = self.vocab.map(src)
processed += [[src, tgt_in, tgt_out, d[0]]]
def process(self, sample):
if not self.evaluation and self.augment_apos > 0:
sample = self.maybe_augment_apos(sample)
src = list(sample[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
src = self.vocab.map(src)
processed = [src, tgt_in, tgt_out, sample[0]]
return processed

def prepare_target(self, vocab, datum):
Expand All @@ -108,31 +106,52 @@ def __getitem__(self, key):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch = self.process(batch)
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 4
sample = self.data[key]
sample = self.process(sample)
assert len(sample) == 4

src = torch.tensor(sample[0])
tgt_in = torch.tensor(sample[1])
tgt_out = torch.tensor(sample[2])
orig_text = sample[3]
result = DataSample(src, tgt_in, tgt_out, orig_text), key
return result

# sort all fields by lens for easy RNN operations
lens = [len(x) for x in batch[0]]
batch, orig_idx = sort_all(batch, lens)
@staticmethod
def __collate_fn(data):
(data, idx) = zip(*data)
(src, tgt_in, tgt_out, orig_text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

lens = [len(x) for x in tgt_in]
(src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)
lens = [len(x) for x in tgt_in]

# convert to tensors
src = batch[0]
src = get_long_tensor(src, batch_size)
src = pad_sequence(src, True, constant.PAD_ID)
src_mask = torch.eq(src, constant.PAD_ID)
tgt_in = get_long_tensor(batch[1], batch_size)
tgt_out = get_long_tensor(batch[2], batch_size)
orig_text = batch[3]
tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)
tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)
assert tgt_in.size(1) == tgt_out.size(1), \
"Target input and output sequence sizes do not match."
return (src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)
return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)

def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)

def to_loader(self):
"""Converts self to a DataLoader """

batch_size = self.batch_size
shuffle = not self.evaluation
return DL(self,
collate_fn=self.__collate_fn,
batch_size=batch_size,
shuffle=shuffle)

def load_doc(self, doc, evaluation=False):
data = doc.get_mwt_expansions(evaluation)
if evaluation: data = [[e] for e in data]
Expand Down
6 changes: 3 additions & 3 deletions stanza/models/mwt_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(args):
# start training
for epoch in range(1, args['num_epoch']+1):
train_loss = 0
for i, batch in enumerate(train_batch):
for i, batch in enumerate(train_batch.to_loader()):
start_time = time.time()
global_step += 1
loss = trainer.update(batch, eval=False) # update step
Expand All @@ -218,7 +218,7 @@ def train(args):
# eval on dev
logger.info("Evaluating on dev set...")
dev_preds = []
for i, batch in enumerate(dev_batch):
for i, batch in enumerate(dev_batch.to_loader()):
preds = trainer.predict(batch)
dev_preds += preds
if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):
Expand Down Expand Up @@ -296,7 +296,7 @@ def evaluate(args):
else:
logger.info("Running the seq2seq model...")
preds = []
for i, b in enumerate(batch):
for i, b in enumerate(batch.to_loader()):
preds += trainer.predict(b)

if loaded_args.get('ensemble_dict', False):
Expand Down
2 changes: 1 addition & 1 deletion stanza/pipeline/mwt_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def process(self, document):
else:
with torch.no_grad():
preds = []
for i, b in enumerate(batch):
for i, b in enumerate(batch.to_loader()):
preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab)

if self.config.get('ensemble_dict', False):
Expand Down

0 comments on commit 0f9e8d1

Please sign in to comment.