Skip to content

Commit

Permalink
Fixed all not None
Browse files Browse the repository at this point in the history
  • Loading branch information
euxhenh committed Aug 10, 2023
1 parent 2159d15 commit 3186f17
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 168 deletions.
178 changes: 111 additions & 67 deletions src/grinch/cond_filter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import logging
from typing import Any, Literal, overload
from typing import Any, Generic, Literal, TypeVar, overload

import numpy as np
from pydantic import BaseModel, model_validator, validate_call
from pydantic import BaseModel, NonNegativeInt, model_validator, validate_call
from sklearn.utils import column_or_1d

from .custom_types import NP1D_bool, NP1D_float, NP1D_int
from .utils.validation import only_one_not_None
from .custom_types import NP1D_bool, NP1D_int, PercentFraction
from .utils.validation import any_not_None

logger = logging.getLogger(__name__)

T = TypeVar("T", int, float, bool, str)

class Filter(BaseModel):
"""Takes any object and looks for 'key' in its members. It then selects

class Filter(BaseModel, Generic[T]):
"""Selects and returns item indices based on criteria.
Takes any object and looks for 'key' in its members. It then selects
indices from 'key' based on the conditions defined in this class. If
cutoff is not None, will take all values greater than or less than
'cutoff'. If top_k is not None, will take the top k greatest (smallest)
Expand All @@ -25,87 +29,122 @@ class Filter(BaseModel):
If key is None, will assume the passed object to call is the array to
filter itself.
P
To take a mask of True or False, simply use gt=False or lt=True.
Parameters
----------
key: str
If not None, will search in obj for a member named as `key`.
ge, le, gt, lt: T
Greater than or less than in either strict or non-strict mode.
top_k, bot_k: int
Top or bottom k items to pick.
top_ratio, bot_ratio: float
A percent fraction between 0 and 1. Will round up to the nearest
item.
Examples
--------
>>> f1 = Filter(gt=3)
>>> f1([1, 2, 3, 4, 5, 6], as_mask=True)
array([False, False, False, True, True, True])
>>> f1([5, 4, 6, 3, 2], as_mask=False)
array([0, 1, 2])
>>> f2 = Filter(top_k=2)
>>> f2([7, 1, 2, 5, 6, 8], as_mask=False)
array([5, 0])
>>> f3 = Filter(bot_ratio=0.4)
>>> f3([1, 7, 5, 3, 4], as_mask=False)
array([0, 3])
>>> f = f1 & f2 # Take greater than 3, but no more than 2 elements
>>> f([2, 4, 3, 5, 6, 0, 1, 7], as_mask=False)
array([4, 7])
"""
model_config = {
'validate_assignment': True,
'validate_default': True,
'extra': 'forbid',
}

key: str | None = None
ge: float | None = None # greater than or equal
le: float | None = None # less than or equal
gt: float | None = None # greater than
lt: float | None = None # less than
top_k: int | None = None # top k items after sorting
bot_k: int | None = None # bottom k items after sorting
top_ratio: float | None = None # top fraction of items
bot_ratio: float | None = None # bottom fraction of items
key: str | None = None # Set to None if passing a container

ordered: bool = False
dtype: Literal['float', 'bool', 'int', 'str'] = 'float'
ge: T | None = None # greater than or equal
le: T | None = None # less than or equal
gt: T | None = None # greater than
lt: T | None = None # less than

top_k: NonNegativeInt | None = None # top k items after sorting
bot_k: NonNegativeInt | None = None # bottom k items after sorting
# These will be rounded up to the nearest item
top_ratio: PercentFraction | None = None # top fraction of items
bot_ratio: PercentFraction | None = None # bottom fraction of items

@model_validator(mode='before')
def at_most_one_not_None(cls, data):
    """Reject input that sets more than one filter criterion.

    Runs before field validation. Criteria are read from ``data`` with
    ``.get``, so a criterion that is absent counts the same as one that
    is explicitly None.

    Raises
    ------
    ValueError
        If more than one criterion is not None.
    """
    criteria = ('ge', 'le', 'gt', 'lt', 'top_k', 'bot_k', 'top_ratio', 'bot_ratio')
    given = [name for name in criteria if data.get(name, None) is not None]
    if len(given) > 1:
        raise ValueError(
            "At most one filter key should not be None. If more than "
            "one key is desired, then stack multiple filters together."
        )
    return data

def __and__(self, other) -> 'StackedFilter':
    """Combine this filter with ``other`` via ``f1 & f2``.

    Returns a ``StackedFilter`` built from the two operands.
    """
    return StackedFilter(self, other)

def _take_top_k(self, arr: NP1D_float, as_mask: bool = True):
"""Takes the top k elements from arr and returns a mask or index
array. If these elements need to be sorted, pass ordered=True.
"""
if self.top_k is None:
raise ValueError("Expected integer but 'top_k' is None.")
if self.top_k > len(arr):
raise ValueError(f"Requested {self.top_k} items but array has size {len(arr)}.")

if self.greater_is_True:
arr = -arr
# argpartition is faster if we don't care about the order
idx = np.argsort(arr) if self.ordered else np.argpartition(arr, self.top_k)
idx = idx[:self.top_k]
@staticmethod
def _take_k_functional(arr, k: NonNegativeInt, as_mask: bool, top: bool):
    """Select the k largest or k smallest items of ``arr``.

    Parameters
    ----------
    arr
        Array to select from; ordered with ``np.argsort``.
    k: int
        Number of items to take. If larger than ``len(arr)``, a warning
        is logged and every item is taken. ``k == 0`` selects nothing.
    as_mask: bool
        If True, return a boolean mask shaped like ``arr``; otherwise
        return the selected indices.
    top: bool
        If True, take the largest items (indices ordered from greatest
        to smallest); otherwise take the smallest (ordered ascending).
    """
    n = len(arr)
    if k > n:
        logger.warning(f"Requested {k} items but array has size {n}.")
        k = n  # clamp so the explicit slice start below never goes negative
    argidx = np.argsort(arr)
    # Slice from an explicit start: `argidx[-k:]` would return the whole
    # array for k == 0, since -0 == 0.
    # Flip so that we start with greatest
    idx = np.flip(argidx[n - k:]) if top else argidx[:k]

    if not as_mask:
        return idx

    mask = np.full_like(arr, False, dtype=bool)
    mask[idx] = True
    return mask

def _take_cutoff(self, arr: NP1D_float, as_mask: bool = True):
"""Takes the elements which are greater than or less than cutoff
depending on the value of greater_is_True.
def _take_k(self, arr, as_mask: bool = True):
"""Take top or bot k items.
"""
top = self.top_k is not None
k = self.top_k if top else self.bot_k
assert k is not None
return self._take_k_functional(arr, k, as_mask, top)

def _take_ratio(self, arr, as_mask: bool = True):
"""Take top or bot fraction of items.
"""
top = self.top_ratio is not None
ratio = self.top_ratio if top else self.bot_ratio
assert ratio is not None
k = int(np.ceil(ratio * len(arr))) # round up
return self._take_k_functional(arr, k, as_mask, top)

def _take_cutoff(self, arr, as_mask: bool = True):
"""Takes the elements which are greater than or less than cutoff.
"""
if self.cutoff is None:
raise ValueError("Expected float but 'cutoff' is None.")

mask = arr >= self.cutoff if self.greater_is_True else arr <= self.cutoff
if as_mask:
if self.ordered:
logger.warning("'ordered=True' will be ignored when returning mask.")
return mask

idx = np.argwhere(mask).ravel()
if self.ordered:
idx = idx[np.argsort(arr[idx])] # Sort idx based on arr
return np.flip(idx) if self.greater_is_True else idx

def _take_mask(self, arr: NP1D_float | NP1D_bool, as_mask: bool = True):
"""Assumes arr is a mask."""
if not arr.dtype == bool:
logger.warning("Array type is not boolean. Converting to bool...")
arr = arr.astype(bool)
return arr if as_mask else np.argwhere(arr).ravel() # type: ignore
assert any_not_None(self.gt, self.ge, self.lt, self.le)
top = any_not_None(self.gt, self.ge)
strict = any_not_None(self.gt, self.lt)

match top, strict:
case True, True:
mask = arr > self.gt
case True, False:
mask = arr >= self.ge
case False, True:
mask = arr < self.lt

Check warning on line 143 in src/grinch/cond_filter.py

View check run for this annotation

Codecov / codecov/patch

src/grinch/cond_filter.py#L143

Added line #L143 was not covered by tests
case False, False:
mask = arr <= self.le

return mask if as_mask else np.argwhere(mask).ravel()

@staticmethod
def _get_repr(obj: Any, key: str) -> Any:
Expand Down Expand Up @@ -133,13 +172,18 @@ def __call__(self, obj, as_mask=True):
"""
if self.key is not None:
obj = self._get_repr(obj, self.key)
arr: NP1D_float | NP1D_bool = column_or_1d(obj).astype(self.dtype)
if self.cutoff is not None:
return self._take_cutoff(arr.astype(float), as_mask=as_mask)
elif self.top_k is not None:
return self._take_top_k(arr.astype(float), as_mask=as_mask)
# default to a mask
return self._take_mask(arr, as_mask=as_mask)

arr: np.ndarray[T, Any] = column_or_1d(obj)

if any_not_None(self.ge, self.gt, self.le, self.lt):
return self._take_cutoff(arr, as_mask)
if any_not_None(self.top_k, self.bot_k):
return self._take_k(arr, as_mask)
if any_not_None(self.top_ratio, self.bot_ratio):
return self._take_ratio(arr, as_mask)

# Default to taking True
return arr.astype(bool) if as_mask else np.argwhere(arr).ravel()


class StackedFilter:
Expand Down
5 changes: 4 additions & 1 deletion src/grinch/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect
from operator import attrgetter
from typing import Any, Dict, List, Tuple, TypeAlias
from typing import Annotated, Any, Dict, List, Tuple, TypeAlias

import numpy as np
import numpy.typing as npt
from pydantic import Field

REP_KEY: TypeAlias = str | List[str] | Dict[str, str] | None
REP: TypeAlias = Dict[str, Any] | List[Any] | Any
Expand All @@ -21,6 +22,8 @@
NP_int = npt.NDArray[np.int_]
NP_float = npt.NDArray[np.float_]

PercentFraction = Annotated[float, Field(ge=0, le=1)]


def optional_staticmethod(klas: str, special_args: Dict[str, str]):
"""Marks a method as optionally static. If the method is called from an
Expand Down
108 changes: 12 additions & 96 deletions src/grinch/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,103 +4,19 @@

from .cond_filter import Filter


# Pre-configured Filter factories. Each is a `partial`, so extra keyword
# arguments can still be supplied when the filter is instantiated.

# Significance cutoffs (keep values at or below the threshold).
pVal_Filter_01 = partial(Filter, key='pvals', le=0.01)
pVal_Filter_05 = partial(Filter, key='pvals', le=0.05)
qVal_Filter_01 = partial(Filter, key='qvals', le=0.01)
qVal_Filter_05 = partial(Filter, key='qvals', le=0.05)
# log2 fold-change cutoffs; the `m` variants keep values at or below
# the negative threshold.
log2fc_Filter_1 = partial(Filter, key='log2fc', ge=1)
log2fc_Filter_m1 = partial(Filter, key='log2fc', le=-1)
log2fc_Filter_2 = partial(Filter, key='log2fc', ge=2)
log2fc_Filter_m2 = partial(Filter, key='log2fc', le=-2)
abs_log2fc_Filter_1 = partial(Filter, key='abs_log2fc', ge=1)
abs_log2fc_Filter_2 = partial(Filter, key='abs_log2fc', ge=2)
# For lead gene discovery in a GSEA prerank test
FDRqVal_Filter_05 = partial(Filter, key='FDR q-val', le=0.05)
FWERpVal_Filter_05 = partial(Filter, key='FWER p-val', le=0.05)

__all__ = [
'pVal_Filter_01',
Expand Down
20 changes: 18 additions & 2 deletions src/grinch/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,28 @@ def all_None(*args):


def all_not_None(*args):
    """Return True if every argument is not None.

    Only identity with None is tested, so falsy values such as ``0`` or
    ``""`` count as "not None". With no arguments the result is True.

    Examples
    --------
    >>> all_not_None(None, 1, 2)
    False
    >>> all_not_None(5, "bar")
    True
    """
    return all(arg is not None for arg in args)


def only_one_not_None(*args):
    """Return True if exactly one argument is not None.

    Only identity with None is tested, so falsy values such as ``0`` or
    ``""`` count as "not None". With no arguments the result is False.

    Examples
    --------
    >>> only_one_not_None(None, 1, 'bar')
    False
    >>> only_one_not_None(None, 'foo', None)
    True
    """
    return sum(arg is not None for arg in args) == 1


Expand Down
Loading

0 comments on commit 3186f17

Please sign in to comment.