Skip to content

Commit

Permalink
Switching to ReadKey and WriteKey
Browse files Browse the repository at this point in the history
  • Loading branch information
euxhenh committed Aug 1, 2023
1 parent 7c1c094 commit 950fd5d
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 151 deletions.
3 changes: 0 additions & 3 deletions src/grinch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ def __init__(self, cfg: Config, /):
def __repr__(self):
    """Return ``ClassName(<repr of cfg>)`` for this instance."""
    cls_name = self.__class__.__name__
    return f"{cls_name}({self.cfg!r})"

def __copy__(self):
    """Return a shallow copy of the model.

    A new instance of the same class is created without running
    ``__init__`` and the instance attributes are rebound onto it, so the
    copy shares its attribute values (including ``cfg``) with the
    original object.
    """
    cls = self.__class__
    clone = cls.__new__(cls)
    # Rebind, do not duplicate, the attributes: this is what makes the
    # copy shallow.
    clone.__dict__.update(self.__dict__)
    return clone


class BaseConfigurable(_BaseConfigurable):
r"""BaseConfigurable class with some default parameters and methods.
Expand Down
11 changes: 10 additions & 1 deletion src/grinch/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .base_processor import BaseProcessor, adata_modifier
from .base_processor import (
BaseProcessor,
ProcessorParam,
ReadKey,
WriteKey,
adata_modifier,
)
from .de import BimodalTest, KSTest, TTest
from .feature_selection import PhenotypeCover
from .graphs import BaseGraphConstructor, FuzzySimplicialSetGraph, KNNGraph
Expand Down Expand Up @@ -27,6 +33,9 @@
__all__ = [
'adata_modifier',
'BaseProcessor',
'ReadKey',
'WriteKey',
'ProcessorParam',
'TTest',
'KSTest',
'BimodalTest',
Expand Down
76 changes: 14 additions & 62 deletions src/grinch/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@
from functools import partial
from itertools import islice, starmap
from operator import itemgetter
from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, List, TypeVar
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
List,
TypeAlias,
TypeVar,
)

from anndata import AnnData
from pydantic import field_validator, validate_call

from ..aliases import ALLOWED_KEYS
from ..conf import BaseConfigurable
from ..custom_types import REP, REP_KEY, NP1D_int
from ..utils.ops import compose, safe_format
Expand All @@ -23,6 +31,10 @@
T = TypeVar('T')
ProcessorParam = Annotated[T, 'ProcessorParam']

# Storage and retrieval keys. Both names are plain aliases of ``str``; they
# are used as annotations on config fields to mark whether the field holds
# a key that is read from (ReadKey) or written to (WriteKey).
ReadKey: TypeAlias = str
WriteKey: TypeAlias = str


def adata_modifier(f: Callable):
"""A decorator for lazy adata setattr. This exists so that all
Expand Down Expand Up @@ -119,66 +131,6 @@ class Config(BaseConfigurable.Config):
# Processor kwargs
kwargs: Dict[str, Any] = {}

@staticmethod
def _validate_single_rep_key(val: str):
"""Validates the format of a single key (str)."""
if val is None or val in ['X', 'obs_names', 'var_names']:
return val
if '.' not in val:
raise ValueError(
"Representation keys must equal 'X' or must contain a "
"dot '.' that points to the AnnData column to use."
)
if len(parts := val.split('.')) > 2 and parts[0] != 'uns':
raise ValueError(
"There can only be one dot "
"'.' in non-uns representation keys."
)
if parts[0] not in ALLOWED_KEYS:
raise ValueError(
f"AnnData annotation key should be one of {ALLOWED_KEYS}."
)
if len(parts[1]) >= 120:
raise ValueError(
"Columns keys should be less than 120 characters."
)
return val

@field_validator('*')
def rep_format_is_correct(cls, val, info):
"""Select the representation to use. If val is str: if 'X',
will use adata.X, otherwise it must contain a dot that splits
the annotation key that will be used and the column key. E.g.,
'obsm.x_emb' will use 'adata.obsm['x_emb']'. A list of str will
be parsed as *args, and a dict of (str, str) should map a
dictionary key to the desired representation. The latter is
useful when calling, for example, predictors which require a
data representation X and labels y. In this case, X and y would
be dictionary keys and the corresponding representations for X
and y would be the values.
This validator will only check fields that end with '_key'.
"""
if not info.field_name.endswith('_key'):
return val

match val:
case str() as v:
return cls._validate_single_rep_key(v)
case [*vals]:
return [cls._validate_single_rep_key(v) for v in vals]
case {**vals}:
return {k: cls._validate_single_rep_key(v) # type: ignore
for k, v in vals.items()}
case None:
return None
case _:
raise ValueError(
f"Could not interpret format for {info.field_name}. "
"Please make sure it is a str, list[str], "
"or dict[str, str]."
)

def get_save_key_prefix(
self,
current_prefix: str,
Expand Down
24 changes: 12 additions & 12 deletions src/grinch/processors/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ttest_from_mean_var,
)
from ..utils.validation import all_None, any_not_None
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,9 +80,9 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'PairwiseDETest']

x_key: str = "X"
save_key: str
group_key: str
x_key: ReadKey = "X"
group_key: ReadKey
save_key: WriteKey

is_logged: bool = True
# If the data is logged, this should point to the base of the
Expand All @@ -97,9 +97,9 @@ class Config(BaseProcessor.Config):
# control. If both `experimental` and `control` samples are given,
# will perform a single one_vs_one test between those two.
experimental_label: str | None = None
experimental_key: str | None = None
experimental_key: ReadKey | None = None
control_label: str | None = None
control_key: str | None = None
control_key: ReadKey | None = None

show_progress_bar: bool = Field(True, exclude=True)

Expand Down Expand Up @@ -202,7 +202,7 @@ class Config(PairwiseDETest.Config):
if TYPE_CHECKING:
create: Callable[..., 'TTest']

save_key: str = f"uns.{UNS.TTEST}"
save_key: WriteKey = f"uns.{UNS.TTEST}"

cfg: Config

Expand Down Expand Up @@ -262,7 +262,7 @@ class Config(PairwiseDETest.Config):
if TYPE_CHECKING:
create: Callable[..., 'KSTest']

save_key: str = f"uns.{UNS.KSTEST}"
save_key: WriteKey = f"uns.{UNS.KSTEST}"
method: str = 'auto'
alternative: str = 'two-sided'
max_workers: Optional[int] = Field(None, ge=1, le=2 * mp.cpu_count(),
Expand Down Expand Up @@ -350,13 +350,13 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'BimodalTest']

x_key: str = "X"
save_key: str = f"uns.{UNS.BIMODALTEST}"
x_key: ReadKey = "X"
save_key: WriteKey = f"uns.{UNS.BIMODALTEST}"
correction: str = 'fdr_bh'
skip_zeros: bool = False

max_workers: Optional[int] = Field(None, ge=1, le=2 * mp.cpu_count(),
exclude=True)
max_workers: int | None = Field(None, ge=1, le=2 * mp.cpu_count(),
exclude=True)

@field_validator('max_workers')
def init_max_workers(cls, val):
Expand Down
12 changes: 6 additions & 6 deletions src/grinch/processors/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.utils import check_X_y

from ..aliases import UNS, VAR
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey


class PhenotypeCover(BaseProcessor):
Expand All @@ -17,13 +17,13 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'PhenotypeCover']

x_key: str = "X"
y_key: str
feature_mask_key: str = f"var.{VAR.PCOVER_M}"
feature_importance_key: str = f"var.{VAR.PCOVER_I}"
x_key: ReadKey = "X"
y_key: ReadKey
feature_mask_key: WriteKey = f"var.{VAR.PCOVER_M}"
feature_importance_key: WriteKey = f"var.{VAR.PCOVER_I}"

save_stats: bool = True
stats_key: str = f"uns.{UNS.PCOVER_}"
stats_key: WriteKey = f"uns.{UNS.PCOVER_}"

# GreedyPC args
coverage: int
Expand Down
20 changes: 10 additions & 10 deletions src/grinch/processors/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..aliases import OBSM, OBSP
from ..custom_types import NP1D_float, NP1D_int, NP2D_float
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey
from .wrappers import FuzzySimplicialSet as _FuzzySimplicialSet


Expand All @@ -33,11 +33,11 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'BaseGraphConstructor']

x_key: str = f"obsm.{OBSM.X_PCA}"
conn_key: str
dist_key: str
x_key: ReadKey = f"obsm.{OBSM.X_PCA}"
conn_key: WriteKey
dist_key: WriteKey
save_stats: bool = True
stats_key: str | None = None
stats_key: WriteKey | None = None
kwargs: Dict[str, Any] = {}

cfg: Config
Expand Down Expand Up @@ -76,8 +76,8 @@ class Config(BaseGraphConstructor.Config):
if TYPE_CHECKING:
create: Callable[..., 'KNNGraph']

conn_key: str = f"obsp.{OBSP.KNN_CONNECTIVITY}"
dist_key: str = f"obsp.{OBSP.KNN_DISTANCE}"
conn_key: WriteKey = f"obsp.{OBSP.KNN_CONNECTIVITY}"
dist_key: WriteKey = f"obsp.{OBSP.KNN_DISTANCE}"
n_neighbors: int = Field(15, gt=0)
n_jobs: int = Field(4, gt=0)

Expand Down Expand Up @@ -117,9 +117,9 @@ class Config(BaseGraphConstructor.Config):
if TYPE_CHECKING:
create: Callable[..., 'FuzzySimplicialSetGraph']

conn_key: str = f"obsp.{OBSP.UMAP_CONNECTIVITY}"
dist_key: str = f"obsp.{OBSP.UMAP_DISTANCE}"
affinity_key: str = f"obsp.{OBSP.UMAP_AFFINITY}"
conn_key: WriteKey = f"obsp.{OBSP.UMAP_CONNECTIVITY}"
dist_key: WriteKey = f"obsp.{OBSP.UMAP_DISTANCE}"
affinity_key: WriteKey = f"obsp.{OBSP.UMAP_AFFINITY}"
precomputed: bool = False
n_neighbors: int = 15
metric: str = "euclidean"
Expand Down
6 changes: 3 additions & 3 deletions src/grinch/processors/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..custom_types import NP1D_str
from ..utils.ops import group_indices
from ..utils.validation import validate_axis
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,9 +47,9 @@ class Config(BaseProcessor.Config):

processor: BaseProcessor.Config
# Key to group by, must be recognized by np.unique.
group_key: str
group_key: ReadKey
axis: int | Literal['obs', 'var'] = Field(0, ge=0, le=1)
group_prefix: str = f'g-{{group_key}}{GROUP_SEP}{{label}}.'
group_prefix: WriteKey = f'g-{{group_key}}{GROUP_SEP}{{label}}.'
min_points_per_group: int = Field(default_factory=int, ge=0)
# Whether to drop the groups which have less than
# `min_points_per_group` points or not.
Expand Down
24 changes: 12 additions & 12 deletions src/grinch/processors/gsea.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..shortcuts import FWERpVal_Filter_05, log2fc_Filter_1, qVal_Filter_05
from ..utils.decorators import retry
from ..utils.validation import all_not_None
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,8 +66,8 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'GSEA']

read_key: str = f"uns.{UNS.TTEST}"
save_key: str
read_key: ReadKey = f"uns.{UNS.TTEST}"
save_key: WriteKey

gene_sets: str | List[str]
# Dict of keys to use for filtering DE genes; keys are ignored
Expand Down Expand Up @@ -209,7 +209,7 @@ class Config(GSEA.Config):
if TYPE_CHECKING:
create: Callable[..., 'GSEAEnrich']

save_key: str = f"uns.{UNS.GSEA_ENRICH}"
save_key: WriteKey = f"uns.{UNS.GSEA_ENRICH}"
gene_sets: str | List[str] = DEFAULT_GENE_SET_ENRICH
filter_by: List[Filter] = DEFAULT_ENRICH_FILTERS

Expand Down Expand Up @@ -268,7 +268,7 @@ class Config(GSEA.Config):
if TYPE_CHECKING:
create: Callable[..., 'GSEAPrerank']

save_key: str = f"uns.{UNS.GSEA_PRERANK}"
save_key: WriteKey = f"uns.{UNS.GSEA_PRERANK}"
gene_sets: str | List[str] = DEFAULT_GENE_SET_PRERANK
# By default all genes are inputted into prerank. DE tests are
# still needed in order to scale log2fc by the q-values before
Expand Down Expand Up @@ -340,11 +340,11 @@ class Config(BaseProcessor.Config):
if TYPE_CHECKING:
create: Callable[..., 'FindLeadGenes']

read_key: str = f"uns.{UNS.GSEA_PRERANK}"
all_leads_save_key: str = f"var.{VAR.IS_LEAD}"
lead_group_save_key: str = f"var.{VAR.LEAD_GROUP}"
read_key: ReadKey = f"uns.{UNS.GSEA_PRERANK}"
all_leads_save_key: WriteKey = f"var.{VAR.IS_LEAD}"
lead_group_save_key: WriteKey = f"var.{VAR.LEAD_GROUP}"
filter_by: List[Filter] = DEFAULT_FILTERS_LEAD_GENES
gene_names_key: str = "var_names"
gene_names_key: ReadKey = "var_names"

@field_validator('filter_by', mode='before')
def ensure_filter_list(cls, val):
Expand Down Expand Up @@ -403,10 +403,10 @@ class Config(BaseProcessor.Config):
gene_set: str = 'GO_Biological_Process_2023'
organism: str = 'Human'
terms: str | List[str] = '.*' # by default take all
save_key: str = f'var.{VAR.CUSTOM_LEAD_GENES}'
save_key: WriteKey = f'var.{VAR.CUSTOM_LEAD_GENES}'
regex: bool = False
all_leads_save_key: str = f'uns.{UNS.ALL_CUSTOM_LEAD_GENES}'
gene_names_key: str = "var_names"
all_leads_save_key: WriteKey = f'uns.{UNS.ALL_CUSTOM_LEAD_GENES}'
gene_names_key: ReadKey = "var_names"

@field_validator('terms')
def to_list(cls, val):
Expand Down
Loading

0 comments on commit 950fd5d

Please sign in to comment.