From 37824a491d73812b47dab1ec35df14c5fb74c962 Mon Sep 17 00:00:00 2001
From: markus583
Date: Tue, 23 Apr 2024 19:28:50 +0000
Subject: [PATCH] add mixed corruption strategy

---
 wtpsplit/utils.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index 8c6ca4f7..abe5bbcf 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -1,17 +1,17 @@
 import json
+import logging
 import os
 import random
+from collections import defaultdict
 from dataclasses import dataclass, field
-from cached_property import cached_property
 from pathlib import Path
 from typing import List
-import logging
-from collections import defaultdict
 
 import numpy as np
 import pandas as pd
-from transformers import AutoTokenizer
+from cached_property import cached_property
 from mosestokenizer import MosesTokenizer
+from transformers import AutoTokenizer
 
 # same as in CANINE
 PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
@@ -86,7 +86,8 @@ class LabelArgs:
     case_corruption_prob_after_newline: float = 0.0
     case_corruption_prob_after_punct: float = 0.0
     corrupt_entire_chunk_prob: float = 0.0
-    corrupt_entire_chunk_strategy: str = "full"
+    corrupt_entire_chunk_strategy: str = "mix"
+    corrupt_entire_chunk_prob_full: float = 0.5
 
     def __post_init__(self):
         if self.custom_punctuation_file:
@@ -239,27 +240,32 @@ def corrupt_training(
     input_ids = input_ids.copy()
     block_ids = block_ids.copy()
     if random.random() < label_args.corrupt_entire_chunk_prob:
-        # lowercase all text
+        # choose corruption strategy
+        if label_args.corrupt_entire_chunk_strategy == "mix":
+            corrupt_strategy = "full" if random.random() < label_args.corrupt_entire_chunk_prob_full else "asr"
+        else:
+            corrupt_strategy = label_args.corrupt_entire_chunk_strategy
+
         input_text = tokenizer.decode(input_ids)
-        if label_args.corrupt_entire_chunk_strategy == "tokenizer":
+        if corrupt_strategy == "tokenizer":
             if not tokenizer:
                 raise NotImplementedError()
             corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=False)
-            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False, verbose=False)
             # remove ALL punct *tokens*
             auxiliary_remove_prob = 1.0
-        elif label_args.corrupt_entire_chunk_strategy == "full":
+        elif corrupt_strategy == "full":
             # remove all punct *characters*
             corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=True)
-            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False, verbose=False)
             auxiliary_remove_prob = 1.0  # just for safety/consistency
-        elif label_args.corrupt_entire_chunk_strategy == "asr":
+        elif corrupt_strategy == "asr":
             if not tokenizer:
                 raise NotImplementedError()
             corrupted_sentences = corrupt_asr(input_text, lang)
             corrupted_text = "\n".join(corrupted_sentences)
-            input_ids = tokenizer.encode(corrupted_text, add_special_tokens=False)
-            auxiliary_remove_prob = 0.0
+            input_ids = tokenizer.encode(corrupted_text, add_special_tokens=False, verbose=False)
+            auxiliary_remove_prob = 0.0  # do not remove additional characters.
         block_ids = [0] * len(input_ids)
     else:
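
Note (not part of the patch): a minimal standalone sketch of how the new "mix" strategy resolves to a concrete corruption mode before any text is touched. The LabelArgs fields mirror the diff above; the choose_corrupt_strategy helper and the simulation are hypothetical illustrations, not code from wtpsplit.

import random
from dataclasses import dataclass

@dataclass
class LabelArgs:
    # only the fields relevant to chunk corruption; the real class has more
    corrupt_entire_chunk_prob: float = 0.0
    corrupt_entire_chunk_strategy: str = "mix"
    corrupt_entire_chunk_prob_full: float = 0.5

def choose_corrupt_strategy(label_args: LabelArgs) -> str:
    # "mix" flips a biased coin between "full" and "asr";
    # any other value ("full", "asr", "tokenizer") is used as-is.
    if label_args.corrupt_entire_chunk_strategy == "mix":
        if random.random() < label_args.corrupt_entire_chunk_prob_full:
            return "full"
        return "asr"
    return label_args.corrupt_entire_chunk_strategy

# With corrupt_entire_chunk_prob_full = 0.5, corrupted chunks split roughly
# evenly between full lowercasing/punctuation removal and ASR-style corruption.
random.seed(0)
counts = {"full": 0, "asr": 0}
for _ in range(10_000):
    counts[choose_corrupt_strategy(LabelArgs())] += 1
print(counts)  # e.g. {'full': 4986, 'asr': 5014}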