diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py new file mode 100644 index 0000000000000..7710da14d1f98 --- /dev/null +++ b/benchmarks/bench_tree_nocats.py @@ -0,0 +1,128 @@ +from itertools import product +from timeit import timeit + +import numpy as np +import pandas as pd + +from sklearn.datasets import fetch_openml +from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import StratifiedKFold +from sklearn.preprocessing import OneHotEncoder + + +def get_data(trunc_ncat): + # the data is located here: https://www.openml.org/d/4135 + X, y = fetch_openml(data_id=4135, return_X_y=True) + X = pd.DataFrame(X) + + Xdicts = [] + for trunc in trunc_ncat: + X_trunc = X % trunc if trunc > 0 else X + keep_idx = np.array( + [idx[0] for idx in X_trunc.groupby(list(X.columns)).groups.values()] + ) + X_trunc = X_trunc.values[keep_idx] + y_trunc = y[keep_idx] + + X_ohe = OneHotEncoder(categories="auto").fit_transform(X_trunc) + + Xdicts.append({"X": X_trunc, "y": y_trunc, "ohe": False, "trunc": trunc}) + Xdicts.append({"X": X_ohe, "y": y_trunc, "ohe": True, "trunc": trunc}) + + return Xdicts + + +# Training dataset +trunc_factor = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 64, 0] +data = get_data(trunc_factor) +results = [] +# Loop over classifiers and datasets +for Xydict, clf_type in product(data, [RandomForestClassifier, ExtraTreesClassifier]): + # Can't use non-truncated categorical data with RandomForest + # and it becomes intractable with too many categories + if ( + clf_type is RandomForestClassifier + and not Xydict["ohe"] + and (not Xydict["trunc"] or Xydict["trunc"] > 16) + ): + continue + + X, y = Xydict["X"], Xydict["y"] + tech = "One-hot" if Xydict["ohe"] else "NOCATS" + trunc = "truncated({})".format(Xydict["trunc"]) if Xydict["trunc"] > 0 else "full" + cat = "none" if Xydict["ohe"] else "all" + cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=17).split(X, y) + + traintimes = [] + testtimes = [] + aucs = [] + name = "({}, {}, {})".format(clf_type.__name__, trunc, tech) + + for train, test in cv: + # Train + clf = clf_type( + n_estimators=10, + max_features=None, + min_samples_leaf=1, + random_state=23, + bootstrap=False, + max_depth=None, + categorical=cat, + ) + + traintimes.append( + timeit( + "clf.fit(X[train], y[train])".format(), + "from __main__ import clf, X, y, train", + number=1, + ) + ) + + """ + # Check that all leaf nodes are pure + for est in clf.estimators_: + leaves = est.tree_.children_left < 0 + print(np.max(est.tree_.impurity[leaves])) + #assert(np.all(est.tree_.impurity[leaves] == 0)) + """ + + # Test + probs = [] + testtimes.append( + timeit( + "probs.append(clf.predict_proba(X[test]))", + "from __main__ import probs, clf, X, test", + number=1, + ) + ) + + aucs.append(roc_auc_score(y[test], probs[0][:, 1])) + + traintimes = np.array(traintimes) + testtimes = np.array(testtimes) + aucs = np.array(aucs) + results.append( + [ + name, + traintimes.mean(), + traintimes.std(), + testtimes.mean(), + testtimes.std(), + aucs.mean(), + aucs.std(), + ] + ) + + results_df = pd.DataFrame(results) + results_df.columns = [ + "name", + "train time mean", + "train time std", + "test time mean", + "test time std", + "auc mean", + "auc std", + ] + results_df = results_df.set_index("name") + print(results_df) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 7e494b0e9bccc..8fe042f91aa44 100644 --- a/sklearn/ensemble/_forest.py +++ 
b/sklearn/ensemble/_forest.py @@ -725,7 +725,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"): ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) Input data. - quantiles : float, optional + quantiles : array-like, float, optional The quantiles at which to evaluate, by default 0.5 (median). method : str, optional The method to interpolate, by default 'linear'. Can be any keyword @@ -746,7 +746,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"): X = self._validate_X_predict(X) if not isinstance(quantiles, (np.ndarray, list)): - quantiles = np.array([quantiles]) + quantiles = np.atleast_1d(np.array(quantiles)) # if we trained a binning tree, then we should re-bin the data # XXX: this is inefficient and should be improved to be in line with what @@ -777,15 +777,15 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"): # (n_total_leaf_samples, n_outputs) leaf_node_samples = np.vstack( - ( + [ est.leaf_nodes_samples_[leaf_nodes[jdx]] for jdx, est in enumerate(self.estimators_) - ) + ] ) # get quantiles across all leaf node samples y_hat[idx, ...] = np.quantile( - leaf_node_samples, quantiles, axis=0, interpolation=method + leaf_node_samples, quantiles, axis=0, method=method ) if is_classifier(self): @@ -1550,6 +1550,17 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` @@ -1693,6 +1704,7 @@ def __init__( max_bins=None, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( estimator=DecisionTreeClassifier(), @@ -1710,6 +1722,7 @@ def __init__( "ccp_alpha", "store_leaf_values", "monotonic_cst", + "categorical", ), bootstrap=bootstrap, oob_score=oob_score, @@ -1733,6 +1746,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.monotonic_cst = monotonic_cst self.ccp_alpha = ccp_alpha + self.categorical = categorical class RandomForestRegressor(ForestRegressor): @@ -1935,6 +1949,17 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. 
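The `categorical` parameter documented here accepts an array of feature indices, a boolean mask, `'all'`, or `None`. A minimal usage sketch, assuming this branch is installed (stock scikit-learn does not expose `categorical`); the data and column choices are illustrative:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(0)
# two numeric columns followed by two integer-coded categorical columns
X = np.hstack([rng.normal(size=(300, 2)), rng.randint(0, 4, size=(300, 2))])
y = (X[:, 2] == 1).astype(int)  # target depends on the first categorical column

# mark columns 2 and 3 as categorical; a boolean mask [False, False, True, True]
# or the string "all" would be accepted as well
clf = RandomForestClassifier(n_estimators=50, categorical=[2, 3], random_state=0)
clf.fit(X, y)
print(clf.feature_importances_)
```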
+ Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor` @@ -2065,6 +2090,7 @@ def __init__( max_bins=None, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( estimator=DecisionTreeRegressor(), @@ -2082,6 +2108,7 @@ def __init__( "ccp_alpha", "store_leaf_values", "monotonic_cst", + "categorical", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2104,6 +2131,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.categorical = categorical class ExtraTreesClassifier(ForestClassifier): @@ -2316,24 +2344,16 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 1.4 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonically increasing - - 0: no constraint - - -1: monotonically decreasing - - If monotonic_cst is None, no constraints are applied. - - Monotonicity constraints are not supported for: - - multiclass classifications (i.e. when `n_classes > 2`), - - multioutput classifications (i.e. when `n_outputs_ > 1`), - - classifications trained on data with missing values. - - The constraints hold over the probability of the positive class. - - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. Attributes ---------- @@ -2467,6 +2487,7 @@ def __init__( max_bins=None, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( estimator=ExtraTreeClassifier(), @@ -2484,6 +2505,7 @@ def __init__( "ccp_alpha", "store_leaf_values", "monotonic_cst", + "categorical", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2507,6 +2529,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.categorical = categorical class ExtraTreesRegressor(ForestRegressor): @@ -2704,6 +2727,17 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. 
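Because the splitter reads categorical columns as small non-negative integers (the `_fit` validation later in this diff rejects negative values), string or object categories should be ordinal-encoded rather than one-hot encoded before fitting. A sketch under the same assumption that this branch is installed; `OrdinalEncoder` is used purely for illustration:

```python
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sklearn.ensemble import ExtraTreesClassifier

colors = np.array([["red"], ["green"], ["blue"], ["green"], ["red"], ["blue"]])
codes = OrdinalEncoder().fit_transform(colors)          # values in {0.0, 1.0, 2.0}

rng = np.random.RandomState(0)
X = np.hstack([rng.normal(size=(6, 1)), codes])         # one numeric + one categorical column
y = (codes.ravel() == 1.0).astype(int)                  # "green" vs the rest

clf = ExtraTreesClassifier(n_estimators=20, categorical=[1], random_state=0).fit(X, y)
```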
+ Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` @@ -2819,6 +2853,7 @@ def __init__( max_bins=None, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -2836,6 +2871,7 @@ def __init__( "ccp_alpha", "store_leaf_values", "monotonic_cst", + "categorical", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2858,6 +2894,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.categorical = categorical class RandomTreesEmbedding(TransformerMixin, BaseForest): @@ -2969,6 +3006,17 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): new forest. See :term:`Glossary ` and :ref:`gradient_boosting_warm_start` for details. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance @@ -3073,6 +3121,7 @@ def __init__( verbose=0, warm_start=False, store_leaf_values=False, + categorical=None, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -3088,6 +3137,7 @@ def __init__( "min_impurity_decrease", "random_state", "store_leaf_values", + "categorical", ), bootstrap=False, oob_score=False, @@ -3106,6 +3156,7 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.sparse_output = sparse_output + self.categorical = categorical def _set_oob_score_and_attributes(self, X, y, scoring_function=None): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 3a14da52047ad..00af00cffa267 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -139,6 +139,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): _parameter_constraints.pop("store_leaf_values") _parameter_constraints.pop("splitter") _parameter_constraints.pop("monotonic_cst") + _parameter_constraints.pop("categorical") @abstractmethod def __init__( diff --git a/sklearn/ensemble/tests/test_categorical_forest.py b/sklearn/ensemble/tests/test_categorical_forest.py new file mode 100644 index 0000000000000..4f370c2f43463 --- /dev/null +++ b/sklearn/ensemble/tests/test_categorical_forest.py @@ -0,0 +1,199 @@ +""" +Testing for the forest module (sklearn.ensemble.forest). 
+""" + +# Authors: Gilles Louppe, +# Brian Holt, +# Andreas Mueller, +# Arnaud Joly +# License: BSD 3 clause + +from typing import Any, Dict + +import joblib +import numpy as np +import pytest + +from sklearn import datasets +from sklearn.ensemble import ( + ExtraTreesClassifier, + ExtraTreesRegressor, + RandomForestClassifier, + RandomForestRegressor, + RandomTreesEmbedding, +) +from sklearn.utils.validation import check_random_state + +# toy sample +X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] +y = [-1, -1, -1, 1, 1, 1] +T = [[-1, -1], [2, 2], [3, 2]] +true_result = [-1, 1, 1] + +# Larger classification sample used for testing feature importances +X_large, y_large = datasets.make_classification( + n_samples=500, + n_features=10, + n_informative=3, + n_redundant=0, + n_repeated=0, + shuffle=False, + random_state=0, +) + +# also load the iris dataset +# and randomly permute it +iris = datasets.load_iris() +rng = check_random_state(0) +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + +# Make regression dataset +X_reg, y_reg = datasets.make_regression(n_samples=500, n_features=10, random_state=1) + +# also make a hastie_10_2 dataset +hastie_X, hastie_y = datasets.make_hastie_10_2(n_samples=20, random_state=1) +hastie_X = hastie_X.astype(np.float32) + +# Get the default backend in joblib to test parallelism and interaction with +# different backends +DEFAULT_JOBLIB_BACKEND = joblib.parallel.get_active_backend()[0].__class__ + +FOREST_CLASSIFIERS = { + "ExtraTreesClassifier": ExtraTreesClassifier, + "RandomForestClassifier": RandomForestClassifier, +} + +FOREST_REGRESSORS = { + "ExtraTreesRegressor": ExtraTreesRegressor, + "RandomForestRegressor": RandomForestRegressor, +} + +FOREST_TRANSFORMERS = { + "RandomTreesEmbedding": RandomTreesEmbedding, +} + +FOREST_ESTIMATORS: Dict[str, Any] = dict() +FOREST_ESTIMATORS.update(FOREST_CLASSIFIERS) +FOREST_ESTIMATORS.update(FOREST_REGRESSORS) +FOREST_ESTIMATORS.update(FOREST_TRANSFORMERS) + +FOREST_CLASSIFIERS_REGRESSORS: Dict[str, Any] = FOREST_CLASSIFIERS.copy() +FOREST_CLASSIFIERS_REGRESSORS.update(FOREST_REGRESSORS) + + +pytest.mark.parametrize("model", FOREST_CLASSIFIERS_REGRESSORS) + + +def _make_categorical( + n_rows: int, + n_numerical: int, + n_categorical: int, + cat_size: int, + n_num_meaningful: int, + n_cat_meaningful: int, + regression: bool, + return_tuple: bool, + random_state: int, +): + from sklearn.preprocessing import OneHotEncoder + + rng = np.random.RandomState(random_state) + + numeric = rng.standard_normal((n_rows, n_numerical)) + categorical = rng.randint(0, cat_size, (n_rows, n_categorical)) + categorical_ohe = OneHotEncoder(categories="auto").fit_transform( + categorical[:, :n_cat_meaningful] + ) + + data_meaningful = np.hstack( + (numeric[:, :n_num_meaningful], categorical_ohe.todense()) + ) + _, cols = data_meaningful.shape + coefs = rng.standard_normal(cols) + y = np.dot(data_meaningful, coefs) + y = np.asarray(y).reshape(-1) + X = np.hstack((numeric, categorical)) + + if not regression: + y = (y < y.mean()).astype(int) + + meaningful_features = np.r_[ + np.arange(n_num_meaningful), np.arange(n_cat_meaningful) + n_numerical + ] + + if return_tuple: + return X, y, meaningful_features + else: + return {"X": X, "y": y, "meaningful_features": meaningful_features} + + +@pytest.mark.parametrize("model", FOREST_CLASSIFIERS_REGRESSORS) +@pytest.mark.parametrize( + "data_params", + [ + { + "n_rows": 10000, + "n_numerical": 10, + "n_categorical": 5, + "cat_size": 
3, + "n_num_meaningful": 1, + "n_cat_meaningful": 2, + }, + { + "n_rows": 1000, + "n_numerical": 0, + "n_categorical": 5, + "cat_size": 3, + "n_num_meaningful": 0, + "n_cat_meaningful": 3, + }, + { + "n_rows": 1000, + "n_numerical": 5, + "n_categorical": 5, + "cat_size": 64, + "n_num_meaningful": 0, + "n_cat_meaningful": 2, + }, + { + "n_rows": 1000, + "n_numerical": 5, + "n_categorical": 5, + "cat_size": 3, + "n_num_meaningful": 0, + "n_cat_meaningful": 3, + }, + ], +) +def test_categorical_data(model, data_params): + # DecisionTrees are too slow for large category sizes. + if data_params["cat_size"] > 8 and "RandomForest" in model: + pass + + X, y, meaningful_features = _make_categorical( + **data_params, + regression=model in FOREST_REGRESSORS, + return_tuple=True, + random_state=42, + ) + rows, cols = X.shape + categorical_features = ( + np.arange(data_params["n_categorical"]) + data_params["n_numerical"] + ) + + model = FOREST_CLASSIFIERS_REGRESSORS[model]( + random_state=42, categorical=categorical_features, n_estimators=100 + ).fit(X, y) + fi = model.feature_importances_ + bad_features = np.array([True] * cols) + bad_features[meaningful_features] = False + + good_ones = fi[meaningful_features] + print(good_ones) + bad_ones = fi[bad_features] + print(bad_ones) + + # all good features should be more important than all bad features. + assert np.all([np.all(x > bad_ones) for x in good_ones]) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 9291b6982a923..58601e025930f 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -2021,7 +2021,7 @@ def test_multioutput_quantiles(name): ) est.fit(X_train, y_train) - y_pred = est.predict_quantiles(X_test, quantiles=[0.25, 0.5, 0.75]) + y_pred = est.predict_quantiles(X_test, quantiles=np.array([0.25, 0.5, 0.75])) assert_array_almost_equal(y_pred[:, 1, :], y_test) assert_array_almost_equal(y_pred[:, 0, :], y_test) assert_array_almost_equal(y_pred[:, 2, :], y_test) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 1b718f3a04052..e336dea404a35 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -125,6 +125,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], "store_leaf_values": ["boolean"], "monotonic_cst": ["array-like", None], + "categorical": ["array-like", StrOptions({"all"}), None], } @abstractmethod @@ -145,6 +146,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): self.criterion = criterion self.splitter = splitter @@ -160,6 +162,7 @@ def __init__( self.ccp_alpha = ccp_alpha self.store_leaf_values = store_leaf_values self.monotonic_cst = monotonic_cst + self.categorical = categorical def get_depth(self): """Return the depth of the decision tree. @@ -396,6 +399,69 @@ def _fit( else: sample_weight = expanded_class_weight + # Validate categorical features + if self.categorical is None: + categorical = np.array([], dtype=np.int32) + elif isinstance(self.categorical, str): + if self.categorical == "all": + categorical = np.arange(self.n_features_in_) + else: + raise ValueError( + "Invalid value for categorical: {}. 
Allowed" + " strings are 'all'" + "".format(self.categorical) + ) + else: + categorical = np.atleast_1d(self.categorical).flatten() + if categorical.dtype == np.bool_: + if categorical.size != self.n_features_in_: + raise ValueError( + "Invalid value for categorical: Shape of " + "boolean parameter categorical must " + "be (n_features,)" + ) + categorical = np.nonzero(categorical)[0] + if np.size(categorical) > self.n_features_in_ or ( + categorical.size > 0 + and (categorical.min() < 0 or categorical.max() >= self.n_features_in_) + ): + raise ValueError( + "Invalid value for categorical: Invalid shape or " + "feature index for parameter categorical " + "invalid." + ) + if issparse(X): + if categorical.size > 0: + raise NotImplementedError( + "Categorical features not supported with sparse inputs" + ) + else: + if np.any(X[:, categorical].astype(np.int32) < 0): + raise ValueError( + "Invalid value for categorical: given values " + "for categorical features must be " + "non-negative." + ) + + # Calculate n_categories and verify they are all at least 1% populated + n_cat_present = np.array( + [ + np.unique(X[:, i].astype(np.int32)).size if i in categorical else -1 + for i in range(self.n_features_in_) + ], + dtype=np.int32, + ) + if np.any((n_cat_present < 0.01 * n_cat_present)[categorical]): + warnings.warn( + ( + "At least one categorical feature has less than 1%" + " of its categories present in the sample. Runtime" + " and memory usage will be much smaller if you" + " represent the categories as sequential integers." + ), + UserWarning, + ) + # Set min_weight_leaf from min_weight_fraction_leaf if sample_weight is None: min_weight_leaf = self.min_weight_fraction_leaf * n_samples @@ -408,6 +474,7 @@ def _fit( y, sample_weight, missing_values_in_feature_mask, + categorical, min_samples_leaf, min_weight_leaf, max_leaf_nodes, @@ -427,6 +494,7 @@ def _build_tree( y, sample_weight, missing_values_in_feature_mask, + categorical, min_samples_leaf, min_weight_leaf, max_leaf_nodes, @@ -444,6 +512,11 @@ def _build_tree( Y targets. sample_weight : Array-like Sample weights + missing_values_in_feature_mask : ndarray of shape (n_features,), or None + Missing value mask. If missing values are not supported or there + are no missing values, return None. + categorical : ndarray of shape (n_categorical_features,) + Indices of categorical features. min_samples_leaf : float Number of samples required to be a leaf. 
min_weight_leaf : float @@ -460,6 +533,14 @@ def _build_tree( n_samples = X.shape[0] + self.n_categories_ = np.array( + [ + np.int32(X[:, i].max()) + 1 if i in categorical else -1 + for i in range(self.n_features_in_) + ], + dtype=np.int32, + ) + # Build tree criterion = self.criterion if not isinstance(criterion, BaseCriterion): @@ -474,6 +555,14 @@ def _build_tree( # might be shared and modified concurrently during parallel fitting criterion = copy.deepcopy(criterion) + if is_classifier(self): + breiman_shortcut = self.n_classes_.tolist() == [2] and ( + isinstance(criterion, _criterion.Gini) + or isinstance(criterion, _criterion.Entropy) + ) + else: + breiman_shortcut = isinstance(criterion, _criterion.MSE) + SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS if self.monotonic_cst is None: @@ -523,16 +612,35 @@ def _build_tree( min_weight_leaf, random_state, monotonic_cst, + breiman_shortcut, + ) + + # once splitter is inferred, we want to error-check that the splitter + # supports certain number of categorical data + if ( + not isinstance(splitter, _splitter.RandomSplitter) + and np.max(self.n_categories_) > 64 + ): + raise ValueError( + "Categorical features with greater than 64" + " categories not supported with DecisionTree;" + " try ExtraTree." ) if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree( + self.n_features_in_, + self.n_classes_, + self.n_outputs_, + self.n_categories_, + ) else: self.tree_ = Tree( self.n_features_in_, # TODO: tree shouldn't need this in this case np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.n_categories_, ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise @@ -557,7 +665,14 @@ def _build_tree( self.min_impurity_decrease, self.store_leaf_values, ) - builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + builder.build( + self.tree_, + X, + y, + sample_weight, + missing_values_in_feature_mask, + self.n_categories_, + ) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -583,6 +698,10 @@ def _validate_X_predict(self, X, check_input): X.indices.dtype != np.intc or X.indptr.dtype != np.intc ): raise ValueError("No support for np.int64 index based sparse matrices") + if issparse(X) and np.any(self.tree_.n_categories > 0): + raise NotImplementedError( + "Categorical features not supported with sparse inputs" + ) else: # The number of features is checked regardless of `check_input` self._check_n_features(X, reset=False) @@ -828,13 +947,16 @@ def _prune_tree(self): # build pruned tree if is_classifier(self): n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + pruned_tree = Tree( + self.n_features_in_, n_classes, self.n_outputs_, self.n_categories_ + ) else: pruned_tree = Tree( self.n_features_in_, # TODO: the tree shouldn't need this param np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.n_categories_, ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) @@ -1068,6 +1190,19 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. 
In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -1187,6 +1322,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( criterion=criterion, @@ -1203,6 +1339,7 @@ def __init__( monotonic_cst=monotonic_cst, ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, + categorical=categorical, ) @_fit_context(prefer_skip_nested_validation=True) @@ -1479,6 +1616,18 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data using the + ``MSE`` criterion. In this case, the runtime is linear in the number + of categories. Extra-random trees have an upper limit of :math:`2^{31}` + categories, and runtimes linear in the number of categories. + Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1579,6 +1728,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( criterion=criterion, @@ -1594,6 +1744,7 @@ def __init__( ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, monotonic_cst=monotonic_cst, + categorical=categorical, ) @_fit_context(prefer_skip_nested_validation=True) @@ -1837,6 +1988,13 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -1939,6 +2097,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( criterion=criterion, @@ -1955,6 +2114,7 @@ def __init__( ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, monotonic_cst=monotonic_cst, + categorical=categorical, ) @@ -2111,6 +2271,13 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 1.4 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or `None`. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. 
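The per-feature category counts referenced in these docstrings are inferred in `_build_tree` as `max value + 1` for each categorical column and `-1` for ordinal columns; best-split trees then reject anything above 64 categories, while the extra-random splitter accepts up to 2**31. A small sketch of that bookkeeping with a hypothetical function name:

```python
import numpy as np

def infer_n_categories(X, categorical_idx):
    """Per-feature category counts: max value + 1 for categorical columns, -1 otherwise."""
    categorical_idx = set(categorical_idx)
    return np.array(
        [int(X[:, j].max()) + 1 if j in categorical_idx else -1 for j in range(X.shape[1])],
        dtype=np.int32,
    )

X = np.array([[0.5, 3.0], [1.2, 0.0], [0.1, 7.0]])
n_categories = infer_n_categories(X, categorical_idx=[1])
print(n_categories)                 # [-1  8]
assert n_categories.max() <= 64     # the limit enforced for non-random splitters
```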
+ Attributes ---------- max_features_ : int @@ -2196,6 +2363,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + categorical=None, ): super().__init__( criterion=criterion, @@ -2211,4 +2379,5 @@ def __init__( ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, monotonic_cst=monotonic_cst, + categorical=categorical, ) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 690f4d0c54c64..3a7ecb522dcb7 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -14,11 +14,11 @@ cimport numpy as cnp from libcpp.vector cimport vector -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport UINT32_t # Unsigned 32 bit integer +from ._utils cimport DOUBLE_t # Type of y, sample_weight +from ._utils cimport DTYPE_t # Type of X +from ._utils cimport INT32_t # Signed 32 bit integer +from ._utils cimport SIZE_t # Type for indices and counters +from ._utils cimport UINT32_t # Unsigned 32 bit integer cdef class BaseCriterion: diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 7da118347414a..94d5df7138e7b 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -15,27 +15,15 @@ cimport numpy as cnp from libcpp.vector cimport vector from ._criterion cimport BaseCriterion, Criterion -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport UINT32_t # Unsigned 32 bit integer - - -cdef struct SplitRecord: - # Data to track sample split - SIZE_t feature # Which feature to split on. - SIZE_t pos # Split samples array at the given position, - # # i.e. count of samples below threshold for feature. - # # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. - double improvement # Impurity improvement given parent node. - double impurity_left # Impurity of the left split. - double impurity_right # Impurity of the right split. - double lower_bound # Lower bound on value of both children for monotonicity - double upper_bound # Upper bound on value of both children for monotonicity - unsigned char missing_go_to_left # Controls if missing values go to the left node. 
- SIZE_t n_missing # Number of missing values for the feature being split on +from ._utils cimport DOUBLE_t # Type of y, sample_weight +from ._utils cimport DTYPE_t # Type of X +from ._utils cimport INT32_t # Signed 32 bit integer +from ._utils cimport SIZE_t # Type for indices and counters +from ._utils cimport UINT32_t # Unsigned 32 bit integer +from ._utils cimport UINT64_t # Unsigned 64 bit integer + +from ._utils cimport SplitValue, SplitRecord, Node + cdef class BaseSplitter: """Abstract interface for splitter.""" @@ -102,9 +90,16 @@ cdef class BaseSplitter: cdef int pointer_size(self) noexcept nogil cdef class Splitter(BaseSplitter): - cdef public Criterion criterion # Impurity criterion + cdef public Criterion criterion # Impurity criterion cdef const DOUBLE_t[:, ::1] y + cdef INT32_t[:] n_categories # (n_features,) array giving number of + # # categories (<0 for non-categorical) + cdef UINT64_t[:] cat_cache # Cache buffer for fast categorical split evaluation + cdef bint breiman_shortcut # Whether decision trees are allowed to use the + # # Breiman shortcut for categorical features + # # during binary classification. + # Monotonicity constraints for each feature. # The encoding is as follows: # -1: monotonic decrease @@ -119,6 +114,7 @@ cdef class Splitter(BaseSplitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ) except -1 cdef void node_samples(self, vector[vector[DOUBLE_t]]& dest) noexcept nogil @@ -141,3 +137,13 @@ cdef class Splitter(BaseSplitter): double lower_bound, double upper_bound ) noexcept nogil + + cdef void _breiman_sort_categories( + self, + SIZE_t start, + SIZE_t end, + INT32_t ncat, + SIZE_t ncat_present, + const INT32_t *cat_offset, + SIZE_t *sorted_cat + ) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index f2d0a4dfde0f2..5138832c90c58 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -21,6 +21,7 @@ from cython cimport final from libc.math cimport isnan from libc.stdlib cimport qsort from libc.string cimport memcpy +from libc.string cimport memset cimport numpy as cnp from ._criterion cimport Criterion @@ -31,7 +32,6 @@ from scipy.sparse import issparse from ._utils cimport RAND_R_MAX, log, rand_int, rand_uniform - cdef double INFINITY = np.inf # Mitigate precision differences between 32 bit and 64 bit @@ -149,6 +149,7 @@ cdef class Splitter(BaseSplitter): double min_weight_leaf, object random_state, const cnp.int8_t[:] monotonic_cst, + bint breiman_shortcut, *argv ): """ @@ -186,6 +187,8 @@ cdef class Splitter(BaseSplitter): self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state + + self.breiman_shortcut = breiman_shortcut self.monotonic_cst = monotonic_cst self.with_monotonic_cst = monotonic_cst is not None @@ -195,7 +198,8 @@ cdef class Splitter(BaseSplitter): self.min_samples_leaf, self.min_weight_leaf, self.random_state, - self.monotonic_cst), self.__getstate__()) + self.monotonic_cst, + self.breiman_shortcut), self.__getstate__()) cdef int init( self, @@ -203,6 +207,7 @@ cdef class Splitter(BaseSplitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories ) except -1: """Initialize the splitter. @@ -226,8 +231,13 @@ cdef class Splitter(BaseSplitter): are assumed to have uniform weight. This is represented as a Cython memoryview. 
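The `breiman_shortcut` flag and `_breiman_sort_categories` declared here implement Breiman's (1984) trick used for binary classification (Gini/Entropy) and regression (MSE): order the categories of a feature by their weighted mean outcome, after which the optimal binary grouping is found by scanning the prefixes of that ordering like an ordinary numeric split, instead of enumerating 2^(k-1) subsets. A NumPy sketch of the ordering step only (not the Cython implementation; names are illustrative):

```python
import numpy as np

def breiman_category_order(x_cat, y, sample_weight=None):
    """Order category values by their weighted mean of y (Breiman, 1984)."""
    x_cat = np.asarray(x_cat, dtype=int)
    y = np.asarray(y, dtype=float)
    w = np.ones_like(y) if sample_weight is None else np.asarray(sample_weight, dtype=float)

    cats = np.unique(x_cat)
    means = np.array([np.average(y[x_cat == c], weights=w[x_cat == c]) for c in cats])
    return cats[np.argsort(means)]

# categories 0..2; category 2 has the highest positive rate, category 1 the lowest
x = np.array([0, 0, 1, 1, 2, 2, 2])
y = np.array([0, 1, 0, 0, 1, 1, 0])
print(breiman_category_order(x, y))   # [1 0 2]
```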
- has_missing : bool - At least one missing values is in X. + missing_values_in_feature_mask : ndarray, dtype=unsigned char + Whether or not each feature has missing values. This is represented + as a Cython memoryview. + + n_categories : array of INT32_t, shape=(n_features,) + Number of categories for categorical features, or -1 for + non-categorical features """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) cdef SIZE_t n_samples = X.shape[0] @@ -281,6 +291,21 @@ cdef class Splitter(BaseSplitter): if missing_values_in_feature_mask is not None: self.criterion.init_sum_missing() + # Initialize the number of categories for each feature + # A value of -1 indicates a non-categorical feature + if n_categories is None: + self.n_categories = np.array([-1] * n_features, dtype=np.int32) + else: + self.n_categories = np.empty_like(n_categories, dtype=np.int32) + self.n_categories[:] = n_categories + + # If needed, allocate cache space for categorical splits + cdef INT32_t max_n_categories = max(self.n_categories) + if max_n_categories > 0: + cache_size = ((max_n_categories + 63) // 64) + self.cat_cache = np.zeros(cache_size, dtype=np.uint64) + # safe_realloc(&self.cat_cache, cache_size, sizeof(UINT64_t)) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -373,6 +398,57 @@ cdef class Splitter(BaseSplitter): return 0 + cdef void _breiman_sort_categories( + self, + SIZE_t start, + SIZE_t end, + INT32_t ncat, + SIZE_t ncat_present, + const INT32_t *cat_offset, + SIZE_t *sorted_cat + ) noexcept nogil: + """The Breiman shortcut for finding the best split involves a + preprocessing step wherein we sort the categories by + increasing (weighted) mean of the outcome y (whether 0/1 + binary for classification or quantitative for + regression). This function implements this preprocessing step + and produces a sorted list of category values. + """ + cdef: + DTYPE_t[:] Xf = self.feature_values + SIZE_t cat, localcat + DTYPE_t sort_value[64] + DTYPE_t sort_density[64] + + # categorical features with more than 64 categories are not supported + # here. 
+ memset(sort_value, 0, 64 * sizeof(DTYPE_t)) + memset(sort_density, 0, 64 * sizeof(DTYPE_t)) + + cdef int i, p + cdef DOUBLE_t w = 1.0 + + # apply a sorting over the y values + # since we are in binary classification, there is only one column of y + for p in range(start, end): + # get the categorical variable value + cat = Xf[p] + + # apply sorting with weighting by sample weight + i = self.samples[p] + if self.sample_weight is not None: + w = self.sample_weight[i] + sort_value[cat] += w * (self.y[i, 0]) + sort_density[cat] += w + + for localcat in range(ncat_present): + cat = localcat + cat_offset[localcat] + if sort_density[cat] == 0: # Avoid dividing by zero + sort_density[cat] = 1 + sort_value[localcat] = sort_value[cat] / sort_density[cat] + sorted_cat[localcat] = cat + + sort(&sort_value[0], sorted_cat, ncat_present) cdef inline void shift_missing_values_to_left_if_required( SplitRecord* best, @@ -1004,6 +1080,7 @@ cdef class DensePartitioner: cdef SIZE_t end cdef SIZE_t n_missing cdef const unsigned char[::1] missing_values_in_feature_mask + cdef const INT32_t[:] n_categories def __init__( self, @@ -1011,11 +1088,13 @@ cdef class DensePartitioner: SIZE_t[::1] samples, DTYPE_t[::1] feature_values, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ): self.X = X self.samples = samples self.feature_values = feature_values self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories = n_categories cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" @@ -1214,6 +1293,7 @@ cdef class SparsePartitioner: cdef SIZE_t end cdef SIZE_t n_missing cdef const unsigned char[::1] missing_values_in_feature_mask + cdef const INT32_t[:] n_categories cdef const DTYPE_t[::1] X_data cdef const INT32_t[::1] X_indices @@ -1235,6 +1315,7 @@ cdef class SparsePartitioner: SIZE_t n_samples, DTYPE_t[::1] feature_values, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ): if not (issparse(X) and X.format == "csc"): raise ValueError("X should be in csc format") @@ -1259,6 +1340,7 @@ cdef class SparsePartitioner: self.index_to_samples[samples[p]] = p self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories = n_categories cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" @@ -1630,10 +1712,19 @@ cdef class BestSplitter(Splitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init( + self, + X, + y, + sample_weight, + missing_values_in_feature_mask, + n_categories + ) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, self.samples, self.feature_values, missing_values_in_feature_mask, + n_categories ) cdef int node_split( @@ -1666,10 +1757,12 @@ cdef class BestSparseSplitter(Splitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories) self.partitioner = SparsePartitioner( - X, 
self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask + X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask, + n_categories, ) cdef int node_split( @@ -1702,10 +1795,12 @@ cdef class RandomSplitter(Splitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, self.samples, self.feature_values, missing_values_in_feature_mask, + n_categories ) cdef int node_split( @@ -1738,10 +1833,12 @@ cdef class RandomSparseSplitter(Splitter): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, + const INT32_t[:] n_categories, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories) self.partitioner = SparsePartitioner( - X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask + X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask, + n_categories ) cdef int node_split( self, diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index dedd820c41e0f..98d78280afbf0 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -16,26 +16,28 @@ cimport numpy as cnp from libcpp.unordered_map cimport unordered_map from libcpp.vector cimport vector -ctypedef cnp.npy_float32 DTYPE_t # Type of X -ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight -ctypedef cnp.npy_intp SIZE_t # Type for indices and counters -ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer -ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer +from ._splitter cimport Splitter +from ._utils cimport DOUBLE_t # Type of y, sample_weight +from ._utils cimport DTYPE_t # Type of X +from ._utils cimport INT32_t # Signed 32 bit integer +from ._utils cimport SIZE_t # Type for indices and counters +from ._utils cimport UINT32_t # Unsigned 32 bit integer +from ._utils cimport UINT64_t # Unsigned 64 bit integer +from ._utils cimport SplitValue, SplitRecord, Node -from ._splitter cimport SplitRecord, Splitter +cdef class CategoryCacheMgr: + # Class to manage the category cache memory during Tree.apply() -cdef struct Node: - # Base storage structure for the nodes in a Tree object + cdef SIZE_t n_nodes + cdef vector[vector[UINT64_t]] bits - SIZE_t left_child # id of the left child of the node - SIZE_t right_child # id of the right child of the node - SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node - DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) - SIZE_t n_node_samples # Number of samples at the node - DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node - unsigned char missing_go_to_left # Whether features have missing values + cdef void populate( + self, + Node* nodes, + SIZE_t n_nodes, + INT32_t[:] n_categories + ) noexcept cdef class BaseTree: @@ -109,10 +111,13 @@ cdef class Tree(BaseTree): # - value = (capacity, n_outputs, max_n_classes) array of values # Input/Output layout for supervised tree - cdef public SIZE_t n_features # Number of 
features in X - cdef SIZE_t* n_classes # Number of classes in y[:, k] - cdef public SIZE_t n_outputs # Number of outputs in y - cdef public SIZE_t max_n_classes # max(n_classes) + cdef public SIZE_t n_features # Number of features in X + cdef SIZE_t* n_classes # Number of classes in y[:, k] + cdef public SIZE_t n_outputs # Number of outputs in y + cdef public SIZE_t max_n_classes # max(n_classes) + + cdef INT32_t* n_categories # (n_features,) array of number of categories per feature + # # is <0 for non-categorial (i.e. -1) # Enables the use of tree to store distributions of the output to allow # arbitrary usage of the the leaves. This is used in the quantile @@ -158,6 +163,7 @@ cdef class TreeBuilder: const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=*, const unsigned char[::1] missing_values_in_feature_mask=*, + const INT32_t[:] n_categories=*, ) cdef _check_input( diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 492b5219fa18e..93bc1b71deaaf 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -38,7 +38,9 @@ from scipy.sparse import csr_matrix from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray - +from ._utils cimport int32_ptr_to_ndarray +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(PyTypeObject* subtype, cnp.dtype descr, @@ -98,6 +100,7 @@ cdef class TreeBuilder: const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, + const INT32_t[:] n_categories=None, ): """Build a decision tree from the training set (X, y).""" pass @@ -182,6 +185,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, + const INT32_t[:] n_categories=None, ): """Build a decision tree from the training set (X, y).""" @@ -207,7 +211,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories) cdef SIZE_t start cdef SIZE_t end @@ -458,6 +462,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, + const INT32_t[:] n_categories=None, ): """Build a decision tree from the training set (X, y).""" @@ -469,7 +474,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories) cdef vector[FrontierRecord] frontier cdef FrontierRecord record @@ -718,6 +723,76 @@ cdef class BestFirstTreeBuilder(TreeBuilder): return 0 +cdef class CategoryCacheMgr: + """Class to manage the category cache memory during Tree.apply().""" + + def __cinit__(self): + self.n_nodes = 0 + # self.bits = NULL + + cdef void populate( + self, + Node *nodes, + SIZE_t n_nodes, + const INT32_t[:] n_categories + ) noexcept: + """Populate the category cache memory. + + This method initializes and populates the category cache memory for each node + in the tree. 
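`CategoryCacheMgr.populate`, shown below, builds for every split node on a categorical feature a bitset of `(n_categories + 63) // 64` 64-bit words so that `apply` and `decision_path` can answer "does this category go left?" with a single bit test. A plain-Python sketch of that layout; Python ints stand in for the UINT64 words, and an explicit `categories_going_left` list stands in for the membership that the real code reconstructs from the node's `cat_split` via `setup_cat_cache`:

```python
def make_category_cache(nodes, n_categories):
    """Per-node bitsets: one list of 64-bit words per split node on a categorical feature."""
    caches = []
    for node in nodes:
        feature = node["feature"]
        is_leaf = node["left_child"] < 0
        if not is_leaf and n_categories[feature] > 0:
            n_words = (n_categories[feature] + 63) // 64       # 64 categories per word
            words = [0] * n_words
            for cat in node["categories_going_left"]:
                words[cat // 64] |= 1 << (cat % 64)            # set membership bit
            caches.append(words)
        else:
            caches.append([])                                   # leaves / numeric splits need no cache
    return caches

nodes = [{"feature": 1, "left_child": 1, "categories_going_left": [0, 2, 65]}]
print(make_category_cache(nodes, n_categories=[-1, 70]))        # [[5, 2]]
```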
It allocates memory for the cache based on the number of categories + in each feature and sets up the cache values. + + Parameters + ---------- + nodes : Node pointer + A pointer to the array of nodes. + n_nodes : SIZE_t + The number of nodes. + n_categories : INT32_t array + An array of integers representing the number of categories in each feature. + + Notes + ----- + The category cache memory is stored as a ragged array of shape + (n_nodes, # categories in each feature), represented by a vector of + vectors of UINT64_t. + + The UINT64_t type is a custom type used for storing category cache values. + + This method modifies the `self.bits` attribute of the `CategoryCacheMgr` + instance to store the populated cache memory. + """ + cdef SIZE_t i + cdef INT32_t ncat + + if nodes == NULL or n_categories is None: + return + + self.n_nodes = n_nodes + + # initialize bits as a vector of vectors of UINT64_t + # it is essentially a ragged array of shape (n_nodes, # categories in each feature) + self.bits = vector[vector[UINT64_t]](self.n_nodes) + + for i in range(n_nodes): + # self.bits[i] = NULL + + # if the node is a split-node, then we need to allocate memory for the cache + if nodes[i].left_child != _TREE_LEAF: + # get the number of categories in the feature and then set up the cache + ncat = n_categories[nodes[i].feature] + if ncat > 0: + cache_size = (ncat + 63) // 64 + self.bits[i] = vector[UINT64_t](cache_size) + + # allocate values to the cache for this node + setup_cat_cache( + self.bits[i], + nodes[i].cat_split, + ncat + ) + + # ============================================================================= # Tree # ============================================================================= @@ -793,6 +868,8 @@ cdef class BaseTree: # left_child and right_child will be set later for a split node node.feature = split_node.feature node.threshold = split_node.threshold + node.cat_split = split_node.cat_split + # node.split_value = split_node.split_value return 1 cdef int _set_leaf_node( @@ -813,6 +890,7 @@ cdef class BaseTree: node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + node.cat_split = _TREE_UNDEFINED return 1 cdef DTYPE_t _compute_feature( @@ -926,9 +1004,24 @@ cdef class BaseTree: cdef Node* node = NULL cdef SIZE_t i = 0 + # initialize Cache to go over categorical features + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits + # cdef vector[UINT64_t] cache = NULL + + cdef const INT32_t[:] n_categories = self.n_categories + + # apply Cache to speed up categorical "apply" + # cache_mgr = CategoryCacheMgr() + # cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + # cdef UINT64_t** cat_caches = cache_mgr.bits + # cdef UINT64_t* cache = NULL + with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] # While node not a leaf while node.left_child != _TREE_LEAF: @@ -939,7 +1032,15 @@ cdef class BaseTree: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] - elif X_i_node_feature <= node.threshold: + elif goes_left( + X_i_node_feature, + # node.split_value, + # node.threshold, + # self.n_categories[node.feature], + node, + n_categories, + cache + ): node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -977,6 +1078,13 @@ cdef class BaseTree: cdef SIZE_t i = 0 cdef INT32_t k = 0 + # initialize Cache to go over categorical features + cache_mgr = 
CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits + # cdef vector[UINT64_t] cache = NULL + + cdef const INT32_t[:] n_categories = self.n_categories # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. @@ -990,6 +1098,7 @@ cdef class BaseTree: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] for k in range(X_indptr[i], X_indptr[i + 1]): feature_to_sample[X_indices[k]] = i @@ -1003,9 +1112,19 @@ cdef class BaseTree: else: feature_value = 0. - if feature_value <= node.threshold: + if goes_left( + feature_value, + # node.split_value, + # node.threshold, + # self.n_categories[node.feature], + node, + n_categories, + cache + ): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out[i] = (node - self.nodes) # node offset @@ -1044,10 +1163,18 @@ cdef class BaseTree: n_samples * (1 + self.max_depth), dtype=np.intp ) + cdef const INT32_t[:] n_categories = self.n_categories + # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + # initialize Cache to go over categorical features + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits + # cdef vector[UINT64_t] cache = NULL + # the feature index cdef DOUBLE_t feature @@ -1055,6 +1182,7 @@ cdef class BaseTree: for i in range(n_samples): node = self.nodes indptr[i + 1] = indptr[i] + cache = cat_caches[0] # Add all external nodes while node.left_child != _TREE_LEAF: @@ -1064,9 +1192,19 @@ cdef class BaseTree: # compute the feature value to compare against threshold feature = self._compute_feature(X_ndarray, i, node) - if feature <= node.threshold: + if goes_left( + feature, + # node.split_value, + # node.threshold, + # self.n_categories[node.feature], + node, + n_categories, + cache + ): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1112,6 +1250,14 @@ cdef class BaseTree: cdef SIZE_t i = 0 cdef INT32_t k = 0 + # initialize Cache to go over categorical features + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits + # cdef vector[UINT64_t] cache = NULL + + cdef const INT32_t[:] n_categories = self.n_categories + # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. @@ -1127,6 +1273,9 @@ cdef class BaseTree: node = self.nodes indptr[i + 1] = indptr[i] + # start off each sample with the root cache + cache = cat_caches[0] + for k in range(X_indptr[i], X_indptr[i + 1]): feature_to_sample[X_indices[k]] = i X_sample[X_indices[k]] = X_data[k] @@ -1144,9 +1293,19 @@ cdef class BaseTree: else: feature_value = 0. 
- if feature_value <= node.threshold: + if goes_left( + feature_value, + # node.split_value, + # node.threshold, + # self.n_categories[node.feature], + node, + n_categories, + cache + ): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1406,6 +1565,10 @@ cdef class Tree(BaseTree): weighted_n_node_samples : array of double, shape [node_count] weighted_n_node_samples[i] holds the weighted number of training samples reaching node i. + + n_categories : array of int, shape [n_features] + Number of expected category values for categorical features, or + -1 for non-categorical features. """ # Wrap for outside world. # WARNING: these reference the current `nodes` and `value` buffers, which @@ -1465,9 +1628,16 @@ cdef class Tree(BaseTree): leaf_node_samples[node_id] = self._get_value_samples_ndarray(node_id) return leaf_node_samples + @property + def n_categories(self): + return int32_ptr_to_ndarray(self.n_categories, self.n_features).copy() + # TODO: Convert n_classes to cython.integral memory view once # https://github.com/cython/cython/issues/5243 is fixed - def __cinit__(self, int n_features, cnp.ndarray n_classes, int n_outputs): + def __cinit__( + self, int n_features, cnp.ndarray n_classes, int n_outputs, + cnp.ndarray[INT32_t, ndim=1] n_categories + ): """Constructor.""" cdef SIZE_t dummy = 0 size_t_dtype = np.array(dummy).dtype @@ -1480,12 +1650,21 @@ cdef class Tree(BaseTree): self.n_classes = NULL safe_realloc(&self.n_classes, n_outputs) + self.n_categories = NULL + safe_realloc(&self.n_categories, n_features) + + # n-categories is a 1D array of size n_features + # self.n_categories = np.empty(n_features, dtype=np.int32) + # self.n_categories = n_categories + self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes cdef SIZE_t k for k in range(n_outputs): self.n_classes[k] = n_classes[k] + for k in range(n_features): + self.n_categories[k] = n_categories[k] # Inner structures self.max_depth = 0 @@ -1497,18 +1676,29 @@ cdef class Tree(BaseTree): # initialize the hash map for the value samples self.value_samples = unordered_map[SIZE_t, vector[vector[DOUBLE_t]]]() + # Ensure cython and numpy node sizes match up + np_node_size = ( NODE_DTYPE).itemsize + node_size = sizeof(Node) + if (np_node_size != node_size): + raise TypeError('Size of numpy NODE_DTYPE ({} bytes) does not' + ' match size of Node ({} bytes)'.format( + np_node_size, node_size)) + def __dealloc__(self): """Destructor.""" # Free all inner structures free(self.n_classes) free(self.value) free(self.nodes) + free(self.n_categories) def __reduce__(self): """Reduce re-implementation, for pickling.""" return (Tree, (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + self.n_outputs, + int32_ptr_to_ndarray(self.n_categories, self.n_features)), + self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 61ba8af197c2e..26bc6fc5d2bb3 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -7,19 +7,20 @@ # License: BSD 3 clause # See _utils.pyx for details. 
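In `apply`, `decision_path`, and their sparse variants above, the plain `feature <= threshold` test is replaced by `goes_left`, which keeps the threshold comparison for ordinal features and consults the node's category bitset for categorical ones. A Python sketch of that routing decision, with Python ints standing in for the packed UINT64 cache words:

```python
def goes_left(feature_value, threshold, n_categories_for_feature, cache_words):
    """Route a sample at a node: threshold test for ordinal features,
    bitset lookup for categorical features."""
    if n_categories_for_feature <= 0:
        return feature_value <= threshold          # ordinary numeric split
    cat = int(feature_value)
    word, bit = divmod(cat, 64)
    return bool((cache_words[word] >> bit) & 1)    # bit set => category goes left

# numeric feature
assert goes_left(0.3, threshold=0.5, n_categories_for_feature=-1, cache_words=[])
# categorical feature with categories {0, 2} going left
assert goes_left(2, threshold=0.0, n_categories_for_feature=4, cache_words=[0b0101])
assert not goes_left(1, threshold=0.0, n_categories_for_feature=4, cache_words=[0b0101])
```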
+import cython +from libcpp.vector cimport vector cimport numpy as cnp from sklearn.neighbors._quad_tree cimport Cell -from ._tree cimport Node ctypedef cnp.npy_float32 DTYPE_t # Type of X ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef cnp.npy_intp SIZE_t # Type for indices and counters ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer - +ctypedef cnp.npy_uint64 UINT64_t # Unsigned 64 bit integer cdef enum: # Max value for our rand_r replacement (near the bottom). @@ -30,6 +31,59 @@ cdef enum: RAND_R_MAX = 2147483647 +cdef union SplitValue: + # Union type to generalize the concept of a threshold to categorical + # features. The floating point view, i.e. ``SplitValue.threshold`` is used + # for numerical features, where feature values less than or equal to the + # threshold go left, and values greater than the threshold go right. + # + # For categorical features, the UINT64_t view (`SplitValue.cat_split``) is + # used. It works in one of two ways, indicated by the value of its least + # significant bit (LSB). If the LSB is 0, then cat_split acts as a bitfield + # for up to 64 categories, sending samples left if the bit corresponding to + # their category is 1 or right if it is 0. If the LSB is 1, then the most + # significant 32 bits of cat_split make a random seed. To evaluate a + # sample, use the random seed to flip a coin (category_value + 1) times and + # send it left if the last flip gives 1; otherwise right. This second + # method allows up to 2**31 category values, but can only be used for + # RandomSplitter. + DOUBLE_t threshold + UINT64_t cat_split + +cdef struct SplitRecord: + # Data to track sample split + SIZE_t feature # Which feature to split on. + SIZE_t pos # Split samples array at the given position, + # # i.e. count of samples below threshold for feature. + # # pos is >= end if the node is a leaf. + # SplitValue split_value # Generalized threshold for categorical and + # # non-categorical features + DOUBLE_t threshold + UINT64_t cat_split + double improvement # Impurity improvement given parent node. + double impurity_left # Impurity of the left split. + double impurity_right # Impurity of the right split. + double lower_bound # Lower bound on value of both children for monotonicity + double upper_bound # Upper bound on value of both children for monotonicity + unsigned char missing_go_to_left # Controls if missing values go to the left node. + SIZE_t n_missing # Number of missing values for the feature being split on + +cdef struct Node: + # Base storage structure for the nodes in a Tree object + + SIZE_t left_child # id of the left child of the node + SIZE_t right_child # id of the right child of the node + SIZE_t feature # Feature used for splitting the node + # SplitValue split_value # Generalized threshold for categorical and + # # non-categorical features + DOUBLE_t threshold + UINT64_t cat_split + DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) + SIZE_t n_node_samples # Number of samples at the node + DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + unsigned char missing_go_to_left # Whether features have missing values + + # safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or # raises a MemoryError. It never calls free, since that's __dealloc__'s job. 
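To make the ``SplitValue.cat_split`` encoding described above concrete, here is a Python sketch (not part of the patch) of how a split value expands into a per-category "goes left" mask, mirroring what ``setup_cat_cache`` and ``goes_left`` do in Cython further down. ``rand_draw`` is a simplified stand-in for the library's ``rand_int(0, 2, &seed)`` helper, not the actual generator.

    def rand_draw(seed):
        # Simplified stand-in for rand_int(0, 2, &seed); the real generator is a
        # rand_r-style PRNG, this one is only for illustration.
        seed = (seed * 1103515245 + 12345) & 0x7FFFFFFF
        return (seed >> 16) & 1, seed

    def expand_cat_split(cat_split, n_categories):
        """Expand a 64-bit cat_split into a per-category "goes left" mask."""
        if cat_split & 1 == 0:
            # LSB == 0 (BestSplitter): cat_split is a plain bitfield over at
            # most 64 categories; bit c set means category c goes left.
            return [bool((cat_split >> c) & 1) for c in range(n_categories)]
        # LSB == 1 (RandomSplitter): the 32 most significant bits seed a PRNG
        # and each category takes one fair coin flip, in order.
        seed = cat_split >> 32
        mask = []
        for _ in range(n_categories):
            bit, seed = rand_draw(seed)
            mask.append(bool(bit))
        return mask

    # A sample whose category value is c follows the left branch when mask[c]
    # is True; values >= n_categories fall through to the right child.
    mask = expand_cat_split(0b101100, n_categories=6)   # bitfield form
    print(mask)  # [False, False, True, True, False, True]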
# cdef DTYPE_t *p = NULL @@ -46,6 +100,10 @@ ctypedef fused realloc_ptr: (Node*) (Cell*) (Node**) + (void**) + (INT32_t*) + (UINT32_t*) + (UINT64_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * nogil @@ -53,6 +111,9 @@ cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * nogil cdef cnp.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef cnp.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size) + + cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) noexcept nogil @@ -63,6 +124,24 @@ cdef double rand_uniform(double low, double high, cdef double log(double x) noexcept nogil + +cdef void setup_cat_cache( + vector[UINT64_t]& cachebits, + UINT64_t cat_split, + INT32_t n_categories +) noexcept nogil + + +cdef bint goes_left( + DTYPE_t feature_value, + # SplitValue split, + # DOUBLE_t threshold, + # INT32_t n_categories, + Node* node, + const INT32_t[:] n_categories, + vector[UINT64_t]& cachebits +) noexcept nogil + # ============================================================================= # WeightedPQueue data structure # ============================================================================= @@ -110,3 +189,13 @@ cdef class WeightedMedianCalculator: self, DOUBLE_t data, DOUBLE_t weight, DOUBLE_t original_median) noexcept nogil cdef DOUBLE_t get_median(self) noexcept nogil + + +cdef UINT64_t bs_set(UINT64_t value, SIZE_t i) noexcept nogil +cdef UINT64_t bs_reset(UINT64_t value, SIZE_t i) noexcept nogil +cdef UINT64_t bs_flip(UINT64_t value, SIZE_t i) noexcept nogil +cdef UINT64_t bs_flip_all(UINT64_t value, SIZE_t n_low_bits) noexcept nogil +cdef bint bs_get(UINT64_t value, SIZE_t i) noexcept nogil +cdef UINT64_t bs_from_template(UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 02dc7cf426efc..d4c84f6ae0a55 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -13,6 +13,7 @@ from libc.math cimport isnan from libc.math cimport log as ln from libc.stdlib cimport free, realloc +from libcpp.vector cimport vector import numpy as np @@ -61,6 +62,13 @@ cdef inline cnp.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size): return cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_INTP, data).copy() +cdef inline cnp.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size): + """Encapsulate data into a 1D numpy array of int32's.""" + cdef cnp.npy_intp shape[1] + shape[0] = size + return cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_INT32, data).copy() + + cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) noexcept nogil: """Generate a random integer in [low; end).""" @@ -77,6 +85,110 @@ cdef inline double rand_uniform(double low, double high, cdef inline double log(double x) noexcept nogil: return ln(x) / ln(2.0) + +cdef inline void setup_cat_cache( + vector[UINT64_t]& cachebits, + UINT64_t cat_split, + INT32_t n_categories +) noexcept nogil: + """Populate the bits of the category cache from a split. + + Attributes + ---------- + cachebits : Reference of vector[UINT64_t] + This is a pointer to the output array. The size of the array should be + ``ceil(n_categories / 64)``. This function assumes the required + memory is allocated for the array by the caller. + cat_split : UINT64_t + If ``least significant bit == 0``: + It stores the split of the maximum 64 categories in its bits. + This is used in `BestSplitter`, and without loss of generality it + is assumed to be even, i.e. 
for any odd value there is an + equivalent even ``cat_split``. + If ``least significant bit == 1``: + It is a random split, and the 32 most significant bits of + ``cat_split`` contain the random seed of the split. The + ``n_categories`` lowest bits of ``cachebits`` are then filled with + random zeros and ones given the random seed. + n_categories : INT32_t + The number of categories. + """ + cdef INT32_t j + cdef UINT32_t rng_seed, val + + # cache_size is equal to cachebits.size() + cdef SIZE_t cache_size = (n_categories + 63) // 64 + + if n_categories > 0: + if cat_split & 1: + # RandomSplitter + for j in range(cache_size): + cachebits[j] = 0 + rng_seed = cat_split >> 32 + for j in range(n_categories): + val = rand_int(0, 2, &rng_seed) + if not val: + continue + cachebits[j // 64] = bs_set(cachebits[j // 64], j % 64) + else: + # BestSplitter + # In practice, cache_size here should ALWAYS be 1 + # XXX TODO: check cache_size == 1? + cachebits[0] = cat_split + + +cdef inline bint goes_left( + DTYPE_t feature_value, + Node* node, + const INT32_t[:] n_categories, + vector[UINT64_t]& cachebits +) noexcept nogil: + """Determine whether a sample goes to the left or right child node. + + For numerical features, ``(-inf, node.threshold]`` is the left child, and + ``(node.threshold, inf)`` the right child. + + For categorical features, if the corresponding bit for the category is set + in cachebits, the left child is used, and if not set, the right child. If + the given input category is greater than or equal to ``n_categories`` for + the split feature, the right child is used. + + Attributes + ---------- + feature_value : DTYPE_t + The value of the feature for which the decision needs to be made. + node : Node* + The node being evaluated; ``node.threshold`` is used for numerical + features and ``node.cat_split`` (through ``cachebits``) for + categorical ones. + n_categories : const INT32_t[:] + The number of categories for each feature. A feature is considered + numerical rather than categorical if its entry is negative. + cachebits : Reference of vector[UINT64_t] + The array containing the expansion of ``node.cat_split``, as filled in + by ``setup_cat_cache``. + + Returns + ------- + result : bint + Indicating whether the left branch should be used.
+ """ + cdef SIZE_t idx + cdef INT32_t n_categories_feature = n_categories[node.feature] + + if n_categories_feature < 0: + # Non-categorical feature + return feature_value <= node.threshold + else: + # Categorical feature, using bit cache + if ( feature_value) < n_categories_feature: + idx = ( feature_value) // 64 + offset = ( feature_value) % 64 + return bs_get(cachebits[idx], offset) + else: + return 0 + + # ============================================================================= # WeightedPQueue data structure # ============================================================================= @@ -470,3 +582,29 @@ def _any_isnan_axis0(const DTYPE_t[:, :] X): isnan_out[j] = True break return np.asarray(isnan_out) + + +cdef inline UINT64_t bs_set(UINT64_t value, SIZE_t i) noexcept nogil: + return value | ( 1) << i + +cdef inline UINT64_t bs_reset(UINT64_t value, SIZE_t i) noexcept nogil: + return value & ~(( 1) << i) + +cdef inline UINT64_t bs_flip(UINT64_t value, SIZE_t i) noexcept nogil: + return value ^ ( 1) << i + +cdef inline UINT64_t bs_flip_all(UINT64_t value, SIZE_t n_low_bits) noexcept nogil: + return (~value) & ((~( 0)) >> (64 - n_low_bits)) + +cdef inline bint bs_get(UINT64_t value, SIZE_t i) noexcept nogil: + return (value >> i) & ( 1) + +cdef inline UINT64_t bs_from_template(UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) noexcept nogil: + cdef SIZE_t i + cdef UINT64_t value = 0 + for i in range(ncats_present): + value |= (template & + (( 1) << i)) << cat_offs[i] + return value diff --git a/sklearn/tree/tests/test_categorical_tree.py b/sklearn/tree/tests/test_categorical_tree.py new file mode 100644 index 0000000000000..5ce5781cb8932 --- /dev/null +++ b/sklearn/tree/tests/test_categorical_tree.py @@ -0,0 +1,284 @@ +import numpy as np +import pytest +from scipy.sparse import csc_matrix + +from sklearn import datasets +from sklearn.random_projection import _sparse_random_matrix +from sklearn.tree import ( + DecisionTreeClassifier, + DecisionTreeRegressor, + ExtraTreeClassifier, + ExtraTreeRegressor, +) +from sklearn.utils.validation import check_random_state + +CLF_CRITERIONS = ("gini", "log_loss") +REG_CRITERIONS = ("squared_error", "absolute_error", "friedman_mse", "poisson") + +CLF_TREES = { + "DecisionTreeClassifier": DecisionTreeClassifier, + "ExtraTreeClassifier": ExtraTreeClassifier, +} + +REG_TREES = { + "DecisionTreeRegressor": DecisionTreeRegressor, + "ExtraTreeRegressor": ExtraTreeRegressor, +} + +ALL_TREES: dict = dict() +ALL_TREES.update(CLF_TREES) +ALL_TREES.update(REG_TREES) + + +X_small = np.array( + [ + [0, 0, 4, 0, 0, 0, 1, -14, 0, -4, 0, 0, 0, 0], + [0, 0, 5, 3, 0, -4, 0, 0, 1, -5, 0.2, 0, 4, 1], + [-1, -1, 0, 0, -4.5, 0, 0, 2.1, 1, 0, 0, -4.5, 0, 1], + [-1, -1, 0, -1.2, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 1], + [-1, -1, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 1], + [-1, -2, 0, 4, -3, 10, 4, 0, -3.2, 0, 4, 3, -4, 1], + [2.11, 0, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0.5, 0, -3, 1], + [2.11, 0, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0, 0, -2, 1], + [2.11, 8, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0, 0, -2, 1], + [2.11, 8, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0.5, 0, -1, 0], + [2, 8, 5, 1, 0.5, -4, 10, 0, 1, -5, 3, 0, 2, 0], + [2, 0, 1, 1, 1, -1, 1, 0, 0, -2, 3, 0, 1, 0], + [2, 0, 1, 2, 3, -1, 10, 2, 0, -1, 1, 2, 2, 0], + [1, 1, 0, 2, 2, -1, 1, 2, 0, -5, 1, 2, 3, 0], + [3, 1, 0, 3, 0, -4, 10, 0, 1, -5, 3, 0, 3, 1], + [2.11, 8, -6, -0.5, 0, 1, 0, 0, -3.2, 6, 0.5, 0, -3, 1], + [2.11, 8, -6, -0.5, 0, 1, 0, 0, -3.2, 6, 1.5, 1, -1, -1], + [2.11, 8, -6, -0.5, 0, 10, 0, 0, -3.2, 6, 0.5, 0, 
-1, -1], + [2, 0, 5, 1, 0.5, -2, 10, 0, 1, -5, 3, 1, 0, -1], + [2, 0, 1, 1, 1, -2, 1, 0, 0, -2, 0, 0, 0, 1], + [2, 1, 1, 1, 2, -1, 10, 2, 0, -1, 0, 2, 1, 1], + [1, 1, 0, 0, 1, -3, 1, 2, 0, -5, 1, 2, 1, 1], + [3, 1, 0, 1, 0, -4, 1, 0, 1, -2, 0, 0, 1, 0], + ] +) + +y_small = [1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0] +y_small_reg = [ + 1.0, + 2.1, + 1.2, + 0.05, + 10, + 2.4, + 3.1, + 1.01, + 0.01, + 2.98, + 3.1, + 1.1, + 0.0, + 1.2, + 2, + 11, + 0, + 0, + 4.5, + 0.201, + 1.06, + 0.9, + 0, +] + +# toy sample +X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] +y = [-1, -1, -1, 1, 1, 1] +T = [[-1, -1], [2, 2], [3, 2]] +true_result = [-1, 1, 1] + +# also load the iris dataset +# and randomly permute it +iris = datasets.load_iris() +rng = np.random.RandomState(1) +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + +# also load the diabetes dataset +# and randomly permute it +diabetes = datasets.load_diabetes() +perm = rng.permutation(diabetes.target.size) +diabetes.data = diabetes.data[perm] +diabetes.target = diabetes.target[perm] + +digits = datasets.load_digits() +perm = rng.permutation(digits.target.size) +digits.data = digits.data[perm] +digits.target = digits.target[perm] + +random_state = check_random_state(0) +X_multilabel, y_multilabel = datasets.make_multilabel_classification( + random_state=0, n_samples=30, n_features=10 +) + +# NB: despite their names X_sparse_* are numpy arrays (and not sparse matrices) +X_sparse_pos = random_state.uniform(size=(20, 5)) +X_sparse_pos[X_sparse_pos <= 0.8] = 0.0 +y_random = random_state.randint(0, 4, size=(20,)) +X_sparse_mix = _sparse_random_matrix(20, 10, density=0.25, random_state=0).toarray() + + +DATASETS = { + "iris": {"X": iris.data, "y": iris.target}, + "diabetes": {"X": diabetes.data, "y": diabetes.target}, + "digits": {"X": digits.data, "y": digits.target}, + "toy": {"X": X, "y": y}, + "clf_small": {"X": X_small, "y": y_small}, + "reg_small": {"X": X_small, "y": y_small_reg}, + "multilabel": {"X": X_multilabel, "y": y_multilabel}, + "sparse-pos": {"X": X_sparse_pos, "y": y_random}, + "sparse-neg": {"X": -X_sparse_pos, "y": y_random}, + "sparse-mix": {"X": X_sparse_mix, "y": y_random}, + "zeros": {"X": np.zeros((20, 3)), "y": y_random}, +} +for name in DATASETS: + DATASETS[name]["X_sparse"] = csc_matrix(DATASETS[name]["X"]) + + +@pytest.mark.parametrize("name", ALL_TREES) +@pytest.mark.parametrize( + "categorical", + ["invalid string", [[0]], [False, False, False], [1, 2], [-3], [0, 0, 1]], +) +def test_invalid_categorical(name, categorical): + Tree = ALL_TREES[name] + if categorical == "invalid string": + with pytest.raises(ValueError, match="The 'categorical' parameter"): + Tree(categorical=categorical).fit(X, y) + else: + with pytest.raises(ValueError, match="Invalid value for categorical"): + Tree(categorical=categorical).fit(X, y) + + +@pytest.mark.parametrize("name", ALL_TREES) +def test_no_sparse_with_categorical(name): + # Currently we do not support sparse categorical features + X, y, X_sparse = [DATASETS["clf_small"][z] for z in ["X", "y", "X_sparse"]] + Tree = ALL_TREES[name] + with pytest.raises( + NotImplementedError, match="Categorical features not supported with sparse" + ): + Tree(categorical=[6, 10]).fit(X_sparse, y) + + with pytest.raises( + NotImplementedError, match="Categorical features not supported with sparse" + ): + Tree(categorical=[6, 10]).fit(X, y).predict(X_sparse) + + +def _make_categorical( + n_rows: int, + n_numerical: int, 
+ n_categorical: int, + cat_size: int, + n_num_meaningful: int, + n_cat_meaningful: int, + regression: bool, + return_tuple: bool, + random_state: int, +): + from sklearn.preprocessing import OneHotEncoder + + rng = np.random.RandomState(random_state) + + numeric = rng.standard_normal((n_rows, n_numerical)) + categorical = rng.randint(0, cat_size, (n_rows, n_categorical)) + categorical_ohe = OneHotEncoder(categories="auto").fit_transform( + categorical[:, :n_cat_meaningful] + ) + + data_meaningful = np.hstack( + (numeric[:, :n_num_meaningful], categorical_ohe.todense()) + ) + _, cols = data_meaningful.shape + coefs = rng.standard_normal(cols) + y = np.dot(data_meaningful, coefs) + y = np.asarray(y).reshape(-1) + X = np.hstack((numeric, categorical)) + + if not regression: + y = (y < y.mean()).astype(int) + + meaningful_features = np.r_[ + np.arange(n_num_meaningful), np.arange(n_cat_meaningful) + n_numerical + ] + + if return_tuple: + return X, y, meaningful_features + else: + return {"X": X, "y": y, "meaningful_features": meaningful_features} + + +@pytest.mark.parametrize("model", ALL_TREES) +@pytest.mark.parametrize( + "data_params", + [ + { + "n_rows": 1000, + "n_numerical": 5, + "n_categorical": 5, + "cat_size": 3, + "n_num_meaningful": 2, + "n_cat_meaningful": 3, + }, + { + "n_rows": 1000, + "n_numerical": 0, + "n_categorical": 5, + "cat_size": 3, + "n_num_meaningful": 0, + "n_cat_meaningful": 3, + }, + { + "n_rows": 1000, + "n_numerical": 5, + "n_categorical": 5, + "cat_size": 64, + "n_num_meaningful": 0, + "n_cat_meaningful": 2, + }, + { + "n_rows": 1000, + "n_numerical": 5, + "n_categorical": 5, + "cat_size": 3, + "n_num_meaningful": 0, + "n_cat_meaningful": 3, + }, + ], +) +def test_categorical_data(model, data_params): + # DecisionTrees are too slow for large category sizes. + if data_params["cat_size"] > 8 and "DecisionTree" in model: + pass + + X, y, meaningful_features = _make_categorical( + **data_params, regression=model in REG_TREES, return_tuple=True, random_state=42 + ) + rows, cols = X.shape + categorical_features = ( + np.arange(data_params["n_categorical"]) + data_params["n_numerical"] + ) + + model = ALL_TREES[model](random_state=43, categorical=categorical_features).fit( + X, y + ) + fi = model.feature_importances_ + bad_features = np.array([True] * cols) + bad_features[meaningful_features] = False + + good_ones = fi[meaningful_features] + bad_ones = fi[bad_features] + + # all good features should be more important than all bad features. 
+ # XXX: or at least a large fraction of them + # assert np.all([np.all(x > bad_ones) for x in good_ones]) + assert np.mean([np.all(x > bad_ones) for x in good_ones]) > 0.7 + + leaves = model.tree_.children_left < 0 + assert np.all(model.tree_.impurity[leaves] < 1e-6) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index ccca6d60ed48b..68a3c26a01eb8 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2200,13 +2200,15 @@ def get_different_alignment_node_ndarray(node_ndarray): def reduce_tree_with_different_bitness(tree): new_dtype = np.int64 if _IS_32BIT else np.int32 - tree_cls, (n_features, n_classes, n_outputs), state = tree.__reduce__() + tree_cls, (n_features, n_classes, n_outputs, n_categories), state = ( + tree.__reduce__() + ) new_n_classes = n_classes.astype(new_dtype, casting="same_kind") new_state = state.copy() new_state["nodes"] = get_different_bitness_node_ndarray(new_state["nodes"]) - return (tree_cls, (n_features, new_n_classes, n_outputs), new_state) + return (tree_cls, (n_features, new_n_classes, n_outputs, n_categories), new_state) def test_different_bitness_pickle(): @@ -2371,7 +2373,9 @@ def test_splitter_serializable(Splitter): n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp) criterion = CRITERIA_CLF["gini"](n_outputs, n_classes) - splitter = Splitter(criterion, max_features, 5, 0.5, rng, monotonic_cst=None) + splitter = Splitter( + criterion, max_features, 5, 0.5, rng, monotonic_cst=None, breiman_shortcut=False + ) splitter_serialize = pickle.dumps(splitter) splitter_back = pickle.loads(splitter_serialize)
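Closing with a usage sketch (not part of the patch) of the new ``categorical`` parameter exercised by the tests above, assuming a build of scikit-learn that includes this branch; the column index, sample size, and category count are made up for the example.

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.RandomState(0)
    X_num = rng.standard_normal((200, 2))       # two numerical features
    X_cat = rng.randint(0, 4, size=(200, 1))    # one categorical feature
    X = np.hstack([X_num, X_cat])
    y = (X_cat[:, 0] % 2).astype(int)           # label depends only on the category

    # treat column 2 as categorical rather than ordinal
    clf = DecisionTreeClassifier(random_state=0, categorical=[2]).fit(X, y)
    print(clf.score(X, y))           # expected to fit this toy data perfectly
    print(clf.tree_.n_categories)    # numerical features should be marked -1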