# Patch summary (reviewer notes placed ahead of the first `diff --git` header).
# NOTE(review): confirm your patch tooling skips free text before the first header.
#
# src/grinch/filters.py
#   * Imports: adds column_or_1d (sklearn.utils.validation), StorageMixin (.base)
#     and ReadKey (.processors).
#   * Config (subclass of BaseFilter.Config; presumably FilterGenes.Config, which
#     the tests below instantiate -- verify): adds two fields,
#       remove_MT_genes: bool = False
#       gene_names_key: ReadKey = "var_names"   (only read when remove_MT_genes)
#   * _filter: when cfg.remove_MT_genes is set, gene names are read with
#     StorageMixin.read(adata, cfg.gene_names_key), flattened by column_or_1d,
#     cast to str and upper-cased, then every gene whose name starts with 'MT-'
#     is cleared from the `to_keep` mask. Because of the upper-casing the prefix
#     match is case-insensitive. This happens before the existing
#     `to_keep.sum() < 1` check that raises ValueError.
#
# tests/test_filters.py
#   * test_gene_all: enables remove_MT_genes and assigns
#     var_names = ["G1", "MT2", "MT-3", "G4"]; the existing expectations are kept.
#   * test_gene_MT (new): with only remove_MT_genes enabled, expects columns
#     [0, 1, 3] to survive -- i.e. "MT2" (no hyphen) is kept and only "MT-3" is
#     dropped -- and checks that the input X is left unmodified.
diff --git a/src/grinch/filters.py b/src/grinch/filters.py index f12d369..32ef6eb 100644 --- a/src/grinch/filters.py +++ b/src/grinch/filters.py @@ -5,10 +5,12 @@ import numpy as np from anndata import AnnData from pydantic import NonNegativeFloat, NonNegativeInt, validate_call -from sklearn.utils.validation import check_non_negative +from sklearn.utils.validation import check_non_negative, column_or_1d from .aliases import OBS, VAR +from .base import StorageMixin from .conf import BaseConfigurable +from .processors import ReadKey from .utils import any_not_None, true_inside from .utils.decorators import plt_interactive from .utils.plotting import plot1d @@ -132,6 +134,8 @@ class Config(BaseFilter.Config): max_counts: NonNegativeFloat | None = None min_cells: NonNegativeInt | None = None max_cells: NonNegativeInt | None = None + remove_MT_genes: bool = False + gene_names_key: ReadKey = "var_names" # only used if remove_MT_genes cfg: Config @@ -170,6 +174,11 @@ def _filter(self, adata: AnnData) -> None: self.cfg.max_cells, ) + if self.cfg.remove_MT_genes: + gene_names = StorageMixin.read(adata, self.cfg.gene_names_key) + gene_names = np.char.upper(column_or_1d(gene_names).astype(str)) + to_keep &= ~np.char.startswith(gene_names, 'MT-') + if to_keep.sum() < 1: raise ValueError( "Filtering options are too stringent. 
" diff --git a/tests/test_filters.py b/tests/test_filters.py index 8da720a..e505c50 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -235,11 +235,13 @@ def test_gene_all(X): "min_counts": 6, "min_cells": 1, "max_cells": 2, + "remove_MT_genes": True, } ) cfg = instantiate(cfg) filter_genes = cfg.create() adata = AnnData(X) + adata.var_names = ["G1", "MT2", "MT-3", "G4"] X_original = adata.X.copy() filter_genes(adata) assert_allclose(X_original, X) @@ -249,6 +251,25 @@ def test_gene_all(X): assert_allclose(adata.var[VAR.N_CELLS], [2]) +@pytest.mark.parametrize("X", X_mods) +def test_gene_MT(X): + cfg = OmegaConf.create( + { + "_target_": "src.grinch.FilterGenes.Config", + "remove_MT_genes": True, + } + ) + cfg = instantiate(cfg) + filter_genes = cfg.create() + adata = AnnData(X) + adata.var_names = ["G1", "MT2", "MT-3", "G4"] + X_original = adata.X.copy() + filter_genes(adata) + assert_allclose(X_original, X) + X_filtered = X[:, [0, 1, 3]] + assert_allclose(X_filtered, adata.X) + + @pytest.mark.parametrize("X", X_mods) def test_gene_inplace(X): cfg = OmegaConf.create(