Skip to content

Commit

Permalink
Merge pull request #3591 from flairNLP/relation_printout
Browse files Browse the repository at this point in the history
Modify printouts in RelationClassifier evaluation to remove clutter
  • Loading branch information
alanakbik authored Jan 2, 2025
2 parents 68508cc + 58e903c commit 8bc9c28
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,32 @@ def predict(

return loss if return_loss else None

def _print_predictions(self, batch, gold_label_type: str) -> list[str]:
lines = []
for datapoint in batch:
# check if there is a label mismatch
g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)]
p = [label.labeled_identifier for label in datapoint.get_labels("predicted")]
g.sort()
p.sort()

# if the gold label is O and is correctly predicted as no label, do not print out as this clutters
# the output file with trivial predictions
if not (
len(datapoint.get_labels(gold_label_type)) == 1
and datapoint.get_label(gold_label_type).value == "O"
and len(datapoint.get_labels("predicted")) == 0
):
correct_string = " -> MISMATCH!\n" if g != p else ""
eval_line = (
f"{datapoint.text}\n"
f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n"
f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n"
f"{correct_string}\n"
)
lines.append(eval_line)
return lines

def _get_state_dict(self) -> dict[str, Any]:
model_state: dict[str, Any] = {
**super()._get_state_dict(),
Expand Down

0 comments on commit 8bc9c28

Please sign in to comment.