feat(EstimatorReport): Display the feature permutation importance #1365
base: 1320-featestimatorreport-display-the-feature-weights-for-linear-models
Conversation
The coverage report in the comment differs from the report I get locally (which is 100%)...
It is a good start; I have the following remarks:
- I think we can limit the API of scoring and only adopt what we do in report.metrics.report_metrics().
- Otherwise it is only small details.

I think the caching is OK-ish but not the best. To make it better, we might need to revisit the scikit-learn implementation and cache at a different level.

If someone requests the permutation importance (with a seed) and a single metric, and then requests another computation, we relaunch the prediction computation of the model even though it was computed the first time. When passing a list of metrics, the cache of the scikit-learn scorer saves us, but we get no benefit in a multiple-call scenario.

To properly do it, we either need:
- to reimplement a good chunk of the permutation importance so that we cache the predictions of the model (sketched below), or
- to find a dirty way to intervene with the cache of the scorer, though it would not be easy (not sure about this one).
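A minimal sketch of the first option, assuming the wrapper below (an invented name, not skore's implementation) is handed to scikit-learn's permutation_importance in place of the raw estimator:

```python
import numpy as np


class _PredictionCachingEstimator:
    """Illustrative wrapper that memoizes predict() results keyed on the
    bytes of X.

    Permutation importance permutes one column at a time, so permuted
    matrices rarely repeat; the win is that the baseline (unpermuted)
    predictions are reused across repeated permutation-importance calls.
    """

    def __init__(self, estimator):
        self._estimator = estimator
        self._cache = {}

    def __getattr__(self, name):
        # Delegate fitted attributes, predict_proba, etc. to the estimator.
        return getattr(self._estimator, name)

    def predict(self, X):
        key = np.ascontiguousarray(X).tobytes()
        if key not in self._cache:
            self._cache[key] = self._estimator.predict(X)
        return self._cache[key]
```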
```python
n_jobs: Optional[int] = None,
random_state: Optional[Union[int, RandomState]] = None,
sample_weight: Optional[ArrayLike] = None,  # Sample weights used in scoring.
max_samples: float = 1.0,
```
Let's move this one under n_repeats, since it is more closely related.
Sorry, what do you mean?
```python
n_jobs=n_jobs,
random_state=random_state,
sample_weight=sample_weight,
max_samples=max_samples,
```
I think it makes sense to have an additional aggregate parameter to compute an aggregated score, like the mean.
The aggregation can happen after reloading from the cache because it will not be costly.
We also need a flat_index to flatten the index if desired.
Added aggregation. I don't see why flat_index is necessary (take a look at the doctest).
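A hypothetical usage sketch of the aggregation discussed above; the accessor path and method name are assumptions taken from this thread, not the final API:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from skore import EstimatorReport

X, y = make_regression(n_features=5, random_state=0)
report = EstimatorReport(Ridge(), X_train=X, y_train=y, X_test=X, y_test=y)

# Per-repeat importances as a (feature x repeat) DataFrame; the method name
# feature_permutation is an assumption based on this thread.
importances = report.feature_importance.feature_permutation(
    scoring="r2", random_state=42
)

# Reloading from the cache is cheap, so aggregation can happen afterwards,
# e.g. mean and standard deviation over the repeat columns:
print(importances.agg(["mean", "std"], axis=1))
```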
```python
)
score = pd.DataFrame(data=data, index=index, columns=columns)

# Unless random_state is an int (i.e. the call is deterministic),
```
It is a good point. I think we have a bug in the prediction_error plot when subsampling.
yeah... caching is hard
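A minimal sketch of the caching policy this implies; the helper name is illustrative:

```python
def _is_deterministic_call(random_state) -> bool:
    """Only an int seed makes repeated calls reproducible, so only then is
    it safe to serve a cached result.

    random_state=None draws from global randomness, and a RandomState
    instance mutates as it is consumed, so both give different results on
    reuse.
    """
    return isinstance(random_state, int)
```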
```python
def case_several_scoring_numpy():
    data = regression_data()

    kwargs = {"scoring": ["r2", "neg_root_mean_squared_error"], "random_state": 42}
```
I would not accept the neg_* variants, to be consistent with what we already have in the metrics report.
That opens a can of worms: how do I know which metrics to accept? Using _SCORE_OR_LOSS_INFO, maybe?
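A hedged sketch of that validation, assuming _SCORE_OR_LOSS_INFO can provide a mapping from user-facing metric names to their metadata (its exact structure is not shown in this thread):

```python
def _validate_scoring(scoring, known_metrics):
    """Reject the neg_* scorer aliases and unknown names.

    `known_metrics` stands in for the keys of _SCORE_OR_LOSS_INFO; the
    real lookup may differ.
    """
    names = [scoring] if isinstance(scoring, str) else list(scoring)
    for name in names:
        if name.startswith("neg_"):
            raise ValueError(
                f"{name!r} is not accepted; use the positive metric name, "
                "consistent with report.metrics.report_metrics()."
            )
        if name not in known_metrics:
            raise ValueError(f"Unknown metric {name!r}.")
    return names
```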
```python
data = score
n_repeats = data.shape[1]
index = pd.Index(feature_names, name="Feature")
columns = pd.Index(
```
I still think that we need to include the score name, because otherwise we don't know what the repeats relate to.
If the user passes scoring=None, how do I know what metric was used?
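One possible answer, sketched from scikit-learn's conventions: with scoring=None, permutation_importance falls back to estimator.score, which is accuracy for classifiers and R² for regressors, so a label can be derived from the estimator type (the helper name is illustrative):

```python
from sklearn.base import is_classifier, is_regressor


def _default_metric_name(estimator) -> str:
    # estimator.score is accuracy for ClassifierMixin and R^2 for
    # RegressorMixin, per scikit-learn's conventions.
    if is_classifier(estimator):
        return "accuracy"
    if is_regressor(estimator):
        return "r2"
    return "score"
```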
Closes #1319
Todo:
- # TODO in code
- scoring is a callable (is it cached?)
- random_state is a RandomState instance (is it cached?)

If random_state is a RandomState, then the call should not be cached, because reusing a RandomState gives a different result.
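A quick demonstration of the last point, i.e. why a call carrying a RandomState instance must not be served from the cache:

```python
import numpy as np

rng = np.random.RandomState(42)
print(rng.randint(10, size=3))  # first draw
print(rng.randint(10, size=3))  # different: the generator's state advanced

# An int seed, by contrast, makes each call reproducible:
print(np.random.RandomState(42).randint(10, size=3))
print(np.random.RandomState(42).randint(10, size=3))  # identical
```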