From 24e23bc95638c17f3669cab0ee597630a0e0e986 Mon Sep 17 00:00:00 2001 From: CharlesCousyn Date: Mon, 3 Feb 2025 14:28:15 -0500 Subject: [PATCH] Add minimal test --- .../test_tree_treeshapiq.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py index de724728..85aad268 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from shapiq.explainer.tree import TreeModel, TreeSHAPIQ +from shapiq.explainer.tree import TreeModel, TreeSHAPIQ, TreeExplainer def test_init(dt_clf_model, background_clf_data): @@ -132,3 +132,28 @@ def test_edge_case_params(): # test with max_order = 0 with pytest.raises(ValueError): _ = TreeSHAPIQ(model=tree_model, max_order=0) + +def test_no_bug_with_one_feature_tree(): + # create the dataset + X = np.array( + [ + [1, 1, 1, 1], + [1, 1, 1, 2], + [2, 1, 1, 1], + [3, 2, 1, 1], + ] + ) + + # Define simple one feature tree + tree = { + "children_left": np.array([1, -1, -1]), + "children_right": np.array([2, -1, -1]), + "features": np.array([0, -2, -2]), + "thresholds": np.array([2.5, -2, -2]), + "values": np.array([0.5, 0.0, 1]), + "node_sample_weight": np.array([14, 5, 9]), + } + tree = TreeModel(**tree) + explainer = TreeExplainer(model=tree, index="SV", max_order=1) + shapley_values = explainer.explain(X[2]) + print(shapley_values)