From a7bc9543edc60dda85d88ee1ad6026cf5a5a4547 Mon Sep 17 00:00:00 2001
From: golmschenk
Date: Sat, 13 Jul 2024 12:11:01 -0400
Subject: [PATCH] Rename metrics

---
 src/qusi/internal/train_session.py | 48 ++++++++++++++++++++----------
 tests/unit_tests/metrics.py        | 14 +++++++++
 2 files changed, 47 insertions(+), 15 deletions(-)
 create mode 100644 tests/unit_tests/metrics.py

diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py
index 3ebf3e2..07b1eab 100644
--- a/src/qusi/internal/train_session.py
+++ b/src/qusi/internal/train_session.py
@@ -2,6 +2,7 @@
 
 import logging
 from pathlib import Path
+from warnings import warn
 
 import numpy as np
 import torch
@@ -26,12 +27,15 @@ def train_session(
     validation_datasets: list[LightCurveDataset],
     model: Module,
     optimizer: Optimizer | None = None,
-    loss_function: Module | None = None,
-    metric_functions: list[Module] | None = None,
+    loss_metric: Module | None = None,
+    logging_metrics: list[Module] | None = None,
     *,
     hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
     system_configuration: TrainSystemConfiguration | None = None,
     logging_configuration: TrainLoggingConfiguration | None = None,
+    # Deprecated keyword parameters.
+    loss_function: Module | None = None,
+    metric_functions: list[Module] | None = None,
 ) -> None:
     """
     Runs a training session.
@@ -40,22 +44,36 @@
     :param validation_datasets: The datasets to validate on.
     :param model: The model to train.
     :param optimizer: The optimizer to be used during training.
-    :param loss_function: The loss function to train the model on.
-    :param metric_functions: A list of metric functions to record during the training process.
+    :param loss_metric: The loss metric to train the model on.
+    :param logging_metrics: A list of metrics to log during the training process.
     :param hyperparameter_configuration: The configuration of the hyperparameters.
     :param system_configuration: The configuration of the system.
     :param logging_configuration: The configuration of the logging.
     """
+    if loss_metric is not None and loss_function is not None:
+        raise ValueError('`loss_metric` and `loss_function` cannot both be set at the same time.')
+    if logging_metrics is not None and metric_functions is not None:
+        raise ValueError('`logging_metrics` and `metric_functions` cannot both be set at the same time.')
+    if loss_function is not None:
+        warn('`loss_function` is deprecated and will be removed in the future. '
+             'Please use `loss_metric` instead.', UserWarning)
+        loss_metric = loss_function
+    if metric_functions is not None:
+        warn('`metric_functions` is deprecated and will be removed in the future. '
+             'Please use `logging_metrics` instead.', UserWarning)
+        logging_metrics = metric_functions
+
     if hyperparameter_configuration is None:
         hyperparameter_configuration = TrainHyperparameterConfiguration.new()
     if system_configuration is None:
         system_configuration = TrainSystemConfiguration.new()
     if logging_configuration is None:
         logging_configuration = TrainLoggingConfiguration.new()
-    if loss_function is None:
-        loss_function = BCELoss()
-    if metric_functions is None:
-        metric_functions = [BinaryAccuracy(), BinaryAUROC()]
+    if loss_metric is None:
+        loss_metric = BCELoss()
+    if logging_metrics is None:
+        logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
+
     set_up_default_logger()
     sessions_directory = Path("sessions")
     sessions_directory.mkdir(exist_ok=True)
@@ -99,21 +117,21 @@
     else:
         device = torch.device('cpu')
     model = model.to(device, non_blocking=True)
-    loss_function = loss_function.to(device, non_blocking=True)
+    loss_metric = loss_metric.to(device, non_blocking=True)
     if optimizer is None:
         optimizer = AdamW(model.parameters())
-    metric_functions: list[Module] = [
+    logging_metrics: list[Module] = [
         metric_function.to(device, non_blocking=True)
-        for metric_function in metric_functions
+        for metric_function in logging_metrics
     ]
     for _cycle_index in range(hyperparameter_configuration.cycles):
         logger.info(f'Cycle {_cycle_index}')
-        train_phase(dataloader=train_dataloader, model=model, loss_metric=loss_function,
-                    logging_metrics=metric_functions, optimizer=optimizer,
+        train_phase(dataloader=train_dataloader, model=model, loss_metric=loss_metric,
+                    logging_metrics=logging_metrics, optimizer=optimizer,
                     steps=hyperparameter_configuration.train_steps_per_cycle, device=device)
         for validation_dataloader in validation_dataloaders:
-            validation_phase(dataloader=validation_dataloader, model=model, loss_metric=loss_function,
-                             logging_metrics=metric_functions,
+            validation_phase(dataloader=validation_dataloader, model=model, loss_metric=loss_metric,
+                             logging_metrics=logging_metrics,
                              steps=hyperparameter_configuration.validation_steps_per_cycle, device=device)
     save_model(model, suffix="latest_model", process_rank=0)
     wandb_commit(process_rank=0)
diff --git a/tests/unit_tests/metrics.py b/tests/unit_tests/metrics.py
new file mode 100644
index 0000000..61bd188
--- /dev/null
+++ b/tests/unit_tests/metrics.py
@@ -0,0 +1,14 @@
+import torch
+from torch.nn import MSELoss
+
+from qusi.internal.train_session import update_logging_metrics
+
+
+def test_update_logging_metrics_for_functional_metrics():
+    predicted_targets = torch.tensor([1.])
+    targets = torch.tensor([3.])
+    metric_totals = torch.tensor([2.])
+    expected_metric_totals = torch.tensor([6.])
+    metric = MSELoss()
+    update_logging_metrics(predicted_targets, targets, [metric], metric_totals)
+    assert metric_totals == expected_metric_totals
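
Usage note (illustrative, not part of the patch to apply): a minimal sketch of the renamed keywords and the deprecation path added above. It assumes `train_session` is imported from the same internal module this patch touches (as the new test does for `update_logging_metrics`), and that the first parameter is named `train_datasets` per the surrounding docstring; the dataset and model variables are hypothetical placeholders.

    from torch.nn import BCELoss
    from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

    from qusi.internal.train_session import train_session

    # New keyword names introduced by this patch.
    train_session(
        train_datasets=[my_train_dataset],            # hypothetical LightCurveDataset
        validation_datasets=[my_validation_dataset],  # hypothetical LightCurveDataset
        model=my_model,                               # hypothetical Module
        loss_metric=BCELoss(),
        logging_metrics=[BinaryAccuracy(), BinaryAUROC()],
    )

    # Old keyword names still work, but emit a UserWarning and are mapped
    # to the new names internally.
    train_session(
        train_datasets=[my_train_dataset],
        validation_datasets=[my_validation_dataset],
        model=my_model,
        loss_function=BCELoss(),              # warns, becomes loss_metric
        metric_functions=[BinaryAccuracy()],  # warns, becomes logging_metrics
    )

    # Passing an old keyword together with its replacement raises a ValueError.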