diff --git a/stanza/models/mwt/data.py b/stanza/models/mwt/data.py
index b8b0b2f1e..791c23efa 100644
--- a/stanza/models/mwt/data.py
+++ b/stanza/models/mwt/data.py
@@ -3,6 +3,7 @@
 import os
 from collections import Counter
 import logging
+
 import torch
 
 import stanza.models.common.seq2seq_constant as constant
@@ -17,7 +18,9 @@
 # including covering some potential " typos
 # setting the augmentation to a very low value should be enough to teach it
 # about the unknown characters without messing up the predictions for other text
-APOS = ('"', '’', 'ʼ')
+#
+# 0x22, 0x27, 0x02BC, 0x02CA, 0x2019
+APOS = ('"', "'", 'ʼ', 'ˊ', '’')
 
 # TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
 class DataLoader:
@@ -34,7 +37,7 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
         if vocab is None:
             assert self.evaluation == False # for eval vocab must exist
             self.vocab = self.init_vocab(data)
-            if self.augment_apos > 0 and "'" in self.vocab:
+            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
                 for apos in APOS:
                     self.vocab.add_unit(apos)
         elif expand_unk_vocab:
@@ -66,9 +69,12 @@ def init_vocab(self, data):
         return vocab
 
     def maybe_augment_apos(self, datum):
-        if "'" in datum[0] and random.uniform(0,1) < self.augment_apos:
-            replacement = random.choice(APOS)
-            datum = (datum[0].replace("'", replacement), datum[1].replace("'", replacement))
+        for original in APOS:
+            if original in datum[0]:
+                if random.uniform(0,1) < self.augment_apos:
+                    replacement = random.choice(APOS)
+                    datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))
+                break
         return datum
 
 