Skip to content

Commit

Permalink
Merge pull request #47 from CSOgroup/enrichment_v2
Browse files Browse the repository at this point in the history
Rewrite of cc.pl.enrichment
  • Loading branch information
marcovarrone authored Jul 2, 2024
2 parents 5eb35ea + 08265d7 commit 734f014
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 183 deletions.
51 changes: 47 additions & 4 deletions src/cellcharter/gr/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pandas as pd
from anndata import AnnData
from squidpy._docs import d
from tqdm import tqdm


def _proportion(adata, id_key, val_key, normalize=True):
df = pd.pivot(adata.obs[[id_key, val_key]].value_counts().reset_index(), index=id_key, columns=val_key)
def _proportion(obs, id_key, val_key, normalize=True):
df = pd.pivot(obs[[id_key, val_key]].value_counts().reset_index(), index=id_key, columns=val_key)
df[df.isna()] = 0
df.columns = df.columns.droplevel(0)
if normalize:
Expand All @@ -16,6 +17,11 @@ def _proportion(adata, id_key, val_key, normalize=True):
return df


def _observed_permuted(annotations, group_key, label_key):
    """Shuffle the group column and return the transposed ``_proportion`` table.

    Parameters
    ----------
    annotations
        DataFrame holding at least the ``group_key`` and ``label_key`` columns.
        NOTE: it is modified in place -- its ``group_key`` column is overwritten
        with a random permutation of its own values, so pass a copy if the
        original order must be preserved.
    group_key
        Name of the column whose values are shuffled.
    label_key
        Name of the column used as ``id_key`` when calling ``_proportion``.

    Returns
    -------
    The result of ``_proportion`` on the shuffled annotations, transposed so
    that groups are on the rows and labels on the columns.
    """
    shuffled_groups = annotations[group_key].sample(frac=1)
    # Assign the raw values (not the Series) so the shuffled order is kept
    # instead of being re-aligned on the original index.
    annotations[group_key] = shuffled_groups.values
    permuted = _proportion(annotations, id_key=label_key, val_key=group_key)
    return permuted.T


def _enrichment(observed, expected, log=True):
enrichment = observed.div(expected, axis="index", level=0)

Expand All @@ -25,11 +31,24 @@ def _enrichment(observed, expected, log=True):
return enrichment


def _empirical_pvalues(observed, expected):
    """Compute one-sided empirical p-values of observed values against permutations.

    Parameters
    ----------
    observed
        2-D :class:`pandas.DataFrame` of observed values.
    expected
        :class:`numpy.ndarray` whose first axis indexes the permutations and
        whose remaining two axes match ``observed.shape``.

    Returns
    -------
    :class:`pandas.DataFrame` with the same index and columns as ``observed``.
    For cells where ``observed > 0`` the p-value is the fraction of permutations
    whose value is not strictly smaller than the observed one; for cells where
    ``observed < 0`` it is the fraction of permutations whose value is not
    strictly larger. Cells where ``observed == 0`` are not tested and get 1.0.
    """
    values = observed.values
    n_perms = expected.shape[0]
    # Compute each mask once; they select the same cells in `values` and in
    # every permutation slice of `expected`.
    positive = values > 0
    negative = values < 0

    # Default to 1.0 (not significant): a cell that is neither positive nor
    # negative is never tested, and must not be reported as p = 0.
    pvalues = np.ones(observed.shape)
    pvalues[positive] = 1 - np.sum(expected[:, positive] < values[positive], axis=0) / n_perms
    pvalues[negative] = 1 - np.sum(expected[:, negative] > values[negative], axis=0) / n_perms
    return pd.DataFrame(pvalues, columns=observed.columns, index=observed.index)


@d.dedent
def enrichment(
adata: AnnData,
group_key: str,
label_key: str,
pvalues: bool = False,
n_perms: int = 1000,
log: bool = True,
observed_expected: bool = False,
copy: bool = False,
Expand All @@ -44,6 +63,10 @@ def enrichment(
Key in :attr:`anndata.AnnData.obs` where groups are stored.
label_key
Key in :attr:`anndata.AnnData.obs` where labels are stored.
pvalues
If `True`, compute empirical p-values by permutation. It will result in a slower computation.
n_perms
Number of permutations to compute empirical p-values.
log
If `True` use log2 fold change, otherwise use fold change.
observed_expected
Expand All @@ -54,17 +77,34 @@ def enrichment(
-------
If ``copy = True``, returns a :class:`dict` with the following keys:
- ``'enrichment'`` - the enrichment values.
- ``'pvalue'`` - the enrichment pvalues (if `pvalues is True`).
- ``'observed'`` - the observed proportions (if `observed_expected is True`).
- ``'expected'`` - the expected proportions (if `observed_expected is True`).
Otherwise, modifies the ``adata`` with the following keys:
- :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_nhood_enrichment']`` - the above mentioned dict.
- :attr:`anndata.AnnData.uns` ``['{group_key}_{label_key}_nhood_enrichment']['params']`` - the parameters used.
"""
observed = _proportion(adata, id_key=label_key, val_key=group_key).reindex().T
observed = _proportion(adata.obs, id_key=label_key, val_key=group_key).reindex().T
observed[observed.isna()] = 0
expected = adata.obs[group_key].value_counts() / adata.shape[0]
if not pvalues:
expected = adata.obs[group_key].value_counts() / adata.shape[0]
# Repeat over the number of labels
expected = pd.concat([expected] * len(observed.columns), axis=1, keys=observed.columns)
else:
annotations = adata.obs.copy()

expected = [_observed_permuted(annotations, group_key, label_key) for _ in tqdm(range(n_perms))]
expected = np.stack(expected, axis=0)

print(expected.shape)

empirical_pvalues = _empirical_pvalues(observed, expected)

expected = np.mean(expected, axis=0)
expected = pd.DataFrame(expected, columns=observed.columns, index=observed.index)

print(expected)
enrichment = _enrichment(observed, expected, log=log)

result = {"enrichment": enrichment}
Expand All @@ -73,6 +113,9 @@ def enrichment(
result["observed"] = observed
result["expected"] = expected

if pvalues:
result["pvalue"] = empirical_pvalues

if copy:
return result
else:
Expand Down
Loading

0 comments on commit 734f014

Please sign in to comment.