From 399dc8b531ae167f663402333ea7bd553338cc0c Mon Sep 17 00:00:00 2001
From: Paul Tunison
Date: Wed, 6 Nov 2024 13:23:10 -0500
Subject: [PATCH] Use best F1 instead of accuracy

---
 tcn_hpl/callbacks/plot_metrics.py |  6 ++---
 tcn_hpl/models/ptg_module.py      | 41 ++++++++++++++++++-------------
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py
index 062c17f23..894b129c3 100644
--- a/tcn_hpl/callbacks/plot_metrics.py
+++ b/tcn_hpl/callbacks/plot_metrics.py
@@ -253,8 +253,8 @@ def on_validation_epoch_end(
         current_epoch = pl_module.current_epoch

         curr_acc = pl_module.val_acc.compute()
-        best_acc = pl_module.val_acc_best.compute()
         curr_f1 = pl_module.val_f1.compute()
+        best_f1 = pl_module.val_f1_best.compute()

         class_ids = np.arange(all_probs.shape[-1])
         num_classes = len(class_ids)
@@ -296,7 +296,7 @@ def on_validation_epoch_end(
         if Image is not None:
             pl_module.logger.experiment.track(Image(fig), name=f"CM Validation Epoch")

-        if curr_acc >= best_acc:
+        if curr_f1 >= best_f1:
             fig.savefig(
                 self.output_dir
                 / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg",
@@ -380,7 +380,7 @@ def on_test_epoch_end(

         fig, ax = plt.subplots(figsize=(num_classes, num_classes))

-        sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", vmin=0, vmax=1)
+        sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)

         # labels, title and ticks
         ax.set_xlabel("Predicted labels")
diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py
index 969b43fee..5de396dd0 100644
--- a/tcn_hpl/models/ptg_module.py
+++ b/tcn_hpl/models/ptg_module.py
@@ -123,8 +123,7 @@ def __init__(
         self.test_loss = MeanMetric()

         # for tracking best so far validation accuracy
-        self.val_acc_best = MaxMetric()
-        self.train_acc_best = MaxMetric()
+        self.val_f1_best = MaxMetric()

     def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
         """Perform a forward pass through the model `self.net`.
@@ -141,13 +140,16 @@ def on_train_start(self) -> None:
         # so it's worth to make sure validation metrics don't store results from these checks
         self.val_loss.reset()
         self.val_acc.reset()
-        self.val_acc_best.reset()
+        self.val_f1.reset()
+        self.val_recall.reset()
+        self.val_precision.reset()
+        self.val_f1_best.reset()

     def compute_loss(self, p, y, mask):
         """Compute the total loss for a batch

         :param p: The prediction
-        :param batch_target: The target labels
+        :param y: The target labels
         :param mask: Marks valid input data

         :return: The loss
@@ -325,13 +327,6 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]])
         all_preds = torch.cat([o['preds'] for o in outputs])
         all_targets = torch.cat([o['targets'] for o in outputs])

-        acc = self.val_acc.compute()
-        # log `val_acc_best` as a value through `.compute()` return, instead of
-        # as a metric object otherwise metric would be reset by lightning after
-        # each epoch.
-        best_val_acc = self.val_acc_best(acc)  # update best so far val acc
-        self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)
-
         self.val_f1(all_preds, all_targets)
         self.val_recall(all_preds, all_targets)
         self.val_precision(all_preds, all_targets)
@@ -339,6 +334,12 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]])
         self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True)
         self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True)

+        # log `val_f1_best` as a value through `.compute()` return, instead of
+        # as a metric object otherwise metric would be reset by lightning after
+        # each epoch.
+        self.val_f1_best(self.val_f1.compute())
+        self.log("val/f1_best", self.val_f1_best.compute(), prog_bar=True, on_epoch=True)
+
     def test_step(
         self,
         batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
@@ -357,16 +358,10 @@ def test_step(
         # update and log metrics
         self.test_loss(loss)
         self.test_acc(preds, targets[:, -1])
-        self.test_f1(preds, targets[:, -1])
-        self.test_recall(preds, targets[:, -1])
-        self.test_precision(preds, targets[:, -1])
         self.log(
             "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
         )
         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True)

         # Only retain the truth and source vid/frame IDs for the final window
         # frame as this is the ultimately relevant result.
@@ -379,6 +374,18 @@ def test_step(
             "source_frame": source_frame[:, -1],
         }

+    def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+        all_preds = torch.cat([o['preds'] for o in outputs])
+        all_targets = torch.cat([o['targets'] for o in outputs])
+
+        # update and log metrics
+        self.test_f1(all_preds, all_targets)
+        self.test_recall(all_preds, all_targets)
+        self.test_precision(all_preds, all_targets)
+        self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True)
+
     def setup(self, stage: Optional[str] = None) -> None:
         """Lightning hook that is called at the beginning of fit (train + validate),
         validate, test, or predict.
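
For context, a minimal standalone sketch of the best-F1 bookkeeping this patch adopts: a per-epoch F1 metric is updated with the epoch's predictions, its `.compute()` value is fed into a `MaxMetric` running maximum, and that maximum is what gets logged by value. This is illustrative only and is not part of the patch; the `MulticlassF1Score` configuration below is an assumption, and in `ptg_module.py` the module's own `val_f1` / `val_f1_best` attributes play these roles.

# Illustrative sketch only -- not from the patch. Assumes torchmetrics is
# installed; the MulticlassF1Score settings here are placeholders.
import torch
from torchmetrics import MaxMetric
from torchmetrics.classification import MulticlassF1Score

num_classes = 5
val_f1 = MulticlassF1Score(num_classes=num_classes)  # per-epoch F1
val_f1_best = MaxMetric()                            # running best across epochs

for epoch in range(3):
    # stand-in for one validation epoch's accumulated predictions/targets
    preds = torch.randint(0, num_classes, (100,))
    targets = torch.randint(0, num_classes, (100,))
    val_f1(preds, targets)

    # feed this epoch's F1 value into the running maximum; in the patch the
    # .compute() result (a plain tensor value) is what gets logged, so
    # Lightning does not reset the best-so-far metric object at epoch end
    val_f1_best(val_f1.compute())
    print(f"epoch {epoch}: f1={val_f1.compute():.4f} best={val_f1_best.compute():.4f}")

    val_f1.reset()  # per-epoch metric resets between epochs; MaxMetric keeps its state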