Skip to content

Commit

Permalink
Eh/cleanups (#103)
Browse files Browse the repository at this point in the history
* Simplify attr access

* Mostly fixed new attrs design

* Fixed failing tests
  • Loading branch information
euxhenh authored Aug 15, 2023
1 parent c95cbcc commit a2f3d58
Show file tree
Hide file tree
Showing 26 changed files with 617 additions and 588 deletions.
9 changes: 1 addition & 8 deletions src/grinch/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class OBS:
class VAR:
N_COUNTS = auto()
N_CELLS = auto()
PCOVER_M = auto()
PCOVER = auto()
PCOVER_I = auto()
VARIANCE = auto()
ENSEMBL_ID = auto()
Expand All @@ -49,18 +49,11 @@ class VARM:
LOG_REG_COEF = auto()

class UNS:
X_PCA_ = auto()
X_TRUNCATED_SVD_ = auto()
KMEANS_ = auto()
GAUSSIAN_MIXTURE_ = auto()
LEIDEN_ = auto()
LOG_REG_ = auto()
TTEST = auto()
KSTEST = auto()
GSEA_ENRICH = auto()
GSEA_PRERANK = auto()
BIMODALTEST = auto()
PCOVER_ = auto()
KNN_ = auto()
N_GENE_ID_TO_NAME_FAILED = auto()
ALL_CUSTOM_LEAD_GENES = auto()
Expand Down
262 changes: 262 additions & 0 deletions src/grinch/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import inspect
import logging
from itertools import islice
from operator import itemgetter
from typing import Any, Callable, Dict

import numpy as np
from anndata import AnnData
from pydantic import validate_call

from .utils.ops import compose, safe_format

logger = logging.getLogger(__name__)


class StorageMixin:
r"""Mixin class for classes that read and store data into AnnData.
Supports reading and writing values into adata using a
'column.key_name' format. E.g., `self.read(adata, 'obs.leiden')` will read
`adata.obs['leiden']` and `self.read(adata, 'X)` will read `adata.X`.
The benefit of using this class comes from simplifying access for
composite keys such as those stored under the `uns` column. E.g., the
following is valid: `self.read(adata, 'uns.group-main.leiden-0.ttest')`
and will map to `adata.uns['group-main']['leiden-0']['ttest']`. Same
logic applies to writing objects.
In addition, this class handles prefixes for all keys. This can be used
by classes such as GroupBy to set a prefix for all downstream keys.
E.g., if running a t-Test using 'leiden' clusters as groups, one can set
the prefix for all those t-Tests to be 'TTest-leiden.' followed by the
cluster ID itself.
Finally, this class provides a storage dictionary which holds keys
*to-be-written*. This is useful when the processor wants to hold writing
save keys until the end (e.g., when a processor runs on a view of
adata, its caller (having access to the full adata) may be the one that
actually wants to do the writing by reading `self.storage`.)
Attributes
----------
prefix : str
The prefix to prepend to all save keys after the column.
storage : Dict[str, Any]
A dict mapping a key to a representation.
"""
@property
def prefix(self) -> str:
"""The prefix to prepend to all save keys."""
if not hasattr(self, '_prefix'):
self._prefix = '' # By default no prefix
return self._prefix

@prefix.setter
def prefix(self, val) -> None:
"""Updates save key prefix."""
self._prefix = val

@property
def storage(self) -> Dict[str, Any]:
"""Return a pointer to the inner storage dict.
The storage dict will be used by the `lazy_writer` decorator to
dump all the keys with the respective values into the AnnData
columns. The keys in storage should not include the upstream prefix
as this will be added here.
"""
if not hasattr(self, '_storage'):
self._storage: Dict[str, Any] = {}
return self._storage

def store_item(self, key: str, val: Any, /, add_prefix: bool = True) -> None:
"""Will store the value to a key for lazy saving into adata.
Updates the inner storage only. Does not set the value in AnnData.
Parameters
----------
key : str
The key to use for writing. Can contain dots '.'.
val : Any
The object to store.
add_prefix : bool, default=True.
Whether to prepend `self.prefix` to the key, after the column.
"""
if add_prefix:
key = StorageMixin.__insert_prefix_after_col(key, self.prefix)
self.storage[key] = val

def store_items(self, items: Dict[str, Any], add_prefix: bool = True) -> None:
"""Will store the values to their keys for lazy saving into adata.
Updates the inner storage only. Does not set the value in AnnData.
This takes a dictionary as input. To store a single key, use
`self.store_item`.
"""
if add_prefix:
items = {
StorageMixin.__insert_prefix_after_col(key, self.prefix): val
for key, val in items.items()
}
self.storage.update(items)

@staticmethod
def __insert_prefix_after_col(key: str, prefix: str):
"""Inserts prefix to the key after performing a column split.
Example
-------
If key='uns.ttest' and prefix='group-0.', then this returns
'uns.group-0.ttest'. Note the dot '.' after 'group-0'.
"""
first_key, store_keys = key.split('.', maxsplit=1)
return f'{first_key}.{prefix}{store_keys}'

def update_prefix(self, suffix: str, **kwargs) -> None:
"""Appends suffix to prefix.
This is useful for example, with recursive `GroupProcessor`'s,
where each `GroupProcessor` needs to prepend some string to its
`save_key`'s.
Parameters
----------
suffix : str
The key to append (will be formatted if it has curly brackets).
"""
self.prefix += safe_format(suffix, **kwargs)

@staticmethod
@validate_call(config=dict(arbitrary_types_allowed=True))
def read(adata: AnnData, key: str, prefix: str = '', to_numpy: bool = False):
"""Get the data representation that key points to.
'X', 'obs_names', and 'var_names' are special keys and will be
mapped to adata.X, adata.obs_names, and adata.var_names,
respectively. All other keys should contain a dot '.' that splits
the AnnData column and the key. E.g., 'obs.leiden'. We only perform
dictionary style access here. So 'obs.leiden' will be mapped to
adata.obs['leiden']. Multiple dots '.' can be used (but these are
only supported for 'uns' keys). E.g., 'uns.group-0.ttest.leiden-0'
will be mapped to adata.uns['group-0']['ttest']['leiden-0'].
Parameters
----------
adata : AnnData
The annotated data structure.
key : str
The key to read from.
prefix : str
If non-empty, will prepend to the key after the column.
to_numpy : bool, default=False
If True, will call `to_numpy()` on the retrieved object.
Returns
-------
item: Any
The object read from adata columns. This is likely going to be
a view so care should be taken when updating it inplace.
"""
if key == 'X':
return adata.X
if key in ['obs_names', 'var_names']:
return np.asarray(getattr(adata, key)).astype(str)

key = StorageMixin.__insert_prefix_after_col(key, prefix)
read_class, *read_keys = key.split('.')

# Compose attribute access for all read_keys in case there was more
# than one split.
rec_itemgetter = compose(*(itemgetter(rk) for rk in read_keys))
# We only support dictionary style access for read_keys
item = rec_itemgetter(getattr(adata, read_class))
return item if not to_numpy else item.to_numpy()

@staticmethod
@validate_call(config=dict(arbitrary_types_allowed=True))
def write(adata: AnnData, key: str, value: Any) -> None:
"""Immediately save value under the key pointed to by key.
Keys should contain a dot '.' that splits the anndata column with
the key name. E.g., `obs.leiden` will write into
`adata.obs['leiden']`. If the column is not `uns` and key contains
more than one dot, all dots after the first will be replaced with
dashes. Multiple dots are allowed when storing into `uns`. Any
non-existing intermediate dictionaries will be created on the fly.
Parameters
----------
adata : AnnData
The annotated data structure.
key : str
The key to store `value` into.
value : Any
The object to store under `key`.
"""
save_class, *save_keys = key.split('.')

# Make sure save keys are not empty
for _ in filter(lambda x: len(x) == 0, save_keys):
raise ValueError("Found 'save_key' of zero-length.")

# Make sure we have no dots in non-uns columns.
if save_class != 'uns':
if len(save_keys) > 1:
logger.warning(
"Found non-'uns' save_class, but more than one save key."
"Replacing 'save_keys' dots with dashes."
)
# Only uns allows nested dict's
save_keys = [''.join(save_keys).replace('.', '-')]

klas = getattr(adata, save_class)
# Iterate over all save keys and initialize empty dictionaries if
# the keys are not found.
while len(save_keys) > 1:
save_key = save_keys.pop(0)
if save_key not in klas:
klas[save_key] = {}
klas = klas[save_key]

save_key = save_keys.pop(0) # Last key
# If save_key points to an existing dictionary, we merge the results.
if save_key in klas and len(klas[save_key]):
if isinstance(klas[save_key], dict) and isinstance(value, dict):
klas[save_key] |= value
return
klas[save_key] = value

@staticmethod
def lazy_writer(f: Callable, /):
"""A decorator for functions that wish to lazily store items.
Will reset `self.storage` on function call. To update this, the
caller should use the `store_item` and `store_items` methods.
"""
params = inspect.signature(f).parameters
if len(params) < 2: # require 'self' and 'adata'
raise ValueError("A 'setter_method' should take at least 2 arguments.")

_, adata_parameter = next(islice(params.items(), 1, 2))
if not issubclass(AnnData, adata_parameter.annotation):
raise ValueError(
"First argument to a 'setter_method' "
"should be explicitly typed 'AnnData'."
)

def _wrapper(self: StorageMixin, adata: AnnData,
*args, **kwargs) -> Dict[str, Any] | None:
"""Call f and either write or return storage."""
self.storage.clear() # Empty storage dict before function call
if f(self, adata, *args, **kwargs) is not None: # Also calls f
logger.warning(f"Function '{f.__name__}' returned a value.")

if kwargs.get('return_storage', False):
return self.storage
for key, item in self.storage.items():
self.write(adata, key, item)
return None

return _wrapper
34 changes: 22 additions & 12 deletions src/grinch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,30 @@ def __repr__(self):


class BaseConfigurable(_BaseConfigurable):
r"""BaseConfigurable class with some default parameters and methods.
Args:
seed (int, optional): Random state to use for the wrapped model. It
is the job of that configurable class to make sure the seed is
used.
logs_path (Path): The path to use for saving logs.
sanity_check (bool): If ``True`` will make a quick run to ensure
that everything runs without errors.
interactive (bool): If ``True`` will run in interactive mode. It is
the job of that configurable class to make use of this.
"""
"""BaseConfigurable class with an inner Config.
To be inherited by all configurable classes.
"""
class Config(_BaseConfigurable.Config):
r"""BaseConfigurable.Config
Parameters
----------
logs_path : Path, default=Path('./grinch_logs')
The path to use for saving logs.
interactive : bool, default=False
If ``True`` will run in interactive mode. It is the job of that
configurable class to make use of this.
seed : int, default=None
Random state to use for the wrapped model. It is the job of
that configurable class to make sure the seed is used.
sanity_check : bool, default=False
If ``True`` will make a quick run to ensure that everything
runs without errors.
"""
if TYPE_CHECKING:
create: Callable[..., 'BaseConfigurable']

Expand All @@ -130,4 +139,5 @@ def model_post_init(self, _) -> None:

@property
def logs_path(self) -> Path:
"""The path to logs."""
return self.cfg.logs_path
5 changes: 2 additions & 3 deletions src/grinch/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Annotated, Any, Dict, List, Tuple, TypeAlias
from typing import Annotated, Any, List, Tuple, TypeAlias

import numpy as np
import numpy.typing as npt
import scipy.sparse as sp
from pydantic import Field

REP_KEY: TypeAlias = str | List[str] | Dict[str, str] | None
REP: TypeAlias = Dict[str, Any] | List[Any] | Any
REP_KEY: TypeAlias = str | List[str] | None

NP1D_Any = np.ndarray[Tuple[Any], np.dtype[Any]]
NP2D_Any = np.ndarray[Tuple[Any, Any], np.dtype[Any]]
Expand Down
2 changes: 1 addition & 1 deletion src/grinch/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Config(BaseNormalizer.Config):
cfg: Config

def _normalize(self, adata: AnnData) -> None:
batch: pd.Series = BaseProcessor._get_repr(adata, self.cfg.batch_key)
batch: pd.Series = BaseProcessor.read(adata, self.cfg.batch_key)
if not isinstance(batch, pd.Series):
raise ValueError("Batch should be a pandas series")

Expand Down
2 changes: 0 additions & 2 deletions src/grinch/processors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
ProcessorParam,
ReadKey,
WriteKey,
adata_modifier,
)
from .de import BimodalTest, KSTest, TTest
from .feature_selection import PhenotypeCover
Expand Down Expand Up @@ -31,7 +30,6 @@
from .transformers import MDS, PCA, UMAP, BaseTransformer, TruncatedSVD

__all__ = [
'adata_modifier',
'BaseProcessor',
'ReadKey',
'WriteKey',
Expand Down
Loading

0 comments on commit a2f3d58

Please sign in to comment.