Skip to content

Commit

Permalink
Add minimal test
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesCousyn committed Feb 3, 2025
1 parent 99d943d commit 24e23bc
Showing 1 changed file with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from shapiq.explainer.tree import TreeModel, TreeSHAPIQ
from shapiq.explainer.tree import TreeModel, TreeSHAPIQ, TreeExplainer


def test_init(dt_clf_model, background_clf_data):
Expand Down Expand Up @@ -132,3 +132,28 @@ def test_edge_case_params():
# test with max_order = 0
with pytest.raises(ValueError):
_ = TreeSHAPIQ(model=tree_model, max_order=0)

def test_no_bug_with_one_feature_tree():
# create the dataset
X = np.array(
[
[1, 1, 1, 1],
[1, 1, 1, 2],
[2, 1, 1, 1],
[3, 2, 1, 1],
]
)

# Define simple one feature tree
tree = {
"children_left": np.array([1, -1, -1]),
"children_right": np.array([2, -1, -1]),
"features": np.array([0, -2, -2]),
"thresholds": np.array([2.5, -2, -2]),
"values": np.array([0.5, 0.0, 1]),
"node_sample_weight": np.array([14, 5, 9]),
}
tree = TreeModel(**tree)
explainer = TreeExplainer(model=tree, index="SV", max_order=1)
shapley_values = explainer.explain(X[2])
print(shapley_values)

0 comments on commit 24e23bc

Please sign in to comment.