Skip to content

Commit

Permalink
Fix last epoch validation loss not saving
Browse files Browse the repository at this point in the history
  • Loading branch information
elseml committed Mar 1, 2024
1 parent 9a5b125 commit 56a8dd5
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions bayesflow/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,8 @@ def train_online(
p_bar.update(1)

# Store and compute validation loss, if specified
self._save_trainer(save_checkpoint)
self._validation(ep, validation_sims, **kwargs)
self._save_trainer(save_checkpoint)

# Check early stopping, if specified
if self._check_early_stopping(early_stopper):
Expand Down Expand Up @@ -579,13 +579,13 @@ def train_offline(
# Format for display on progress bar
disp_str = format_loss_string(ep, bi, loss, avg_dict, lr=lr, it_str="Batch")

# Update progress
# Update progress bar
p_bar.set_postfix_str(disp_str, refresh=False)
p_bar.update(1)

# Store and compute validation loss, if specified
self._save_trainer(save_checkpoint)
self._validation(ep, validation_sims, **kwargs)
self._save_trainer(save_checkpoint)

# Check early stopping, if specified
if self._check_early_stopping(early_stopper):
Expand Down Expand Up @@ -762,15 +762,14 @@ def train_from_presimulation(
p_bar.update(1)

# Store after each epoch, if specified
self._save_trainer(save_checkpoint)

self._validation(ep, validation_sims, **kwargs)
self._save_trainer(save_checkpoint)

# Check early stopping, if specified
if self._check_early_stopping(early_stopper):
break

# Remove reference to optimizer, if not set to persistent
# Remove optimizer reference, if not set as persistent
if not reuse_optimizer:
self.optimizer = None
return self.loss_history.get_plottable()
Expand Down Expand Up @@ -906,8 +905,8 @@ def train_experience_replay(
p_bar.update(1)

# Store and compute validation loss, if specified
self._save_trainer(save_checkpoint)
self._validation(ep, validation_sims, **kwargs)
self._save_trainer(save_checkpoint)

# Check early stopping, if specified
if self._check_early_stopping(early_stopper):
Expand Down

0 comments on commit 56a8dd5

Please sign in to comment.