From fa8377b199926c368cfc5ed00178928be3a103a6 Mon Sep 17 00:00:00 2001 From: euxhenh Date: Sat, 29 Jul 2023 17:46:08 -0400 Subject: [PATCH 1/4] Adding interactive mode for filters --- .gitignore | 3 +- src/grinch/__init__.py | 3 +- src/grinch/conf.py | 37 +++++++++++++++- src/grinch/filters.py | 44 +++++++++++++++++-- src/grinch/normalizers.py | 34 ++++++++++++++- src/grinch/processors/transformers.py | 4 +- src/grinch/utils/plotting.py | 47 +++++++++++++++++++++ src/grinch/utils/stats.py | 61 ++++++++++++++++++++++++++- 8 files changed, 221 insertions(+), 12 deletions(-) create mode 100644 src/grinch/utils/plotting.py diff --git a/.gitignore b/.gitignore index de4c585..6b9ebc1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ conf/* !conf/example_DE_analysis.yaml docs *.h5ad -figs \ No newline at end of file +figs +*logs \ No newline at end of file diff --git a/src/grinch/__init__.py b/src/grinch/__init__.py index f858fb8..dddd327 100644 --- a/src/grinch/__init__.py +++ b/src/grinch/__init__.py @@ -11,7 +11,7 @@ ) from .filters import FilterCells, FilterGenes from .main import instantiate_config -from .normalizers import Log1P, NormalizeTotal +from .normalizers import Log1P, NormalizeTotal, Scale from .pipeline import GRPipeline from .processors import * # noqa from .reporter import Report, Reporter @@ -43,6 +43,7 @@ 'Filter', 'StackedFilter', 'Log1P', + 'Scale', 'NormalizeTotal', 'instantiate_config', ] diff --git a/src/grinch/conf.py b/src/grinch/conf.py index dadea01..2fa8895 100644 --- a/src/grinch/conf.py +++ b/src/grinch/conf.py @@ -1,11 +1,15 @@ import abc import inspect +from contextlib import contextmanager from itertools import islice +from pathlib import Path from typing import ClassVar, List, Tuple -from pydantic import BaseModel, Field +import matplotlib.pyplot as plt +from pydantic import BaseModel, Field, field_validator from .reporter import Report, Reporter +from .utils.validation import all_not_None reporter = Reporter() @@ -100,7 +104,15 @@ class BaseConfigurable(_BaseConfigurable): class Config(BaseConfig): seed: int | None = None - sanity_check: ClassVar[bool] = Field(False) + logs_path: Path | None = Path('./grinch_logs') # Default + sanity_check: ClassVar[bool] = Field(False, exclude=True) + interactive: bool = Field(False, exclude=True) + + @field_validator('logs_path', mode='before') + def convert_to_Path(cls, val): + if val is None: + return None + return Path(val) cfg: Config @@ -108,6 +120,27 @@ def __init__(self, cfg: Config, /): self.cfg = cfg self._reporter = reporter + @property + def logs_path(self) -> Path: + return self.cfg.logs_path + + @contextmanager + def interactive(self, save_path: str | Path | None = None, **kwargs): + plt.ion() + yield None + plt.ioff() + + if all_not_None(self.logs_path, save_path): + self.logs_path.mkdir(parents=True, exist_ok=True) + # Set good defaults + kwargs.setdefault('dpi', 300) + kwargs.setdefault('bbox_inches', 'tight') + kwargs.setdefault('transparent', True) + plt.savefig(self.logs_path / save_path, **kwargs) + + plt.clf() + plt.show() + def log( self, message: str, diff --git a/src/grinch/filters.py b/src/grinch/filters.py index 9acba03..d545600 100644 --- a/src/grinch/filters.py +++ b/src/grinch/filters.py @@ -10,6 +10,7 @@ from .aliases import OBS, VAR from .conf import BaseConfigurable from .utils import any_not_None, true_inside +from .utils.plotting import plot1d from .utils.stats import _var logger = logging.getLogger(__name__) @@ -52,10 +53,10 @@ class FilterCells(BaseFilter): """Filters cells based on counts and number of expressed genes.""" class Config(BaseFilter.Config): - min_counts: Optional[float] = Field(None, ge=0) - max_counts: Optional[float] = Field(None, ge=0) - min_genes: Optional[int] = Field(None, ge=0) - max_genes: Optional[int] = Field(None, ge=0) + min_counts: float | None = Field(None, ge=0) + max_counts: float | None = Field(None, ge=0) + min_genes: int | None = Field(None, ge=0) + max_genes: int | None = Field(None, ge=0) cfg: Config @@ -66,6 +67,13 @@ def _filter(self, adata: AnnData) -> None: to_keep = np.ones(adata.shape[0], dtype=bool) counts_per_cell = np.ravel(adata.X.sum(axis=1)) + + if self.cfg.interactive: + with self.interactive('counts_per_cell.png'): + plot1d(counts_per_cell, 'nbinom', title='Counts per Cell') + self.cfg.min_counts = eval(input("Enter min_counts=")) + self.cfg.max_counts = eval(input("Enter max_counts=")) + if any_not_None(self.cfg.min_counts, self.cfg.max_counts): to_keep &= true_inside( counts_per_cell, @@ -75,6 +83,13 @@ def _filter(self, adata: AnnData) -> None: # Values are ensured to be non-negative genes_per_cell = np.ravel((adata.X > 0).sum(axis=1)) + + if self.cfg.interactive: + with self.interactive('genes_per_cell.png'): + plot1d(genes_per_cell, 'nbinom', title='Genes per Cell') + self.cfg.min_genes = eval(input("Enter min_genes=")) + self.cfg.max_genes = eval(input("Enter max_genes=")) + if any_not_None(self.cfg.min_genes, self.cfg.max_genes): to_keep &= true_inside( genes_per_cell, @@ -115,6 +130,13 @@ def _filter(self, adata: AnnData) -> None: to_keep = np.ones(adata.shape[1], dtype=bool) counts_per_gene = np.ravel(adata.X.sum(axis=0)) + + if self.cfg.interactive: + with self.interactive('counts_per_gene.png'): + plot1d(counts_per_gene, 'halfnorm', title='Counts per Gene') + self.cfg.min_counts = eval(input("Enter min_counts=")) + self.cfg.max_counts = eval(input("Enter max_counts=")) + if any_not_None(self.cfg.min_counts, self.cfg.max_counts): to_keep &= true_inside( counts_per_gene, @@ -123,6 +145,13 @@ def _filter(self, adata: AnnData) -> None: ) cells_per_gene = np.ravel((adata.X > 0).sum(axis=0)) + + if self.cfg.interactive: + with self.interactive('cells_per_gene.png'): + plot1d(cells_per_gene, 'nbinom', title='Cells per Gene') + self.cfg.min_cells = eval(input("Enter min_cells=")) + self.cfg.max_cells = eval(input("Enter max_cells=")) + if any_not_None(self.cfg.min_cells, self.cfg.max_cells): to_keep &= true_inside( cells_per_gene, @@ -132,6 +161,13 @@ def _filter(self, adata: AnnData) -> None: # TODO separate variance filter into a new module gene_var = _var(adata.X, axis=0, ddof=self.cfg.ddof) + + if self.cfg.interactive: + with self.interactive('gene_var.png'): + plot1d(gene_var, 'halfnorm', title='Gene Variance') + self.cfg.min_var = eval(input("Enter min_var=")) + self.cfg.max_var = eval(input("Enter max_var=")) + if any_not_None(self.cfg.min_var, self.cfg.max_var): to_keep &= true_inside(gene_var, self.cfg.min_var, self.cfg.max_var) diff --git a/src/grinch/normalizers.py b/src/grinch/normalizers.py index 9289753..94869ce 100644 --- a/src/grinch/normalizers.py +++ b/src/grinch/normalizers.py @@ -9,6 +9,7 @@ from sklearn.utils.validation import check_array, check_non_negative from .conf import BaseConfigurable +from .utils.stats import mean_var class BaseNormalizer(BaseConfigurable): @@ -29,7 +30,7 @@ class BaseNormalizer(BaseConfigurable): class Config(BaseConfigurable.Config): inplace: bool = True save_input: bool = True - input_layer_name: Optional[str] = None + input_layer_name: str | None = None @field_validator('input_layer_name') def resolve_input_layer_name(cls, value): @@ -104,3 +105,34 @@ def _normalize(self, adata: AnnData) -> None: check_non_negative(adata.X, f'{self.__class__.__name__}') to_log = adata.X.data if sp.issparse(adata.X) else adata.X np.log1p(to_log, out=to_log) + + +class Scale(BaseNormalizer): + """Scales to standard deviation and optionally zero mean. + """ + + class Config(BaseNormalizer.Config): + max_value: float | None = Field(None, gt=0) + with_mean: bool = True + with_std: bool = True + + cfg: Config + + def _normalize(self, adata: AnnData) -> None: + X = adata.X + # Run before densifying: faster computation if sparse + mean, var = mean_var(X, axis=0) + var[var == 0] = 1 + std = var ** (1/2) + + if self.cfg.with_mean and sp.issparse(adata.X): + X = X.toarray() + if self.cfg.with_mean: + X -= mean + X /= std + + if self.cfg.max_value is not None: + to_clip = X.data if sp.issparse(X) else X + np.clip(to_clip, None, self.cfg.max_value, out=to_clip) + + adata.X = X diff --git a/src/grinch/processors/transformers.py b/src/grinch/processors/transformers.py index 3171148..cbe98dc 100644 --- a/src/grinch/processors/transformers.py +++ b/src/grinch/processors/transformers.py @@ -1,6 +1,6 @@ import abc import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from anndata import AnnData from pydantic import Field, field_validator, validate_call @@ -61,7 +61,7 @@ class PCA(BaseTransformer): class Config(BaseTransformer.Config): x_emb_key: str = f"obsm.{OBSM.X_PCA}" # PCA args - n_components: Optional[int | float | str] = None + n_components: int | float | str | None = 50 whiten: bool = False svd_solver: str = 'auto' diff --git a/src/grinch/utils/plotting.py b/src/grinch/utils/plotting.py new file mode 100644 index 0000000..a3efc33 --- /dev/null +++ b/src/grinch/utils/plotting.py @@ -0,0 +1,47 @@ +import logging + +import matplotlib.pyplot as plt +import numpy.typing as npt +import seaborn as sns +from scipy.stats import norm + +from .stats import fit_nbinom, stats1d + +logger = logging.getLogger(__name__) + + +def plot1d( + rvs: npt.ArrayLike, + dist: str, + *, + title: str | None = None, + ax: plt.Axes | None = None, + **kwargs, +) -> None: + """Generates random variables from distribution `dist` and plots a + histogram in kde mode. + """ + # For better plot view in case there are a few extreme outliers + params = norm.fit(rvs) + z1 = norm.ppf(0.01, *params) + z2 = norm.ppf(0.99, *params) + to_keep = (rvs >= z1) & (rvs <= z2) + to_remove = (~to_keep).sum() + if 0 < to_remove <= 10: + logger.warning(f"Removing {to_remove} points from view.") + rvs_to_plot = rvs[to_keep] + else: + rvs_to_plot = rvs + + sns.violinplot(rvs_to_plot, color='#b56576') + ax = sns.stripplot(rvs_to_plot, color='black', size=1, jitter=1) + ax.set_title(title) + params = fit_nbinom(rvs) if dist == 'nbinom' else None + stats = stats1d(rvs, dist, params=params, pprint=True) + ax.axhline(stats['dist_q05'], color='#e56b6f', zorder=100) + ax.axhline(stats['dist_q95'], color='#e56b6f', zorder=100) + + y2 = ax.twinx() + y2.set_ylim(ax.get_ylim()) + y2.set_yticks([stats['dist_q05'], stats['dist_q95']]) + y2.set_yticklabels(["dist_q=0.05", "dist_q=0.95"]) diff --git a/src/grinch/utils/stats.py b/src/grinch/utils/stats.py index 69494fe..8b2a846 100644 --- a/src/grinch/utils/stats.py +++ b/src/grinch/utils/stats.py @@ -1,17 +1,21 @@ import logging from dataclasses import dataclass from functools import wraps -from typing import Dict, Hashable, List, Optional, Tuple, overload +from typing import Any, Dict, Hashable, List, Optional, Tuple, overload import numpy as np import numpy.typing as npt import scipy.sparse as sp +import scipy.stats as scs +from rich.pretty import pretty_repr +from scipy.stats import rv_continuous from scipy.stats._stats_py import ( Ttest_indResult, _ttest_ind_from_stats, _unequal_var_ttest_denom, ) from sklearn.utils import check_consistent_length, indexable +from statsmodels.discrete.discrete_model import NegativeBinomial from statsmodels.stats.multitest import multipletests from tqdm.auto import tqdm @@ -218,6 +222,61 @@ def _compute_log2fc(mean1, mean2, base='e', is_logged=False): return log2fc +def fit_nbinom(rvs) -> Tuple[float, float]: + """Fit a negative binomial distribution to the data. Returns (n, p). + """ + s = np.ones_like(rvs.astype(float)) + nb = NegativeBinomial(rvs, s).fit() + + mu = np.exp(nb.params[0]) + p = 1 / (1 + mu * nb.params[1]) + n = mu * p / (1-p) + + return n, p + + +def stats1d( + rvs, + dist: str, + *, + params: Any | None = None, + alpha: float = 0.05, + pprint: bool = False, +) -> dict[str, float]: + """Computes main statistics for a 1D distribution assumed to follow + `dist`. `dist` should be a string pointing to a scipy distribution. + + Parameters + __________ + rvs: array-like + Samples from the distribution. + dist: str + A continuous distribution from scipy.stats. + params: any + If None, will call sc_dist.fit. + alpha: float + The outlier threshold to use for computing inverse quantiles. + pprint: bool + If True, will print the dictionary of stats. + """ + assert alpha > 0 and alpha < 1 + sc_dist: rv_continuous = getattr(scs, dist) + if params is None: + params = sc_dist.fit(rvs) + mean, var = sc_dist.stats(*params) + stats = {'dist_mean': mean, 'dist_std': var ** (1/2), + 'dist_q05': sc_dist.ppf(alpha, *params), + 'dist_q95': sc_dist.ppf(1 - alpha, *params), + 'min': rvs.min(), 'max': rvs.max(), + 'data_mean': rvs.mean(), 'data_std': rvs.std(), + 'data_q05': np.quantile(rvs, 0.05), + 'data_q95': np.quantile(rvs, 0.95)} + if pprint: + logger.info(f"'{dist}' distribution statistics") + logger.info(pretty_repr(stats)) + return stats + + @dataclass class _MeanVarVector: n: int # number of samples accumulated so far From 7f4b55e234eae3c5dd71c100b1f400586686e9f3 Mon Sep 17 00:00:00 2001 From: euxhenh Date: Sat, 29 Jul 2023 17:53:52 -0400 Subject: [PATCH 2/4] Fix mypy --- src/grinch/conf.py | 6 +++--- src/grinch/utils/plotting.py | 4 ++-- src/grinch/utils/stats.py | 24 +++++++++++++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/grinch/conf.py b/src/grinch/conf.py index 2fa8895..963fc87 100644 --- a/src/grinch/conf.py +++ b/src/grinch/conf.py @@ -121,7 +121,7 @@ def __init__(self, cfg: Config, /): self._reporter = reporter @property - def logs_path(self) -> Path: + def logs_path(self) -> Path | None: return self.cfg.logs_path @contextmanager @@ -131,12 +131,12 @@ def interactive(self, save_path: str | Path | None = None, **kwargs): plt.ioff() if all_not_None(self.logs_path, save_path): - self.logs_path.mkdir(parents=True, exist_ok=True) + self.logs_path.mkdir(parents=True, exist_ok=True) # type: ignore # Set good defaults kwargs.setdefault('dpi', 300) kwargs.setdefault('bbox_inches', 'tight') kwargs.setdefault('transparent', True) - plt.savefig(self.logs_path / save_path, **kwargs) + plt.savefig(self.logs_path / save_path, **kwargs) # type: ignore plt.clf() plt.show() diff --git a/src/grinch/utils/plotting.py b/src/grinch/utils/plotting.py index a3efc33..a673db2 100644 --- a/src/grinch/utils/plotting.py +++ b/src/grinch/utils/plotting.py @@ -1,7 +1,7 @@ import logging import matplotlib.pyplot as plt -import numpy.typing as npt +import numpy as np import seaborn as sns from scipy.stats import norm @@ -11,7 +11,7 @@ def plot1d( - rvs: npt.ArrayLike, + rvs: np.ndarray, dist: str, *, title: str | None = None, diff --git a/src/grinch/utils/stats.py b/src/grinch/utils/stats.py index 8b2a846..9ba7b5c 100644 --- a/src/grinch/utils/stats.py +++ b/src/grinch/utils/stats.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass from functools import wraps -from typing import Any, Dict, Hashable, List, Optional, Tuple, overload +from typing import Any, Dict, Hashable, List, Tuple, overload import numpy as np import numpy.typing as npt @@ -101,11 +101,25 @@ def _var(x, axis=None, ddof=0, mean=None): return var +@overload def mean_var( x: npt.ArrayLike, - axis: Optional[int] = None, - ddof: int = 0 -) -> Tuple[int | np.ndarray, int | np.ndarray]: + axis: None = None, + ddof: int = 0, +) -> Tuple[float, float]: + ... + + +@overload +def mean_var( + x: npt.ArrayLike, + axis: int, + ddof: int = 0, +) -> Tuple[np.ndarray, np.ndarray]: + ... + + +def mean_var(x, axis=None, ddof=0): """Returns both mean and variance. Parameters @@ -135,7 +149,7 @@ def mean_var( def ttest( a: npt.ArrayLike, b: npt.ArrayLike, - axis: Optional[int] = 0 + axis: int | None = 0 ) -> Tuple[np.ndarray, np.ndarray]: """Performs a Welch's t-test (unequal sample sizes, unequal vars). Extends scipy's ttest_ind to support sparse matrices. From 9796c7623ac296c9b546de6d38e1b43df1ee445e Mon Sep 17 00:00:00 2001 From: euxhenh Date: Sat, 29 Jul 2023 18:08:07 -0400 Subject: [PATCH 3/4] Add seaborn to reqs --- requirements.txt | 3 ++- src/grinch/conf.py | 14 ++++++-------- src/grinch/pipeline.py | 1 + 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 271b9c4..3ab236c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ gseapy diptest phenotype_cover leidenalg -pyensembl \ No newline at end of file +pyensembl +seaborn \ No newline at end of file diff --git a/src/grinch/conf.py b/src/grinch/conf.py index 963fc87..ff7337f 100644 --- a/src/grinch/conf.py +++ b/src/grinch/conf.py @@ -104,14 +104,12 @@ class BaseConfigurable(_BaseConfigurable): class Config(BaseConfig): seed: int | None = None - logs_path: Path | None = Path('./grinch_logs') # Default + logs_path: Path = Path('./grinch_logs') # Default sanity_check: ClassVar[bool] = Field(False, exclude=True) interactive: bool = Field(False, exclude=True) @field_validator('logs_path', mode='before') def convert_to_Path(cls, val): - if val is None: - return None return Path(val) cfg: Config @@ -121,7 +119,7 @@ def __init__(self, cfg: Config, /): self._reporter = reporter @property - def logs_path(self) -> Path | None: + def logs_path(self) -> Path: return self.cfg.logs_path @contextmanager @@ -130,16 +128,16 @@ def interactive(self, save_path: str | Path | None = None, **kwargs): yield None plt.ioff() - if all_not_None(self.logs_path, save_path): - self.logs_path.mkdir(parents=True, exist_ok=True) # type: ignore + if save_path is not None: + self.logs_path.mkdir(parents=True, exist_ok=True) # Set good defaults kwargs.setdefault('dpi', 300) kwargs.setdefault('bbox_inches', 'tight') kwargs.setdefault('transparent', True) - plt.savefig(self.logs_path / save_path, **kwargs) # type: ignore + plt.savefig(self.logs_path / save_path, **kwargs) plt.clf() - plt.show() + plt.close() def log( self, diff --git a/src/grinch/pipeline.py b/src/grinch/pipeline.py index 2d7c9f5..a5a4c0d 100644 --- a/src/grinch/pipeline.py +++ b/src/grinch/pipeline.py @@ -48,6 +48,7 @@ def __init__(self, cfg: Config, /) -> None: for c in self.cfg.processors: if self.cfg.seed is not None: c.seed = self.cfg.seed + c.logs_path = c.logs_path / self.cfg.data_writepath.split('/')[-1] self.processors.append(c.initialize()) @validate_call(config=dict(arbitrary_types_allowed=True)) From ed35af2ca27c987f5f4b88be794db43014588d85 Mon Sep 17 00:00:00 2001 From: euxhenh Date: Sat, 29 Jul 2023 18:12:49 -0400 Subject: [PATCH 4/4] Fix path None --- src/grinch/conf.py | 1 - src/grinch/pipeline.py | 4 +++- src/grinch/utils/stats.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/grinch/conf.py b/src/grinch/conf.py index ff7337f..93a8301 100644 --- a/src/grinch/conf.py +++ b/src/grinch/conf.py @@ -9,7 +9,6 @@ from pydantic import BaseModel, Field, field_validator from .reporter import Report, Reporter -from .utils.validation import all_not_None reporter = Reporter() diff --git a/src/grinch/pipeline.py b/src/grinch/pipeline.py index a5a4c0d..d2790a0 100644 --- a/src/grinch/pipeline.py +++ b/src/grinch/pipeline.py @@ -48,7 +48,9 @@ def __init__(self, cfg: Config, /) -> None: for c in self.cfg.processors: if self.cfg.seed is not None: c.seed = self.cfg.seed - c.logs_path = c.logs_path / self.cfg.data_writepath.split('/')[-1] + path = self.cfg.data_writepath or self.cfg.data_readpath + if path is not None: + c.logs_path = c.logs_path / path.split('/')[-1] self.processors.append(c.initialize()) @validate_call(config=dict(arbitrary_types_allowed=True)) diff --git a/src/grinch/utils/stats.py b/src/grinch/utils/stats.py index 9ba7b5c..dac9202 100644 --- a/src/grinch/utils/stats.py +++ b/src/grinch/utils/stats.py @@ -278,7 +278,8 @@ def stats1d( if params is None: params = sc_dist.fit(rvs) mean, var = sc_dist.stats(*params) - stats = {'dist_mean': mean, 'dist_std': var ** (1/2), + stats = {'dist': dist, + 'dist_mean': mean, 'dist_std': var ** (1/2), 'dist_q05': sc_dist.ppf(alpha, *params), 'dist_q95': sc_dist.ppf(1 - alpha, *params), 'min': rvs.min(), 'max': rvs.max(),