Skip to content

Commit

Permalink
add aggregate argument
Browse files Browse the repository at this point in the history
  • Loading branch information
auguste-probabl committed Mar 4, 2025
1 parent af7a97a commit 1c5166f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
33 changes: 33 additions & 0 deletions skore/src/skore/sklearn/_estimator/feature_importance_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# - a dictionary with metric names as keys and callables as values.
Scoring = Union[str, Callable, Iterable[str], dict[str, Callable]]

# Aggregation functions accepted by `aggregate=`: applied across permutation repeats.
Aggregation = Literal["mean", "std"]


class _FeatureImportanceAccessor(_BaseAccessor["EstimatorReport"], DirNamesMixin):
"""Accessor for feature importance related operations.
Expand Down Expand Up @@ -111,6 +113,7 @@ def feature_permutation(
data_source: DataSource = "test",
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
aggregate: Optional[Union[Aggregation, list[Aggregation]]] = None,
scoring: Optional[Scoring] = None,
n_repeats: int = 5,
n_jobs: Optional[int] = None,
Expand Down Expand Up @@ -147,6 +150,9 @@ def feature_permutation(
New target on which to compute the metric. By default, we use the test
target provided when creating the report.
aggregate : {"mean", "std"} or list of such str, default=None
Function to aggregate the scores across the repeats.
scoring : str, callable, list, tuple, or dict, default=None
The scorer to pass to :func:`~sklearn.inspection.permutation_importance`.
Expand Down Expand Up @@ -211,12 +217,30 @@ def feature_permutation(
Feature #7 0.023... 0.017...
Feature #8 0.077... 0.077...
Feature #9 0.011... 0.023...
>>> report.feature_importance.feature_permutation(
... n_repeats=2,
... aggregate=["mean", "std"],
... random_state=0,
... )
mean std
Feature
Feature #0 0.001... 0.002...
Feature #1 0.009... 0.007...
Feature #2 0.128... 0.019...
Feature #3 0.074... 0.004...
Feature #4 0.000... 0.000...
Feature #5 -0.000... 0.002...
Feature #6 0.031... 0.002...
Feature #7 0.020... 0.004...
Feature #8 0.077... 0.000...
Feature #9 0.017... 0.008...
"""
return self._feature_permutation(
data_source=data_source,
data_source_hash=None,
X=X,
y=y,
aggregate=aggregate,
scoring=scoring,
n_repeats=n_repeats,
n_jobs=n_jobs,
Expand All @@ -231,6 +255,7 @@ def _feature_permutation(
data_source_hash: Optional[int] = None,
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
aggregate: Optional[Union[Aggregation, list[Aggregation]]] = None,
scoring: Optional[Scoring] = None,
n_repeats: int = 5,
n_jobs: Optional[int] = None,
Expand Down Expand Up @@ -269,6 +294,9 @@ def _feature_permutation(
else:
cache_key_parts.append(scoring)

# aggregate is not included in the cache
# in order to trade off computation for storage

# order arguments by key to ensure cache works
# n_jobs variable should not be in the cache
kwargs = {
Expand Down Expand Up @@ -334,6 +362,11 @@ def _feature_permutation(
if cache_key is not None:
self._parent._cache[cache_key] = score

if aggregate:
if isinstance(aggregate, str):
aggregate = [aggregate]
score = score.aggregate(func=aggregate, axis=1)

return score

####################################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,19 @@ def case_X_y():
return data, kwargs, expected


def case_aggregate():
    """Case: permutation scores aggregated across repeats with ``"mean"``.

    Returns the ``(data, kwargs, expected)`` triple used by the shared
    parametrized test: regression data, the keyword arguments to pass to
    ``feature_permutation``, and the expected aggregated DataFrame (a single
    "mean" column of zeros, one row per feature).
    """
    data = regression_data()

    kwargs = dict(data_source="train", aggregate="mean", random_state=42)

    feature_names = [f"Feature #{i}" for i in range(3)]
    expected = pd.DataFrame(
        data=np.zeros((len(feature_names), 1)),
        index=pd.Index(feature_names, name="Feature"),
        columns=pd.Index(["mean"]),
    )
    return data, kwargs, expected


def case_default_args_dataframe():
data = regression_data_dataframe()

Expand Down Expand Up @@ -179,6 +192,7 @@ def case_several_scoring_dataframe():
case_r2_numpy,
case_train_numpy,
case_several_scoring_numpy,
case_aggregate,
case_default_args_dataframe,
case_r2_dataframe,
case_train_dataframe,
Expand Down

0 comments on commit 1c5166f

Please sign in to comment.