Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(train): in case of last batch <=2, move to validation if possible #3036

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Added adaptive handling for the last training minibatch of 1–2 cells when
`datasplitter_kwargs={"drop_last": False}` and `train_size = None`, by moving those cells into the validation set, if available.
{pr}`3036`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
77 changes: 55 additions & 22 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from scvi.utils._docstrings import devices_dsp


def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size_for_adaptive_last_batch: int | None = None,
):
"""Check data splitting parameters and return n_train and n_val.

Parameters
Expand All @@ -32,21 +37,14 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
Size of train set. Need to be: 0 < train_size <= 1.
validation_size
Size of validation set. Need to be 0 <= validation_size < 1
batch_size_for_adaptive_last_batch
Batch size of each iteration. If `None`, adaptive last-batch sizing is not performed.
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

n_train = ceil(train_size * n_samples)

if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % settings.batch_size}"
f"samples. Consider changing settings.batch_size or batch_size in model.train"
f"currently {settings.batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

if validation_size is None:
n_val = n_samples - n_train
elif validation_size >= 1.0 or validation_size < 0.0:
Expand All @@ -59,10 +57,31 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
if n_train == 0:
raise ValueError(
f"With n_samples={n_samples}, train_size={train_size} and "
f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
"any of the aforementioned parameters."
)

if batch_size_for_adaptive_last_batch is not None:
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
if (n_train % batch_size_for_adaptive_last_batch < 3 and
n_train % batch_size_for_adaptive_last_batch > 0):
num_of_cells = n_train % batch_size_for_adaptive_last_batch
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size_for_adaptive_last_batch} to avoid errors during model"
f" training. Those cells will be removed from the training set automatically",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
n_train -= num_of_cells
if n_val > 0:
n_val += num_of_cells
warnings.warn(
f" {num_of_cells} cells moved from training set to " f"validation set",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


Expand Down Expand Up @@ -145,7 +164,7 @@ class DataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -182,7 +201,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -192,7 +211,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_was_none = not bool(train_size)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.train_size = 0.9 if self.train_size_was_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.load_sparse_tensor = load_sparse_tensor
Expand All @@ -208,7 +228,11 @@ def __init__(
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs, self.train_size, self.validation_size
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -298,7 +322,7 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -333,7 +357,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
n_samples_per_label: int | None = None,
Expand All @@ -343,7 +367,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_was_none = not bool(train_size)
self.train_size = 0.9 if train_size is None else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.drop_last = kwargs.pop("drop_last", False)
Expand Down Expand Up @@ -382,7 +407,11 @@ def setup(self, stage: str | None = None):
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx, self.train_size, self.validation_size
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -416,7 +445,11 @@ def setup(self, stage: str | None = None):
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx, self.train_size, self.validation_size
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size",settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down Expand Up @@ -508,7 +541,7 @@ class DeviceBackedDataSplitter(DataSplitter):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9 + adaptive last batch)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
validation_size
float, or None (default is None)
%(param_accelerator)s
Expand Down Expand Up @@ -536,7 +569,7 @@ class DeviceBackedDataSplitter(DataSplitter):
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 1.0,
train_size: float | None = None,
validation_size: float | None = None,
accelerator: str = "auto",
device: int | str = "auto",
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def train(
lr: float = 3e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
Expand Down
14 changes: 11 additions & 3 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
adata_manager: AnnDataManager,
background_indices: list[int],
target_indices: list[int],
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -78,10 +78,18 @@ def __init__(
self.n_target = len(target_indices)
if external_indexing is None:
self.n_background_train, self.n_background_val = validate_data_split(
self.n_background, self.train_size, self.validation_size
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
self.train_size_was_none and not self.drop_last else None,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target, self.train_size, self.validation_size
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size) if
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.train_size_was_none and not self.drop_last else None,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/contrastivevi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train(
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 128,
early_stopping: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/scbasset/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def train(
lr: float = 0.01,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def train(
lr: float = 1e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/velovi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def train(
weight_decay: float = 1e-2,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 256,
early_stopping: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def train(
max_epochs: int | None = None,
n_samples_per_label: float | None = None,
check_val_every_n_epoch: int | None = None,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def train(
lr: float = 4e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 256,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_jaxmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_pyromixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
device: int | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/base/_training_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
8 changes: 8 additions & 0 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,14 @@ def test_scvi_n_obs_error(n_latent: int = 5):
# Warning is emitted if last batch less than 3 cells.
model.train(1, train_size=1.0, batch_size=127)
model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})

adata = synthetic_iid()
adata = adata[0:143].copy()
SCVI.setup_anndata(adata)
model = SCVI(adata, n_latent=n_latent)
with pytest.raises(ValueError):
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(1)
assert model.is_trained is True


Expand Down
Loading