Skip to content

Commit

Permalink
fix threshold investigation
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Apr 4, 2024
1 parent 888c1c9 commit 99c6b9c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions wtpsplit/train/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ def evaluate_sentence_kmers(
info["threshold_best"] / (1 - info["threshold_best"])
) # inverse sigmoid
# For accuracy, check if all the labels in between are correctly predicted (ignore the one at the end)
- intermediate_newline_labels = newline_labels[: -len(separator)] # Exclude the end
- intermediate_predicted_labels = predicted_labels[: -len(separator)] # Exclude the end
- intermediate_predicted_labels_opt = predicted_labels_optimal[: -len(separator)] # Exclude the end
+ intermediate_newline_labels = newline_labels[:-1] # Exclude the end
+ intermediate_predicted_labels = predicted_labels[:-1]
+ intermediate_predicted_labels_opt = predicted_labels_optimal[:-1]
correct = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels)
correct_optimal = np.array_equal(intermediate_newline_labels, intermediate_predicted_labels_opt)
accuracy_list.append(correct)
Expand Down
8 changes: 4 additions & 4 deletions wtpsplit/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,8 @@ def compute_metrics(trainer):
lang_code,
dataset["data"],
model,
- stride=args.eval_stride,
- block_size=args.block_size,
+ stride=128,
+ block_size=512,
batch_size=training_args.per_device_eval_batch_size,
)
metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
Expand Down Expand Up @@ -608,8 +608,8 @@ def compute_metrics(trainer):
lang_code,
dataset["data"],
model,
- stride=args.eval_stride,
- block_size=args.block_size,
+ stride=128,
+ block_size=512,
batch_size=training_args.per_device_eval_batch_size,
k=k,
# sample_pct=0.1,
Expand Down

0 comments on commit 99c6b9c

Please sign in to comment.