From 0ccb22f2c5acde978d1d9ed61ae98da2c33f5a4f Mon Sep 17 00:00:00 2001 From: alanakbik Date: Wed, 1 Jan 2025 11:58:44 +0100 Subject: [PATCH 1/2] Unclutter printouts of RelationClassifier during evaluation --- flair/models/relation_classifier_model.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 9c6c69577f..389bf05157 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -667,6 +667,28 @@ def predict( return loss if return_loss else None + def _print_predictions(self, batch, gold_label_type: str) -> list[str]: + lines = [] + for datapoint in batch: + # check if there is a label mismatch + g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)] + p = [label.labeled_identifier for label in datapoint.get_labels("predicted")] + g.sort() + p.sort() + + # if the gold label is O and is correctly predicted as no label, do not print out as this clutters + # the output file with trivial predictions + if not (len(datapoint.get_labels(gold_label_type)) == 1 and datapoint.get_label(gold_label_type).value == "O" and len(datapoint.get_labels("predicted")) == 0): + correct_string = " -> MISMATCH!\n" if g != p else "" + eval_line = ( + f"{datapoint.text}\n" + f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n" + f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n" + f"{correct_string}\n" + ) + lines.append(eval_line) + return lines + def _get_state_dict(self) -> dict[str, Any]: model_state: dict[str, Any] = { **super()._get_state_dict(), From 58e903cf60d207f93bce0a4108ae9c34a8762e3c Mon Sep 17 00:00:00 2001 From: alanakbik Date: Wed, 1 Jan 2025 12:08:03 +0100 Subject: [PATCH 2/2] Formatting --- flair/models/relation_classifier_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 389bf05157..8aca236230 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -678,7 +678,11 @@ def _print_predictions(self, batch, gold_label_type: str) -> list[str]: # if the gold label is O and is correctly predicted as no label, do not print out as this clutters # the output file with trivial predictions - if not (len(datapoint.get_labels(gold_label_type)) == 1 and datapoint.get_label(gold_label_type).value == "O" and len(datapoint.get_labels("predicted")) == 0): + if not ( + len(datapoint.get_labels(gold_label_type)) == 1 + and datapoint.get_label(gold_label_type).value == "O" + and len(datapoint.get_labels("predicted")) == 0 + ): correct_string = " -> MISMATCH!\n" if g != p else "" eval_line = ( f"{datapoint.text}\n"