diff --git a/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html b/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html
index b07c25fd..0cce5261 100644
--- a/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html
+++ b/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html
@@ -587,6 +587,10 @@
         self.subsample = subsample or {}
         self.refresh_subsample = refresh_subsample
+        self.batch_size = dataloader_kwargs.get("batch_size", 1)
+        # init size is used to check if the batch size has changed (used for Automatic batch size finder)
+        self._init_size = self.batch_size
+
         for key in list(self.subsample.keys()):
             self.subsample[get_canonical_split_name(key)] = self.subsample[key]
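Lightning's batch size finder discovers where to write candidate sizes by looking for an attribute named `batch_size` (its default `batch_arg_name`) on the model or the datamodule, so surfacing `self.batch_size` here is what makes the datamodule tunable; `self._init_size` remembers the size the cached dataloaders were built with so a later change can be detected. A minimal sketch of the same pattern in isolation (the `SketchDataModule` class and its `_loader` cache are illustrative stand-ins, not cyto-dl code):

from torch.utils.data import DataLoader, Dataset

class SketchDataModule:
    """Illustrative stand-in, not cyto-dl code."""

    def __init__(self, dataset: Dataset, batch_size: int = 1):
        self.dataset = dataset
        self.batch_size = batch_size   # attribute the tuner mutates (batch_arg_name="batch_size")
        self._init_size = batch_size   # size the cached loader was built with
        self._loader = None

    def train_dataloader(self) -> DataLoader:
        # rebuild only when no loader exists yet or the tuner changed batch_size
        if self._loader is None or self._init_size != self.batch_size:
            self._loader = DataLoader(self.dataset, batch_size=self.batch_size)
            self._init_size = self.batch_size
        return self._loader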
@@ -603,16 +607,25 @@
     def make_dataloader(self, split):
         kwargs = dict(**self.dataloader_kwargs)
         kwargs["shuffle"] = kwargs.get("shuffle", True) and split == "train"
+        kwargs["batch_size"] = self.batch_size
+
         subset = self.get_dataset(split)
         return DataLoader(dataset=subset, **kwargs)
     def get_dataloader(self, split):
         sample_size = self.subsample.get(split, -1)
-        if (split not in self.dataloaders) or (sample_size != -1 and self.refresh_subsample):
+        if (
+            (split not in self.dataloaders)
+            or (sample_size != -1 and self.refresh_subsample)
+            # check if batch size has changed (used for Automatic batch size finder)
+            or (self._init_size != self.batch_size)
+        ):
             # if we want to use a subsample per epoch, we need to remake the
             # dataloader, to refresh the sample
             self.dataloaders[split] = self.make_dataloader(split)
+            # reset the init size to the current batch size so dataloader isn't recreated every epoch
+            self._init_size = self.batch_size
         return self.dataloaders[split]
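Taken together, the two methods above make dataloader recreation idempotent per size change: `get_dataloader` rebuilds once when `_init_size` and `batch_size` disagree, and `make_dataloader` overrides any stale `batch_size` captured in `dataloader_kwargs` with the current value. A hypothetical interaction (`dm` stands for a constructed DataframeDatamodule; the attribute write mimics what the tuner does):

dm.batch_size = 2                    # the tuner writes a candidate size
first = dm.get_dataloader("train")   # _init_size != batch_size -> loader rebuilt
second = dm.get_dataloader("train")  # sizes agree again -> cached loader returned
assert first is second
assert first.batch_size == 2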
diff --git a/_modules/cyto_dl/train.html b/_modules/cyto_dl/train.html
index 473bd86f..7ba471c2 100644
--- a/_modules/cyto_dl/train.html
+++ b/_modules/cyto_dl/train.html
@@ -453,6 +453,7 @@
 import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers.logger import Logger
+from lightning.pytorch.tuner import Tuner
 from omegaconf import DictConfig, OmegaConf
 
 from cyto_dl import utils
@@ -499,6 +500,12 @@ 
     utils.remove_aux_key(cfg)
 
     log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>")
+
+    use_batch_tuner = False
+    if cfg.data.get("batch_size") == "AUTO":
+        use_batch_tuner = True
+        cfg.data.batch_size = 1
+
     data = utils.create_dataloader(cfg.data, data)
     if not isinstance(data, LightningDataModule):
         if not isinstance(data, MutableMapping) or "train_dataloaders" not in data:
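The `"AUTO"` sentinel has to be swapped for a real integer before `utils.create_dataloader` instantiates the datamodule, since the dataloader expects a numeric batch size; the tuner later overwrites that placeholder with the found value. Assuming the Hydra/OmegaConf config style used here, the sentinel check behaves like this (config values illustrative):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"batch_size": "AUTO"}})
use_batch_tuner = cfg.data.get("batch_size") == "AUTO"
if use_batch_tuner:
    cfg.data.batch_size = 1  # placeholder; the tuner picks the real size later
assert use_batch_tuner and cfg.data.batch_size == 1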
@@ -535,6 +542,10 @@ 
         log.info("Logging hyperparameters!")
         utils.log_hyperparameters(object_dict)
 
+    if use_batch_tuner:
+        tuner = Tuner(trainer=trainer)
+        tuner.scale_batch_size(model, datamodule=data, mode="power")
+
     if cfg.get("train"):
         log.info("Starting training!")
         model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
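With `mode="power"`, `Tuner.scale_batch_size` runs a few trial steps per candidate, doubling the batch size from a small initial value until it hits an out-of-memory error, then writes the largest working size back to the `batch_size` attribute on the datamodule before `trainer.fit` starts; Lightning also offers `mode="binsearch"`, which refines the estimate after the first failure. A standalone sketch of the same call pattern (`model` and `data` stand for any LightningModule and datamodule exposing a `batch_size` attribute):

from lightning import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer=trainer)
# trial-fits at growing batch sizes; returns the largest size that fit in memory
found_size = tuner.scale_batch_size(model, datamodule=data, mode="power")
print(f"tuned batch size: {found_size}")
trainer.fit(model, datamodule=data)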