Skip to content

Commit

Permalink
change batch size for dataframe datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Morris committed Sep 20, 2024
1 parent 48e7cb9 commit 0c4bdca
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
15 changes: 14 additions & 1 deletion cyto_dl/datamodules/dataframe/dataframe_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def __init__(
self.subsample = subsample or {}
self.refresh_subsample = refresh_subsample

self.batch_size = dataloader_kwargs.get("batch_size", 1)
# init size is used to check if the batch size has changed (used for Automatic batch size finder)
self._init_size = self.batch_size

for key in list(self.subsample.keys()):
self.subsample[get_canonical_split_name(key)] = self.subsample[key]

Expand All @@ -161,16 +165,25 @@ def get_dataset(self, split):
def make_dataloader(self, split):
kwargs = dict(**self.dataloader_kwargs)
kwargs["shuffle"] = kwargs.get("shuffle", True) and split == "train"
kwargs["batch_size"] = self.batch_size

subset = self.get_dataset(split)
return DataLoader(dataset=subset, **kwargs)

def get_dataloader(self, split):
sample_size = self.subsample.get(split, -1)

if (split not in self.dataloaders) or (sample_size != -1 and self.refresh_subsample):
if (
(split not in self.dataloaders)
or (sample_size != -1 and self.refresh_subsample)
# check if batch size has changed (used for Automatic batch size finder)
or (self._init_size != self.batch_size)
):
# if we want to use a subsample per epoch, we need to remake the
# dataloader, to refresh the sample
self.dataloaders[split] = self.make_dataloader(split)
# reset the init size to the current batch size so dataloader isn't recreated every epoch
self._init_size = self.batch_size

return self.dataloaders[split]

Expand Down
11 changes: 11 additions & 0 deletions cyto_dl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.tuner import Tuner
from omegaconf import DictConfig, OmegaConf

from cyto_dl import utils
Expand Down Expand Up @@ -57,6 +58,12 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
utils.remove_aux_key(cfg)

log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>")

use_batch_tuner = False
if cfg.data.batch_size == "AUTO":
use_batch_tuner = True
cfg.data.batch_size = 1

data = utils.create_dataloader(cfg.data, data)
if not isinstance(data, LightningDataModule):
if not isinstance(data, MutableMapping) or "train_dataloaders" not in data:
Expand Down Expand Up @@ -93,6 +100,10 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)

if use_batch_tuner:
tuner = Tuner(trainer=trainer)
tuner.scale_batch_size(model, datamodule=data, mode="power")

if cfg.get("train"):
log.info("Starting training!")
model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
Expand Down

0 comments on commit 0c4bdca

Please sign in to comment.