generalize pairwise eval to k-mer eval
markus583 committed Mar 28, 2024
1 parent fa5cb3a commit e2eae1c
Showing 4 changed files with 251 additions and 71 deletions.
45 changes: 45 additions & 0 deletions configs/xlmr_stratify_0.1_6layers_p_v3_look6.json
@@ -0,0 +1,45 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-6l-v3_look6",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 256,
"eval_stride": 128,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 128,
"per_device_eval_batch_size": 64,
"gradient_accumulation_steps": 1,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 4,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 200000,
"save_steps": 50000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "none",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": 6,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"auxiliary_remove_prob": 0.2,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 6,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "warning"
}
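As a quick orientation, here is a minimal sketch of inspecting the new config added in this commit. Only the file path comes from the commit; the loading code is illustrative and is not the repository's training entry point.

import json

# Illustrative only: peek at the config added in this commit.
with open("configs/xlmr_stratify_0.1_6layers_p_v3_look6.json") as f:
    cfg = json.load(f)

# Presumably the source of the "6l ... look6" run name:
print(cfg["num_hidden_layers"], cfg["lookahead"])  # 6 6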
70 changes: 60 additions & 10 deletions wtpsplit/evaluation/intrinsic_pairwise.py
@@ -61,27 +61,27 @@ class Args:
min_pair_length: int = 0


def process_logits_pairwise(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]:
def process_logits_k_mers(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]:
logits_list = []
# create batches of sentence pairs
batched_pairs = [pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)]
for batch in tqdm(batched_pairs, disable=not verbose):
pair_texts = [pair[0] + Constants.SEPARATORS[lang_code] + pair[1] for pair in batch]
batched_k_mers = [pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)]
for batch in tqdm(batched_k_mers, disable=not verbose):
k_mer_texts = [Constants.SEPARATORS[lang_code].join(pair) for pair in batch]
all_logits, offsets_mapping, tokenizer = extract_batched(
pair_texts,
k_mer_texts,
model,
lang_code=lang_code,
block_size=block_size,
batch_size=batch_size,
pad_last_batch=True,
)

for pair, logit, offset_mapping in zip(pair_texts, all_logits, offsets_mapping):
for k_mer, logit, offset_mapping in zip(k_mer_texts, all_logits, offsets_mapping):
if "xlm" in model.config.model_type:
tokens = tokenizer.tokenize(pair, verbose=False)
tokens = tokenizer.tokenize(k_mer, verbose=False)

# padding is also removed here (via offset_mapping)
logits = token_to_char_probs(pair, tokens, logit, tokenizer, offset_mapping)
logits = token_to_char_probs(k_mer, tokens, logit, tokenizer, offset_mapping)
logits_list.append(logits)
else:
if len(logit) < offset_mapping:
@@ -143,6 +143,56 @@ def generate_pairs(
]
return all_pairs

def generate_k_mers(
sentences: List[str],
k: int,
do_lowercase: bool,
do_remove_punct: bool,
sample_pct: float = 1,
max_n_samples: int = sys.maxsize,
min_k_mer_length: int = 0,
) -> List[Tuple[str, ...]]:
"""Generate k-mers from a list of sentences.
Args:
sentences (List[str]): Input list of sentences.
k (int): The number of sentences to include in each k-mer.
sample_pct (float): Percentage of k-mers to sample.
max_n_samples (int): Maximum number of k-mers to sample.
        min_k_mer_length (int): Minimum total character length of a k-mer (summed over its sentences).
do_lowercase (bool): Whether to lowercase the sentences.
do_remove_punct (bool): Whether to remove punctuation from the sentences.
Returns:
List[Tuple[str, ...]]: List of k-mers.
"""
random.seed(42)
n_k_mers = len(sentences) // k
sample_size = min(round(n_k_mers * sample_pct), max_n_samples)

# Efficient sampling of a subset of all possible k-mers if needed
if sample_size < n_k_mers:
sampled_indices = set(random.sample(range(n_k_mers), sample_size))
all_k_mers = [
tuple(sentences[i*k+j] for j in range(k))
for i in sampled_indices
if sum(len(sentences[i*k+j]) for j in range(k)) > min_k_mer_length
]
else:
# Generate all k-mers that meet the min_k_mer_length criterion
all_k_mers = [
tuple(sentences[i+j] for j in range(k))
for i in range(0, len(sentences) - k + 1, k)
if sum(len(sentences[i+j]) for j in range(k)) > min_k_mer_length
]

# Apply corruption to k-mers
all_k_mers = [
tuple(corrupt(sentence, do_lowercase, do_remove_punct) for sentence in k_mer)
for k_mer in all_k_mers
]

return all_k_mers
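
To illustrate the generalization, a small usage sketch with hypothetical sentences (it assumes corrupt leaves text unchanged when both corruption flags are False). With k=2 and a single-space separator, joining each tuple reproduces the old pairwise behaviour of pair[0] + sep + pair[1].

from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers

# Hypothetical illustration (not part of this commit).
sentences = ["First sentence.", "Second one.", "Third.", "Fourth here.", "Fifth."]

k_mers = generate_k_mers(
    sentences,
    k=2,
    do_lowercase=False,
    do_remove_punct=False,
)
# Non-overlapping, in order; the trailing fifth sentence is dropped because it
# cannot complete a full 2-mer:
# [("First sentence.", "Second one."), ("Third.", "Fourth here.")]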

def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: str = None):
logits_path = Constants.CACHE_DIR / "intrinsic_pairwise" / f"{save_str}.h5"
@@ -204,7 +254,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
)

start_time = time.time() # Start timing for test logits processing
test_logits = process_logits_pairwise(
test_logits = process_logits_k_mers(
all_pairs_test,
model,
lang_code,
@@ -250,7 +300,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
min_pair_length=args.min_pair_length,
)

train_logits = process_logits_pairwise(
train_logits = process_logits_k_mers(
all_pairs_train, model, lang_code, args.block_size, args.batch_size
)
train_logits = np.concatenate(train_logits)
75 changes: 72 additions & 3 deletions wtpsplit/train/evaluate.py
@@ -7,7 +7,7 @@

from wtpsplit.evaluation import token_to_char_probs
from wtpsplit.evaluation.intrinsic import corrupt
from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, process_logits_pairwise
from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, sigmoid

@@ -164,7 +164,7 @@ def evaluate_sentence_pairwise(
)

# get logits for each pair
logits = process_logits_pairwise(
logits = process_logits_k_mers(
pairs=sampled_pairs,
model=PyTorchWrapper(model.backbone),
lang_code=lang_code,
@@ -174,7 +174,6 @@
)

# simulate performance for WtP-U

for i, (sentence1, sentence2) in enumerate(sampled_pairs):
newline_probs = logits[i][:, positive_index]

@@ -200,3 +199,73 @@
average_metric = np.mean(metrics_list)
avg_accuracy = np.mean(accuracy_list)
return average_metric, avg_accuracy

def evaluate_sentence_kmers(
lang_code,
sentences,
model,
stride,
block_size,
batch_size,
k: int = 3,
sample_pct: float = 0.1,
max_n_samples: int = sys.maxsize,
use_pysbd=False,
positive_index=None,
do_lowercase=False,
do_remove_punct=False,
threshold: float = 0.1
):
if positive_index is None:
positive_index = Constants.NEWLINE_INDEX

# Preprocess sentences
sentences = [sentence.lstrip("-").strip() for sentence in sentences]

separator = Constants.SEPARATORS[lang_code]
metrics_list = []
accuracy_list = []

    # get k-mers of sentences (non-overlapping)
sampled_k_mers = generate_k_mers(
sentences=sentences,
k=k,
do_lowercase=do_lowercase,
do_remove_punct=do_remove_punct,
sample_pct=sample_pct,
max_n_samples=max_n_samples,
min_k_mer_length=0,
)

    # get logits for each k-mer
logits = process_logits_k_mers( # TODO
pairs=sampled_k_mers,
model=PyTorchWrapper(model.backbone),
lang_code=lang_code,
block_size=block_size,
batch_size=batch_size,
verbose=False,
)

for i, k_mer in enumerate(sampled_k_mers):
newline_probs = logits[i][:, positive_index]

k_mer_text = separator.join(k_mer)
true_end_indices = np.cumsum(np.array([len(s) for s in k_mer])) + np.arange(len(k_mer)) * len(separator)
newline_labels = np.zeros(len(k_mer_text))
newline_labels[true_end_indices - 1] = 1

# Get metrics for the k-mer
k_mer_metrics, _ = get_metrics(newline_labels, newline_probs)
metrics_list.append(k_mer_metrics["pr_auc"])
predicted_labels = newline_probs > np.log(threshold / (1 - threshold)) # inverse sigmoid
# For accuracy, check if all the labels in between are correctly predicted (ignore the one at the end)
intermediate_newline_labels = newline_labels[:-len(separator)] # Exclude the end
intermediate_predicted_labels = predicted_labels[:-len(separator)] # Exclude the end
correct = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels)
accuracy_list.append(correct)

# Compute and return the average metric and accuracy
average_metric = np.mean(metrics_list)
avg_accuracy = np.mean(accuracy_list)
return average_metric, avg_accuracy
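
A short worked sketch of the labelling used above, with hypothetical inputs: true_end_indices are cumulative sentence lengths plus the separators already consumed, so newline_labels marks the final character of every sentence in the joined k-mer text, and the threshold comparison applies the inverse sigmoid so raw logits can be compared against a probability threshold.

import numpy as np

# Hypothetical illustration of the labelling in evaluate_sentence_kmers.
separator = " "
k_mer = ("Hi there.", "Bye.", "Ok.")

k_mer_text = separator.join(k_mer)  # "Hi there. Bye. Ok."
true_end_indices = np.cumsum([len(s) for s in k_mer]) + np.arange(len(k_mer)) * len(separator)
# lengths 9, 4, 3 -> cumsum [9, 13, 16]; plus [0, 1, 2] separators -> [9, 14, 18]
newline_labels = np.zeros(len(k_mer_text))
newline_labels[true_end_indices - 1] = 1  # marks the '.' at indices 8, 13, 17

# Thresholding logits at p = 0.1 via the inverse sigmoid (logit function):
threshold = 0.1
logit_threshold = np.log(threshold / (1 - threshold))  # approximately -2.197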
