From a14b31c4c0e04169b09fa1412d3a3fd538124d3e Mon Sep 17 00:00:00 2001
From: golmschenk
Date: Mon, 2 Sep 2024 11:05:41 -0400
Subject: [PATCH] Add basic PyTorch Lightning module version

---
 src/qusi/internal/lightning_train_session.py | 108 +++++++++++++++++
 src/qusi/internal/module.py                  | 111 ++++++++++++++++++
 .../test_toy_train_lightning_session.py      |  28 +++++
 3 files changed, 247 insertions(+)
 create mode 100644 src/qusi/internal/lightning_train_session.py
 create mode 100644 src/qusi/internal/module.py
 create mode 100644 tests/end_to_end_tests/test_toy_train_lightning_session.py

diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py
new file mode 100644
index 00000000..68b87a65
--- /dev/null
+++ b/src/qusi/internal/lightning_train_session.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import logging
+from warnings import warn
+
+import lightning
+from torch.nn import BCELoss, Module
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
+
+from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
+from qusi.internal.logging import set_up_default_logger
+from qusi.internal.module import QusiLightningModule
+from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
+from qusi.internal.train_logging_configuration import TrainLoggingConfiguration
+from qusi.internal.train_system_configuration import TrainSystemConfiguration
+
+logger = logging.getLogger(__name__)
+
+
+def train_session(
+        train_datasets: list[LightCurveDataset],
+        validation_datasets: list[LightCurveDataset],
+        model: Module,
+        optimizer: Optimizer | None = None,
+        loss_metric: Module | None = None,
+        logging_metrics: list[Module] | None = None,
+        *,
+        hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
+        system_configuration: TrainSystemConfiguration | None = None,
+        logging_configuration: TrainLoggingConfiguration | None = None,
+        # Deprecated keyword parameters.
+        loss_function: Module | None = None,
+        metric_functions: list[Module] | None = None,
+) -> None:
+    """
+    Runs a training session.
+
+    :param train_datasets: The datasets to train on.
+    :param validation_datasets: The datasets to validate on.
+    :param model: The model to train.
+    :param optimizer: The optimizer to be used during training.
+    :param loss_metric: The loss function to train the model on.
+    :param logging_metrics: A list of metric functions to record during the training process.
+    :param hyperparameter_configuration: The configuration of the hyperparameters.
+    :param system_configuration: The configuration of the system.
+    :param logging_configuration: The configuration of the logging.
+    """
+    if loss_metric is not None and loss_function is not None:
+        raise ValueError('Both `loss_metric` and `loss_function` cannot be set at the same time.')
+    if logging_metrics is not None and metric_functions is not None:
+        raise ValueError('Both `logging_metrics` and `metric_functions` cannot be set at the same time.')
+    if loss_function is not None:
+        warn('`loss_function` is deprecated and will be removed in the future. '
+             'Please use `loss_metric` instead.', UserWarning)
+        loss_metric = loss_function
+    if metric_functions is not None:
+        warn('`metric_functions` is deprecated and will be removed in the future. '
+             'Please use `logging_metrics` instead.', UserWarning)
+        logging_metrics = metric_functions
+
+    if hyperparameter_configuration is None:
+        hyperparameter_configuration = TrainHyperparameterConfiguration.new()
+    if system_configuration is None:
+        system_configuration = TrainSystemConfiguration.new()
+    if loss_metric is None:
+        loss_metric = BCELoss()
+    if logging_metrics is None:
+        logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
+
+    set_up_default_logger()
+    train_dataset = InterleavedDataset.new(*train_datasets)
+    workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
+    if workers_per_dataloader == 0:
+        prefetch_factor = None
+        persistent_workers = False
+    else:
+        prefetch_factor = 10
+        persistent_workers = True
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=hyperparameter_configuration.batch_size,
+        pin_memory=True,
+        persistent_workers=persistent_workers,
+        prefetch_factor=prefetch_factor,
+        num_workers=workers_per_dataloader,
+    )
+    validation_dataloaders: list[DataLoader] = []
+    for validation_dataset in validation_datasets:
+        validation_dataloader = DataLoader(
+            validation_dataset,
+            batch_size=hyperparameter_configuration.batch_size,
+            pin_memory=True,
+            persistent_workers=persistent_workers,
+            prefetch_factor=prefetch_factor,
+            num_workers=workers_per_dataloader,
+        )
+        validation_dataloaders.append(validation_dataloader)
+
+    lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
+                                              logging_metrics=logging_metrics)
+    trainer = lightning.Trainer(
+        max_epochs=hyperparameter_configuration.cycles,
+        limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
+        limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
+    )
+    trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py
new file mode 100644
index 00000000..506b58f1
--- /dev/null
+++ b/src/qusi/internal/module.py
@@ -0,0 +1,111 @@
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+from lightning import LightningModule
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.nn import Module, BCELoss, ModuleList
+from torch.optim import Optimizer, AdamW
+from torchmetrics import Metric
+from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
+from typing_extensions import Self
+
+from qusi.internal.logging import get_metric_name
+
+
+class QusiLightningModule(LightningModule):
+    @classmethod
+    def new(
+            cls,
+            model: Module,
+            optimizer: Optimizer | None,
+            loss_metric: Module | None = None,
+            logging_metrics: list[Module] | None = None,
+    ) -> Self:
+        if optimizer is None:
+            optimizer = AdamW(model.parameters())
+        if loss_metric is None:
+            loss_metric = BCELoss()
+        if logging_metrics is None:
+            logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
+        state_based_logging_metrics: ModuleList = ModuleList()
+        functional_logging_metrics: list[Module] = []
+        for logging_metric in logging_metrics:
+            if isinstance(logging_metric, Metric):
+                state_based_logging_metrics.append(logging_metric)
+            else:
+                functional_logging_metrics.append(logging_metric)
+        instance = cls(model=model, optimizer=optimizer, loss_metric=loss_metric,
+                       state_based_logging_metrics=state_based_logging_metrics,
+                       functional_logging_metrics=functional_logging_metrics)
+        return instance
+
+    def __init__(
+            self,
+            model: Module,
+            optimizer: Optimizer,
+            loss_metric: Module,
+            state_based_logging_metrics: ModuleList,
+            functional_logging_metrics: list[Module],
+    ):
+        super().__init__()
+        self.model: Module = model
+        self._optimizer: Optimizer = optimizer
+        self.loss_metric: Module = loss_metric
+        self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
+        self.functional_logging_metrics: list[Module] = functional_logging_metrics
+        self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics),
+                                                                             dtype=np.float32)
+        self._loss_cycle_total: int = 0
+        self._steps_run_in_cycle: int = 0
+
+    def forward(self, inputs: Any) -> Any:
+        return self.model(inputs)
+
+    def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
+        return self.compute_loss_and_metrics(batch)
+
+    def compute_loss_and_metrics(self, batch):
+        inputs, target = batch
+        predicted = self(inputs)
+        loss = self.loss_metric(predicted, target)
+        self._loss_cycle_total += loss
+        for state_based_logging_metric in self.state_based_logging_metrics:
+            state_based_logging_metric(predicted, target)
+        for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
+            functional_logging_metric_value = functional_logging_metric(predicted, target)
+            self._functional_logging_metric_cycle_totals[
+                functional_logging_metric_index] += functional_logging_metric_value
+        self._steps_run_in_cycle += 1
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        self.log_loss_and_metrics()
+
+    def log_loss_and_metrics(self, logging_name_prefix: str = ''):
+        for state_based_logging_metric in self.state_based_logging_metrics:
+            state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
+            self.log(name=logging_name_prefix + state_based_logging_metric_name,
+                     value=state_based_logging_metric.compute(), sync_dist=True)
+            state_based_logging_metric.reset()
+        for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
+            functional_logging_metric_name = get_metric_name(functional_logging_metric)
+            functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[
+                                                              functional_logging_metric_index])
+
+            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_cycle
+            self.log(name=logging_name_prefix + functional_logging_metric_name,
+                     value=functional_logging_metric_cycle_mean,
+                     sync_dist=True)
+        mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_cycle
+        self.log(name=logging_name_prefix + 'loss',
+                 value=mean_cycle_loss, sync_dist=True)
+        self._loss_cycle_total = 0
+        self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32)
+        self._steps_run_in_cycle = 0
+
+    def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
+        return self.compute_loss_and_metrics(batch)
+
+    def configure_optimizers(self):
+        return self._optimizer
diff --git a/tests/end_to_end_tests/test_toy_train_lightning_session.py b/tests/end_to_end_tests/test_toy_train_lightning_session.py
new file mode 100644
index 00000000..750d747e
--- /dev/null
+++ b/tests/end_to_end_tests/test_toy_train_lightning_session.py
@@ -0,0 +1,28 @@
+import os
+from functools import partial
+
+from qusi.internal.light_curve_dataset import (
+    default_light_curve_observation_post_injection_transform,
+)
+from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
+from qusi.internal.toy_light_curve_collection import get_toy_dataset
+from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
+from qusi.internal.lightning_train_session import train_session
+
+
+def test_toy_train_session():
+    os.environ["WANDB_MODE"] = "disabled"
+    model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
+    dataset = get_toy_dataset()
+    dataset.post_injection_transform = partial(
+        default_light_curve_observation_post_injection_transform, length=100
+    )
+    train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
+        batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
+    )
+    train_session(
+        train_datasets=[dataset],
+        validation_datasets=[dataset],
+        model=model,
+        hyperparameter_configuration=train_hyperparameter_configuration,
+    )
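
Usage sketch: a minimal example of calling the new Lightning-based train_session with
explicit overrides, mirroring the toy end-to-end test above. QusiLightningModule.new routes
torchmetrics.Metric instances to state-based logging and treats any other callable as a
functional metric averaged per cycle; when optimizer, loss_metric, and logging_metrics are
omitted, BCELoss and [BinaryAccuracy(), BinaryAUROC()] are used and QusiLightningModule.new
supplies an AdamW optimizer. The learning rate below is an illustrative assumption, not a
value taken from the patch.

    from functools import partial

    from torch.optim import AdamW
    from torchmetrics.classification import BinaryAccuracy

    from qusi.internal.light_curve_dataset import (
        default_light_curve_observation_post_injection_transform,
    )
    from qusi.internal.lightning_train_session import train_session
    from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
    from qusi.internal.toy_light_curve_collection import get_toy_dataset
    from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration

    # Build the toy dataset and model exactly as the end-to-end test does.
    model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
    dataset = get_toy_dataset()
    dataset.post_injection_transform = partial(
        default_light_curve_observation_post_injection_transform, length=100
    )

    # Optional overrides: an explicit optimizer and a single logging metric.
    # The learning rate here is an illustrative assumption.
    train_session(
        train_datasets=[dataset],
        validation_datasets=[dataset],
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
        logging_metrics=[BinaryAccuracy()],
        hyperparameter_configuration=TrainHyperparameterConfiguration.new(
            batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
        ),
    )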