
feat(EstimatorReport): Display the feature permutation importance #1365

Open
wants to merge 30 commits into base: 1320-featestimatorreport-display-the-feature-weights-for-linear-models

Conversation

auguste-probabl
Contributor

@auguste-probabl auguste-probabl commented Feb 26, 2025

Closes #1319

Todo:

  • Add example
  • Coverage at 100%
  • Add API docs
  • Check for # TODO in code
  • Check what happens if scoring is a callable (is it cached?)
    • scoring was not included in the cache key at all! Fixed
  • Check what happens if random_state is a RandomState instance (is it cached?)
    • If random_state is a RandomState then the call should not be cached, because reusing a RandomState gives a different result.
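
The last point can be checked with a minimal sketch (plain numpy, no report cache involved): an int seed is reproducible, while a `RandomState` instance mutates on each use, so a result cached against it would silently go stale.

```python
import numpy as np

# An int seed is reproducible: two calls yield the same permutation,
# so caching keyed on the seed value is safe.
assert np.array_equal(
    np.random.RandomState(42).permutation(10),
    np.random.RandomState(42).permutation(10),
)

# A RandomState instance advances its internal state on each draw:
# two calls yield different permutations, so the call must not be cached.
rng = np.random.RandomState(42)
first = rng.permutation(10)
second = rng.permutation(10)
assert not np.array_equal(first, second)
```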

@auguste-probabl auguste-probabl changed the base branch from main to 1320-featestimatorreport-display-the-feature-weights-for-linear-models February 26, 2025 15:38

github-actions bot commented Feb 26, 2025

Coverage

Coverage Report for backend
File | Stmts | Miss | Cover | Missing
venv/lib/python3.12/site-packages/skore
   __init__.py150100% 
   __main__.py880%3–19
   _config.py280100% 
   exceptions.py440%4–23
venv/lib/python3.12/site-packages/skore/persistence
   __init__.py00100% 
venv/lib/python3.12/site-packages/skore/persistence/item
   __init__.py56393%96–99
   altair_chart_item.py19191%14
   item.py22195%86
   matplotlib_figure_item.py36195%19
   media_item.py220100% 
   numpy_array_item.py27194%16
   pandas_dataframe_item.py29194%14
   pandas_series_item.py29194%14
   pickle_item.py220100% 
   pillow_image_item.py25193%15
   plotly_figure_item.py20192%14
   polars_dataframe_item.py27194%14
   polars_series_item.py22192%14
   primitive_item.py23291%13–15
   sklearn_base_estimator_item.py29194%15
   skrub_table_report_item.py10186%11
venv/lib/python3.12/site-packages/skore/persistence/repository
   __init__.py20100% 
   item_repository.py59591%15–16, 202–203, 226
venv/lib/python3.12/site-packages/skore/persistence/storage
   __init__.py40100% 
   abstract_storage.py220100% 
   disk_cache_storage.py33195%44
   in_memory_storage.py200100% 
venv/lib/python3.12/site-packages/skore/persistence/view
   __init__.py220%3–5
   view.py550%3–20
venv/lib/python3.12/site-packages/skore/project
   __init__.py30100% 
   _open.py50100% 
   project.py81199%284
venv/lib/python3.12/site-packages/skore/sklearn
   __init__.py60100% 
   _base.py1621392%43, 115, 118, 171–180, 192–>197, 212, 215–216
   find_ml_task.py61099%136–>144
   types.py13285%33, 61
venv/lib/python3.12/site-packages/skore/sklearn/_comparison
   __init__.py50100% 
   metrics_accessor.py164297%165, 166–>168, 1218
   precision_recall_curve_display.py73197%196–>199, 304
   prediction_error_display.py671078%97, 154–>exit, 209, 214–218, 227, 231, 236–238
   report.py64196%16, 251–>254
   roc_curve_display.py69196%204–>213, 213–>216, 308
venv/lib/python3.12/site-packages/skore/sklearn/_cross_validation
   __init__.py50100% 
   metrics_accessor.py170099%142–>144, 144–>146
   report.py105198%22
venv/lib/python3.12/site-packages/skore/sklearn/_estimator
   __init__.py70100% 
   feature_importance_accessor.py89099%271–>277
   metrics_accessor.py3251195%166–175, 203–>212, 211, 241, 252–>254, 282, 309–313, 328, 351, 363, 364–>366
   report.py127197%22, 229–>235, 237–>239
venv/lib/python3.12/site-packages/skore/sklearn/_plot
   __init__.py40100% 
   precision_recall_curve.py129198%240–>257, 329
   prediction_error.py102198%173, 189–>192
   roc_curve.py1430100% 
   style.py140100% 
   utils.py99594%31, 55–57, 61
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split
   __init__.py00100% 
   train_test_split.py36294%16–17
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split/warning
   __init__.py80100% 
   high_class_imbalance_too_few_examples_warning.py17190%79
   high_class_imbalance_warning.py180100% 
   random_state_unset_warning.py12188%15
   shuffle_true_warning.py10183%46
   stratify_is_set_warning.py12188%15
   time_based_column_warning.py23286%17, 73
   train_test_split_warning.py5180%21
venv/lib/python3.12/site-packages/skore/utils
   __init__.py60100% 
   _accessor.py170100% 
   _environment.py27270%1–51
   _index.py50100% 
   _logger.py22220%3–38
   _parallel.py38388%23–33, 124
   _patch.py13553%21–37
   _progress_bar.py340100% 
   _show_versions.py330100% 
TOTAL304815994% 

Tests | Skipped | Failures | Errors | Time
665 | 3 💤 | 0 ❌ | 0 🔥 | 48.464s ⏱️

@auguste-probabl auguste-probabl force-pushed the 1319-featestimatorreport-display-the-feature-permutation-importance branch from e8d40ac to 475b6ef Compare February 27, 2025 17:27

github-actions bot commented Feb 27, 2025

Documentation preview @ d6f9fc3

@auguste-probabl auguste-probabl force-pushed the 1319-featestimatorreport-display-the-feature-permutation-importance branch from 475b6ef to 3ca3f74 Compare February 28, 2025 11:27
@auguste-probabl auguste-probabl marked this pull request as ready for review March 3, 2025 10:24
@auguste-probabl
Contributor Author

The coverage report in the comment differs from the report I get locally (which is 100%)...

@glemaitre glemaitre self-requested a review March 3, 2025 15:35
Member

@glemaitre glemaitre left a comment

It is a good start; I have the following remarks:

  • I think we can limit the API of scoring and only accept what we do in report.metrics.report_metrics().
  • Otherwise, only small details.

I think the caching is OK-ish but not the best. To make it better, we might need to revisit the scikit-learn implementation and cache at a different level.

If someone requests the permutation importance (with a seed) and a single metric, and then requests another computation, we recompute the model's predictions even though they were already computed the first time. When passing a list of metrics, the cache of the scikit-learn scorer saves us, but we don't deliver in the multiple-call scenario.

To do it properly, we either need:

  • to reimplement a good chunk of the permutation importance so that we cache the model's predictions, or
  • to find a (possibly dirty) way to hook into the scorer's cache, which would not be easy (not sure about this one).

n_jobs: Optional[int] = None,
random_state: Optional[Union[int, RandomState]] = None,
sample_weight: Optional[ArrayLike] = None, # Sample weights used in scoring.
max_samples: float = 1.0,
Member

Let's move this one under n_repeats since it is more.

Contributor Author

sorry what do you mean?

n_jobs=n_jobs,
random_state=random_state,
sample_weight=sample_weight,
max_samples=max_samples,
Member

I think it makes sense to have an additional aggregate parameter to compute an aggregated score, like the mean.

Member

The aggregation can happen after reloading from the cache because it will not be costly.
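
A sketch of that idea, with a hypothetical cached result (the actual cache layout in skore may differ): the expensive permutation run is stored once, and the aggregation is a cheap post-processing step applied after the cache lookup, so it need not be part of the cache key.

```python
import numpy as np
import pandas as pd

# Hypothetical per-repeat importances as reloaded from the cache:
# one row per feature, one column per repeat.
cached = pd.DataFrame(
    [[0.80, 0.90, 0.85], [0.10, 0.20, 0.15]],
    index=pd.Index(["x0", "x1"], name="Feature"),
    columns=["Repeat #0", "Repeat #1", "Repeat #2"],
)

# Aggregating over repeats is cheap, so it runs after the cache lookup.
aggregated = cached.agg(["mean", "std"], axis=1)
```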

Member

We also need a flat_index to flatten the index if desired.
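
For illustration, a `flat_index` option could collapse a multi-level column index into single strings; the column layout below is hypothetical, not necessarily what the accessor returns.

```python
import pandas as pd

# Hypothetical multi-level columns as the accessor might return them.
columns = pd.MultiIndex.from_product(
    [["r2", "rmse"], ["Repeat #0", "Repeat #1"]], names=["Metric", "Repeat"]
)
df = pd.DataFrame([[0.1, 0.2, 0.3, 0.4]], index=["x0"], columns=columns)

# With flat_index=True, the levels would be joined into flat strings.
flat = df.copy()
flat.columns = ["_".join(col) for col in flat.columns]
```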

Contributor Author

Added aggregation. I'm not sure flat_index is necessary (take a look at the doctest).

)
score = pd.DataFrame(data=data, index=index, columns=columns)

# Unless random_state is an int (i.e. the call is deterministic),
Member

It is a good point. I think we have a bug in the prediction_error plot when subsampling.

Contributor Author

yeah... caching is hard

def case_several_scoring_numpy():
data = regression_data()

kwargs = {"scoring": ["r2", "neg_root_mean_squared_error"], "random_state": 42}
Member

I would not accept the neg_* scorers, to be consistent with what we already have in the metrics report.

Contributor Author

That opens a can of worms: how do I know which metrics to accept? Using _SCORE_OR_LOSS_INFO, maybe?
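
A sketch of what such an allowlist check could look like; the contents of `_SCORE_OR_LOSS_INFO` below are hypothetical, not skore's actual table:

```python
# Hypothetical allowlist mapping accepted metric names to display info;
# skore's real _SCORE_OR_LOSS_INFO may look different.
_SCORE_OR_LOSS_INFO = {
    "r2": {"name": "R²"},
    "rmse": {"name": "RMSE"},
}


def validate_scoring(scoring):
    """Reject metric names outside the allowlist (e.g. sklearn's neg_* aliases)."""
    unknown = [name for name in scoring if name not in _SCORE_OR_LOSS_INFO]
    if unknown:
        raise ValueError(f"Unsupported scoring: {unknown!r}")
    return scoring
```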

data = score
n_repeats = data.shape[1]
index = pd.Index(feature_names, name="Feature")
columns = pd.Index(
Member

I would still think that we need to include the score name, because otherwise we don't know what the repeats relate to.

Contributor Author

If the user passes scoring=None, how do I know what metric was used?
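
The ambiguity is real: with scoring=None, scikit-learn's permutation_importance falls back to estimator.score (R² for regressors, accuracy for classifiers), and the returned Bunch carries no metric name, only the importance arrays. A small illustration:

```python
from sklearn.datasets import make_regression
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=100, n_features=3, random_state=0)
model = LinearRegression().fit(X, y)

# scoring=None falls back to model.score (R² for this regressor), but the
# result only exposes the importance arrays, not the metric's name.
result = permutation_importance(model, X, y, scoring=None, random_state=0)
assert sorted(result) == ["importances", "importances_mean", "importances_std"]
```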

@auguste-probabl auguste-probabl force-pushed the 1319-featestimatorreport-display-the-feature-permutation-importance branch 2 times, most recently from 39c86d5 to 48bee8b Compare March 4, 2025 10:55
@auguste-probabl auguste-probabl force-pushed the 1319-featestimatorreport-display-the-feature-permutation-importance branch from d6f9fc3 to 9c8e914 Compare March 4, 2025 14:25
@auguste-probabl auguste-probabl force-pushed the 1320-featestimatorreport-display-the-feature-weights-for-linear-models branch from 3ebfa52 to 69eb4f4 Compare March 4, 2025 14:25
@auguste-probabl auguste-probabl force-pushed the 1319-featestimatorreport-display-the-feature-permutation-importance branch from 9c8e914 to 1c5166f Compare March 4, 2025 14:37
Successfully merging this pull request may close these issues.

Feat(EstimatorReport): Display the feature permutation importance