Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(train): in case of last batch <=2, move to validation if possible #3036

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Added adaptive handling for the last training minibatch of 1–2 cells when
`datasplitter_kwargs={"drop_last": False}` and `train_size = None`, by moving those cells into the
validation set, if one is available.
{pr}`3036`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
118 changes: 96 additions & 22 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from scvi.utils._docstrings import devices_dsp


def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
train_size_is_none: bool | int = True,
):
"""Check data splitting parameters and return n_train and n_val.

Parameters
Expand All @@ -32,21 +39,18 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
Size of train set. Need to be: 0 < train_size <= 1.
validation_size
Size of validation set. Need to be 0 <= validation_size < 1
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
train_size_is_none
Whether the user did not explicitly provide `train_size`.
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

n_train = ceil(train_size * n_samples)

if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % settings.batch_size}"
f"samples. Consider changing settings.batch_size or batch_size in model.train"
f"currently {settings.batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

if validation_size is None:
n_val = n_samples - n_train
elif validation_size >= 1.0 or validation_size < 0.0:
Expand All @@ -59,16 +63,41 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
if n_train == 0:
raise ValueError(
f"With n_samples={n_samples}, train_size={train_size} and "
f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
"any of the aforementioned parameters."
)

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and not (
num_of_cells == 1 and drop_last is True
):
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size} to avoid errors during model training, "
f"or use drop_last parameter if there is 1 cell left",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
if train_size_is_none:
n_train -= num_of_cells
if n_val > 0:
n_val += num_of_cells
warnings.warn(
f"{num_of_cells} cells moved from training set to validation set",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


def validate_data_split_with_external_indexing(
n_samples: int,
external_indexing: list[np.array, np.array, np.array] | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.

Expand All @@ -79,6 +108,10 @@ def validate_data_split_with_external_indexing(
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
"""
if not isinstance(external_indexing, list):
raise ValueError("External indexing is not of list type")
Expand Down Expand Up @@ -132,6 +165,21 @@ def validate_data_split_with_external_indexing(
n_train = len(external_indexing[0])
n_val = len(external_indexing[1])

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and not (
num_of_cells == 1 and drop_last is True
):
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {settings.batch_size} to avoid errors during model training "
f"or change the given external indices accordingly or use drop_last parameter if "
f"there is 1 cell left",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


Expand All @@ -145,7 +193,8 @@ class DataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -182,7 +231,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -192,7 +241,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if self.train_size_is_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.load_sparse_tensor = load_sparse_tensor
Expand All @@ -205,10 +255,17 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs, self.train_size, self.validation_size
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -298,7 +355,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
shuffle_set_split
Expand Down Expand Up @@ -333,7 +391,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
n_samples_per_label: int | None = None,
Expand All @@ -343,7 +401,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if train_size is None else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.drop_last = kwargs.pop("drop_last", False)
Expand Down Expand Up @@ -379,10 +438,17 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx, self.train_size, self.validation_size
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -413,10 +479,17 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx, self.train_size, self.validation_size
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down Expand Up @@ -508,7 +581,8 @@ class DeviceBackedDataSplitter(DataSplitter):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which is practically 0.9, potentially moving a small last
batch into the validation set)
validation_size
float, or None (default is None)
%(param_accelerator)s
Expand Down Expand Up @@ -536,7 +610,7 @@ class DeviceBackedDataSplitter(DataSplitter):
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 1.0,
train_size: float | None = None,
validation_size: float | None = None,
accelerator: str = "auto",
device: int | str = "auto",
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def train(
lr: float = 3e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
Expand Down
20 changes: 17 additions & 3 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
adata_manager: AnnDataManager,
background_indices: list[int],
target_indices: list[int],
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand All @@ -78,10 +78,20 @@ def __init__(
self.n_target = len(target_indices)
if external_indexing is None:
self.n_background_train, self.n_background_val = validate_data_split(
self.n_background, self.train_size, self.validation_size
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target, self.train_size, self.validation_size
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand All @@ -93,6 +103,8 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
self.background_train_idx, self.background_val_idx, self.background_test_idx = (
Expand All @@ -107,6 +119,8 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
self.target_train_idx.tolist(),
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/contrastivevi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train(
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 128,
early_stopping: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/scbasset/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def train(
lr: float = 0.01,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def train(
lr: float = 1e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/velovi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def train(
weight_decay: float = 1e-2,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 256,
early_stopping: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
Expand Down
Loading