diff --git a/src/grinch/aliases.py b/src/grinch/aliases.py index f7ba2da..99ea239 100644 --- a/src/grinch/aliases.py +++ b/src/grinch/aliases.py @@ -53,6 +53,7 @@ class VARM: class UNS: TTEST = auto() + RANK_SUM = auto() KSTEST = auto() GSEA_ENRICH = auto() GSEA_PRERANK = auto() diff --git a/src/grinch/processors/__init__.py b/src/grinch/processors/__init__.py index 0d0dc41..a5b1e26 100644 --- a/src/grinch/processors/__init__.py +++ b/src/grinch/processors/__init__.py @@ -1,5 +1,5 @@ from .base_processor import BaseProcessor, ReadKey, WriteKey -from .de import KSTest, TTest, UnimodalityTest +from .de import KSTest, RankSum, TTest, UnimodalityTest from .feature_selection import PhenotypeCover from .graphs import BaseGraphConstructor, FuzzySimplicialSetGraph, KNNGraph from .group import GroupProcess @@ -30,6 +30,7 @@ 'ReadKey', 'WriteKey', 'TTest', + 'RankSum', 'KSTest', 'UnimodalityTest', 'GSEAEnrich', diff --git a/src/grinch/processors/de.py b/src/grinch/processors/de.py index 20a9601..33bfcbe 100644 --- a/src/grinch/processors/de.py +++ b/src/grinch/processors/de.py @@ -12,7 +12,7 @@ from anndata import AnnData from diptest import diptest from pydantic import Field, PositiveFloat, field_validator -from scipy.stats import ks_2samp +from scipy.stats import ks_2samp, ranksums from sklearn.utils import ( check_array, check_consistent_length, @@ -80,7 +80,7 @@ class Config(BaseProcessor.Config): These will be replaced with appropriate values (1 for p-values). - control_label : str, default=None + control_group : str, default=None The label to use in a 'one_vs_one' test type. Must be present in the array specified by `group_key`. @@ -110,7 +110,7 @@ class Config(BaseProcessor.Config): # test. E.g., if control samples are given in `control_key`, then for # each group G determined by `group_key`, will run the test G vs # control. - control_label: str | None = None + control_group: str | None = None control_key: ReadKey | None = None show_progress_bar: bool = Field(True, exclude=True) @@ -129,7 +129,7 @@ def is_one_vs_one(self) -> bool: is_ovo : bool True if any of the control keys is not None. """ - return any_not_None(self.control_key, self.control_label) + return any_not_None(self.control_key, self.control_group) def get_label_key(self, label: str) -> str: """Get the WriteKey to store a given test in.""" @@ -187,7 +187,7 @@ def _to_iter(self, labels) -> Iterable: def _get_x_control(self, adata, x, group_labels) -> NP_SP | None: """Return the x representation of a condition if set. """ - if all_None(self.cfg.control_key, self.cfg.control_label): + if all_None(self.cfg.control_key, self.cfg.control_group): return None if self.cfg.control_key is not None: # Found in a separate key @@ -198,9 +198,9 @@ def _get_x_control(self, adata, x, group_labels) -> NP_SP | None: # Ensure same number of features check_consistent_length(x.T, x_cond.T) else: # Index x by label - if self.cfg.control_label not in group_labels: - raise ValueError(f"Could not find label='{self.cfg.control_label}'.") - x_cond = x[group_labels == self.cfg.control_label] + if self.cfg.control_group not in group_labels: + raise ValueError(f"Could not find label='{self.cfg.control_group}'.") + x_cond = x[group_labels == self.cfg.control_group] return x_cond def _process(self, adata: AnnData) -> None: @@ -214,11 +214,12 @@ def _process(self, adata: AnnData) -> None: logger.warning("Cannot run test with only one group.") return + pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar) # Read x_control and run test - self._test(x, group_labels, self._get_x_control(adata, x, group_labels)) + self._test(pmv, x, group_labels, self._get_x_control(adata, x, group_labels)) @abc.abstractmethod - def _test(self, x, group_labels, x_control=None) -> None: + def _test(self, pmv: PartMeanVar, x, group_labels, x_control=None) -> None: raise NotImplementedError @@ -241,10 +242,14 @@ class Config(PairwiseDETest.Config): cfg: Config - def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None): - """Runs t-Test. - """ - pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar) + def _test( + self, + pmv: PartMeanVar, + x: NP_SP, + group_labels: NP1D_Any, + x_control: NP_SP | None = None, + ): + """Runs t-Test.""" unq_labels = np.unique(group_labels) def get_x_stats(x_cond) -> _Statistics | None: @@ -254,7 +259,7 @@ def get_x_stats(x_cond) -> _Statistics | None: return _Statistics(n=len(x_cond), mean=mean, var=var) # Skip if label is the same as control label - for label in filter(lambda x: x != self.cfg.control_label, self._to_iter(unq_labels)): + for label in filter(lambda x: x != self.cfg.control_group, self._to_iter(unq_labels)): test = self._single_test(pmv, label, get_x_stats(x_control)) self.store_item(self.cfg.get_label_key(label), test) @@ -279,6 +284,59 @@ def _single_test( ) +class RankSum(PairwiseDETest): + """A class for performing differential expression analysis by using + Wilcoxon rank-sum statistic to determine if a gene is differentially + expressed in one group vs the other. + """ + + class Config(PairwiseDETest.Config): + + if TYPE_CHECKING: + create: Callable[..., 'RankSum'] + + write_key: WriteKey = f"uns.{UNS.RANK_SUM}" + alternative: str = 'two-sided' + + cfg: Config + + def _test( + self, + pmv: PartMeanVar, + x: NP_SP, + group_labels: NP1D_Any, + x_control: NP_SP | None = None, + ): + """Runs rank sum.""" + x = densify(x, ensure_2d=True) + if x_control is not None: + x_control = densify(x, ensure_2d=True) + m2 = x_control.mean(axis=0).ravel() + else: + m2 = None + + unq_labels, groups = group_indices(group_labels, as_mask=True) + for label, group in zip(self._to_iter(unq_labels), groups): + if label == self.cfg.control_group: + continue + y = x_control if x_control is not None else x[~group] + test = self._single_test(pmv, label, x=x[group], y=y, m2=m2) + self.store_item(self.cfg.get_label_key(label), test) + + def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2=None) -> pd.DataFrame: + """Perform a single rank sum test.""" + statistic, pvals = ranksums(x, y, alternative=self.cfg.alternative) + pvals, qvals = self.get_pqvals(pvals) + m1 = pmv.compute([label], ddof=1)[1] # take label + m2 = m2 if m2 is not None else pmv.compute([label], ddof=1, exclude=True)[1] + log2fc = self.get_log2fc(m1, m2) + + return pd.DataFrame( + data=dict(pvals=pvals, qvals=qvals, statistic=statistic, + mean1=m1, mean2=m2, log2fc=log2fc) + ) + + class KSTest(PairwiseDETest): """A class for comparing two distributions based on the Kolmogorov-Smirnov Test. @@ -297,10 +355,15 @@ class Config(PairwiseDETest.Config): cfg: Config - def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None): + def _test( + self, + pmv: PartMeanVar, + x: NP_SP, + group_labels: NP1D_Any, + x_control: NP_SP | None = None, + ): """Runs KS test. """ - pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar) x, = indexable(densify(x, ensure_2d=True, warn=True)) if x_control is not None: @@ -312,7 +375,7 @@ def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None unq_labels, groups = group_indices(group_labels, as_mask=True) for label, group in zip(self._to_iter(unq_labels), groups): - if label == self.cfg.control_label: + if label == self.cfg.control_group: continue y = x_control if x_control is not None else x[~group] test = self._single_test(pmv, label, x=x[group], y=y, m2=m2) diff --git a/tests/test_de.py b/tests/test_de.py index 675afe5..db57339 100644 --- a/tests/test_de.py +++ b/tests/test_de.py @@ -22,7 +22,8 @@ X_mods = [X, sp.csr_matrix(X), to_view(X)] tests = [("TTest", UNS.TTEST), - ("KSTest", UNS.KSTEST)] + ("KSTest", UNS.KSTEST), + ("RankSum", UNS.RANK_SUM)] # Test all combinations