Skip to content

Commit

Permalink
made the InteractionValues.values array satisfy efficiency (values sum up to the prediction)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Jan 14, 2025
1 parent 649b871 commit 886bb6a
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 28 deletions.
31 changes: 27 additions & 4 deletions shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The base Explainer classes for the shapiq package."""

from abc import abstractmethod
from typing import Optional
from warnings import warn

Expand Down Expand Up @@ -81,13 +82,35 @@ def _validate_data(self, data: np.ndarray, raise_error: bool = False) -> None:
else:
warn(message)

def explain(self, x: np.ndarray) -> InteractionValues:
"""Explain the model's prediction in terms of interaction values.
def explain(self, x: np.ndarray, *args, **kwargs) -> InteractionValues:
"""Explain a single prediction in terms of interaction values.
Args:
x: An instance/point/sample/observation to be explained.
x: A numpy array of a data point to be explained.
*args: Additional positional arguments passed to the explainer.
**kwargs: Additional keyword-only arguments passed to the explainer.
Returns:
The interaction values of the prediction.
"""
explanation = self.explain_function(x=x, *args, **kwargs)
if explanation.min_order == 0:
explanation[()] = explanation.baseline_value
return explanation

@abstractmethod
def explain_function(self, x: np.ndarray, *args, **kwargs) -> InteractionValues:
"""Explain a single prediction in terms of interaction values.
Args:
x: A numpy array of a data point to be explained.
*args: Additional positional arguments passed to the explainer.
**kwargs: Additional keyword-only arguments passed to the explainer.
Returns:
The interaction values of the prediction.
"""
return {}
raise NotImplementedError("The method `explain` must be implemented in a subclass.")

def explain_X(
self, X: np.ndarray, n_jobs=None, random_state=None, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion shapiq/explainer/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
self._max_order: int = max_order
self._approximator = self._init_approximator(approximator, self.index, self._max_order)

def explain(
def explain_function(
self, x: np.ndarray, budget: Optional[int] = None, random_state: Optional[int] = None
) -> InteractionValues:
"""Explains the model's predictions.
Expand Down
11 changes: 10 additions & 1 deletion shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ def __init__(
]
self.baseline_value = self._compute_baseline_value()

def explain(self, x: np.ndarray) -> InteractionValues:
def explain_function(self, x: np.ndarray, **kwargs) -> InteractionValues:
"""Computes the Shapley Interaction values for a single instance.
Args:
x: The instance to explain as a 1-dimensional array.
**kwargs: Additional keyword arguments are ignored.
Returns:
The interaction values for the instance.
"""
if len(x.shape) != 1:
raise TypeError("explain expects a single instance, not a batch.")
# run treeshapiq for all trees
Expand Down
36 changes: 35 additions & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,19 @@ def __post_init__(self) -> None:
)

if not isinstance(self.baseline_value, (int, float)):
raise TypeError("Baseline value must be provided as a number.")
raise TypeError(
f"Baseline value must be provided as a number. Got {self.baseline_value}."
)

# check if () is in the interaction_lookup if min_order is 0 -> add it to the end
if self.min_order == 0 and () not in self.interaction_lookup:
self.interaction_lookup[()] = len(self.interaction_lookup)
self.values = np.concatenate((self.values, np.array([self.baseline_value])))

# update the baseline value in the values vector if index is not SII
# # TODO: this might be a good idea; check if this is okay to do
# if self.index != "SII" and self.baseline_value != self.values[self.interaction_lookup[()]]:
# self.values[self.interaction_lookup[()]] = self.baseline_value

@property
def dict_values(self) -> dict[tuple[int, ...], float]:
Expand Down Expand Up @@ -226,6 +238,28 @@ def __getitem__(self, item: Union[int, tuple[int, ...]]) -> float:
except KeyError:
return 0.0

def __setitem__(self, item: Union[int, tuple[int, ...]], value: float) -> None:
    """Sets the score for the given interaction.

    Args:
        item: The interaction for which to set the score. If ``item`` is an integer, it is
            used as a direct index into the values vector; otherwise it is treated as an
            interaction tuple and looked up in ``interaction_lookup``.
        value: The value to set for the interaction.

    Raises:
        KeyError: If the interaction is not found in the InteractionValues object.
    """
    try:
        if isinstance(item, int):
            # integer access indexes the values vector directly
            self.values[item] = value
        else:
            # interactions are stored with sorted player indices
            item = tuple(sorted(item))
            self.values[self.interaction_lookup[item]] = value
    # narrow catch: only lookup/index failures should surface as a missing interaction;
    # an unrelated error raised in here must not be masked as a KeyError
    except (KeyError, IndexError, TypeError) as e:
        raise KeyError(
            f"Interaction {item} not found in the InteractionValues. Unable to set a value."
        ) from e

def __eq__(self, other: object) -> bool:
"""Checks if two InteractionValues objects are equal.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_base_interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ def test_initialization(index, n, min_order, max_order, estimation_budget, estim
assert interaction_values[0] == interaction_values.values[0]
assert interaction_values[-1] == interaction_values.values[-1]

# check setitem
interaction_values[(0,)] = 999_999
assert interaction_values[(0,)] == 999_999

# check setitem with integer as input
interaction_values[0] = 111_111
assert interaction_values[0] == 111_111

# check setitem raises error for invalid interaction
with pytest.raises(KeyError):
interaction_values[(100, 101)] = 0

# test __len__
assert len(interaction_values) == len(interaction_values.values)

Expand Down
38 changes: 19 additions & 19 deletions tests/tests_explainer/test_explainer_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_torch_reg(torch_reg_model, background_reg_data):
explainer = Explainer(model=torch_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values, rel=0.01)


Expand All @@ -32,13 +32,13 @@ def test_torch_clf(torch_clf_model, background_clf_data):
explainer = Explainer(model=torch_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values, rel=0.001)

explainer = Explainer(model=torch_clf_model, data=background_clf_data, class_index=0)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[0] == pytest.approx(sum_of_values, rel=0.001)


Expand All @@ -51,13 +51,13 @@ def test_sklearn_clf_tree(dt_clf_model, background_clf_data):
explainer = TabularExplainer(model=dt_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values, abs=0.001)

explainer = TabularExplainer(model=dt_clf_model, data=background_clf_data, class_index=0)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[0] == pytest.approx(sum_of_values, abs=0.001)

# do the same with the bare explainer (only for class_label=2)
Expand All @@ -78,7 +78,7 @@ def test_sklearn_reg_tree(dt_reg_model, background_reg_data):
explainer = TabularExplainer(model=dt_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values, abs=0.001)

# do the same with the bare explainer
Expand All @@ -99,13 +99,13 @@ def test_sklearn_clf_forest(rf_clf_model, background_clf_data):
explainer = TabularExplainer(model=rf_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values, rel=0.001)

explainer = TabularExplainer(model=rf_clf_model, data=background_clf_data, class_index=0)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[0] == pytest.approx(sum_of_values, rel=0.001)

# do the same with the bare explainer (only for class_label=2)
Expand All @@ -125,14 +125,14 @@ def test_sklearn_reg_forest(rf_reg_model, background_reg_data):
explainer = TabularExplainer(model=rf_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values)

# do the same with the bare explainer
explainer = Explainer(model=rf_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values, rel=0.01)


Expand All @@ -145,20 +145,20 @@ def test_sklearn_clf_logistic_regression(lr_clf_model, background_clf_data):
explainer = TabularExplainer(model=lr_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values)

explainer = TabularExplainer(model=lr_clf_model, data=background_clf_data, class_index=0)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[0] == pytest.approx(sum_of_values)

# do the same with the bare explainer (only for class_label=2)
explainer = Explainer(model=lr_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values)


Expand All @@ -171,14 +171,14 @@ def test_sklearn_reg_linear_regression(lr_reg_model, background_reg_data):
explainer = TabularExplainer(model=lr_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values)

# do the same with the bare explainer
explainer = Explainer(model=lr_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values)


Expand All @@ -191,7 +191,7 @@ def test_lightgbm_reg(lightgbm_reg_model, background_reg_data):
explainer = TabularExplainer(model=lightgbm_reg_model, data=background_reg_data)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction == pytest.approx(sum_of_values)

# do the same with the bare explainer
Expand All @@ -212,13 +212,13 @@ def test_lightgbm_clf(lightgbm_clf_model, background_clf_data):
explainer = TabularExplainer(model=lightgbm_clf_model, data=background_clf_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[2] == pytest.approx(sum_of_values, rel=0.001)

explainer = TabularExplainer(model=lightgbm_clf_model, data=background_clf_data, class_index=0)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert prediction[0] == pytest.approx(sum_of_values, rel=0.001)

# do the same with the bare explainer (only for class_label=2)
Expand All @@ -241,7 +241,7 @@ def test_isoforest_clf(if_clf_model, if_clf_dataset):
explainer = TabularExplainer(model=if_clf_model, data=x_data, class_index=2)
values = explainer.explain(x_explain)
assert isinstance(values, InteractionValues)
sum_of_values = sum(values.values) + values.baseline_value
sum_of_values = sum(values.values)
assert pytest.approx(sum_of_values, abs=0.001) == prediction

# do the same with the bare explainer
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_explainer/test_explainer_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def test_explain(dt_model, data, index, budget, max_order, imputer):
# test for efficiency
if index in ("FSII", "k-SII"):
prediction = float(model_function(x)[0])
sum_of_values = float(np.sum(interaction_values.values) + interaction_values.baseline_value)
assert interaction_values[()] == 0.0
sum_of_values = float(np.sum(interaction_values.values))
assert pytest.approx(interaction_values[()]) == interaction_values.baseline_value
assert pytest.approx(sum_of_values, 0.01) == prediction


Expand Down

0 comments on commit 886bb6a

Please sign in to comment.