From e2eae1cccc1fd7225c8161a3f606db559970867c Mon Sep 17 00:00:00 2001
From: markus583
Date: Thu, 28 Mar 2024 15:48:11 +0000
Subject: [PATCH] generalize pairwise eval to k-mer eval

---
 .../xlmr_stratify_0.1_6layers_p_v3_look6.json |  45 ++++++
 wtpsplit/evaluation/intrinsic_pairwise.py     |  70 ++++++++--
 wtpsplit/train/evaluate.py                    |  75 +++++++++-
 wtpsplit/train/train.py                       | 132 ++++++++++--------
 4 files changed, 251 insertions(+), 71 deletions(-)
 create mode 100644 configs/xlmr_stratify_0.1_6layers_p_v3_look6.json

diff --git a/configs/xlmr_stratify_0.1_6layers_p_v3_look6.json b/configs/xlmr_stratify_0.1_6layers_p_v3_look6.json
new file mode 100644
index 00000000..445f0bce
--- /dev/null
+++ b/configs/xlmr_stratify_0.1_6layers_p_v3_look6.json
@@ -0,0 +1,45 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-6l-v3_look6",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 256,
+    "eval_stride": 128,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 128,
+    "per_device_eval_batch_size": 64,
+    "gradient_accumulation_steps": 1,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 4,
+    "learning_rate": 1e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 200000,
+    "save_steps": 50000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "none",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": 6,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "auxiliary_remove_prob": 0.2,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 6,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "warning"
+}
\ No newline at end of file
diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py
index cbf19e1b..90ce3d57 100644
--- a/wtpsplit/evaluation/intrinsic_pairwise.py
+++ b/wtpsplit/evaluation/intrinsic_pairwise.py
@@ -61,14 +61,14 @@ class Args:
     min_pair_length: int = 0


-def process_logits_pairwise(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]:
+def process_logits_k_mers(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]:
     logits_list = []
     # create batches of sentence pairs
-    batched_pairs = [pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)]
-    for batch in tqdm(batched_pairs, disable=not verbose):
-        pair_texts = [pair[0] + Constants.SEPARATORS[lang_code] + pair[1] for pair in batch]
+    batched_k_mers = [pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)]
+    for batch in tqdm(batched_k_mers, disable=not verbose):
+        k_mer_texts = [Constants.SEPARATORS[lang_code].join(pair) for pair in batch]
         all_logits, offsets_mapping, tokenizer = extract_batched(
-            pair_texts,
+            k_mer_texts,
             model,
             lang_code=lang_code,
             block_size=block_size,
@@ -76,12 +76,12 @@ def process_logits_pairwise(pairs, model, lang_code, block_size, batch_size, ver
             pad_last_batch=True,
         )

-        for pair, logit, offset_mapping in zip(pair_texts, all_logits, offsets_mapping):
+        for k_mer, logit, offset_mapping in zip(k_mer_texts, all_logits, offsets_mapping):
             if "xlm" in model.config.model_type:
-                tokens = tokenizer.tokenize(pair, verbose=False)
+                tokens = tokenizer.tokenize(k_mer, verbose=False)
                 # padding is also removed here (via offset_mapping)
-                logits = token_to_char_probs(pair, tokens, logit, tokenizer, offset_mapping)
+                logits = token_to_char_probs(k_mer, tokens, logit, tokenizer, offset_mapping)
                 logits_list.append(logits)
             else:
                 if len(logit) < offset_mapping:
@@ -143,6 +143,56 @@ def generate_pairs(
     ]
     return all_pairs

+def generate_k_mers(
+    sentences: List[str],
+    k: int,
+    do_lowercase: bool,
+    do_remove_punct: bool,
+    sample_pct: float = 1,
+    max_n_samples: int = sys.maxsize,
+    min_k_mer_length: int = 0,
+) -> List[Tuple[str, ...]]:
+    """Generate k-mers from a list of sentences.
+
+    Args:
+        sentences (List[str]): Input list of sentences.
+        k (int): The number of sentences to include in each k-mer.
+        sample_pct (float): Percentage of k-mers to sample.
+        max_n_samples (int): Maximum number of k-mers to sample.
+        min_k_mer_length (int): Minimum length of a k-mer.
+        do_lowercase (bool): Whether to lowercase the sentences.
+        do_remove_punct (bool): Whether to remove punctuation from the sentences.
+
+    Returns:
+        List[Tuple[str, ...]]: List of k-mers.
+    """
+    random.seed(42)
+    n_k_mers = len(sentences) // k
+    sample_size = min(round(n_k_mers * sample_pct), max_n_samples)
+
+    # Efficient sampling of a subset of all possible k-mers if needed
+    if sample_size < n_k_mers:
+        sampled_indices = set(random.sample(range(n_k_mers), sample_size))
+        all_k_mers = [
+            tuple(sentences[i*k+j] for j in range(k))
+            for i in sampled_indices
+            if sum(len(sentences[i*k+j]) for j in range(k)) > min_k_mer_length
+        ]
+    else:
+        # Generate all k-mers that meet the min_k_mer_length criterion
+        all_k_mers = [
+            tuple(sentences[i+j] for j in range(k))
+            for i in range(0, len(sentences) - k + 1, k)
+            if sum(len(sentences[i+j]) for j in range(k)) > min_k_mer_length
+        ]
+
+    # Apply corruption to k-mers
+    all_k_mers = [
+        tuple(corrupt(sentence, do_lowercase, do_remove_punct) for sentence in k_mer)
+        for k_mer in all_k_mers
+    ]
+
+    return all_k_mers

 def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: str = None):
     logits_path = Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}.h5"
@@ -204,7 +254,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                 )

                 start_time = time.time()  # Start timing for test logits processing
-                test_logits = process_logits_pairwise(
+                test_logits = process_logits_k_mers(
                     all_pairs_test,
                     model,
                     lang_code,
@@ -250,7 +300,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                     min_pair_length=args.min_pair_length,
                 )

-                train_logits = process_logits_pairwise(
+                train_logits = process_logits_k_mers(
                     all_pairs_train, model, lang_code, args.block_size, args.batch_size
                 )
                 train_logits = np.concatenate(train_logits)
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index c2558ff6..0cca1220 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -7,7 +7,7 @@

 from wtpsplit.evaluation import token_to_char_probs
 from wtpsplit.evaluation.intrinsic import corrupt
-from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, process_logits_pairwise
+from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
 from wtpsplit.extract import PyTorchWrapper, extract
 from wtpsplit.utils import Constants, sigmoid
@@ -164,7 +164,7 @@ def evaluate_sentence_pairwise(
     )

     # get logits for each pair
-    logits = process_logits_pairwise(
+    logits = process_logits_k_mers(
         pairs=sampled_pairs,
         model=PyTorchWrapper(model.backbone),
         lang_code=lang_code,
         block_size=block_size,
         batch_size=batch_size,
         verbose=False,
     )
@@ -174,7 +174,6 @@ def evaluate_sentence_pairwise(

     # simulate performance for WtP-U
-
     for i, (sentence1, sentence2) in enumerate(sampled_pairs):
         newline_probs = logits[i][:, positive_index]

@@ -200,3 +199,73 @@ def evaluate_sentence_pairwise(
     average_metric = np.mean(metrics_list)
     avg_accuracy = np.mean(accuracy_list)
     return average_metric, avg_accuracy
+
+def evaluate_sentence_kmers(
+    lang_code,
+    sentences,
+    model,
+    stride,
+    block_size,
+    batch_size,
+    k: int = 3,
+    sample_pct: float = 0.1,
+    max_n_samples: int = sys.maxsize,
+    use_pysbd=False,
+    positive_index=None,
+    do_lowercase=False,
+    do_remove_punct=False,
+    threshold: float = 0.1
+):
+    if positive_index is None:
+        positive_index = Constants.NEWLINE_INDEX
+
+    # Preprocess sentences
+    sentences = [sentence.lstrip("-").strip() for sentence in sentences]
+
+    separator = Constants.SEPARATORS[lang_code]
+    metrics_list = []
+    accuracy_list = []
+
+    # get pairs of sentences (non-overlapping)
+    sampled_k_mers = generate_k_mers(
+        sentences=sentences,
+        k=k,
+        do_lowercase=do_lowercase,
+        do_remove_punct=do_remove_punct,
+        sample_pct=sample_pct,
+        max_n_samples=max_n_samples,
+        min_k_mer_length=0,
+    )
+
+    # get logits for each pair
+    logits = process_logits_k_mers(  # TODO
+        pairs=sampled_k_mers,
+        model=PyTorchWrapper(model.backbone),
+        lang_code=lang_code,
+        block_size=block_size,
+        batch_size=batch_size,
+        verbose=False,
+    )
+
+    for i, k_mer in enumerate(sampled_k_mers):
+        newline_probs = logits[i][:, positive_index]
+
+        k_mer_text = separator.join(k_mer)
+        true_end_indices = np.cumsum(np.array([len(s) for s in k_mer])) + np.arange(len(k_mer)) * len(separator)
+        newline_labels = np.zeros(len(k_mer_text))
+        newline_labels[true_end_indices - 1] = 1
+
+        # Get metrics for the k-mer
+        k_mer_metrics, _ = get_metrics(newline_labels, newline_probs)
+        metrics_list.append(k_mer_metrics["pr_auc"])
+        predicted_labels = newline_probs > np.log(threshold / (1 - threshold))  # inverse sigmoid
+        # For accuracy, check if all the labels in between are correctly predicted (ignore the one at the end)
+        intermediate_newline_labels = newline_labels[:-len(separator)]  # Exclude the end
+        intermediate_predicted_labels = predicted_labels[:-len(separator)]  # Exclude the end
+        correct = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels)
+        accuracy_list.append(correct)
+
+    # Compute and return the average metric and accuracy
+    average_metric = np.mean(metrics_list)
+    avg_accuracy = np.mean(accuracy_list)
+    return average_metric, avg_accuracy
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 7b043d30..9645ddd6 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -32,7 +32,7 @@
     SubwordXLMConfig,
     SubwordXLMForTokenClassification,
 )
-from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
+from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise, evaluate_sentence_kmers
 from wtpsplit.train.trainer import Trainer
 from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict
 from wtpsplit.train.utils import Model, cleanup_cache_files
@@ -501,13 +501,13 @@ def maybe_pad(text):
             dataset = dataset.rename_column(args.text_column, "input_ids")
             logger.warning(f"Tokenized {split} dataset.")

-        if split == "train" and args.use_subwords:
-            with training_args.main_process_first():
-                for root, dirs, files in os.walk(os.environ.get("HF_DATASETS_CACHE")):
-                    for file in files:
-                        if file.startswith("m_c4-test-train"):
-                            logger.warning(f"Removing {os.path.join(root, file)}")
-                            os.remove(os.path.join(root, file))
+        # if split == "train" and args.use_subwords:
+        #     with training_args.main_process_first():
+        #         for root, dirs, files in os.walk(os.environ.get("HF_DATASETS_CACHE")):
+        #             for file in files:
+        #                 if file.startswith("m_c4-test-train"):
+        #                     logger.warning(f"Removing {os.path.join(root, file)}")
+        #                     os.remove(os.path.join(root, file))

         if not args.one_sample_per_line:
             with training_args.main_process_first():
@@ -534,7 +534,7 @@ def maybe_pad(text):
         num_workers=args.preprocessing_num_workers,
         include_languages=args.include_languages,
         shuffle=args.shuffle,
-        split="train",
+        split="valid",
     )
     logger.warning(f"Train dataset has {len(train_dataset)} examples.")

@@ -566,54 +566,70 @@ def compute_metrics(trainer):
         if trainer.args.process_index == 0 and args.do_sentence_training:
             # with training_args.main_process_first():
             for dataset_name, dataset in lang_data["sentence"].items():
-                score, _ = evaluate_sentence(
-                    lang_code,
-                    dataset["data"],
-                    model,
-                    stride=args.eval_stride,
-                    block_size=args.block_size,
-                    batch_size=training_args.per_device_eval_batch_size,
-                )
-                metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
-                avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
-                if lang_code in ["zh", "ja", "my", "km"]:
-                    avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-                else:
-                    avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)
-                score, _ = evaluate_sentence(
-                    lang_code,
-                    dataset["data"],
-                    model,
-                    stride=args.eval_stride,
-                    block_size=args.block_size,
-                    batch_size=training_args.per_device_eval_batch_size,
-                    do_lowercase=True,
-                    do_remove_punct=True,
-                )
-                metrics[f"lower_rmp_{lang_code}_{dataset_name}_pr_auc"] = score
-                avg_metrics[f"lower_rmp_average_{dataset_name}_pr_auc"].append(score)
-                if lang_code in ["zh", "ja", "my", "km"]:
-                    avg_metrics[f"lower_rmp_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-                else:
-                    avg_metrics[f"lower_rmp_average_whitespace_{dataset_name}_pr_auc"].append(score)
-                score, avg_acc = evaluate_sentence_pairwise(
-                    lang_code,
-                    dataset["data"],
-                    model,
-                    stride=args.eval_stride,
-                    block_size=args.block_size,
-                    batch_size=training_args.per_device_eval_batch_size,
-                )
-                metrics[f"pairwise_{lang_code}_{dataset_name}_pr_auc"] = score
-                avg_metrics[f"pairwise_average_{dataset_name}_pr_auc"].append(score)
-                metrics[f"pairwise_{lang_code}_{dataset_name}_acc"] = avg_acc
-                avg_metrics[f"pairwise_average_{dataset_name}_acc"].append(avg_acc)
-                if lang_code in ["zh", "ja", "my", "km"]:
-                    avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-                    avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
-                else:
-                    avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
-                    avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)
+                # score, _ = evaluate_sentence(
+                #     lang_code,
+                #     dataset["data"],
+                #     model,
+                #     stride=args.eval_stride,
+                #     block_size=args.block_size,
+                #     batch_size=training_args.per_device_eval_batch_size,
+                # )
+                # metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
+                # avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
+                # if lang_code in ["zh", "ja", "my", "km"]:
+                #     avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                # else:
+                #     avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)
+                # score, _ = evaluate_sentence(
+                #     lang_code,
+                #     dataset["data"],
+                #     model,
+                #     stride=args.eval_stride,
+                #     block_size=args.block_size,
+                #     batch_size=training_args.per_device_eval_batch_size,
+                #     do_lowercase=True,
+                #     do_remove_punct=True,
+                # )
+                # metrics[f"lower_rmp_{lang_code}_{dataset_name}_pr_auc"] = score
+                # avg_metrics[f"lower_rmp_average_{dataset_name}_pr_auc"].append(score)
+                # if lang_code in ["zh", "ja", "my", "km"]:
+                #     avg_metrics[f"lower_rmp_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                # else:
+                #     avg_metrics[f"lower_rmp_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                # k-mer based evaluation
+                for k in [2, 3, 4]:
+                    score, avg_acc = evaluate_sentence_kmers(
+                        lang_code,
+                        dataset["data"],
+                        model,
+                        stride=args.eval_stride,
+                        block_size=args.block_size,
+                        batch_size=training_args.per_device_eval_batch_size,
+                        k=k,
+                        # sample_pct=0.1,
+                    )
+                    metrics[f"k_{k}_{lang_code}_{dataset_name}_pr_auc"] = score
+                    avg_metrics[f"k_{k}_average_{dataset_name}_pr_auc"].append(score)
+                    metrics[f"k_{k}_{lang_code}_{dataset_name}_acc"] = avg_acc
+                    avg_metrics[f"k_{k}_average_{dataset_name}_acc"].append(avg_acc)
+                    if lang_code in ["zh", "ja", "my", "km"]:
+                        avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                        avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
+                    else:
+                        avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                        avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_acc"].append(avg_acc)
+                    if k == 2:
+                        # keep keys for backwards compat in wandb
+                        metrics[f"pairwise_{lang_code}_{dataset_name}_pr_auc"] = score
+                        avg_metrics[f"pairwise_average_{dataset_name}_pr_auc"].append(score)
+                        metrics[f"pairwise_{lang_code}_{dataset_name}_acc"] = avg_acc
+                        avg_metrics[f"pairwise_average_{dataset_name}_acc"].append(avg_acc)
+                        if lang_code in ["zh", "ja", "my", "km"]:
+                            avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                            avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
+                        else:
+                            avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                            avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)

         for name, values in avg_metrics.items():
             if len(values) > 1:
@@ -622,7 +638,7 @@ def compute_metrics(trainer):
         return metrics

     if "wandb" in training_args.report_to and training_args.process_index == 0:
-        wandb.init(name=wandb_name, project="sentence")
+        wandb.init(name=wandb_name, project="sentence", entity="markus_583")
         wandb.config.update(args)
         wandb.config.update(training_args)
         wandb.config.update(label_args)