Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(train): in case of last batch <=2, move to validation if possible #3036

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Added adaptive handling for the last training minibatch of 1-2 cells in case of
`datasplitter_kwargs={"drop_last": False}` by moving them into the validation set, if available.
{pr}`3036`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
84 changes: 70 additions & 14 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from scvi.utils._docstrings import devices_dsp


def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.

Parameters
Expand All @@ -32,21 +38,16 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
Size of train set. Need to be: 0 < train_size <= 1.
validation_size
Size of validation set. Need to be 0 <= validation_size < 1
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

n_train = ceil(train_size * n_samples)

if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % settings.batch_size}"
f"samples. Consider changing settings.batch_size or batch_size in model.train"
f"currently {settings.batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

if validation_size is None:
n_val = n_samples - n_train
elif validation_size >= 1.0 or validation_size < 0.0:
Expand All @@ -59,16 +60,38 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
if n_train == 0:
raise ValueError(
f"With n_samples={n_samples}, train_size={train_size} and "
f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
"any of the aforementioned parameters."
)

if batch_size is not None and not drop_last:
if n_train % batch_size < 3 and n_train % batch_size > 0:
num_of_cells = n_train % batch_size
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size} to avoid errors during model training. Those cells "
f"will be removed from the training set automatically",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
n_train -= num_of_cells
if n_val > 0:
n_val += num_of_cells
warnings.warn(
f" {num_of_cells} cells moved from training set to " f"validation set",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


def validate_data_split_with_external_indexing(
n_samples: int,
external_indexing: list[np.array, np.array, np.array] | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.

Expand All @@ -79,6 +102,10 @@ def validate_data_split_with_external_indexing(
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
"""
if not isinstance(external_indexing, list):
raise ValueError("External indexing is not of list type")
Expand Down Expand Up @@ -132,6 +159,17 @@ def validate_data_split_with_external_indexing(
n_train = len(external_indexing[0])
n_val = len(external_indexing[1])

if batch_size is not None and not drop_last:
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
if n_train % batch_size < 3 and n_train % batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % batch_size} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {settings.batch_size} to avoid errors during model training "
f"or change the given external indices accordingly",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val


Expand Down Expand Up @@ -205,10 +243,16 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs, self.train_size, self.validation_size
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -379,10 +423,16 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx, self.train_size, self.validation_size
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -413,10 +463,16 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx, self.train_size, self.validation_size
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down
16 changes: 14 additions & 2 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,18 @@ def __init__(
self.n_target = len(target_indices)
if external_indexing is None:
self.n_background_train, self.n_background_val = validate_data_split(
self.n_background, self.train_size, self.validation_size
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target, self.train_size, self.validation_size
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand All @@ -93,6 +101,8 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
self.background_train_idx, self.background_val_idx, self.background_test_idx = (
Expand All @@ -107,6 +117,8 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
self.target_train_idx.tolist(),
Expand Down
10 changes: 9 additions & 1 deletion tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,20 @@ def test_scvi_n_obs_error(n_latent: int = 5):
adata = adata[0:129].copy()
SCVI.setup_anndata(adata)
model = SCVI(adata, n_latent=n_latent)
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
model.train(1, train_size=1.0)
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
with pytest.warns(UserWarning):
# A warning is emitted if the last batch has fewer than 3 cells.
model.train(1, train_size=1.0, batch_size=127)
model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})

adata = synthetic_iid()
adata = adata[0:143].copy()
SCVI.setup_anndata(adata)
model = SCVI(adata, n_latent=n_latent)
with pytest.warns(UserWarning):
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1

assert model.is_trained is True


Expand Down