add full corruption param
markus583 committed Apr 23, 2024
1 parent 42eb4f3 commit 3d1a8cc
Showing 6 changed files with 46 additions and 28 deletions.
13 changes: 2 additions & 11 deletions wtpsplit/evaluation/intrinsic.py
@@ -18,7 +18,7 @@
import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants
from wtpsplit.utils import Constants, corrupt

logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -42,7 +42,7 @@ class Args:
# }
# }
# TODO: for songs/etc., maybe feed in each sample separately?
eval_data_path: str = "data/all_data.pth"
eval_data_path: str = "data/all_data_21-04.pth"
valid_text_path: str = None # "data/sentence/valid.parquet"
device: str = "cpu"
block_size: int = 512
@@ -92,15 +92,6 @@ def process_logits(text, model, lang_code, args):
return logits


def corrupt(text: str, do_lowercase: bool, do_remove_punct: bool):
if do_lowercase:
text = text.lower()
if do_remove_punct:
for punct in Constants.PUNCTUATION_CHARS:
text = text.replace(punct, "")
return text


def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: str = None):
logits_path = Constants.CACHE_DIR / "intrinsic" / f"{save_str}.h5"

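Note: the corrupt helper deleted here is not removed from the project; it moves to wtpsplit/utils.py (see the utils.py hunks below) and intrinsic.py now imports it from there. A minimal usage sketch of the relocated helper, assuming the comma and exclamation mark in the example text are included in Constants.PUNCTUATION_CHARS:

from wtpsplit.utils import corrupt

# lowercase only
print(corrupt("Hello, World!", do_lowercase=True, do_remove_punct=False))  # "hello, world!"
# lowercase and strip punctuation characters
print(corrupt("Hello, World!", do_lowercase=True, do_remove_punct=True))   # "hello world" (assuming both marks are in the set)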
3 changes: 1 addition & 2 deletions wtpsplit/train/evaluate.py
@@ -6,10 +6,9 @@
import sklearn.metrics

from wtpsplit.evaluation import token_to_char_probs
from wtpsplit.evaluation.intrinsic import corrupt
from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, sigmoid
from wtpsplit.utils import Constants, sigmoid, corrupt

logger = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions wtpsplit/train/train.py
@@ -34,7 +34,7 @@
)
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise, evaluate_sentence_kmers
from wtpsplit.train.trainer import Trainer
from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict
from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict
from wtpsplit.train.utils import Model, cleanup_cache_files

logger = logging.getLogger(__name__)
@@ -131,7 +131,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer, add_lang_ids: boo

block_ids = [0] * len(input_ids)

input_ids, _, labels = corrupt(
input_ids, _, labels = corrupt_training(
input_ids,
block_ids,
lang,
13 changes: 10 additions & 3 deletions wtpsplit/train/train_adapter.py
@@ -19,14 +19,13 @@
import adapters
import wandb
from adapters import AdapterArguments
from wtpsplit.evaluation.intrinsic import corrupt
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.trainer import Trainer
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
from wtpsplit.train.train import collate_fn, setup_logging
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict, corrupt
from tqdm import tqdm
from typing import Union, Optional

@@ -117,6 +116,8 @@ def prepare_dataset(
dataset = data[lang]["sentence"][dataset_name]["meta"]["train_data"]
elif split == "valid":
dataset = data[lang]["sentence"][dataset_name]["data"]
if dataset_name == "opus100" and lang == "fr":
dataset = data[lang]["sentence"][dataset_name]["data"]
if dataset is None:
return None

@@ -403,6 +404,12 @@ def maybe_pad(text):
for lang in tqdm(data.keys(), desc="Language"):
if lang in args.include_languages:
for dataset_name in data[lang]["sentence"].keys():
if dataset_name != "ted2020":
continue
# skip langs starting with a, b, ..., k
if lang[0] < "d":
print(f"Skipping {lang} {dataset_name}")
continue
# do model stuff here; otherwise, head params would be overwritten every time
backbone = SubwordXLMForTokenClassification.from_pretrained(
args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True
@@ -551,7 +558,7 @@ def compute_metrics(trainer):
model.backbone.classifier = torch.nn.Sequential(
clf, # original classifier - if frozen above, also frozen here
torch.nn.Linear(clf.out_features, 1),
)
)
model.backbone.config.num_labels = 1

# if args.one_sample_per_line:
3 changes: 1 addition & 2 deletions wtpsplit/train/train_adapter_parallel.py
@@ -35,8 +35,7 @@
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
from wtpsplit.train.train import collate_fn
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
from wtpsplit.evaluation.intrinsic import corrupt
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict, corrupt

os.environ["TOKENIZERS_PARALLELISM"] = "false"

38 changes: 30 additions & 8 deletions wtpsplit/utils.py
@@ -85,6 +85,7 @@ class LabelArgs:
case_corruption_prob_after_newline: float = 0.0
case_corruption_prob_after_punct: float = 0.0
corrupt_entire_chunk_prob: float = 0.0
corrupt_entire_chunk_strategy: str = "full"

def __post_init__(self):
if self.custom_punctuation_file:
@@ -133,7 +134,7 @@ def get_subword_label_dict(label_args, tokenizer):
return label_dict


# numerically more stable sigmoid taken from
# numerically more stable sigmoid taken from
# https://stackoverflow.com/questions/51976461/optimal-way-of-defining-a-numerically-stable-sigmoid-function-for-a-list-in-pyth
def _positive_sigmoid(x):
return 1 / (1 + np.exp(-x))
@@ -196,8 +197,17 @@ def lang_code_to_lang(lang_code):
return languages.get(part3=lang_code).name


def corrupt(text: str, do_lowercase: bool, do_remove_punct: bool):
if do_lowercase:
text = text.lower()
if do_remove_punct:
for punct in Constants.PUNCTUATION_CHARS:
text = text.replace(punct, "")
return text


# does the steps in Figure 2 of the paper
def corrupt(
def corrupt_training(
input_ids,
block_ids,
lang,
@@ -211,14 +221,24 @@ def corrupt(
block_ids = block_ids.copy()
if random.random() < label_args.corrupt_entire_chunk_prob:
# lowercase all text
lowercased = tokenizer.decode(input_ids).lower()
input_ids = tokenizer.encode(lowercased, add_special_tokens=False)
input_text = tokenizer.decode(input_ids)
if label_args.corrupt_entire_chunk_strategy == "tokenizer":
if not tokenizer:
raise NotImplementedError()
corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=False)
input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
# remove ALL punct *tokens*
auxiliary_remove_prob = 1.0
elif label_args.corrupt_entire_chunk_strategy == "full":
# remove all punct *characters*
corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=True)
input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
auxiliary_remove_prob = 1.0 # just for safety/consistency
block_ids = [0] * len(input_ids)
# remove ALL punct
auxiliary_remove_prob = 1.0

else:
auxiliary_remove_prob = label_args.auxiliary_remove_prob

labels = label(input_ids, label_dict)

separator = Constants.SEPARATORS[lang]
@@ -459,7 +479,9 @@ def reconstruct_sentences(text, partial_sentences):
label_dict = get_subword_label_dict(label_args, tokenizer)

# corrupt
input_ids, block_ids, labels = corrupt(input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer)
input_ids, block_ids, labels = corrupt_training(
input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
)
print(input_ids)
print(labels)
print(tokenizer.tokenize(text))
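Net effect of the changes in this file: LabelArgs gains corrupt_entire_chunk_strategy, the text-level corrupt helper moves here from wtpsplit/evaluation/intrinsic.py, and the training-time routine is renamed to corrupt_training. When corrupt_entire_chunk_prob triggers, the "tokenizer" strategy lowercases the decoded chunk and relies on auxiliary_remove_prob = 1.0 to drop punctuation tokens, while the new "full" strategy lowercases and strips punctuation characters before re-encoding. A minimal usage sketch along the lines of the __main__ block above; it assumes the remaining LabelArgs fields keep their defaults, and "xlm-roberta-base" is only an illustrative tokenizer choice:

from transformers import AutoTokenizer

from wtpsplit.utils import LabelArgs, corrupt_training, get_subword_label_dict

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
label_args = LabelArgs(
    corrupt_entire_chunk_prob=1.0,         # always take the chunk-corruption branch
    corrupt_entire_chunk_strategy="full",  # lowercase + remove punctuation characters
)
label_dict = get_subword_label_dict(label_args, tokenizer)

input_ids = tokenizer.encode("This is a test. And another one!", add_special_tokens=False)
block_ids = [0] * len(input_ids)
input_ids, block_ids, labels = corrupt_training(
    input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
)
print(tokenizer.decode(input_ids))  # expected to read roughly "this is a test and another one"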
