Skip to content

Commit

Permalink
Fixed unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Jul 20, 2023
1 parent 9a614f4 commit 8112691
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
11 changes: 8 additions & 3 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,14 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
)

# get quantiles across all leaf node samples
y_hat[idx, ...] = np.quantile(
leaf_node_samples, quantiles, axis=0, method=method
)
try:
y_hat[idx, ...] = np.quantile(
leaf_node_samples, quantiles, axis=0, method=method
)
except TypeError:
y_hat[idx, ...] = np.quantile(
leaf_node_samples, quantiles, axis=0, interpolation=method
)

if is_classifier(self):
if self.n_outputs_ == 1:
Expand Down
27 changes: 5 additions & 22 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1008,16 +1008,8 @@ cdef class BaseTree:
cache_mgr = CategoryCacheMgr()
cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits
# cdef vector[UINT64_t] cache = NULL

cdef const INT32_t[:] n_categories = self.n_categories

# apply Cache to speed up categorical "apply"
# cache_mgr = CategoryCacheMgr()
# cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
# cdef UINT64_t** cat_caches = cache_mgr.bits
# cdef UINT64_t* cache = NULL

with nogil:
for i in range(n_samples):
node = self.nodes
Expand All @@ -1034,9 +1026,6 @@ cdef class BaseTree:
node = &self.nodes[node.right_child]
elif goes_left(
X_i_node_feature,
# node.split_value,
# node.threshold,
# self.n_categories[node.feature],
node,
n_categories,
cache
Expand Down Expand Up @@ -1082,7 +1071,6 @@ cdef class BaseTree:
cache_mgr = CategoryCacheMgr()
cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits
# cdef vector[UINT64_t] cache = NULL

cdef const INT32_t[:] n_categories = self.n_categories
# feature_to_sample as a data structure records the last seen sample
Expand Down Expand Up @@ -1114,9 +1102,6 @@ cdef class BaseTree:

if goes_left(
feature_value,
# node.split_value,
# node.threshold,
# self.n_categories[node.feature],
node,
n_categories,
cache
Expand Down Expand Up @@ -1650,21 +1635,19 @@ cdef class Tree(BaseTree):
self.n_classes = NULL
safe_realloc(&self.n_classes, n_outputs)

self.n_categories = NULL
safe_realloc(&self.n_categories, n_features)
cdef SIZE_t k

# n-categories is a 1D array of size n_features
# self.n_categories = np.empty(n_features, dtype=np.int32)
# self.n_categories = n_categories
self.n_categories = NULL
safe_realloc(&self.n_categories, n_features)
for k in range(n_features):
self.n_categories[k] = n_categories[k]

self.max_n_classes = np.max(n_classes)
self.value_stride = n_outputs * self.max_n_classes

cdef SIZE_t k
for k in range(n_outputs):
self.n_classes[k] = n_classes[k]
for k in range(n_features):
self.n_categories[k] = n_categories[k]

# Inner structures
self.max_depth = 0
Expand Down

0 comments on commit 8112691

Please sign in to comment.