From 6a07e88c3abf3549453c1c2d066c8989f0fbacfb Mon Sep 17 00:00:00 2001 From: golmschenk Date: Fri, 1 Nov 2024 23:29:10 -0400 Subject: [PATCH] Log cycle during both train and validation --- src/qusi/internal/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index 7878f7d..a405926 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -101,7 +101,6 @@ def on_train_epoch_start(self) -> None: # Due to Lightning's inconsistent step ordering, performing this during the train epoch start gives the most # consistent results. self.cycle += 1 - self.log(name='cycle', value=self.cycle, reduce_fx=torch.max, rank_zero_only=True, on_step=False, on_epoch=True) def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: return self.compute_loss_and_metrics(batch, self.train_metric_group) @@ -129,6 +128,7 @@ def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: s self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True, on_step=False, on_epoch=True) + self.log(name='cycle', value=self.cycle, reduce_fx=torch.max, rank_zero_only=True, on_step=False, on_epoch=True) for state_based_logging_metric in metric_group.state_based_logging_metrics: state_based_logging_metric_name = get_metric_name(state_based_logging_metric) self.log(name=logging_name_prefix + state_based_logging_metric_name,