Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hvg flavors seurat and cellranger with batch: bug in subset #3042

Merged
merged 12 commits into from
Jun 28, 2024
3 changes: 3 additions & 0 deletions docs/release-notes/1.10.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@
```{rubric} Bug fixes
```

* Fix `subset=True` of {func}`~scanpy.pp.highly_variable_genes` when `flavor` is `seurat` or `cell_ranger`, and `batch_key!=None` {pr}`3042` {smaller}`E Roellin`


```{rubric} Performance
```
3 changes: 3 additions & 0 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,16 @@ def _highly_variable_genes_batched(
if isinstance(cutoff, int):
# sort genes by how often they selected as hvg within each batch and
# break ties with normalized dispersion across batches

df_orig_ind = adata.var.index.copy()
df.sort_values(
["highly_variable_nbatches", "dispersions_norm"],
ascending=False,
na_position="last",
inplace=True,
)
df["highly_variable"] = np.arange(df.shape[0]) < cutoff
df = df.loc[df_orig_ind]
else:
df["dispersions_norm"] = df["dispersions_norm"].fillna(0) # similar to Seurat
df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"])
Expand Down
60 changes: 46 additions & 14 deletions tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from pathlib import Path
from string import ascii_letters
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -576,15 +577,17 @@ def test_cutoff_info():

@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"])
@pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"])
def test_subset_inplace_consistency(flavor, array_type, subset, inplace):
@pytest.mark.parametrize("batch_key", [None, "batch"])
def test_subset_inplace_consistency(flavor, array_type, batch_key):
"""Tests that, with `n_top_genes=n`
- `inplace` and `subset` interact correctly
- for both the `seurat` and `cell_ranger` flavors
- for dask arrays and non-dask arrays
- for both with and without batch_key
"""
adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0)
rng = np.random.default_rng(0)
adata.obs["batch"] = rng.choice(["a", "b"], adata.shape[0])
adata.X = array_type(np.abs(adata.X).astype(int))

if flavor == "seurat" or flavor == "cell_ranger":
Expand All @@ -599,18 +602,47 @@ def test_subset_inplace_consistency(flavor, array_type, subset, inplace):

n_genes = adata.shape[1]

output_df = sc.pp.highly_variable_genes(
adata,
flavor=flavor,
n_top_genes=15,
subset=subset,
inplace=inplace,
)
adatas: dict[bool, AnnData] = {}
dfs: dict[bool, pd.DataFrame] = {}
# for loops instead of parametrization to compare between settings
for subset, inplace in itertools.product([True, False], repeat=2):
adata_copy = adata.copy()

output_df = sc.pp.highly_variable_genes(
adata_copy,
flavor=flavor,
n_top_genes=15,
batch_key=batch_key,
subset=subset,
inplace=inplace,
)

assert (output_df is None) == inplace
assert len(adata_copy.var if inplace else output_df) == (
15 if subset else n_genes
)
assert sum((adata_copy.var if inplace else output_df)["highly_variable"]) == 15

if not inplace:
assert isinstance(output_df, pd.DataFrame)

if inplace:
assert subset not in adatas
adatas[subset] = adata_copy
else:
assert subset not in dfs
dfs[subset] = output_df

# check that the results are consistent for subset True/False: inplace True
adata_subset = adatas[False][:, adatas[False].var["highly_variable"]]
assert adata_subset.var_names.equals(adatas[True].var_names)

# check that the results are consistent for subset True/False: inplace False
df_subset = dfs[False][dfs[False]["highly_variable"]]
assert df_subset.index.equals(dfs[True].index)

assert (output_df is None) == inplace
assert len(adata.var if inplace else output_df) == (15 if subset else n_genes)
if output_df is not None:
assert isinstance(output_df, pd.DataFrame)
# check that the results are consistent for inplace True/False: subset True
assert adatas[True].var_names.equals(dfs[True].index)


@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
Expand Down
Loading