diff --git a/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html b/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html index b07c25fd..0cce5261 100644 --- a/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html +++ b/_modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html @@ -587,6 +587,10 @@
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers.logger import Logger
+from lightning.pytorch.tuner import Tuner
from omegaconf import DictConfig, OmegaConf
from cyto_dl import utils
@@ -499,6 +500,12 @@ Source code for cyto_dl.train
utils.remove_aux_key(cfg)
log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>")
+
+ use_batch_tuner = False
+ if cfg.data.get("batch_size") == "AUTO":
+ use_batch_tuner = True
+ cfg.data.batch_size = 1
+
data = utils.create_dataloader(cfg.data, data)
if not isinstance(data, LightningDataModule):
if not isinstance(data, MutableMapping) or "train_dataloaders" not in data:
@@ -535,6 +542,10 @@ Source code for cyto_dl.train
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
+ if use_batch_tuner:
+ tuner = Tuner(trainer=trainer)
+ tuner.scale_batch_size(model, datamodule=data, mode="power")
+
if cfg.get("train"):
log.info("Starting training!")
model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))