From eb35a144696978ad8f6132a2e9a83713367cfd79 Mon Sep 17 00:00:00 2001
From: elseml <60779710+elseml@users.noreply.github.com>
Date: Tue, 27 Feb 2024 17:05:20 +0100
Subject: [PATCH 1/2] Include shared context in MultiSimulationDataset for
 offline training

---
 bayesflow/helper_classes.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/bayesflow/helper_classes.py b/bayesflow/helper_classes.py
index 24c75e312..d55f2a923 100644
--- a/bayesflow/helper_classes.py
+++ b/bayesflow/helper_classes.py
@@ -141,10 +141,18 @@ def __init__(self, forward_dict, batch_size, buffer_size=1024):
         self.iters = [iter(d) for d in self.datasets]
         self.batch_size = batch_size

+        # Include further keys (i.e., shared context) from forward_dict
+        self.further_keys = {}
+        for key, value in forward_dict.items():
+            if key not in [DEFAULT_KEYS["model_outputs"], DEFAULT_KEYS["model_indices"]]:
+                self.further_keys[key] = value
+
     def __next__(self):
         if self.current_it < self.num_batches:
             outputs = [next(d) for d in self.iters]
             output_dict = {DEFAULT_KEYS["model_outputs"]: outputs, DEFAULT_KEYS["model_indices"]: self.model_indices}
+            if self.further_keys:
+                output_dict.update(self.further_keys)
             self.current_it += 1
             return output_dict
         self.current_it = 0

From 56a8dd5ec9708c5084f836274d77096409a15201 Mon Sep 17 00:00:00 2001
From: elseml <60779710+elseml@users.noreply.github.com>
Date: Fri, 1 Mar 2024 14:10:15 +0100
Subject: [PATCH 2/2] Fix last-epoch validation loss not being saved

---
 bayesflow/trainers.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/bayesflow/trainers.py b/bayesflow/trainers.py
index 1f717f5e8..4f7dc2091 100644
--- a/bayesflow/trainers.py
+++ b/bayesflow/trainers.py
@@ -454,8 +454,8 @@ def train_online(
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
@@ -579,13 +579,13 @@ def train_offline(
                     # Format for display on progress bar
                     disp_str = format_loss_string(ep, bi, loss, avg_dict, lr=lr, it_str="Batch")

-                    # Update progress
+                    # Update progress bar
                     p_bar.set_postfix_str(disp_str, refresh=False)
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
@@ -762,15 +762,14 @@ def train_from_presimulation(
                     p_bar.update(1)

             # Store after each epoch, if specified
-            self._save_trainer(save_checkpoint)
-
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
                 break

-        # Remove reference to optimizer, if not set to persistent
+        # Remove optimizer reference, if not set as persistent
         if not reuse_optimizer:
             self.optimizer = None
         return self.loss_history.get_plottable()
@@ -906,8 +905,8 @@ def train_experience_replay(
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
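
A minimal sketch of what the first patch does, in plain Python rather than bayesflow itself. The simplified DEFAULT_KEYS mapping and the "shared_context" entry are illustrative assumptions, not part of the patch; bayesflow's real DEFAULT_KEYS maps more entries, and MultiSimulationDataset does the batching with tf.data internally:

    # Standalone sketch of the key filtering added to MultiSimulationDataset.__init__
    # (illustrative only; not the bayesflow class itself)
    DEFAULT_KEYS = {"model_outputs": "model_outputs", "model_indices": "model_indices"}

    forward_dict = {
        "model_outputs": [{"sim_data": [1.0, 2.0]}, {"sim_data": [3.0, 4.0]}],
        "model_indices": [0, 1],
        "shared_context": {"n_obs": 100},  # hypothetical shared-context entry
    }

    # Everything except the two reserved keys counts as shared context ...
    reserved = [DEFAULT_KEYS["model_outputs"], DEFAULT_KEYS["model_indices"]]
    further_keys = {k: v for k, v in forward_dict.items() if k not in reserved}

    # ... and is merged unchanged into every batch that __next__ yields
    output_dict = {"model_outputs": ["batch_0", "batch_1"], "model_indices": [0, 1]}
    if further_keys:
        output_dict.update(further_keys)

    print(output_dict["shared_context"])  # {'n_obs': 100}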
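
The second patch is purely an ordering fix: a checkpoint written before the epoch's validation pass persists a loss history that still lacks that epoch's validation loss, so the final epoch's value was never saved. A minimal sketch with hypothetical stand-ins for _validation and _save_trainer (the real methods do far more):

    # Hypothetical stand-ins, for illustration only
    history = {"val_losses": []}
    saved_snapshots = []

    def _validation(ep):
        # Record this epoch's validation loss in the loss history
        history["val_losses"].append(f"val_loss_ep{ep}")

    def _save_trainer():
        # Checkpoint a copy of the current loss history
        saved_snapshots.append(list(history["val_losses"]))

    # Old order (save, then validate): the final checkpoint misses the last epoch
    for ep in range(1, 3):
        _save_trainer()
        _validation(ep)
    print(saved_snapshots[-1])  # ['val_loss_ep1'] -- epoch 2 is missing

    # New order (validate, then save): the final checkpoint is complete
    history["val_losses"].clear()
    saved_snapshots.clear()
    for ep in range(1, 3):
        _validation(ep)
        _save_trainer()
    print(saved_snapshots[-1])  # ['val_loss_ep1', 'val_loss_ep2']

Swapping the two calls in each train_* loop closes this gap without otherwise changing the training logic.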