diff --git a/cyto_dl/datamodules/dataframe/dataframe_datamodule.py b/cyto_dl/datamodules/dataframe/dataframe_datamodule.py
index 28850de0e..384aba2db 100644
--- a/cyto_dl/datamodules/dataframe/dataframe_datamodule.py
+++ b/cyto_dl/datamodules/dataframe/dataframe_datamodule.py
@@ -145,6 +145,10 @@ def __init__(
         self.subsample = subsample or {}
         self.refresh_subsample = refresh_subsample
 
+        self.batch_size = dataloader_kwargs.get("batch_size", 1)
+        # init size is used to check if the batch size has changed (used for Automatic batch size finder)
+        self._init_size = self.batch_size
+
         for key in list(self.subsample.keys()):
             self.subsample[get_canonical_split_name(key)] = self.subsample[key]
 
@@ -161,16 +165,25 @@ def get_dataset(self, split):
     def make_dataloader(self, split):
         kwargs = dict(**self.dataloader_kwargs)
         kwargs["shuffle"] = kwargs.get("shuffle", True) and split == "train"
+        kwargs["batch_size"] = self.batch_size
+
         subset = self.get_dataset(split)
         return DataLoader(dataset=subset, **kwargs)
 
     def get_dataloader(self, split):
         sample_size = self.subsample.get(split, -1)
-        if (split not in self.dataloaders) or (sample_size != -1 and self.refresh_subsample):
+        if (
+            (split not in self.dataloaders)
+            or (sample_size != -1 and self.refresh_subsample)
+            # check if batch size has changed (used for Automatic batch size finder)
+            or (self._init_size != self.batch_size)
+        ):
             # if we want to use a subsample per epoch, we need to remake the
             # dataloader, to refresh the sample
             self.dataloaders[split] = self.make_dataloader(split)
+            # reset the init size to the current batch size so dataloader isn't recreated every epoch
+            self._init_size = self.batch_size
         return self.dataloaders[split]
diff --git a/cyto_dl/train.py b/cyto_dl/train.py
index ee359a913..308bef050 100644
--- a/cyto_dl/train.py
+++ b/cyto_dl/train.py
@@ -11,6 +11,7 @@ import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers.logger import Logger
+from lightning.pytorch.tuner import Tuner
 from omegaconf import DictConfig, OmegaConf
 
 from cyto_dl import utils
@@ -57,6 +58,12 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
     utils.remove_aux_key(cfg)
 
     log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>")
+
+    use_batch_tuner = False
+    if cfg.data.batch_size == "AUTO":
+        use_batch_tuner = True
+        cfg.data.batch_size = 1
+
     data = utils.create_dataloader(cfg.data, data)
     if not isinstance(data, LightningDataModule):
         if not isinstance(data, MutableMapping) or "train_dataloaders" not in data:
@@ -93,6 +100,10 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
     log.info("Logging hyperparameters!")
     utils.log_hyperparameters(object_dict)
 
+    if use_batch_tuner:
+        tuner = Tuner(trainer=trainer)
+        tuner.scale_batch_size(model, datamodule=data, mode="power")
+
     if cfg.get("train"):
         log.info("Starting training!")
         model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
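
Note (not part of the patch): Lightning's batch-size finder works by reading and overwriting a `batch_size` attribute on the model or datamodule between trial runs, which is why the datamodule above now stores `self.batch_size`, passes it into every `DataLoader`, and rebuilds its cached dataloaders when the value changes. The sketch below is a minimal, self-contained illustration of that interaction; `ToyDataModule` and `ToyModel` are hypothetical stand-ins, not cyto-dl classes.

# Minimal sketch of the interaction the patch relies on (hypothetical ToyModel/ToyDataModule,
# not cyto-dl classes): the Tuner overwrites `datamodule.batch_size`, so dataloaders must be
# (re)built from that attribute for the tuned value to take effect.
import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 1):
        super().__init__()
        self.batch_size = batch_size  # attribute the Tuner scales in place

    def train_dataloader(self):
        # rebuilt on each call, so the current self.batch_size is always used
        data = TensorDataset(torch.randn(512, 8), torch.randn(512, 1))
        return DataLoader(data, batch_size=self.batch_size, shuffle=True)


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == "__main__":
    model, dm = ToyModel(), ToyDataModule(batch_size=1)
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    # "power" mode doubles the batch size per trial and keeps the largest size that ran
    Tuner(trainer=trainer).scale_batch_size(model, datamodule=dm, mode="power")
    trainer.fit(model, datamodule=dm)

With the `train.py` change above, the same path is exercised in cyto-dl by setting `data.batch_size: AUTO` in the config: training starts from a batch size of 1 and `Tuner.scale_batch_size` grows it geometrically until a trial fails or the trial limit is reached, after which the datamodule's changed-batch-size check rebuilds the dataloaders with the tuned value.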