Skip to content

Commit

Permalink
Adding rank sum test
Browse files Browse the repository at this point in the history
  • Loading branch information
euxhenh committed Dec 12, 2023
1 parent dae6b41 commit 1f63146
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/grinch/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class VARM:

class UNS:
TTEST = auto()
RANK_SUM = auto()
KSTEST = auto()
GSEA_ENRICH = auto()
GSEA_PRERANK = auto()
Expand Down
3 changes: 2 additions & 1 deletion src/grinch/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_processor import BaseProcessor, ReadKey, WriteKey
from .de import KSTest, TTest, UnimodalityTest
from .de import KSTest, RankSum, TTest, UnimodalityTest
from .feature_selection import PhenotypeCover
from .graphs import BaseGraphConstructor, FuzzySimplicialSetGraph, KNNGraph
from .group import GroupProcess
Expand Down Expand Up @@ -30,6 +30,7 @@
'ReadKey',
'WriteKey',
'TTest',
'RankSum',
'KSTest',
'UnimodalityTest',
'GSEAEnrich',
Expand Down
99 changes: 81 additions & 18 deletions src/grinch/processors/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from anndata import AnnData
from diptest import diptest
from pydantic import Field, PositiveFloat, field_validator
from scipy.stats import ks_2samp
from scipy.stats import ks_2samp, ranksums
from sklearn.utils import (
check_array,
check_consistent_length,
Expand Down Expand Up @@ -80,7 +80,7 @@ class Config(BaseProcessor.Config):
These will be replaced with appropriate values (1 for
p-values).
control_label : str, default=None
control_group : str, default=None
The label to use in a 'one_vs_one' test type. Must be present
in the array specified by `group_key`.
Expand Down Expand Up @@ -110,7 +110,7 @@ class Config(BaseProcessor.Config):
# test. E.g., if control samples are given in `control_key`, then for
# each group G determined by `group_key`, will run the test G vs
# control.
control_label: str | None = None
control_group: str | None = None
control_key: ReadKey | None = None

show_progress_bar: bool = Field(True, exclude=True)
Expand All @@ -129,7 +129,7 @@ def is_one_vs_one(self) -> bool:
is_ovo : bool
True if any of the control keys is not None.
"""
return any_not_None(self.control_key, self.control_label)
return any_not_None(self.control_key, self.control_group)

def get_label_key(self, label: str) -> str:
"""Get the WriteKey to store a given test in."""
Expand Down Expand Up @@ -187,7 +187,7 @@ def _to_iter(self, labels) -> Iterable:
def _get_x_control(self, adata, x, group_labels) -> NP_SP | None:
"""Return the x representation of a condition if set.
"""
if all_None(self.cfg.control_key, self.cfg.control_label):
if all_None(self.cfg.control_key, self.cfg.control_group):
return None

if self.cfg.control_key is not None: # Found in a separate key
Expand All @@ -198,9 +198,9 @@ def _get_x_control(self, adata, x, group_labels) -> NP_SP | None:
# Ensure same number of features
check_consistent_length(x.T, x_cond.T)
else: # Index x by label
if self.cfg.control_label not in group_labels:
raise ValueError(f"Could not find label='{self.cfg.control_label}'.")
x_cond = x[group_labels == self.cfg.control_label]
if self.cfg.control_group not in group_labels:
raise ValueError(f"Could not find label='{self.cfg.control_group}'.")
x_cond = x[group_labels == self.cfg.control_group]
return x_cond

def _process(self, adata: AnnData) -> None:
Expand All @@ -214,11 +214,12 @@ def _process(self, adata: AnnData) -> None:
logger.warning("Cannot run test with only one group.")
return

pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar)
# Read x_control and run test
self._test(x, group_labels, self._get_x_control(adata, x, group_labels))
self._test(pmv, x, group_labels, self._get_x_control(adata, x, group_labels))

@abc.abstractmethod
def _test(self, x, group_labels, x_control=None) -> None:
def _test(self, pmv: PartMeanVar, x, group_labels, x_control=None) -> None:
raise NotImplementedError


Expand All @@ -241,10 +242,14 @@ class Config(PairwiseDETest.Config):

cfg: Config

def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None):
"""Runs t-Test.
"""
pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar)
def _test(
self,
pmv: PartMeanVar,
x: NP_SP,
group_labels: NP1D_Any,
x_control: NP_SP | None = None,
):
"""Runs t-Test."""
unq_labels = np.unique(group_labels)

def get_x_stats(x_cond) -> _Statistics | None:
Expand All @@ -254,7 +259,7 @@ def get_x_stats(x_cond) -> _Statistics | None:
return _Statistics(n=len(x_cond), mean=mean, var=var)

# Skip if label is the same as control label
for label in filter(lambda x: x != self.cfg.control_label, self._to_iter(unq_labels)):
for label in filter(lambda x: x != self.cfg.control_group, self._to_iter(unq_labels)):
test = self._single_test(pmv, label, get_x_stats(x_control))
self.store_item(self.cfg.get_label_key(label), test)

Expand All @@ -279,6 +284,59 @@ def _single_test(
)


class RankSum(PairwiseDETest):
"""A class for performing differential expression analysis by using
Wilcoxon rank-sum statistic to determine if a gene is differentially
expressed in one group vs the other.
"""

class Config(PairwiseDETest.Config):

if TYPE_CHECKING:
create: Callable[..., 'RankSum']

write_key: WriteKey = f"uns.{UNS.RANK_SUM}"
alternative: str = 'two-sided'

cfg: Config

def _test(
self,
pmv: PartMeanVar,
x: NP_SP,
group_labels: NP1D_Any,
x_control: NP_SP | None = None,
):
"""Runs rank sum."""
x = densify(x, ensure_2d=True)
if x_control is not None:
x_control = densify(x, ensure_2d=True)
m2 = x_control.mean(axis=0).ravel()
else:
m2 = None

unq_labels, groups = group_indices(group_labels, as_mask=True)
for label, group in zip(self._to_iter(unq_labels), groups):
if label == self.cfg.control_group:
continue
y = x_control if x_control is not None else x[~group]
test = self._single_test(pmv, label, x=x[group], y=y, m2=m2)
self.store_item(self.cfg.get_label_key(label), test)

def _single_test(self, pmv: PartMeanVar, label, *, x, y, m2=None) -> pd.DataFrame:
"""Perform a single rank sum test."""
statistic, pvals = ranksums(x, y, alternative=self.cfg.alternative)
pvals, qvals = self.get_pqvals(pvals)
m1 = pmv.compute([label], ddof=1)[1] # take label
m2 = m2 if m2 is not None else pmv.compute([label], ddof=1, exclude=True)[1]
log2fc = self.get_log2fc(m1, m2)

return pd.DataFrame(
data=dict(pvals=pvals, qvals=qvals, statistic=statistic,
mean1=m1, mean2=m2, log2fc=log2fc)
)


class KSTest(PairwiseDETest):
"""A class for comparing two distributions based on the
Kolmogorov-Smirnov Test.
Expand All @@ -297,10 +355,15 @@ class Config(PairwiseDETest.Config):

cfg: Config

def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None):
def _test(
self,
pmv: PartMeanVar,
x: NP_SP,
group_labels: NP1D_Any,
x_control: NP_SP | None = None,
):
"""Runs KS test.
"""
pmv = PartMeanVar(x, group_labels, self.cfg.show_progress_bar)
x, = indexable(densify(x, ensure_2d=True, warn=True))

if x_control is not None:
Expand All @@ -312,7 +375,7 @@ def _test(self, x: NP_SP, group_labels: NP1D_Any, x_control: NP_SP | None = None

unq_labels, groups = group_indices(group_labels, as_mask=True)
for label, group in zip(self._to_iter(unq_labels), groups):
if label == self.cfg.control_label:
if label == self.cfg.control_group:
continue
y = x_control if x_control is not None else x[~group]
test = self._single_test(pmv, label, x=x[group], y=y, m2=m2)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

X_mods = [X, sp.csr_matrix(X), to_view(X)]
tests = [("TTest", UNS.TTEST),
("KSTest", UNS.KSTEST)]
("KSTest", UNS.KSTEST),
("RankSum", UNS.RANK_SUM)]


# Test all combinations
Expand Down

0 comments on commit 1f63146

Please sign in to comment.