Skip to content

Commit

Permalink
Fix TreeExplainer instanciation when a tree in LightGBM Classifier mo…
Browse files Browse the repository at this point in the history
…del has n_features_in_tree = 1 (#310)

* Fix bug at instanciation when a tree has only one feature

* Modifty comments

* Correction using n_interpolation_size instead of _n_features_in_tree

* Add minimal test

* Fix styling

* ran pre-commit

---------

Co-authored-by: mmschlk <[email protected]>
  • Loading branch information
CharlesCousyn and mmschlk authored Feb 5, 2025
1 parent 0e27619 commit 4b04a2f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
8 changes: 7 additions & 1 deletion shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,13 @@ def _init_summary_polynomials(self):
interaction_order=order, n_features=self._n_features_in_tree
)
self.subset_ancestors_store[order] = subset_ancestors
self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

# If the tree has only one feature, we assign a default value of 0
if self.n_interpolation_size == 1:
self.D_store[order] = np.array([0])
else:
self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

self.D_powers_store[order] = self._cache(self.D_store[order])
if self._index in ("SV", "SII", "k-SII"):
self.Ns_store[order] = self._get_N(self.D_store[order])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from shapiq.explainer.tree import TreeModel, TreeSHAPIQ
from shapiq.explainer.tree import TreeExplainer, TreeModel, TreeSHAPIQ


def test_init(dt_clf_model, background_clf_data):
Expand Down Expand Up @@ -132,3 +132,29 @@ def test_edge_case_params():
# test with max_order = 0
with pytest.raises(ValueError):
_ = TreeSHAPIQ(model=tree_model, max_order=0)


def test_no_bug_with_one_feature_tree():
# create the dataset
X = np.array(
[
[1, 1, 1, 1],
[1, 1, 1, 2],
[2, 1, 1, 1],
[3, 2, 1, 1],
]
)

# Define simple one feature tree
tree = {
"children_left": np.array([1, -1, -1]),
"children_right": np.array([2, -1, -1]),
"features": np.array([0, -2, -2]),
"thresholds": np.array([2.5, -2, -2]),
"values": np.array([0.5, 0.0, 1]),
"node_sample_weight": np.array([14, 5, 9]),
}
tree = TreeModel(**tree)
explainer = TreeExplainer(model=tree, index="SV", max_order=1)
shapley_values = explainer.explain(X[2])
print(shapley_values)

0 comments on commit 4b04a2f

Please sign in to comment.