diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index c793aa3..8e63635 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -1,3 +1,4 @@ +import copy from typing import Any import numpy as np @@ -13,6 +14,33 @@ from qusi.internal.logging import get_metric_name +class MetricGroup(Module): + def __init__(self, loss_metric: Module, state_based_logging_metrics: ModuleList, + functional_logging_metrics: ModuleList): + super().__init__() + self.loss_metric: Module = loss_metric + self.state_based_logging_metrics: ModuleList = state_based_logging_metrics + self.functional_logging_metrics: ModuleList = functional_logging_metrics + self.loss_cycle_total: float = 0 + self.steps_run_in_phase: int = 0 + self.functional_logging_metric_cycle_totals: npt.NDArray = np.zeros( + len(self.functional_logging_metrics), dtype=np.float32) + + @classmethod + def new( + cls, + loss_metric: Module, + state_based_logging_metrics: ModuleList, + functional_logging_metrics: ModuleList + ) -> Self: + loss_metric_: Module = copy.deepcopy(loss_metric) + state_based_logging_metrics_: ModuleList = copy.deepcopy(state_based_logging_metrics) + functional_logging_metrics_: ModuleList = copy.deepcopy(functional_logging_metrics) + instance = cls(loss_metric=loss_metric_, state_based_logging_metrics=state_based_logging_metrics_, + functional_logging_metrics=functional_logging_metrics_) + return instance + + class QusiLightningModule(LightningModule): @classmethod def new( @@ -29,54 +57,50 @@ def new( if logging_metrics is None: logging_metrics = [BinaryAccuracy(), BinaryAUROC()] state_based_logging_metrics: ModuleList = ModuleList() - functional_logging_metrics: list[Module] = [] + functional_logging_metrics: ModuleList = ModuleList() for logging_metric in logging_metrics: if isinstance(logging_metric, Metric): state_based_logging_metrics.append(logging_metric) else: functional_logging_metrics.append(logging_metric) - instance = cls(model=model, optimizer=optimizer, loss_metric=loss_metric, - state_based_logging_metrics=state_based_logging_metrics, - functional_logging_metrics=functional_logging_metrics) + train_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics) + validation_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics) + instance = cls(model=model, optimizer=optimizer, train_metric_group=train_metric_group, + validation_metric_groups=[validation_metric_group]) return instance def __init__( self, model: Module, optimizer: Optimizer, - loss_metric: Module, - state_based_logging_metrics: ModuleList, - functional_logging_metrics: list[Module], + train_metric_group: MetricGroup, + validation_metric_groups: list[MetricGroup], ): super().__init__() self.model: Module = model self._optimizer: Optimizer = optimizer - self.loss_metric: Module = loss_metric - self.train_state_based_logging_metrics: ModuleList = state_based_logging_metrics - self.train_functional_logging_metrics: list[Module] = functional_logging_metrics - self._train_functional_logging_metric_cycle_totals: npt.NDArray = np.zeros( - len(self.train_functional_logging_metrics), dtype=np.float32) - self._loss_cycle_total: int = 0 - self._steps_run_in_phase: int = 0 + self.train_metric_group = train_metric_group + self.validation_metric_groups: list[MetricGroup] = validation_metric_groups def forward(self, inputs: Any) -> Any: return self.model(inputs) def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: - return self.compute_loss_and_metrics(batch) + return self.compute_loss_and_metrics(batch, self.train_metric_group) - def compute_loss_and_metrics(self, batch): + def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricGroup): inputs, target = batch predicted = self(inputs) - loss = self.loss_metric(predicted, target) - self._loss_cycle_total += loss - for state_based_logging_metric in self.train_state_based_logging_metrics: + loss = metric_group.loss_metric(predicted, target) + metric_group.loss_cycle_total += loss + for state_based_logging_metric in metric_group.state_based_logging_metrics: state_based_logging_metric(predicted, target) - for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics): + for functional_logging_metric_index, functional_logging_metric in enumerate( + metric_group.functional_logging_metrics): functional_logging_metric_value = functional_logging_metric(predicted, target) - self._train_functional_logging_metric_cycle_totals[ + metric_group.functional_logging_metric_cycle_totals[ functional_logging_metric_index] += functional_logging_metric_value - self._steps_run_in_phase += 1 + metric_group.steps_run_in_phase += 1 return loss def on_train_epoch_end(self) -> None: @@ -88,24 +112,26 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''): self.log(name=logging_name_prefix + state_based_logging_metric_name, value=state_based_logging_metric.compute(), sync_dist=True) state_based_logging_metric.reset() - for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics): + for functional_logging_metric_index, functional_logging_metric in enumerate( + self.train_functional_logging_metrics): functional_logging_metric_name = get_metric_name(functional_logging_metric) functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[ functional_logging_metric_index]) - functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase + functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._train_steps_run_in_phase self.log(name=logging_name_prefix + functional_logging_metric_name, value=functional_logging_metric_cycle_mean, sync_dist=True) - mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_phase + mean_cycle_loss = self._train_loss_cycle_total / self._train_steps_run_in_phase self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True) - self._loss_cycle_total = 0 - self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), dtype=np.float32) - self._steps_run_in_phase = 0 + self._train_loss_cycle_total = 0 + self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), + dtype=np.float32) + self._train_steps_run_in_phase = 0 def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: - return self.compute_loss_and_metrics(batch) + return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0]) def on_validation_epoch_end(self) -> None: self.log_loss_and_metrics(logging_name_prefix='val_') diff --git a/tests/unit_tests/test_lightning_module.py b/tests/unit_tests/test_lightning_module.py new file mode 100644 index 0000000..e2269d7 --- /dev/null +++ b/tests/unit_tests/test_lightning_module.py @@ -0,0 +1,60 @@ +from unittest.mock import Mock + +import torch +from torch.nn import ModuleList, Module +from torchmetrics import MeanSquaredError + +from qusi.internal.module import QusiLightningModule, MetricGroup + + +class MockStateBasedMetric(Mock, MeanSquaredError): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class MockFunctionalMetric(Mock, Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.side_effect: Module = MeanSquaredError() + + +def create_fake_qusi_lightning_module() -> QusiLightningModule: + qusi_lightning_module_mock = QusiLightningModule( + model=Mock(return_value=torch.tensor([1])), optimizer=Mock(), train_metric_group=Mock(), + validation_metric_groups=[Mock()] + ) + return qusi_lightning_module_mock + + +def create_fake_metric_group() -> MetricGroup: + fake_metric_group = MetricGroup(loss_metric=Mock(return_value=torch.tensor([1])), + state_based_logging_metrics=ModuleList([MockStateBasedMetric()]), + functional_logging_metrics=ModuleList([MockFunctionalMetric()])) + return fake_metric_group + + +def test_compute_loss_and_metrics_calls_passed_loss_metric(): + fake_qusi_lightning_module0 = create_fake_qusi_lightning_module() + fake_metric_group = create_fake_metric_group() + batch = (torch.tensor([3]), torch.tensor([4])) + assert not fake_metric_group.loss_metric.called + fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group) + assert fake_metric_group.loss_metric.called + + +def test_compute_loss_and_metrics_uses_correct_phase_state_metric(): + fake_qusi_lightning_module0 = create_fake_qusi_lightning_module() + fake_metric_group = create_fake_metric_group() + batch = (torch.tensor([3]), torch.tensor([4])) + assert not fake_metric_group.state_based_logging_metrics[0].called + fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group) + assert fake_metric_group.state_based_logging_metrics[0].called + + +def test_compute_loss_and_metrics_uses_correct_phase_functional_metric(): + fake_qusi_lightning_module0 = create_fake_qusi_lightning_module() + fake_metric_group = create_fake_metric_group() + batch = (torch.tensor([3]), torch.tensor([4])) + assert not fake_metric_group.functional_logging_metrics[0].called + fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group) + assert fake_metric_group.functional_logging_metrics[0].called