Commit: threshold investigation
markus583 committed Apr 4, 2024
1 parent aec7fc2 commit 284e925
Showing 2 changed files with 47 additions and 9 deletions.
32 changes: 25 additions & 7 deletions wtpsplit/train/evaluate.py
@@ -141,7 +141,7 @@ def evaluate_sentence_pairwise(
positive_index=None,
do_lowercase=False,
do_remove_punct=False,
-threshold: float = 0.1
+threshold: float = 0.1,
):
if positive_index is None:
positive_index = Constants.NEWLINE_INDEX
@@ -200,6 +200,7 @@ def evaluate_sentence_pairwise(
avg_accuracy = np.mean(accuracy_list)
return average_metric, avg_accuracy

+
def evaluate_sentence_kmers(
lang_code,
sentences,
@@ -214,7 +215,7 @@ def evaluate_sentence_kmers(
positive_index=None,
do_lowercase=False,
do_remove_punct=False,
-threshold: float = 0.1
+threshold: float = 0.1,
):
if positive_index is None:
positive_index = Constants.NEWLINE_INDEX
@@ -225,6 +226,8 @@
separator = Constants.SEPARATORS[lang_code]
metrics_list = []
accuracy_list = []
+accuracy_list_optimal = []
+info_list = []

# get pairs of sentences (non-overlapping)
sampled_k_mers = generate_k_mers(
@@ -238,7 +241,7 @@
)

# get logits for each pair
-logits = process_logits_k_mers( # TODO
+logits = process_logits_k_mers(
pairs=sampled_k_mers,
model=PyTorchWrapper(model.backbone),
lang_code=lang_code,
@@ -256,16 +259,31 @@
newline_labels[true_end_indices - 1] = 1

# Get metrics for the k-mer
-k_mer_metrics, _ = get_metrics(newline_labels, newline_probs)
+k_mer_metrics, info = get_metrics(newline_labels, newline_probs)
metrics_list.append(k_mer_metrics["pr_auc"])
+info_list.append(info)
+
predicted_labels = newline_probs > np.log(threshold / (1 - threshold)) # inverse sigmoid
+predicted_labels_optimal = newline_probs > np.log(
+    info["threshold_best"] / (1 - info["threshold_best"])
+) # inverse sigmoid
# For accuracy, check if all the labels in between are correctly predicted (ignore the one at the end)
-intermediate_newline_labels = newline_labels[:-len(separator)] # Exclude the end
-intermediate_predicted_labels = predicted_labels[:-len(separator)] # Exclude the end
+intermediate_newline_labels = newline_labels[: -len(separator)] # Exclude the end
+intermediate_predicted_labels = predicted_labels[: -len(separator)] # Exclude the end
+intermediate_predicted_labels_opt = predicted_labels_optimal[: -len(separator)] # Exclude the end
correct = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels)
+correct_optimal = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels_opt)
accuracy_list.append(correct)
+accuracy_list_optimal.append(correct_optimal)

# Compute and return the average metric and accuracy
average_metric = np.mean(metrics_list)
avg_accuracy = np.mean(accuracy_list)
-return average_metric, avg_accuracy
+# get averages for info_list
+avg_info = {
+    key: np.mean([info[key] for info in info_list])
+    for key in info_list[0].keys()
+    if isinstance(info_list[0][key], (int, float))
+}
+avg_info["accuracy_optimal"] = np.mean(accuracy_list_optimal)
+return average_metric, avg_accuracy, avg_info
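
A note on the thresholding above: the character-level scores are raw logits, so a probability threshold t is applied by comparing them against log(t / (1 - t)), the inverse sigmoid (logit) of t, which is equivalent to applying a sigmoid first and thresholding at t. The sketch below illustrates that equivalence and the per-key averaging of the info dicts; the logits, dict keys, and values are made up for illustration and are not taken from wtpsplit.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Made-up newline logits for a short sequence (illustrative only).
newline_logits = np.array([-4.2, -1.1, 0.3, 2.7, -0.5])
threshold = 0.1

# Thresholding in logit space via the inverse sigmoid ...
pred_logit_space = newline_logits > np.log(threshold / (1 - threshold))
# ... matches thresholding the sigmoid-transformed scores in probability space.
pred_prob_space = sigmoid(newline_logits) > threshold
assert np.array_equal(pred_logit_space, pred_prob_space)

# Per-key averaging of numeric entries across per-k-mer info dicts,
# mirroring the avg_info construction above (keys here are hypothetical).
info_list = [
    {"f1": 0.91, "f1_best": 0.94, "threshold_best": 0.02},
    {"f1": 0.88, "f1_best": 0.92, "threshold_best": 0.05},
]
avg_info = {
    key: np.mean([info[key] for info in info_list])
    for key in info_list[0]
    if isinstance(info_list[0][key], (int, float))
}
print(avg_info)  # per-key means of f1, f1_best, threshold_best
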
24 changes: 22 additions & 2 deletions wtpsplit/train/train.py
@@ -566,7 +566,7 @@ def compute_metrics(trainer):
if trainer.args.process_index == 0 and args.do_sentence_training:
# with training_args.main_process_first():
for dataset_name, dataset in lang_data["sentence"].items():
-score, _ = evaluate_sentence(
+score, info = evaluate_sentence(
lang_code,
dataset["data"],
model,
@@ -575,7 +575,13 @@
batch_size=training_args.per_device_eval_batch_size,
)
metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
+metrics[f"{lang_code}_{dataset_name}_f1"] = info["f1"]
+metrics[f"{lang_code}_{dataset_name}_f1_best"] = info["f1_best"]
+metrics[f"{lang_code}_{dataset_name}_threshold_best"] = info["threshold_best"]
avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
+avg_metrics[f"average_{dataset_name}_f1"].append(info["f1"])
+avg_metrics[f"average_{dataset_name}_f1_best"].append(info["f1_best"])
+avg_metrics[f"average_{dataset_name}_threshold_best"].append(info["threshold_best"])
# if lang_code in ["zh", "ja", "my", "km"]:
# avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
# else:
@@ -598,7 +604,7 @@
# avg_metrics[f"lower_rmp_average_whitespace_{dataset_name}_pr_auc"].append(score)
# k-mer based evaluation
for k in [2, 3, 4, 5, 6]:
-score, avg_acc = evaluate_sentence_kmers(
+score, avg_acc, info = evaluate_sentence_kmers(
lang_code,
dataset["data"],
model,
@@ -612,6 +618,13 @@
avg_metrics[f"k_{k}_average_{dataset_name}_pr_auc"].append(score)
metrics[f"k_{k}_{lang_code}_{dataset_name}_acc"] = avg_acc
avg_metrics[f"k_{k}_average_{dataset_name}_acc"].append(avg_acc)
+metrics[f"k_{k}_{lang_code}_{dataset_name}_f1"] = info["f1"]
+metrics[f"k_{k}_{lang_code}_{dataset_name}_f1_best"] = info["f1_best"]
+metrics[f"k_{k}_{lang_code}_{dataset_name}_threshold_best"] = info["threshold_best"]
+avg_metrics[f"k_{k}_average_{dataset_name}_f1"].append(info["f1"])
+avg_metrics[f"k_{k}_average_{dataset_name}_f1_best"].append(info["f1_best"])
+avg_metrics[f"k_{k}_average_{dataset_name}_threshold_best"].append(info["threshold_best"])
+
# if lang_code in ["zh", "ja", "my", "km"]:
# avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
# avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
@@ -624,6 +637,12 @@
avg_metrics[f"pairwise_average_{dataset_name}_pr_auc"].append(score)
metrics[f"pairwise_{lang_code}_{dataset_name}_acc"] = avg_acc
avg_metrics[f"pairwise_average_{dataset_name}_acc"].append(avg_acc)
+metrics[f"pairwise_{lang_code}_{dataset_name}_f1"] = info["f1"]
+metrics[f"pairwise_{lang_code}_{dataset_name}_f1_best"] = info["f1_best"]
+metrics[f"pairwise_{lang_code}_{dataset_name}_threshold_best"] = info["threshold_best"]
+avg_metrics[f"pairwise_average_{dataset_name}_f1"].append(info["f1"])
+avg_metrics[f"pairwise_average_{dataset_name}_f1_best"].append(info["f1_best"])
+avg_metrics[f"pairwise_average_{dataset_name}_threshold_best"].append(info["threshold_best"])
# if lang_code in ["zh", "ja", "my", "km"]:
# avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
# avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
@@ -676,6 +695,7 @@ def compute_metrics(trainer):
label_args=label_args,
label_dict=label_dict,
tokenizer=tokenizer if args.use_subwords else None,
+add_lang_ids=not args.use_subwords,
),
)

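The f1, f1_best, and threshold_best values logged above come from the info dict returned by get_metrics. As a rough sketch of what a best-F1 threshold sweep can look like (an illustration only: the helper name, the synthetic data, and the use of scikit-learn are assumptions, not the actual wtpsplit implementation):

import numpy as np
from sklearn.metrics import f1_score, precision_recall_curve

def best_f1_threshold(labels, probs):
    # Sweep the precision-recall curve and return (best F1, threshold attaining it).
    precision, recall, thresholds = precision_recall_curve(labels, probs)
    # The final precision/recall pair has no associated threshold, so drop it.
    f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
    best = int(np.argmax(f1))
    return float(f1[best]), float(thresholds[best])

# Tiny synthetic example: 1 marks a true sentence-boundary position.
labels = np.array([0, 0, 1, 0, 1, 0, 0, 1])
probs = np.array([0.05, 0.2, 0.8, 0.3, 0.6, 0.1, 0.4, 0.9])

best_f1, best_threshold = best_f1_threshold(labels, probs)
f1_default = f1_score(labels, probs > 0.1)  # F1 at a fixed default threshold
print(best_f1, best_threshold, f1_default)

Logging both the fixed-threshold F1 and the best attainable F1 together with threshold_best shows, per language and dataset, how far the default threshold is from the optimum, which appears to be the aim of this threshold investigation.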
