Commit: Rename metrics

golmschenk committed Jul 13, 2024
1 parent f452521 commit a7bc954
Showing 2 changed files with 47 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/qusi/internal/train_session.py
@@ -2,6 +2,7 @@
 
 import logging
 from pathlib import Path
+from warnings import warn
 
 import numpy as np
 import torch
@@ -26,12 +27,15 @@ def train_session(
         validation_datasets: list[LightCurveDataset],
         model: Module,
         optimizer: Optimizer | None = None,
-        loss_function: Module | None = None,
-        metric_functions: list[Module] | None = None,
+        loss_metric: Module | None = None,
+        logging_metrics: list[Module] | None = None,
         *,
         hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
         system_configuration: TrainSystemConfiguration | None = None,
         logging_configuration: TrainLoggingConfiguration | None = None,
+        # Deprecated keyword parameters.
+        loss_function: Module | None = None,
+        metric_functions: list[Module] | None = None,
 ) -> None:
     """
     Runs a training session.
@@ -40,22 +44,36 @@
     :param validation_datasets: The datasets to validate on.
     :param model: The model to train.
     :param optimizer: The optimizer to be used during training.
-    :param loss_function: The loss function to train the model on.
-    :param metric_functions: A list of metric functions to record during the training process.
+    :param loss_metric: The loss function to train the model on.
+    :param logging_metrics: A list of metric functions to record during the training process.
     :param hyperparameter_configuration: The configuration of the hyperparameters.
     :param system_configuration: The configuration of the system.
     :param logging_configuration: The configuration of the logging.
     """
+    if loss_metric is not None and loss_function is not None:
+        raise ValueError('Both `loss_metric` and `loss_function` cannot be set at the same time.')
+    if logging_metrics is not None and metric_functions is not None:
+        raise ValueError('Both `logging_metrics` and `metric_functions` cannot be set at the same time.')
+    if loss_function is not None:
+        warn('`loss_function` is deprecated and will be removed in the future. '
+             'Please use `loss_metric` instead.', UserWarning)
+        loss_metric = loss_function
+    if metric_functions is not None:
+        warn('`metric_functions` is deprecated and will be removed in the future. '
+             'Please use `logging_metrics` instead.', UserWarning)
+        logging_metrics = metric_functions
+
     if hyperparameter_configuration is None:
         hyperparameter_configuration = TrainHyperparameterConfiguration.new()
     if system_configuration is None:
         system_configuration = TrainSystemConfiguration.new()
     if logging_configuration is None:
         logging_configuration = TrainLoggingConfiguration.new()
-    if loss_function is None:
-        loss_function = BCELoss()
-    if metric_functions is None:
-        metric_functions = [BinaryAccuracy(), BinaryAUROC()]
+    if loss_metric is None:
+        loss_metric = BCELoss()
+    if logging_metrics is None:
+        logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
+
     set_up_default_logger()
     sessions_directory = Path("sessions")
     sessions_directory.mkdir(exist_ok=True)
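
The warn-and-forward shim added above can be exercised in isolation. A minimal, self-contained sketch of the same pattern (using a stand-in function, not qusi's actual `train_session`):

```python
from warnings import warn


def train(loss_metric=None, *, loss_function=None):
    """Stand-in reproducing the warn-and-forward deprecation shim from this commit."""
    if loss_metric is not None and loss_function is not None:
        raise ValueError('Both `loss_metric` and `loss_function` cannot be set at the same time.')
    if loss_function is not None:
        warn('`loss_function` is deprecated and will be removed in the future. '
             'Please use `loss_metric` instead.', UserWarning)
        loss_metric = loss_function  # Forward the old argument to the new name.
    return loss_metric


assert train(loss_metric='bce') == 'bce'    # New keyword passes through silently.
assert train(loss_function='bce') == 'bce'  # Old keyword still works but emits a UserWarning.
```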
@@ -99,21 +117,21 @@ def train_session(
     else:
         device = torch.device('cpu')
     model = model.to(device, non_blocking=True)
-    loss_function = loss_function.to(device, non_blocking=True)
+    loss_metric = loss_metric.to(device, non_blocking=True)
     if optimizer is None:
         optimizer = AdamW(model.parameters())
-    metric_functions: list[Module] = [
+    logging_metrics: list[Module] = [
         metric_function.to(device, non_blocking=True)
-        for metric_function in metric_functions
+        for metric_function in logging_metrics
     ]
     for _cycle_index in range(hyperparameter_configuration.cycles):
         logger.info(f'Cycle {_cycle_index}')
-        train_phase(dataloader=train_dataloader, model=model, loss_metric=loss_function,
-                    logging_metrics=metric_functions, optimizer=optimizer,
+        train_phase(dataloader=train_dataloader, model=model, loss_metric=loss_metric,
+                    logging_metrics=logging_metrics, optimizer=optimizer,
                     steps=hyperparameter_configuration.train_steps_per_cycle, device=device)
         for validation_dataloader in validation_dataloaders:
-            validation_phase(dataloader=validation_dataloader, model=model, loss_metric=loss_function,
-                             logging_metrics=metric_functions,
+            validation_phase(dataloader=validation_dataloader, model=model, loss_metric=loss_metric,
+                             logging_metrics=logging_metrics,
                              steps=hyperparameter_configuration.validation_steps_per_cycle, device=device)
         save_model(model, suffix="latest_model", process_rank=0)
         wandb_commit(process_rank=0)
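
The first half of this hunk moves the loss metric and each logging metric to the training device alongside the model. A self-contained sketch of that pattern with plain torch objects (the device-selection branch is simplified from the commit's version, and torchmetrics is assumed as the source of `BinaryAccuracy`/`BinaryAUROC`):

```python
import torch
from torch.nn import BCELoss
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mirror the commit: move the loss metric and every logging metric to the device.
loss_metric = BCELoss().to(device, non_blocking=True)
logging_metrics = [
    metric.to(device, non_blocking=True)
    for metric in [BinaryAccuracy(), BinaryAUROC()]
]

predictions = torch.tensor([0.9, 0.2, 0.8], device=device)
targets = torch.tensor([1.0, 0.0, 1.0], device=device)
print(loss_metric(predictions, targets))               # Scalar BCE loss.
print(logging_metrics[0](predictions, targets.int()))  # Binary accuracy for the batch.
```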
14 changes: 14 additions & 0 deletions tests/unit_tests/metrics.py
@@ -0,0 +1,14 @@
+import torch
+from torch.nn import MSELoss
+
+from qusi.internal.train_session import update_logging_metrics
+
+
+def test_update_logging_metrics_for_functional_metrics():
+    predicted_targets = torch.tensor([1.])
+    targets = torch.tensor([3.])
+    metric_totals = torch.tensor([2.])
+    expected_metric_totals = torch.tensor([6.])
+    metric = MSELoss()
+    update_logging_metrics(predicted_targets, targets, [metric], metric_totals)
+    assert metric_totals == expected_metric_totals
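
The implementation of `update_logging_metrics` is not visible in this diff, but the test pins down its contract: evaluate each metric on the batch and accumulate the result into `metric_totals` in place (MSE of 1.0 vs. 3.0 is 4.0, and 2.0 + 4.0 = 6.0). A minimal sketch consistent with that contract, assuming this signature:

```python
import torch
from torch import Tensor
from torch.nn import Module


def update_logging_metrics(predicted_targets: Tensor, targets: Tensor,
                           logging_metrics: list[Module], metric_totals: Tensor) -> None:
    # Accumulate each metric's value for this batch into the running totals in place.
    for metric_index, logging_metric in enumerate(logging_metrics):
        metric_totals[metric_index] += logging_metric(predicted_targets, targets)
```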
