Adding Harmony and multiread (#108)
* Adding Harmony and multiread

* sp
euxhenh authored Nov 13, 2023
1 parent c3cba08 commit 2553023
Showing 15 changed files with 186 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ figs
.coverage
docs/_build
test-results
docs
13 changes: 0 additions & 13 deletions codecov.yml

This file was deleted.

4 changes: 3 additions & 1 deletion requirements.txt
@@ -12,4 +12,6 @@ diptest
phenotype_cover
leidenalg
pyensembl
seaborn
seaborn
harmonypy
scanpy
8 changes: 6 additions & 2 deletions src/grinch/__init__.py
@@ -13,12 +13,13 @@
from . import processors as pr
from . import shortcuts
from .aliases import ADK, OBS, OBSM, OBSP, UNS, VAR, VARM, VARP, AnnDataKeys
from .base import StorageMixin
from .cond_filter import Filter, StackedFilter
from .conf import BaseConfigurable
from .filters import FilterCells, FilterGenes, VarianceFilter
from .main import instantiate_config
from .normalizers import Combat, Log1P, NormalizeTotal, Scale
from .pipeline import GRPipeline
from .normalizers import Combat, Harmony, Log1P, NormalizeTotal, Scale
from .pipeline import GRPipeline, MultiRead
from .processors import * # noqa
from .reporter import Report, Reporter
from .shortcuts import * # noqa
@@ -33,9 +34,11 @@
'OBSP',
'VARP',
'UNS',
'StorageMixin',
'AnnDataKeys',
'BaseConfigurable',
'GRPipeline',
'MultiRead',
'FilterCells',
'FilterGenes',
'VarianceFilter',
@@ -44,6 +47,7 @@
'Filter',
'StackedFilter',
'Combat',
'Harmony',
'Log1P',
'Scale',
'NormalizeTotal',
1 change: 1 addition & 0 deletions src/grinch/aliases.py
@@ -44,6 +44,7 @@ class OBSM:
X_TRUNCATED_SVD = auto()
X_UMAP = auto()
GAUSSIAN_MIXTURE_PROBA = auto()
X_HARMONY = auto()

class VARM:
LOG_REG_COEF = auto()
20 changes: 19 additions & 1 deletion src/grinch/cond_filter.py
@@ -64,7 +64,10 @@ class Filter(BaseModel, Generic[T]):
>>> r([3, 4, 5, 6, 7], as_mask=True)
array([False, False, False, False, False])
"""
__conditions__ = ['ge', 'le', 'gt', 'lt', 'top_k', 'bot_k', 'top_ratio', 'bot_ratio']
__conditions__ = ['ge', 'le', 'gt', 'lt',
'equal', 'not_equal',
'top_k', 'bot_k',
'top_ratio', 'bot_ratio']

model_config = {
'validate_assignment': True,
@@ -80,6 +83,9 @@ class Filter(BaseModel, Generic[T]):
gt: T | None = None # greater than
lt: T | None = None # less than

equal: T | None = None # exactly equal to
not_equal: T | None = None # not equal to

top_k: NonNegativeInt | None = None # top k items after sorting
bot_k: NonNegativeInt | None = None # bottom k items after sorting
# These will be rounded up to the nearest item
@@ -167,6 +173,15 @@ def _take_ratio(self, arr, as_mask: bool = True):
k = int(np.ceil(ratio * len(arr))) # round up
return self._take_k_functional(arr, k, as_mask, self.is_top)

def _take_equal(self, arr, as_mask: bool = True):
"""Take elements exactly equal to `self.cfg.equal`.
"""
if self.equal is not None:
mask = arr == self.equal
elif self.not_equal is not None:
mask = arr != self.not_equal
return mask if as_mask else arr[mask]

def _take_cutoff(self, arr, as_mask: bool = True):
"""Takes the elements which are greater than or less than cutoff.
"""
@@ -233,6 +248,8 @@ def __call__(self, obj, as_mask=True):

if any_not_None(self.ge, self.gt, self.le, self.lt):
return self._take_cutoff(arr, as_mask)
if any_not_None(self.equal, self.not_equal):
return self._take_equal(arr, as_mask)
if any_not_None(self.top_k, self.bot_k):
return self._take_k(arr, as_mask)
if any_not_None(self.top_ratio, self.bot_ratio):
@@ -254,6 +271,7 @@ class StackedFilter(UserList):
*filters: iterable
An iterable of Filter's or StackedFilter's.
"""

def __init__(self, *filters: Filter | StackedFilter):
__filters__: List[Filter] = []

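A quick sketch of the two new conditions, in the spirit of the docstring example above. It assumes a Filter can be constructed directly with just the new field, as the existing conditions are; the data values are made up.

import numpy as np

from grinch import Filter

vals = np.asarray([3, 4, 5, 3, 7])

# Keep only entries exactly equal to 3.
f_eq = Filter(equal=3)
f_eq(vals, as_mask=True)     # array([ True, False, False,  True, False])

# Drop entries equal to 3 and return the surviving values instead of a mask.
f_ne = Filter(not_equal=3)
f_ne(vals, as_mask=False)    # array([4, 5, 7])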
12 changes: 2 additions & 10 deletions src/grinch/main.py
100644 → 100755
@@ -1,3 +1,4 @@
#!/Users/ehasanaj/mambaforge/envs/m10/bin/python
import argparse
import logging
import os
@@ -18,18 +19,9 @@
logging.captureWarnings(True)


try:
import grinch
src_dir = os.path.dirname(grinch.__file__)
except ImportError:
src_dir = os.path.dirname(os.path.realpath(__file__))

root_dir = os.path.abspath(os.path.join(src_dir, os.pardir, os.pardir))


def instantiate_config(config_name):
head, tail = os.path.split(config_name)
config_dir = os.path.join(root_dir, head)
config_dir = os.path.join(os.getcwd(), 'conf')
# context initialization
with hydra.initialize_config_dir(version_base=None, config_dir=config_dir):
cfg = hydra.compose(config_name=tail)
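With this change, instantiate_config resolves config names against a conf/ directory under the current working directory instead of the package root. A minimal sketch, assuming a ./conf directory containing the named YAML exists and that the function returns the composed config (the file name is illustrative):

from grinch import instantiate_config

# Composes ./conf/pipeline.yaml relative to the current working directory.
cfg = instantiate_config("pipeline.yaml")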
40 changes: 35 additions & 5 deletions src/grinch/normalizers.py
@@ -1,6 +1,7 @@
import abc
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

import harmonypy
import numpy as np
import pandas as pd
import scipy.sparse as sp
@@ -9,9 +10,11 @@
from sklearn.preprocessing import normalize
from sklearn.utils.validation import check_array, check_non_negative

from .aliases import OBSM
from .base import StorageMixin
from .conf import BaseConfigurable
from .external.combat import combat # type: ignore
from .processors import BaseProcessor
from .processors import ReadKey, WriteKey
from .utils.stats import mean_var


@@ -78,7 +81,7 @@ def _normalize(self, adata: AnnData) -> None:
raise NotImplementedError


class Combat(BaseNormalizer):
class Combat(BaseNormalizer, StorageMixin):
"""Performs batch correction using Combat
Source:
https://academic.oup.com/biostatistics/article/8/1/118/252073?login=false
@@ -90,12 +93,12 @@ class Config(BaseNormalizer.Config):
if TYPE_CHECKING:
create: Callable[..., 'Combat']

batch_key: str
batch_key: ReadKey

cfg: Config

def _normalize(self, adata: AnnData) -> None:
batch: pd.Series = BaseProcessor.read(adata, self.cfg.batch_key)
batch: pd.Series = self.read(adata, self.cfg.batch_key)
if not isinstance(batch, pd.Series):
raise ValueError("Batch should be a pandas series")

@@ -110,6 +113,33 @@ def _normalize(self, adata: AnnData) -> None:
adata.X = corrected_data.to_numpy()


class Harmony(BaseNormalizer, StorageMixin):
"""Performs batch correction based on Harmony.
https://www.nature.com/articles/s41592-019-0619-0
Uses the harmonypy port.
"""
class Config(BaseNormalizer.Config):

if TYPE_CHECKING:
create: Callable[..., 'Harmony']

batch: str
x_key: ReadKey = f"obsm.{OBSM.X_PCA}"
write_key: WriteKey = f"obsm.{OBSM.X_HARMONY}"
kwargs: Dict[str, Any] = {}

cfg: Config

@StorageMixin.lazy_writer
def _normalize(self, adata: AnnData) -> None:
X = self.read(adata, self.cfg.x_key)
hm_out = harmonypy.run_harmony(
X, adata.obs, self.cfg.batch, **self.cfg.kwargs
)
self.store_item(self.cfg.write_key, hm_out.Z_corr.T)


class NormalizeTotal(BaseNormalizer):
"""Normalizes each cell so that total counts are equal."""

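For reference, a self-contained sketch of the harmonypy call that the new Harmony normalizer wraps; the toy data and the batch column name are illustrative.

import harmonypy
import numpy as np
import pandas as pd

# Toy PCA-like embedding for 200 cells across two batches.
X = np.random.rand(200, 20)
meta = pd.DataFrame({"batch_ID": ["a"] * 100 + ["b"] * 100})

# Mirrors Harmony._normalize: cfg.x_key -> X, cfg.batch -> "batch_ID",
# cfg.kwargs -> extra keyword arguments for run_harmony.
hm_out = harmonypy.run_harmony(X, meta, "batch_ID")
X_harmony = hm_out.Z_corr.T  # cells x components, written under the X_HARMONY obsm key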
91 changes: 81 additions & 10 deletions src/grinch/pipeline.py
@@ -1,34 +1,102 @@
import gc
import logging
import traceback
from os.path import expanduser
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List

import anndata
import scanpy as sc
from anndata import AnnData
from pydantic import Field, FilePath, field_validator, validate_call
from tqdm.auto import tqdm

from .base import StorageMixin
from .conf import BaseConfigurable
from .processors import (
BasePredictor,
BaseTransformer,
DataSplitter,
GroupProcess,
Splitter,
WriteKey,
)

logger = logging.getLogger(__name__)


class GRPipeline(BaseConfigurable):
class ReadMixin:
"""Mixin class for reading data files."""

@staticmethod
def read(filepath: FilePath) -> AnnData:
"""Reads AnnData from filepath"""
if filepath.suffix == '.h5':
return sc.read_10x_h5(filepath)
return anndata.read(filepath)


class MultiRead(BaseConfigurable, ReadMixin):
"""Reads multiple adatas and concatenates them."""

class Config(BaseConfigurable.Config):
"""MultiRead.Config
Parameters
----------
paths: Dict
Maps the ID of a dataset to the path of its AnnData file.
id_key: str
The ID will be stored as a key under `id_key` if not None.
[obs|var]_names_make_unique: bool
If True, will make the corresponding axis labels unique.
kwargs: Dict
Arguments to pass to `concat`.
"""

if TYPE_CHECKING:
create: Callable[..., 'MultiRead']

paths: Dict[str, FilePath] = {}
id_key: WriteKey | None = 'obs.batch_ID'
obs_names_make_unique: bool = True
var_names_make_unique: bool = True
kwargs: Dict[str, Any] = {}

@field_validator('paths', mode='before')
def expand_paths(cls, val):
return {k: expanduser(v) for k, v in val.items()}

cfg: Config

def __call__(self) -> AnnData:
adatas = []
for idx, readpath in self.cfg.paths.items():
logger.info(f"Reading AnnData from '{readpath}'...")
adata = self.read(readpath)
if self.cfg.obs_names_make_unique:
adata.obs_names_make_unique()
if self.cfg.var_names_make_unique:
adata.var_names_make_unique()
if self.cfg.id_key is not None:
StorageMixin.write(adata, self.cfg.id_key, idx)
adatas.append(adata)
adata = anndata.concat(adatas, **self.cfg.kwargs)
del adatas
gc.collect()
adata.obs_names_make_unique()
return adata


class GRPipeline(BaseConfigurable, ReadMixin):

class Config(BaseConfigurable.Config):

if TYPE_CHECKING:
create: Callable[..., 'GRPipeline']

data_readpath: FilePath | None = None # FilePath ensures file exists
# FilePath ensures file exists
data_readpath: FilePath | MultiRead.Config | None = None
data_writepath: Path | None = None
processors: List[BaseConfigurable.Config]
verbose: bool = Field(True, exclude=True)
@@ -41,7 +109,9 @@ class Config(BaseConfigurable.Config):

@field_validator('data_readpath', 'data_writepath', mode='before')
def expand_paths(cls, val):
return expanduser(val) if val is not None else None
if not isinstance(val, MultiRead.Config):
return expanduser(val) if val is not None else None
return val

cfg: Config

@@ -52,13 +122,10 @@ def __init__(self, cfg: Config, /) -> None:
for c in self.cfg.processors:
if self.cfg.seed is not None:
c.seed = self.cfg.seed
path = self.cfg.data_writepath or self.cfg.data_readpath
if path is not None:
c.logs_path = c.logs_path / path.stem
self.processors.append(c.create())

@validate_call(config=dict(arbitrary_types_allowed=True))
def __call__(self, adata: Optional[AnnData] = None, **kwargs) -> DataSplitter:
def __call__(self, adata: AnnData | None = None, **kwargs) -> DataSplitter:
"""Applies processor to the different data splits in DataSplitter.
It differentiates between predictors (calls processor.predict),
transformers (calls processor.transform) and it defaults to
@@ -67,8 +134,12 @@ def __call__(self, adata: Optional[AnnData] = None, **kwargs) -> DataSplitter:
if adata is None:
if self.cfg.data_readpath is None:
raise ValueError("A path to adata or an adata object is required.")
logger.info(f"Reading AnnData from '{self.cfg.data_readpath}'...")
adata = anndata.read_h5ad(self.cfg.data_readpath)
if isinstance(self.cfg.data_readpath, MultiRead.Config):
multi_read = self.cfg.data_readpath.create()
adata = multi_read()
else:
logger.info(f"Reading AnnData from '{self.cfg.data_readpath}'...")
adata = self.read(self.cfg.data_readpath)
logger.info(adata)
ds = DataSplitter(adata) if not isinstance(adata, DataSplitter) else adata

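A hedged sketch of wiring the new MultiRead into a pipeline config. The .h5ad paths are illustrative and must exist on disk (FilePath validates them), an empty processor list is used only to keep the example short, and it assumes no other required Config fields beyond those shown in this diff.

from grinch import GRPipeline, MultiRead

multi_cfg = MultiRead.Config(
    paths={
        "patient_1": "~/data/patient_1.h5ad",
        "patient_2": "~/data/patient_2.h5ad",
    },
    id_key="obs.batch_ID",     # each dataset's ID is written here before concatenation
    kwargs={"join": "inner"},  # forwarded to anndata.concat
)

# MultiRead.Config can now stand in for a single data_readpath.
pipeline_cfg = GRPipeline.Config(data_readpath=multi_cfg, processors=[])
ds = pipeline_cfg.create()()   # reads, concatenates, then applies the processors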
4 changes: 3 additions & 1 deletion src/grinch/processors/__init__.py
@@ -1,4 +1,4 @@
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey
from .de import KSTest, TTest, UnimodalityTest
from .feature_selection import PhenotypeCover
from .graphs import BaseGraphConstructor, FuzzySimplicialSetGraph, KNNGraph
@@ -26,6 +26,8 @@

__all__ = [
'BaseProcessor',
'ReadKey',
'WriteKey',
'TTest',
'KSTest',
'UnimodalityTest',