
Commit

Adding xgboost (#112)
* Adding xgboost

* adding xgboost to reqs

* Remove warnings

* Remove umap warning

* fix test

* fix test

* add params for xgboost

* Switch to 2 comps

* remove xgboost

* remove arg

* Adding xgboost test

* fix seed
euxhenh authored Dec 9, 2023
1 parent 4307d75 commit dae6b41
Showing 7 changed files with 125 additions and 32 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -14,4 +14,5 @@ leidenalg
pyensembl
seaborn
harmonypy
scanpy
scanpy
xgboost
3 changes: 3 additions & 0 deletions src/grinch/aliases.py
@@ -23,6 +23,7 @@ class OBS:
GAUSSIAN_MIXTURE_SCORE = auto()
LEIDEN = auto()
LOG_REG = auto()
XGB_CLASSIFIER = auto()

class VAR:
N_COUNTS = auto()
@@ -45,6 +46,7 @@ class OBSM:
X_UMAP = auto()
GAUSSIAN_MIXTURE_PROBA = auto()
X_HARMONY = auto()
XGB_CLASSIFIER_PROBA = auto()

class VARM:
LOG_REG_COEF = auto()
@@ -59,6 +61,7 @@ class UNS:
N_GENE_ID_TO_NAME_FAILED = auto()
ALL_CUSTOM_LEAD_GENES = auto()
PIPELINE = auto()
XGB_CLASSIFIER_SCORE = auto()

class OBSP:
KNN_CONNECTIVITY = auto()
2 changes: 2 additions & 0 deletions src/grinch/processors/__init__.py
@@ -18,6 +18,7 @@
KMeans,
Leiden,
LogisticRegression,
XGBClassifier,
)
from .repeat import RepeatProcessor
from .splitter import DataSplitter, Splitter
@@ -50,6 +51,7 @@
'FuzzySimplicialSetGraph',
'Leiden',
'LogisticRegression',
'XGBClassifier',
'DataSplitter',
'RepeatProcessor',
'Splitter',
4 changes: 2 additions & 2 deletions src/grinch/processors/base_processor.py
@@ -16,7 +16,7 @@
)

from anndata import AnnData
from pydantic import field_validator, validate_call
from pydantic import Field, field_validator, validate_call

from ..base import StorageMixin
from ..conf import BaseConfigurable
@@ -95,7 +95,7 @@ class Config(BaseConfigurable.Config):
create: Callable[..., 'BaseProcessor']

attrs_key: WriteKey | None = None
kwargs: Dict[str, ProcessorParam] = {} # Processor kwargs
kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict) # Processor kwargs

# Kwargs used by the processor, but are not ProcessorParam's
__extra_processor_params__: List[str] = []
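
A note on the base_processor.py change above: swapping the literal {} default for Field(default_factory=dict) makes the default explicit to pydantic and guarantees each Config instance starts from a freshly built dict. The snippet below is a minimal, self-contained sketch of that pattern only; the Config class and field types here are stand-ins for illustration, not grinch's actual BaseProcessor.Config.

from typing import Any, Dict

from pydantic import BaseModel, Field

class Config(BaseModel):
    # Each instance receives its own dict produced by the factory,
    # instead of a single literal {} declared inline.
    kwargs: Dict[str, Any] = Field(default_factory=dict)

a, b = Config(), Config()
a.kwargs["n_estimators"] = 2
assert b.kwargs == {}  # defaults stay independent across instances
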
54 changes: 51 additions & 3 deletions src/grinch/processors/predictors.py
@@ -5,14 +5,21 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from pydantic import Field, PositiveFloat, PositiveInt, validate_call
from pydantic import (
Field,
NonNegativeInt,
PositiveFloat,
PositiveInt,
validate_call,
)
from sklearn.cluster import KMeans as _KMeans
from sklearn.linear_model import LogisticRegression as _LogisticRegression
from sklearn.mixture import BayesianGaussianMixture as _BayesianGaussianMixture
from sklearn.mixture import GaussianMixture as _GaussianMixture
from sklearn.utils import indexable
from xgboost import XGBClassifier as _XGBClassifier

from ..aliases import OBS, OBSM, OBSP
from ..aliases import OBS, OBSM, OBSP, UNS
from ..base import StorageMixin
from ..custom_types import NP1D_Any, NP1D_float
from ..utils.ops import group_indices
@@ -208,7 +215,7 @@ def _post_process(self, adata: AnnData) -> None:


class BaseSupervisedPredictor(BasePredictor, abc.ABC):
"""A base class for unsupervised predictors, e.g., clustering."""
"""A base class for supervised predictors, e.g., logistic regression."""
__processor_reqs__ = ['fit']

class Config(BasePredictor.Config):
@@ -263,3 +270,44 @@ def __init__(self, cfg: Config, /):
random_state=self.cfg.seed,
**self.cfg.kwargs,
)


class XGBClassifier(BaseSupervisedPredictor):
"""XGBoostClassifier"""
__processor_attrs__ = ['feature_importances_', 'n_features_in_']

class Config(BaseSupervisedPredictor.Config):

if TYPE_CHECKING:
create: Callable[..., 'XGBClassifier']

labels_key: WriteKey = f"obs.{OBS.XGB_CLASSIFIER}"
proba_key: WriteKey = f"obsm.{OBSM.XGB_CLASSIFIER_PROBA}"
score_key: WriteKey = f"uns.{UNS.XGB_CLASSIFIER_SCORE}"
# XGBoost kwargs
n_estimators: ProcessorParam[PositiveInt | None] = 2
max_depth: ProcessorParam[PositiveInt | None] = 1
max_leaves: ProcessorParam[NonNegativeInt] = 0 # 0 == no limit
learning_rate: ProcessorParam[PositiveFloat | None] = 1.0

cfg: Config

def __init__(self, cfg: Config, /):
super().__init__(cfg)

self.processor: _XGBClassifier = _XGBClassifier(
n_estimators=self.cfg.n_estimators,
max_depth=self.cfg.max_depth,
max_leaves=self.cfg.max_leaves,
learning_rate=self.cfg.learning_rate,
random_state=self.cfg.seed,
**self.cfg.kwargs,
)

def _post_process(self, adata: AnnData) -> None:
x = self.read(adata, self.cfg.x_key)
y = self.read(adata, self.cfg.y_key)
proba = self.processor.predict_proba(x)
score = self.processor.score(x, y)
self.store_item(self.cfg.proba_key, proba)
self.store_item(self.cfg.score_key, score)
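
For orientation, here is a hedged usage sketch of the new processor; it mirrors the test_xgboost test added below and shows how the stored outputs could be read back from AnnData. The alias import path (src.grinch) and the assumption that the Config default keys resolve to the obs/obsm/uns slots used here are not confirmed API, only inferred from the diff.

# Hedged sketch mirroring tests/test_predictors.py::test_xgboost below.
# Assumes OBS/OBSM/UNS are importable from src.grinch and that the keys in
# the Config defaults resolve to the adata slots read at the end.
import numpy as np
from anndata import AnnData
from hydra.utils import instantiate
from omegaconf import OmegaConf

from src.grinch import OBS, OBSM, UNS  # assumed import path

X = np.random.rand(100, 5).astype(np.float32)
y = (X[:, 0] > 0.5).astype(int)
adata = AnnData(X)
adata.obs["y"] = y

cfg = OmegaConf.create({
    "_target_": "src.grinch.XGBClassifier.Config",
    "seed": 42,
    "x_key": "X",
    "y_key": "obs.y",
})
clf = instantiate(cfg).create()
clf(adata)  # fit + predict; _post_process also stores proba and score

labels = adata.obs[OBS.XGB_CLASSIFIER].to_numpy()  # predicted labels
proba = adata.obsm[OBSM.XGB_CLASSIFIER_PROBA]      # class probabilities
score = adata.uns[UNS.XGB_CLASSIFIER_SCORE]        # classifier score on (x, y)
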
89 changes: 63 additions & 26 deletions tests/test_predictors.py
@@ -12,11 +12,18 @@
X = np.array([
[6, 8, 0, 0, 0],
[5, 7, 0, 0, 0],
[0, 1, 5, 6, 5],
[6, 8, 1, 0, 0],
[4, 7, 0, 0, 0],
[0, 1, 5, 6, 8],
[2, 1, 7, 9, 8],
[0, 1, 5, 6, 7],
[0, 1, 8, 6, 7],
[0, 1, 8, 6, 5],
[2, 1, 7, 8, 8],
[0, 1, 9, 6, 7],
], dtype=np.float32)

K_plus = 4

X_test = np.array([
[0, -1, 5, 6, 5],
[5, 6, 0, 1, 0],
@@ -42,9 +49,9 @@ def test_kmeans_x(X):
kmeans = cfg.create()
adata = AnnData(X)
kmeans(adata)
outp = adata.obs[OBS.KMEANS]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.KMEANS].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]


@@ -74,15 +81,15 @@ def test_kmeans_x_pca(X):
cfg = instantiate(cfg)
kmeans = cfg.create()
kmeans(adata)
outp = adata.obs[OBS.KMEANS]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.KMEANS].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]

adata_test = AnnData(X_test)
pca.transform(adata_test)
kmeans.predict(adata_test)
outp = adata_test.obs[OBS.KMEANS]
outp = adata_test.obs[OBS.KMEANS].to_numpy()
assert outp[0] == outp[2]
assert outp[0] != outp[1]

@@ -101,17 +108,20 @@ def test_gmix_x(X):
kmeans = cfg.create()
adata = AnnData(X)
kmeans(adata)
outp = adata.obs[OBS.GAUSSIAN_MIXTURE]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.GAUSSIAN_MIXTURE].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]
proba = adata.obsm[OBSM.GAUSSIAN_MIXTURE_PROBA]
assert (proba[:2, 0] > proba[:2, 1]).all()
assert (proba[2:, 0] < proba[2:, 1]).all()
assert (proba[:K_plus, 0] > proba[:K_plus, 1]).all()
assert (proba[K_plus:, 0] < proba[K_plus:, 1]).all()


@pytest.mark.parametrize("X", X_mods_no_sparse)
def test_log_reg_x(X):
@pytest.mark.parametrize(
"classifier, key", [("LogisticRegression", OBS.LOG_REG)]
)
def test_classifiers_x(X, classifier, key):
adata = AnnData(X)
cfg_pca = OmegaConf.create(
{
@@ -139,38 +149,63 @@ def test_log_reg_x(X):

cfg = OmegaConf.create(
{
"_target_": "src.grinch.LogisticRegression.Config",
"_target_": f"src.grinch.{classifier}.Config",
"x_key": f"obsm.{OBSM.X_PCA}",
"y_key": f"obs.{OBS.KMEANS}",
"seed": 42,
"labels_key": f"obs.{OBS.LOG_REG}",
"labels_key": f"obs.{key}",
}
)
# Need to start using convert all for lists and dicts
cfg = instantiate(cfg, _convert_='all')
lr = cfg.create()
lr(adata)
outp = adata.obs[OBS.LOG_REG]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[key].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]

adata_test = AnnData(X_test)
pca.transform(adata_test)
lr.predict(adata_test)
outp = adata_test.obs[OBS.LOG_REG]
outp = adata_test.obs[key].to_numpy()
assert outp[0] == outp[2]
assert outp[0] != outp[1]


def test_xgboost():
from sklearn.datasets import make_classification
X, y = make_classification(
n_samples=100, n_features=2, n_informative=2, n_redundant=0,
random_state=42, n_clusters_per_class=1, flip_y=False, class_sep=2.0)
adata = AnnData(X)
adata.obs['y'] = y

cfg = OmegaConf.create(
{
"_target_": "src.grinch.XGBClassifier.Config",
"seed": 42,
"x_key": "X",
"y_key": "obs.y",
}
)

cfg = instantiate(cfg)
obj = cfg.create()
obj(adata)
outp = adata.obs[OBS.XGB_CLASSIFIER].to_numpy()
# < 5% error
assert (y != outp).mean() < 0.05 or (y != 1 - outp).mean() < 0.05


@pytest.mark.parametrize("X", X_mods)
def test_leiden(X):
adata = AnnData(X)
cfg_knn = OmegaConf.create(
{
"_target_": "src.grinch.KNNGraph.Config",
"x_key": "X",
"n_neighbors": 1,
"n_neighbors": 3,
}
)
cfg_knn = instantiate(cfg_knn)
@@ -182,20 +217,22 @@ def test_leiden(X):
"_target_": "src.grinch.Leiden.Config",
"x_key": f"obsp.{OBSP.KNN_DISTANCE}",
"seed": 42,
"resolution": 0.5,
}
)
cfg = instantiate(cfg)
leiden = cfg.create()
leiden(adata)
pred = adata.obs[OBS.LEIDEN]
true = np.array([0, 0, 1, 1, 1])
pred = adata.obs[OBS.LEIDEN].to_numpy()
true = np.ones(X.shape[0])
true[:K_plus] = 0
if pred[0] == 1:
true = 1 - true
assert_allclose(pred, true)

centroids = {
pred[0]: np.ravel(X[:2].mean(axis=0)),
1 - pred[0]: np.ravel(X[2:].mean(axis=0)),
pred[0]: np.ravel(X[:K_plus].mean(axis=0)),
1 - pred[0]: np.ravel(X[K_plus:].mean(axis=0)),
}
pred_centroid = adata.uns['leiden_']["cluster_centers_"]
assert_allclose(centroids[0], pred_centroid['0'])
2 changes: 2 additions & 0 deletions tests/test_transformers.py
@@ -87,6 +87,7 @@ def test_umap(X):
# things happening with spectral initialization and reproducibility
'kwargs': {
'init': 'random',
'n_jobs': 1, # since using random seed
}
}
)
@@ -98,6 +99,7 @@ def test_umap(X):
random_state=SEED,
transform_seed=SEED,
init='random',
n_jobs=1,
)
adata = AnnData(X)
up = cfg.create()
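
A side note on the n_jobs addition above: the commit message mentions removing a UMAP warning, and recent umap-learn releases warn that n_jobs is overridden to 1 whenever a random_state is supplied, so passing n_jobs=1 explicitly keeps the run reproducible and quiet. A minimal sketch, assuming umap-learn >= 0.5.4 (version and behavior are assumptions, not confirmed by the diff):

# Hedged sketch: with a fixed random_state, umap-learn forces single-threaded
# execution; passing n_jobs=1 explicitly avoids the override warning.
import numpy as np
from umap import UMAP

X = np.random.rand(50, 8).astype(np.float32)
emb = UMAP(random_state=42, init="random", n_jobs=1).fit_transform(X)
print(emb.shape)  # (50, 2) with the default n_components=2
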
