feat: Add dedicated Display classes to ComparisonReport (#1309)
Duplicate the plot classes for `ComparisonReport` use cases:
- `RocCurveDisplay`,
- `PrecisionRecallCurveDisplay`,
- `PredictionErrorDisplay`.

I've chosen to simplify a few things from the original classes, in
particular by removing the user's ability to customize the plot.
The plot tests strictly compare the matplotlib output, pixel by pixel,
against an expected image.
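
For reference, such a pixel-by-pixel test can be built on top of
`matplotlib.testing.compare.compare_images`. The sketch below is only an
illustration, not the test suite added in this PR: `comparison_report` is
assumed to be a pytest fixture providing a fitted `ComparisonReport`, and the
baseline image path is made up.

```python
# Hedged sketch of a pixel-by-pixel plot test (not the actual test suite).
# `comparison_report` is assumed to be a pytest fixture yielding a fitted
# skore.ComparisonReport; the baseline image path is illustrative.
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images


def test_roc_display_matches_baseline(comparison_report, tmp_path):
    display = comparison_report.metrics.roc()
    display.plot()

    actual = tmp_path / "roc.png"
    plt.gcf().savefig(actual)

    # compare_images returns None when both images match within `tol`,
    # otherwise a message describing the difference.
    assert compare_images("baseline/roc.png", str(actual), tol=0) is None
```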

The next iteration should be to merge these duplicated classes back into
the ones in `skore.sklearn._plot`.

---

**Roc curves**

![binary-classification](https://github.com/user-attachments/assets/b3e4fd8a-e201-4f3a-83be-01165c8a9c9d)

![multiclass-classification](https://github.com/user-attachments/assets/6ec2f490-7c55-483d-ada2-5af67a0ff1bd)

**PR curves**

![pr-binary-classification](https://github.com/user-attachments/assets/79a07266-5ddb-42e4-99cb-e6f3fefdfc6d)

![pr-multiclass-classification](https://github.com/user-attachments/assets/80b200cd-20af-4a51-8fe6-401f56e9b534)

**Prediction error curves**

![Figure_1](https://github.com/user-attachments/assets/d64de1de-84de-4f25-9fa9-7ee7e3e17ad9)

![Figure_2](https://github.com/user-attachments/assets/b2d8889e-6573-4dd7-8170-c5974fd767e0)
thomass-dev authored Feb 20, 2025
1 parent 314dde5 commit 01dc294
Showing 13 changed files with 1,844 additions and 26 deletions.
15 changes: 12 additions & 3 deletions examples/getting_started/plot_skore_getting_started.py
@@ -162,25 +162,34 @@
# %%
from skore import ComparisonReport

comparator = ComparisonReport(reports=[log_reg_report, rf_report])
comparison_report = ComparisonReport(reports=[log_reg_report, rf_report])

# %%
# As for the :class:`~skore.EstimatorReport` and the
# :class:`~skore.CrossValidationReport`, we have a helper:

# %%
comparator.help()
comparison_report.help()

# %%
# Let us display the result of our benchmark:

# %%
benchmark_metrics = comparator.metrics.report_metrics()
benchmark_metrics = comparison_report.metrics.report_metrics()
benchmark_metrics

# %%
# We have the result of our benchmark.

# %%
# We display the ROC curves of the two estimator reports we want to compare by
# superimposing them on the same figure:

# %%
comparison_report.metrics.roc().plot()
plt.tight_layout()


# %%
# Train-test split with skore
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
6 changes: 3 additions & 3 deletions skore/src/skore/sklearn/__init__.py
@@ -11,11 +11,11 @@
from skore.sklearn.train_test_split.train_test_split import train_test_split

__all__ = [
"train_test_split",
"ComparisonReport",
"CrossValidationReport",
"EstimatorReport",
"ComparisonReport",
"RocCurveDisplay",
"PrecisionRecallCurveDisplay",
"PredictionErrorDisplay",
"RocCurveDisplay",
"train_test_split",
]
332 changes: 331 additions & 1 deletion skore/src/skore/sklearn/_comparison/metrics_accessor.py
@@ -5,7 +5,12 @@
from sklearn.utils.metaestimators import available_if

from skore.externals._pandas_accessors import DirNamesMixin
from skore.sklearn._base import _BaseAccessor
from skore.sklearn._base import _BaseAccessor, _get_cached_response_values
from skore.sklearn._comparison.precision_recall_curve_display import (
PrecisionRecallCurveDisplay,
)
from skore.sklearn._comparison.prediction_error_display import PredictionErrorDisplay
from skore.sklearn._comparison.roc_curve_display import RocCurveDisplay
from skore.utils._accessor import _check_supported_ml_task
from skore.utils._index import flatten_multi_index
from skore.utils._progress_bar import progress_decorator
@@ -1108,3 +1113,328 @@ def __repr__(self):
class_name="skore.ComparisonReport.metrics",
help_method_name="report.metrics.help()",
)

@progress_decorator(description="Computing predictions for display")
def _get_display(
self,
*,
X,
y,
data_source,
response_method,
display_class,
display_kwargs,
):
"""Get the display from the cache or compute it.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The data.
y : array-like of shape (n_samples,)
The target.
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.
- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.
response_method : str
The response method.
display_class : class
The display class.
display_kwargs : dict
The display kwargs used by `display_class._from_predictions`.

Returns
-------
display : display_class
The display.
"""
cache_key = (self._parent._hash, display_class.__name__)
cache_key += tuple(display_kwargs.values())
cache_key += (data_source,)

progress = self._progress_info["current_progress"]
main_task = self._progress_info["current_task"]
total_estimators = len(self._parent.estimator_reports_)
progress.update(main_task, total=total_estimators)

if cache_key in self._parent._cache:
display = self._parent._cache[cache_key]
else:
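# Nothing cached yet: gather the true targets and the (possibly cached)
# predictions of each underlying estimator report.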
y_true, y_pred = [], []

for report in self._parent.estimator_reports_:
report_X, report_y, _ = report.metrics._get_X_y_and_data_source_hash(
data_source=data_source,
X=X,
y=y,
)

y_true.append(report_y)
y_pred.append(
_get_cached_response_values(
cache=report._cache,
estimator_hash=report._hash,
estimator=report._estimator,
X=report_X,
response_method=response_method,
data_source=data_source,
data_source_hash=None,
pos_label=display_kwargs.get("pos_label", None),
)
)
progress.update(main_task, advance=1, refresh=True)

display = display_class._from_predictions(
y_true,
y_pred,
estimators=[r.estimator_ for r in self._parent.estimator_reports_],
estimator_names=self._parent.report_names_,
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
)
self._parent._cache[cache_key] = display

return display

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def roc(self, *, data_source="test", X=None, y=None, pos_label=None):
"""Plot the ROC curve.
Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.
- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.
X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
set provided when creating the report.
y : array-like of shape (n_samples,), default=None
New target on which to compute the metric. By default, we use the target
provided when creating the report.
pos_label : int, float, bool or str, default=None
The positive class.

Returns
-------
RocCurveDisplay
The ROC curve display.

Examples
--------
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport, EstimatorReport
>>> X, y = load_breast_cancer(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42)
>>> estimator_report_1 = EstimatorReport(
... estimator_1,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43)
>>> estimator_report_2 = EstimatorReport(
... estimator_2,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> comparison_report = ComparisonReport(
... [estimator_report_1, estimator_report_2]
... )
>>> display = comparison_report.metrics.roc()
>>> display.plot()
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=RocCurveDisplay,
display_kwargs=display_kwargs,
)

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def precision_recall(self, *, data_source="test", X=None, y=None, pos_label=None):
"""Plot the precision-recall curve.
Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.
- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.
X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
set provided when creating the report.
y : array-like of shape (n_samples,), default=None
New target on which to compute the metric. By default, we use the target
provided when creating the report.
pos_label : int, float, bool or str, default=None
The positive class.

Returns
-------
PrecisionRecallCurveDisplay
The precision-recall curve display.

Examples
--------
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport, EstimatorReport
>>> X, y = load_breast_cancer(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42)
>>> estimator_report_1 = EstimatorReport(
... estimator_1,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43)
>>> estimator_report_2 = EstimatorReport(
... estimator_2,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> comparison_report = ComparisonReport(
... [estimator_report_1, estimator_report_2]
... )
>>> display = comparison_report.metrics.precision_recall()
>>> display.plot()
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=PrecisionRecallCurveDisplay,
display_kwargs=display_kwargs,
)

@available_if(_check_supported_ml_task(supported_ml_tasks=["regression"]))
def prediction_error(
self,
*,
data_source="test",
X=None,
y=None,
subsample=1_000,
random_state=None,
):
"""Plot the prediction error of a regression model.
Extra keyword arguments will be passed to matplotlib's `plot`.
Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.
- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.
X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
set provided when creating the report.
y : array-like of shape (n_samples,), default=None
New target on which to compute the metric. By default, we use the target
provided when creating the report.
subsample : float, int or None, default=1_000
Sampling the samples to be shown on the scatter plot. If `float`, it
should be between 0 and 1 and represents the proportion of the
original dataset. If `int`, it represents the number of samples
displayed on the scatter plot. If `None`, no subsampling is applied.
By default, 1,000 samples or fewer are displayed.
random_state : int, default=None
The random state to use for the subsampling.

Returns
-------
PredictionErrorDisplay
The prediction error display.

Examples
--------
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.linear_model import Ridge
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport, EstimatorReport
>>> X, y = load_diabetes(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = Ridge(random_state=42)
>>> estimator_report_1 = EstimatorReport(
... estimator_1,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> estimator_2 = Ridge(random_state=43)
>>> estimator_report_2 = EstimatorReport(
... estimator_2,
... X_train=X_train,
... y_train=y_train,
... X_test=X_test,
... y_test=y_test,
... )
>>> comparison_report = ComparisonReport(
... [estimator_report_1, estimator_report_2]
... )
>>> display = comparison_report.metrics.prediction_error()
>>> display.plot(kind="actual_vs_predicted")
"""
display_kwargs = {"subsample": subsample, "random_state": random_state}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method="predict",
display_class=PredictionErrorDisplay,
display_kwargs=display_kwargs,
)