Merge pull request #6 from wakamezake/features/document
Features/document
wakame1367 authored Mar 26, 2020
2 parents 55ce706 + 9f8f3c3 commit b11507b
Showing 2 changed files with 81 additions and 72 deletions.
137 changes: 65 additions & 72 deletions optcat/core.py
@@ -1,15 +1,22 @@
import logging
from typing import Dict, Any, Optional, Union, Iterable

from typing import Dict, Any, Optional, Union
import numpy as np
import catboost as cb
from optuna import distributions
from optuna import samplers
from optuna import study as study_module
from optuna import trial as trial_module
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.model_selection import BaseCrossValidator, check_cv
from sklearn.model_selection import check_cv

CVType = Union[int, Iterable, BaseCrossValidator]
from .typing import (
CVType,
MultipleDataType,
TargetDataType,
PairsType,
SampleWeightType,
FeatureType,
)


# https://catboost.ai/docs/references/eval-metric__supported-metrics.html
@@ -123,7 +130,7 @@ def _get_params(self, trial: trial_module.Trial) -> Dict[str, Any]:
class CatBoostBase(cb.CatBoost):
def __init__(
self,
params,
params: Dict[str, Any],
refit: bool = False,
cv: CVType = 5,
n_trials: int = 20,
@@ -141,31 +148,31 @@ def __init__(

def fit(
self,
X,
y=None,
cat_features=None,
text_features=None,
pairs=None,
sample_weight=None,
group_id=None,
group_weight=None,
subgroup_id=None,
pairs_weight=None,
baseline=None,
use_best_model=None,
eval_set=None,
verbose=None,
logging_level=None,
plot=False,
column_description=None,
verbose_eval=None,
metric_period=None,
silent=None,
early_stopping_rounds=None,
save_snapshot=None,
snapshot_file=None,
snapshot_interval=None,
init_model=None,
X: MultipleDataType,
y: Optional[TargetDataType] = None,
cat_features: Optional[FeatureType] = None,
text_features: Optional[FeatureType] = None,
pairs: Optional[PairsType] = None,
sample_weight: SampleWeightType = None,
group_id: Optional[FeatureType] = None,
group_weight: Optional[FeatureType] = None,
subgroup_id: Optional[FeatureType] = None,
pairs_weight: Optional[FeatureType] = None,
baseline: Optional[FeatureType] = None,
use_best_model: Optional[bool] = None,
eval_set: Optional[cb.Pool] = None,
verbose: Optional[Union[bool, int]] = None,
logging_level: Optional[str] = None,
plot: bool = False,
column_description: Optional[str] = None,
verbose_eval: Optional[Union[bool, int]] = None,
metric_period: Optional[int] = None,
silent: Optional[bool] = None,
early_stopping_rounds: Optional[int] = None,
save_snapshot: Optional[bool] = None,
snapshot_file: Optional[str] = None,
snapshot_interval: Optional[int] = None,
init_model: Optional[str] = None,
):
logger = logging.getLogger(__name__)

@@ -199,7 +206,7 @@ def fit(
init_model,
)

n_samples, _ = X.shape
n_samples = len(X)
# get_params
params = train_params["params"]
eval_name = params.get("loss_function")
@@ -241,11 +248,30 @@ def fit(
def _refit(self):
pass

def predict(
self,
data: MultipleDataType,
prediction_type: str = "RawFormulaVal",
ntree_start: int = 0,
ntree_end: int = 0,
thread_count: int = -1,
verbose: Optional[bool] = None,
) -> np.ndarray:
return self._predict(
data,
"RawFormulaVal",
ntree_start,
ntree_end,
thread_count,
verbose,
"predict",
)


class CatBoostClassifier(CatBoostBase, ClassifierMixin):
def __init__(
self,
params,
params: Dict[str, Any],
refit: bool = False,
cv: CVType = 5,
n_trials: int = 20,
@@ -263,28 +289,14 @@ def __init__(
timeout=timeout,
)

def predict(
def predict_proba(
self,
data,
prediction_type="RawFormulaVal",
ntree_start=0,
ntree_end=0,
thread_count=-1,
verbose=None,
):
return self._predict(
data,
prediction_type,
ntree_start,
ntree_end,
thread_count,
verbose,
"predict",
)

def predict_proba(
self, data, ntree_start=0, ntree_end=0, thread_count=-1, verbose=None
):
ntree_start: int = 0,
ntree_end: int = 0,
thread_count: int = -1,
verbose: Optional[bool] = None,
) -> np.ndarray:
return self._predict(
data,
"Probability",
@@ -299,7 +311,7 @@ def predict_proba(
class CatBoostRegressor(CatBoostBase, RegressorMixin):
def __init__(
self,
params,
params: Dict[str, Any],
refit: bool = False,
cv: CVType = 5,
n_trials: int = 20,
@@ -316,22 +328,3 @@ def __init__(
study=study,
timeout=timeout,
)

def predict(
self,
data,
prediction_type="RawFormulaVal",
ntree_start=0,
ntree_end=0,
thread_count=-1,
verbose=None,
):
return self._predict(
data,
"RawFormulaVal",
ntree_start,
ntree_end,
thread_count,
verbose,
"predict",
)
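
For context, the following is a hedged usage sketch, not part of the commit, that only exercises the signatures visible in this diff: the params/cv/n_trials constructor arguments, fit(X, y), and the new predict/predict_proba methods. The dataset, the params keys, and the assumption that fit leaves a model ready for prediction are illustrative; the diff shows refit: bool = False and a _refit stub, and the rest of fit is collapsed here, so the example may need refit enabled in practice.

# Hypothetical usage sketch -- not from the commit; it only exercises the
# signatures visible in this diff. Whether predict() works as written depends
# on parts of core.py that are collapsed in this view.
from sklearn.datasets import make_classification

from optcat.core import CatBoostClassifier

X, y = make_classification(n_samples=200, n_features=10, random_state=0)

model = CatBoostClassifier(
    params={"loss_function": "Logloss", "iterations": 100},
    cv=5,         # CVType: an int, an iterable of splits, or a BaseCrossValidator
    n_trials=20,  # number of Optuna trials for the hyperparameter search
)
model.fit(X, y)

raw_scores = model.predict(X)           # "RawFormulaVal" scores as np.ndarray
probabilities = model.predict_proba(X)  # "Probability" predictions as np.ndarray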
16 changes: 16 additions & 0 deletions optcat/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Union, Iterable, List

import catboost as cb
import numpy as np
import pandas as pd
from scipy.sparse import spmatrix
from sklearn.model_selection import BaseCrossValidator

CVType = Union[int, Iterable, BaseCrossValidator]
TargetDataType = Union[cb.Pool, np.ndarray, pd.DataFrame, pd.Series]
TwoDimFeatureType = Union[List, pd.DataFrame, pd.Series]
TwoDimSparseType = Union[pd.SparseDataFrame, spmatrix]
MultipleDataType = Union[cb.Pool, TwoDimFeatureType, TwoDimSparseType]
FeatureType = Union[List, np.ndarray]
PairsType = Union[FeatureType, pd.DataFrame]
SampleWeightType = Union[np.ndarray, TwoDimFeatureType]
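
The new optcat/typing.py gathers the aliases that core.py now pulls in via "from .typing import (...)". As a hedged illustration, not part of the commit, a downstream helper could annotate its inputs with the same aliases. Note the module imports pd.SparseDataFrame, which only exists in pandas releases before 1.0, so the sketch assumes such a version is installed.

# Illustrative sketch -- not from the commit. Reuses the aliases defined in
# optcat/typing.py; assumes a pandas release that still ships
# pd.SparseDataFrame (pre-1.0), which that module imports at the top level.
from typing import Optional

from optcat.typing import CVType, MultipleDataType, TargetDataType


def describe_inputs(
    X: MultipleDataType,
    y: Optional[TargetDataType] = None,
    cv: CVType = 5,
) -> str:
    # len(X) mirrors the change in core.py from "n_samples, _ = X.shape" to len(X)
    n_samples = len(X)
    target = "none" if y is None else type(y).__name__
    return f"{n_samples} samples, target={target}, cv={cv!r}"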
