From d8945e842974c0b3839940f157bbe6f473dc0090 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 9 May 2024 14:31:57 -0400 Subject: [PATCH] Add binary AUROC metric --- pyproject.toml | 3 ++- src/qusi/internal/train_session.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b3e2de8..fe9d2221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dependencies = [ "sphinx>=6.1.3", "backports.strenum", "typing_extensions", - "myst-parser" + "myst-parser", + "torcheval>=0.0.7", ] [build-system] diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py index 4366d4a6..4b7f4110 100644 --- a/src/qusi/internal/train_session.py +++ b/src/qusi/internal/train_session.py @@ -9,7 +9,7 @@ from torch.nn import BCELoss, Module from torch.optim import AdamW from torch.utils.data import DataLoader -from torchmetrics.classification import BinaryAccuracy +from torcheval.metrics import BinaryAccuracy, BinaryAUROC import wandb from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset @@ -49,7 +49,7 @@ def train_session( if loss_function is None: loss_function = BCELoss() if metric_functions is None: - metric_functions = [BinaryAccuracy()] + metric_functions = [BinaryAccuracy(), BinaryAUROC()] set_up_default_logger() wandb_init( process_rank=0,