diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index 8e63635..daa185d 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -48,7 +48,7 @@ def new( model: Module, optimizer: Optimizer | None, loss_metric: Module | None = None, - logging_metrics: list[Module] | None = None, + logging_metrics: ModuleList | None = None, ) -> Self: if optimizer is None: optimizer = AdamW(model.parameters()) @@ -66,7 +66,7 @@ def new( train_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics) validation_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics) instance = cls(model=model, optimizer=optimizer, train_metric_group=train_metric_group, - validation_metric_groups=[validation_metric_group]) + validation_metric_groups=ModuleList([validation_metric_group])) return instance def __init__( @@ -74,13 +74,13 @@ def __init__( model: Module, optimizer: Optimizer, train_metric_group: MetricGroup, - validation_metric_groups: list[MetricGroup], + validation_metric_groups: ModuleList, ): super().__init__() self.model: Module = model self._optimizer: Optimizer = optimizer - self.train_metric_group = train_metric_group - self.validation_metric_groups: list[MetricGroup] = validation_metric_groups + self.train_metric_group: MetricGroup = train_metric_group + self.validation_metric_groups: ModuleList | list[MetricGroup] = validation_metric_groups def forward(self, inputs: Any) -> Any: return self.model(inputs) @@ -104,37 +104,37 @@ def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricG return loss def on_train_epoch_end(self) -> None: - self.log_loss_and_metrics() + self.log_loss_and_metrics(self.train_metric_group) - def log_loss_and_metrics(self, logging_name_prefix: str = ''): - for state_based_logging_metric in self.train_state_based_logging_metrics: + def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: str = ''): + for state_based_logging_metric in metric_group.state_based_logging_metrics: state_based_logging_metric_name = get_metric_name(state_based_logging_metric) self.log(name=logging_name_prefix + state_based_logging_metric_name, value=state_based_logging_metric.compute(), sync_dist=True) state_based_logging_metric.reset() for functional_logging_metric_index, functional_logging_metric in enumerate( - self.train_functional_logging_metrics): + metric_group.functional_logging_metrics): functional_logging_metric_name = get_metric_name(functional_logging_metric) - functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[ + functional_logging_metric_cycle_total = float(metric_group.functional_logging_metric_cycle_totals[ functional_logging_metric_index]) - functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._train_steps_run_in_phase + functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / metric_group.steps_run_in_phase self.log(name=logging_name_prefix + functional_logging_metric_name, value=functional_logging_metric_cycle_mean, sync_dist=True) - mean_cycle_loss = self._train_loss_cycle_total / self._train_steps_run_in_phase + mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True) - self._train_loss_cycle_total = 0 - self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), + metric_group.loss_cycle_total = 0 + metric_group.functional_logging_metric_cycle_totals = np.zeros(len(metric_group.functional_logging_metrics), dtype=np.float32) - self._train_steps_run_in_phase = 0 + metric_group.steps_run_in_phase = 0 def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0]) def on_validation_epoch_end(self) -> None: - self.log_loss_and_metrics(logging_name_prefix='val_') + self.log_loss_and_metrics(self.validation_metric_groups[0], logging_name_prefix='val_') def configure_optimizers(self): return self._optimizer diff --git a/tests/end_to_end_tests/test_toy_lightning_train_session.py b/tests/end_to_end_tests/test_toy_lightning_train_session.py new file mode 100644 index 0000000..750d747 --- /dev/null +++ b/tests/end_to_end_tests/test_toy_lightning_train_session.py @@ -0,0 +1,28 @@ +import os +from functools import partial + +from qusi.internal.light_curve_dataset import ( + default_light_curve_observation_post_injection_transform, +) +from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel +from qusi.internal.toy_light_curve_collection import get_toy_dataset +from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration +from qusi.internal.lightning_train_session import train_session + + +def test_toy_train_session(): + os.environ["WANDB_MODE"] = "disabled" + model = SingleDenseLayerBinaryClassificationModel.new(input_size=100) + dataset = get_toy_dataset() + dataset.post_injection_transform = partial( + default_light_curve_observation_post_injection_transform, length=100 + ) + train_hyperparameter_configuration = TrainHyperparameterConfiguration.new( + batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5 + ) + train_session( + train_datasets=[dataset], + validation_datasets=[dataset], + model=model, + hyperparameter_configuration=train_hyperparameter_configuration, + )