Skip to content

Commit

Permalink
Backport PR #3042: hvg flavors seurat and cellranger with batch: bug …
Browse files Browse the repository at this point in the history
…in subset (#3128)

Co-authored-by: Eljas Roellin <[email protected]>
  • Loading branch information
meeseeksmachine and eroell authored Jun 28, 2024
1 parent af88467 commit 4e5d903
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
3 changes: 3 additions & 0 deletions docs/release-notes/1.10.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@
```{rubric} Bug fixes
```

* Fix `subset=True` of {func}`~scanpy.pp.highly_variable_genes` when `flavor` is `seurat` or `cell_ranger`, and `batch_key!=None` {pr}`3042` {smaller}`E Roellin`


```{rubric} Performance
```
3 changes: 3 additions & 0 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,16 @@ def _highly_variable_genes_batched(
if isinstance(cutoff, int):
# sort genes by how often they selected as hvg within each batch and
# break ties with normalized dispersion across batches

df_orig_ind = adata.var.index.copy()
df.sort_values(
["highly_variable_nbatches", "dispersions_norm"],
ascending=False,
na_position="last",
inplace=True,
)
df["highly_variable"] = np.arange(df.shape[0]) < cutoff
df = df.loc[df_orig_ind]
else:
df["dispersions_norm"] = df["dispersions_norm"].fillna(0) # similar to Seurat
df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"])
Expand Down
60 changes: 46 additions & 14 deletions tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from pathlib import Path
from string import ascii_letters
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -576,15 +577,17 @@ def test_cutoff_info():

@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"])
@pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"])
def test_subset_inplace_consistency(flavor, array_type, subset, inplace):
@pytest.mark.parametrize("batch_key", [None, "batch"])
def test_subset_inplace_consistency(flavor, array_type, batch_key):
"""Tests that, with `n_top_genes=n`
- `inplace` and `subset` interact correctly
- for both the `seurat` and `cell_ranger` flavors
- for dask arrays and non-dask arrays
- for both with and without batch_key
"""
adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0)
rng = np.random.default_rng(0)
adata.obs["batch"] = rng.choice(["a", "b"], adata.shape[0])
adata.X = array_type(np.abs(adata.X).astype(int))

if flavor == "seurat" or flavor == "cell_ranger":
Expand All @@ -599,18 +602,47 @@ def test_subset_inplace_consistency(flavor, array_type, subset, inplace):

n_genes = adata.shape[1]

output_df = sc.pp.highly_variable_genes(
adata,
flavor=flavor,
n_top_genes=15,
subset=subset,
inplace=inplace,
)
adatas: dict[bool, AnnData] = {}
dfs: dict[bool, pd.DataFrame] = {}
# for loops instead of parametrization to compare between settings
for subset, inplace in itertools.product([True, False], repeat=2):
adata_copy = adata.copy()

output_df = sc.pp.highly_variable_genes(
adata_copy,
flavor=flavor,
n_top_genes=15,
batch_key=batch_key,
subset=subset,
inplace=inplace,
)

assert (output_df is None) == inplace
assert len(adata_copy.var if inplace else output_df) == (
15 if subset else n_genes
)
assert sum((adata_copy.var if inplace else output_df)["highly_variable"]) == 15

if not inplace:
assert isinstance(output_df, pd.DataFrame)

if inplace:
assert subset not in adatas
adatas[subset] = adata_copy
else:
assert subset not in dfs
dfs[subset] = output_df

# check that the results are consistent for subset True/False: inplace True
adata_subset = adatas[False][:, adatas[False].var["highly_variable"]]
assert adata_subset.var_names.equals(adatas[True].var_names)

# check that the results are consistent for subset True/False: inplace False
df_subset = dfs[False][dfs[False]["highly_variable"]]
assert df_subset.index.equals(dfs[True].index)

assert (output_df is None) == inplace
assert len(adata.var if inplace else output_df) == (15 if subset else n_genes)
if output_df is not None:
assert isinstance(output_df, pd.DataFrame)
# check that the results are consistent for inplace True/False: subset True
assert adatas[True].var_names.equals(dfs[True].index)


@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
Expand Down

0 comments on commit 4e5d903

Please sign in to comment.