From 3d1a8cc46e118fb94f5fc8c77f711b6081e2166e Mon Sep 17 00:00:00 2001
From: markus583
Date: Tue, 23 Apr 2024 08:08:38 +0000
Subject: [PATCH] add full corruption param

---
 wtpsplit/evaluation/intrinsic.py         | 13 ++------
 wtpsplit/train/evaluate.py               |  3 +-
 wtpsplit/train/train.py                  |  4 +--
 wtpsplit/train/train_adapter.py          | 13 ++++++--
 wtpsplit/train/train_adapter_parallel.py |  3 +-
 wtpsplit/utils.py                        | 38 +++++++++++++++++++-----
 6 files changed, 46 insertions(+), 28 deletions(-)

diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index 744d9c8c..d49cd1d2 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -18,7 +18,7 @@
 import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
 from wtpsplit.extract import PyTorchWrapper, extract
-from wtpsplit.utils import Constants
+from wtpsplit.utils import Constants, corrupt
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -42,7 +42,7 @@ class Args:
     #    }
     # }
     # TODO: for songs/etc., maybe feed in each sample separately?
-    eval_data_path: str = "data/all_data.pth"
+    eval_data_path: str = "data/all_data_21-04.pth"
     valid_text_path: str = None  # "data/sentence/valid.parquet"
     device: str = "cpu"
     block_size: int = 512
@@ -92,15 +92,6 @@ def process_logits(text, model, lang_code, args):
     return logits
 
 
-def corrupt(text: str, do_lowercase: bool, do_remove_punct: bool):
-    if do_lowercase:
-        text = text.lower()
-    if do_remove_punct:
-        for punct in Constants.PUNCTUATION_CHARS:
-            text = text.replace(punct, "")
-    return text
-
-
 def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: str = None):
     logits_path = Constants.CACHE_DIR / "intrinsic" / f"{save_str}.h5"
 
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index 60b4b55b..a5eec3ab 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -6,10 +6,9 @@
 import sklearn.metrics
 
 from wtpsplit.evaluation import token_to_char_probs
-from wtpsplit.evaluation.intrinsic import corrupt
 from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
 from wtpsplit.extract import PyTorchWrapper, extract
-from wtpsplit.utils import Constants, sigmoid
+from wtpsplit.utils import Constants, sigmoid, corrupt
 
 logger = logging.getLogger(__name__)
 
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 9dbe919f..ed02dba3 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -34,7 +34,7 @@
 )
 from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise, evaluate_sentence_kmers
 from wtpsplit.train.trainer import Trainer
-from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict
+from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict
 from wtpsplit.train.utils import Model, cleanup_cache_files
 
 logger = logging.getLogger(__name__)
@@ -131,7 +131,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer, add_lang_ids: boo
 
         block_ids = [0] * len(input_ids)
 
-        input_ids, _, labels = corrupt(
+        input_ids, _, labels = corrupt_training(
             input_ids,
             block_ids,
             lang,
diff --git a/wtpsplit/train/train_adapter.py b/wtpsplit/train/train_adapter.py
index 666e1d80..103fc31f 100644
--- a/wtpsplit/train/train_adapter.py
+++ b/wtpsplit/train/train_adapter.py
@@ -19,14 +19,13 @@
 import adapters
 import wandb
 from adapters import AdapterArguments
-from wtpsplit.evaluation.intrinsic import corrupt
 from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
 from wtpsplit.train.adaptertrainer import AdapterTrainer
 from wtpsplit.train.trainer import Trainer
 from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
 from wtpsplit.train.train import collate_fn, setup_logging
 from wtpsplit.train.utils import Model
-from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
+from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict, corrupt
 from tqdm import tqdm
 
 from typing import Union, Optional
@@ -117,6 +116,8 @@ def prepare_dataset(
             dataset = data[lang]["sentence"][dataset_name]["meta"]["train_data"]
         elif split == "valid":
             dataset = data[lang]["sentence"][dataset_name]["data"]
+            if dataset_name == "opus100" and lang == "fr":
+                dataset = data[lang]["sentence"][dataset_name]["data"]
 
         if dataset is None:
             return None
@@ -403,6 +404,12 @@ def maybe_pad(text):
     for lang in tqdm(data.keys(), desc="Language"):
         if lang in args.include_languages:
             for dataset_name in data[lang]["sentence"].keys():
+                if dataset_name != "ted2020":
+                    continue
+                # skip langs starting with a, b, ..., k
+                if lang[0] < "d":
+                    print(f"Skipping {lang} {dataset_name}")
+                    continue
                 # do model stuff here; otherwise, head params would be overwritten every time
                 backbone = SubwordXLMForTokenClassification.from_pretrained(
                     args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True
@@ -551,7 +558,7 @@ def compute_metrics(trainer):
             model.backbone.classifier = torch.nn.Sequential(
                 clf,  # original classifier - if frozen above, also frozen here
                 torch.nn.Linear(clf.out_features, 1),
-                )
+            )
             model.backbone.config.num_labels = 1
 
             # if args.one_sample_per_line:
diff --git a/wtpsplit/train/train_adapter_parallel.py b/wtpsplit/train/train_adapter_parallel.py
index ed715fd2..b67760ba 100644
--- a/wtpsplit/train/train_adapter_parallel.py
+++ b/wtpsplit/train/train_adapter_parallel.py
@@ -35,8 +35,7 @@
 from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
 from wtpsplit.train.train import collate_fn
 from wtpsplit.train.utils import Model
-from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
-from wtpsplit.evaluation.intrinsic import corrupt
+from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict, corrupt
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index e7dcf453..6505904c 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -85,6 +85,7 @@ class LabelArgs:
     case_corruption_prob_after_newline: float = 0.0
     case_corruption_prob_after_punct: float = 0.0
     corrupt_entire_chunk_prob: float = 0.0
+    corrupt_entire_chunk_strategy: str = "full"
 
     def __post_init__(self):
         if self.custom_punctuation_file:
@@ -133,7 +134,7 @@ def get_subword_label_dict(label_args, tokenizer):
     return label_dict
 
 
-# numerically more stable sigmoid taken from 
+# numerically more stable sigmoid taken from
 # https://stackoverflow.com/questions/51976461/optimal-way-of-defining-a-numerically-stable-sigmoid-function-for-a-list-in-pyth
 def _positive_sigmoid(x):
     return 1 / (1 + np.exp(-x))
@@ -196,8 +197,17 @@ def lang_code_to_lang(lang_code):
         return languages.get(part3=lang_code).name
 
 
+def corrupt(text: str, do_lowercase: bool, do_remove_punct: bool):
+    if do_lowercase:
+        text = text.lower()
+    if do_remove_punct:
+        for punct in Constants.PUNCTUATION_CHARS:
+            text = text.replace(punct, "")
+    return text
+
+
 # does the steps in Figure 2 of the paper
-def corrupt(
+def corrupt_training(
     input_ids,
     block_ids,
     lang,
@@ -211,14 +221,24 @@
     block_ids = block_ids.copy()
     if random.random() < label_args.corrupt_entire_chunk_prob:
         # lowercase all text
-        lowercased = tokenizer.decode(input_ids).lower()
-        input_ids = tokenizer.encode(lowercased, add_special_tokens=False)
+        input_text = tokenizer.decode(input_ids)
+        if label_args.corrupt_entire_chunk_strategy == "tokenizer":
+            if not tokenizer:
+                raise NotImplementedError()
+            corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=False)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            # remove ALL punct *tokens*
+            auxiliary_remove_prob = 1.0
+        elif label_args.corrupt_entire_chunk_strategy == "full":
+            # remove all punct *characters*
+            corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=True)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            auxiliary_remove_prob = 1.0  # just for safety/consistency
         block_ids = [0] * len(input_ids)
-        # remove ALL punct
-        auxiliary_remove_prob = 1.0
+
     else:
         auxiliary_remove_prob = label_args.auxiliary_remove_prob
-    
+
     labels = label(input_ids, label_dict)
 
     separator = Constants.SEPARATORS[lang]
@@ -459,7 +479,9 @@ def reconstruct_sentences(text, partial_sentences):
     label_dict = get_subword_label_dict(label_args, tokenizer)
 
     # corrupt
-    input_ids, block_ids, labels = corrupt(input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer)
+    input_ids, block_ids, labels = corrupt_training(
+        input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
+    )
     print(input_ids)
     print(labels)
     print(tokenizer.tokenize(text))
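
For reference, here is a minimal, hypothetical usage sketch (not part of the patch) of the new corrupt_entire_chunk_strategy field together with the moved corrupt helper and the renamed corrupt_training function. The xlm-roberta-base tokenizer checkpoint, the sample text, and the reliance on LabelArgs defaults are assumptions for illustration only; only the wtpsplit.utils names themselves come from this change.

# Hypothetical sketch: exercise the new "full" corruption path.
from transformers import AutoTokenizer  # tokenizer source is an assumption

from wtpsplit.utils import LabelArgs, corrupt, corrupt_training, get_subword_label_dict

# Assumed checkpoint; use whichever subword tokenizer the model was trained with.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# corrupt_entire_chunk_prob=1.0 forces the full-chunk corruption branch;
# all other LabelArgs fields are left at their defaults.
label_args = LabelArgs(corrupt_entire_chunk_prob=1.0, corrupt_entire_chunk_strategy="full")
label_dict = get_subword_label_dict(label_args, tokenizer)

text = "This is a test sentence. This is another one."
input_ids = tokenizer.encode(text, add_special_tokens=False)
block_ids = [0] * len(input_ids)

# Character-level helper: lowercase and strip punctuation characters.
print(corrupt(text, do_lowercase=True, do_remove_punct=True))

# Training-time corruption: with strategy "full", the chunk is lowercased,
# punctuation characters are removed, and the text is re-tokenized before labeling.
input_ids, block_ids, labels = corrupt_training(
    input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
)
print(tokenizer.decode(input_ids))
print(labels)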