Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eh/interactive mode #83

Merged
merged 4 commits into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ conf/*
!conf/example_DE_analysis.yaml
docs
*.h5ad
figs
figs
*logs
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ gseapy
diptest
phenotype_cover
leidenalg
pyensembl
pyensembl
seaborn
3 changes: 2 additions & 1 deletion src/grinch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from .filters import FilterCells, FilterGenes
from .main import instantiate_config
from .normalizers import Log1P, NormalizeTotal
from .normalizers import Log1P, NormalizeTotal, Scale
from .pipeline import GRPipeline
from .processors import * # noqa
from .reporter import Report, Reporter
Expand Down Expand Up @@ -43,6 +43,7 @@
'Filter',
'StackedFilter',
'Log1P',
'Scale',
'NormalizeTotal',
'instantiate_config',
]
34 changes: 32 additions & 2 deletions src/grinch/conf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import abc
import inspect
from contextlib import contextmanager
from itertools import islice
from pathlib import Path
from typing import ClassVar, List, Tuple

from pydantic import BaseModel, Field
import matplotlib.pyplot as plt
from pydantic import BaseModel, Field, field_validator

from .reporter import Report, Reporter

Expand Down Expand Up @@ -100,14 +103,41 @@ class BaseConfigurable(_BaseConfigurable):

class Config(BaseConfig):
seed: int | None = None
sanity_check: ClassVar[bool] = Field(False)
logs_path: Path = Path('./grinch_logs') # Default
sanity_check: ClassVar[bool] = Field(False, exclude=True)
interactive: bool = Field(False, exclude=True)

@field_validator('logs_path', mode='before')
def convert_to_Path(cls, val):
return Path(val)

cfg: Config

def __init__(self, cfg: Config, /):
self.cfg = cfg
self._reporter = reporter

@property
def logs_path(self) -> Path:
return self.cfg.logs_path

@contextmanager
def interactive(self, save_path: str | Path | None = None, **kwargs):
plt.ion()
yield None
plt.ioff()

if save_path is not None:
self.logs_path.mkdir(parents=True, exist_ok=True)
# Set good defaults
kwargs.setdefault('dpi', 300)
kwargs.setdefault('bbox_inches', 'tight')
kwargs.setdefault('transparent', True)
plt.savefig(self.logs_path / save_path, **kwargs)

plt.clf()
plt.close()

def log(
self,
message: str,
Expand Down
44 changes: 40 additions & 4 deletions src/grinch/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .aliases import OBS, VAR
from .conf import BaseConfigurable
from .utils import any_not_None, true_inside
from .utils.plotting import plot1d
from .utils.stats import _var

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,10 +53,10 @@ class FilterCells(BaseFilter):
"""Filters cells based on counts and number of expressed genes."""

class Config(BaseFilter.Config):
min_counts: Optional[float] = Field(None, ge=0)
max_counts: Optional[float] = Field(None, ge=0)
min_genes: Optional[int] = Field(None, ge=0)
max_genes: Optional[int] = Field(None, ge=0)
min_counts: float | None = Field(None, ge=0)
max_counts: float | None = Field(None, ge=0)
min_genes: int | None = Field(None, ge=0)
max_genes: int | None = Field(None, ge=0)

cfg: Config

Expand All @@ -66,6 +67,13 @@ def _filter(self, adata: AnnData) -> None:
to_keep = np.ones(adata.shape[0], dtype=bool)

counts_per_cell = np.ravel(adata.X.sum(axis=1))

if self.cfg.interactive:
with self.interactive('counts_per_cell.png'):
plot1d(counts_per_cell, 'nbinom', title='Counts per Cell')
self.cfg.min_counts = eval(input("Enter min_counts="))
self.cfg.max_counts = eval(input("Enter max_counts="))

if any_not_None(self.cfg.min_counts, self.cfg.max_counts):
to_keep &= true_inside(
counts_per_cell,
Expand All @@ -75,6 +83,13 @@ def _filter(self, adata: AnnData) -> None:

# Values are ensured to be non-negative
genes_per_cell = np.ravel((adata.X > 0).sum(axis=1))

if self.cfg.interactive:
with self.interactive('genes_per_cell.png'):
plot1d(genes_per_cell, 'nbinom', title='Genes per Cell')
self.cfg.min_genes = eval(input("Enter min_genes="))
self.cfg.max_genes = eval(input("Enter max_genes="))

if any_not_None(self.cfg.min_genes, self.cfg.max_genes):
to_keep &= true_inside(
genes_per_cell,
Expand Down Expand Up @@ -115,6 +130,13 @@ def _filter(self, adata: AnnData) -> None:
to_keep = np.ones(adata.shape[1], dtype=bool)

counts_per_gene = np.ravel(adata.X.sum(axis=0))

if self.cfg.interactive:
with self.interactive('counts_per_gene.png'):
plot1d(counts_per_gene, 'halfnorm', title='Counts per Gene')
self.cfg.min_counts = eval(input("Enter min_counts="))
self.cfg.max_counts = eval(input("Enter max_counts="))

if any_not_None(self.cfg.min_counts, self.cfg.max_counts):
to_keep &= true_inside(
counts_per_gene,
Expand All @@ -123,6 +145,13 @@ def _filter(self, adata: AnnData) -> None:
)

cells_per_gene = np.ravel((adata.X > 0).sum(axis=0))

if self.cfg.interactive:
with self.interactive('cells_per_gene.png'):
plot1d(cells_per_gene, 'nbinom', title='Cells per Gene')
self.cfg.min_cells = eval(input("Enter min_cells="))
self.cfg.max_cells = eval(input("Enter max_cells="))

if any_not_None(self.cfg.min_cells, self.cfg.max_cells):
to_keep &= true_inside(
cells_per_gene,
Expand All @@ -132,6 +161,13 @@ def _filter(self, adata: AnnData) -> None:

# TODO separate variance filter into a new module
gene_var = _var(adata.X, axis=0, ddof=self.cfg.ddof)

if self.cfg.interactive:
with self.interactive('gene_var.png'):
plot1d(gene_var, 'halfnorm', title='Gene Variance')
self.cfg.min_var = eval(input("Enter min_var="))
self.cfg.max_var = eval(input("Enter max_var="))

if any_not_None(self.cfg.min_var, self.cfg.max_var):
to_keep &= true_inside(gene_var, self.cfg.min_var, self.cfg.max_var)

Expand Down
34 changes: 33 additions & 1 deletion src/grinch/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.utils.validation import check_array, check_non_negative

from .conf import BaseConfigurable
from .utils.stats import mean_var


class BaseNormalizer(BaseConfigurable):
Expand All @@ -29,7 +30,7 @@ class BaseNormalizer(BaseConfigurable):
class Config(BaseConfigurable.Config):
inplace: bool = True
save_input: bool = True
input_layer_name: Optional[str] = None
input_layer_name: str | None = None

@field_validator('input_layer_name')
def resolve_input_layer_name(cls, value):
Expand Down Expand Up @@ -104,3 +105,34 @@ def _normalize(self, adata: AnnData) -> None:
check_non_negative(adata.X, f'{self.__class__.__name__}')
to_log = adata.X.data if sp.issparse(adata.X) else adata.X
np.log1p(to_log, out=to_log)


class Scale(BaseNormalizer):
"""Scales to standard deviation and optionally zero mean.
"""

class Config(BaseNormalizer.Config):
max_value: float | None = Field(None, gt=0)
with_mean: bool = True
with_std: bool = True

cfg: Config

def _normalize(self, adata: AnnData) -> None:
X = adata.X
# Run before densifying: faster computation if sparse
mean, var = mean_var(X, axis=0)
var[var == 0] = 1
std = var ** (1/2)

if self.cfg.with_mean and sp.issparse(adata.X):
X = X.toarray()
if self.cfg.with_mean:
X -= mean
X /= std

if self.cfg.max_value is not None:
to_clip = X.data if sp.issparse(X) else X
np.clip(to_clip, None, self.cfg.max_value, out=to_clip)

adata.X = X
3 changes: 3 additions & 0 deletions src/grinch/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __init__(self, cfg: Config, /) -> None:
for c in self.cfg.processors:
if self.cfg.seed is not None:
c.seed = self.cfg.seed
path = self.cfg.data_writepath or self.cfg.data_readpath
if path is not None:
c.logs_path = c.logs_path / path.split('/')[-1]
self.processors.append(c.initialize())

@validate_call(config=dict(arbitrary_types_allowed=True))
Expand Down
4 changes: 2 additions & 2 deletions src/grinch/processors/transformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from anndata import AnnData
from pydantic import Field, field_validator, validate_call
Expand Down Expand Up @@ -61,7 +61,7 @@ class PCA(BaseTransformer):
class Config(BaseTransformer.Config):
x_emb_key: str = f"obsm.{OBSM.X_PCA}"
# PCA args
n_components: Optional[int | float | str] = None
n_components: int | float | str | None = 50
whiten: bool = False
svd_solver: str = 'auto'

Expand Down
47 changes: 47 additions & 0 deletions src/grinch/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import norm

from .stats import fit_nbinom, stats1d

logger = logging.getLogger(__name__)


def plot1d(
rvs: np.ndarray,
dist: str,
*,
title: str | None = None,
ax: plt.Axes | None = None,
**kwargs,
) -> None:
"""Generates random variables from distribution `dist` and plots a
histogram in kde mode.
"""
# For better plot view in case there are a few extreme outliers
params = norm.fit(rvs)
z1 = norm.ppf(0.01, *params)
z2 = norm.ppf(0.99, *params)
to_keep = (rvs >= z1) & (rvs <= z2)
to_remove = (~to_keep).sum()
if 0 < to_remove <= 10:
logger.warning(f"Removing {to_remove} points from view.")
rvs_to_plot = rvs[to_keep]
else:
rvs_to_plot = rvs

sns.violinplot(rvs_to_plot, color='#b56576')
ax = sns.stripplot(rvs_to_plot, color='black', size=1, jitter=1)
ax.set_title(title)
params = fit_nbinom(rvs) if dist == 'nbinom' else None
stats = stats1d(rvs, dist, params=params, pprint=True)
ax.axhline(stats['dist_q05'], color='#e56b6f', zorder=100)
ax.axhline(stats['dist_q95'], color='#e56b6f', zorder=100)

y2 = ax.twinx()
y2.set_ylim(ax.get_ylim())
y2.set_yticks([stats['dist_q05'], stats['dist_q95']])
y2.set_yticklabels(["dist_q=0.05", "dist_q=0.95"])
Loading