fix(api): Do not call .plot() in display factory (#1299)
closes #1298 

This PR makes sure that the display factory does not call `.plot()`
directly. It allows:

- plotting parameters to be passed only to `display.plot()`
- a consistent API with pandas objects such as DataFrames:

```python
report.metrics.report_metrics().plot(kind="barh")
report.metrics.roc().plot()
```
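A minimal sketch of the resulting pattern (the `report`, `pos_label`, and axes here are illustrative only; `ax`, `despine`, and `kind` are the plotting parameters this PR moves onto `display.plot()`):

```python
import matplotlib.pyplot as plt

# The factory only builds (and caches) the display; nothing is drawn yet.
display = report.metrics.roc(pos_label=pos_label)

# Plot-specific parameters now go to `.plot()` rather than to the factory.
fig, ax = plt.subplots()
display.plot(ax=ax, despine=True)

# Same pattern for regression displays: `kind` is a `.plot()` argument too.
report.metrics.prediction_error().plot(kind="actual_vs_predicted")
plt.tight_layout()
```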
glemaitre authored Feb 10, 2025
1 parent 3724391 commit e2cc5e6
Showing 12 changed files with 51 additions and 195 deletions.
4 changes: 2 additions & 2 deletions examples/getting_started/plot_skore_getting_started.py
@@ -77,7 +77,7 @@
import matplotlib.pyplot as plt

roc_plot = est_report.metrics.roc()
roc_plot
roc_plot.plot()
plt.tight_layout()

# %%
@@ -117,7 +117,7 @@

# %%
roc_plot_cv = cv_report.metrics.roc()
roc_plot_cv
roc_plot_cv.plot()
plt.tight_layout()

# %%
7 changes: 5 additions & 2 deletions examples/model_evaluation/plot_estimator_report.py
@@ -335,6 +335,7 @@ def operational_decision_cost(y_true, y_pred, amount):
#
# Let's start by plotting the ROC curve for our binary classification task.
display = report.metrics.roc(pos_label=pos_label)
display.plot()
plt.tight_layout()

# %%
@@ -362,7 +363,8 @@ def operational_decision_cost(y_true, y_pred, amount):
# performance gain we can get.
start = time.time()
# we already trigger the computation of the predictions in a previous call
report.metrics.roc(pos_label=pos_label)
display = report.metrics.roc(pos_label=pos_label)
display.plot()
plt.tight_layout()
end = time.time()

@@ -376,7 +378,8 @@ def operational_decision_cost(y_true, y_pred, amount):

# %%
start = time.time()
report.metrics.roc(pos_label=pos_label)
display = report.metrics.roc(pos_label=pos_label)
display.plot()
plt.tight_layout()
end = time.time()

4 changes: 3 additions & 1 deletion examples/use_cases/plot_employee_salaries.py
@@ -313,10 +313,12 @@ def periodic_spline_transformer(period, n_splines=None, degree=3):
if estimator_report is None:
ax.axis("off")
continue
estimator_report.metrics.prediction_error(kind="actual_vs_predicted", ax=ax)
estimator_report.metrics.prediction_error().plot(kind="actual_vs_predicted", ax=ax)
ax.set_title(f"Split #{split_idx + 1}")
ax.legend(loc="lower right")
plt.tight_layout()
# sphinx_gallery_start_ignore
temp_dir.cleanup()
# sphinx_gallery_end_ignore

# %%
44 changes: 4 additions & 40 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -812,7 +812,6 @@ def _get_display(
response_method,
display_class,
display_kwargs,
display_plot_kwargs,
):
"""Get the display from the cache or compute it.
@@ -833,9 +832,6 @@
display_kwargs : dict
The display kwargs used by `display_class._from_predictions`.
display_plot_kwargs : dict
The display kwargs used by `display.plot`.
Returns
-------
display : display_class
@@ -852,7 +848,6 @@

if cache_key in self._parent._cache:
display = self._parent._cache[cache_key]
display.plot(**display_plot_kwargs)
else:
y_true, y_pred = [], []
for report in self._parent.estimator_reports_:
@@ -882,7 +877,6 @@
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
**display_plot_kwargs,
)
self._parent._cache[cache_key] = display

@@ -893,7 +887,7 @@
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def roc(self, *, data_source="test", pos_label=None, ax=None):
def roc(self, *, data_source="test", pos_label=None):
"""Plot the ROC curve.
Parameters
@@ -907,9 +901,6 @@ def roc(self, *, data_source="test", pos_label=None, ax=None):
pos_label : int, float, bool or str, default=None
The positive class.
ax : matplotlib.axes.Axes, default=None
The axes to plot on.
Returns
-------
RocCurveDisplay
@@ -928,21 +919,19 @@ def roc(self, *, data_source="test", pos_label=None, ax=None):
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
display_plot_kwargs = {"ax": ax, "plot_chance_level": True, "despine": True}
return self._get_display(
data_source=data_source,
response_method=response_method,
display_class=RocCurveDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def precision_recall(self, *, data_source="test", pos_label=None, ax=None):
def precision_recall(self, *, data_source="test", pos_label=None):
"""Plot the precision-recall curve.
Parameters
@@ -956,9 +945,6 @@ def precision_recall(self, *, data_source="test", pos_label=None, ax=None):
pos_label : int, float, bool or str, default=None
The positive class.
ax : matplotlib.axes.Axes, default=None
The axes to plot on.
Returns
-------
PrecisionRecallCurveDisplay
@@ -977,22 +963,18 @@ def precision_recall(self, *, data_source="test", pos_label=None, ax=None):
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
display_plot_kwargs = {"ax": ax, "despine": True}
return self._get_display(
data_source=data_source,
response_method=response_method,
display_class=PrecisionRecallCurveDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)

@available_if(_check_supported_ml_task(supported_ml_tasks=["regression"]))
def prediction_error(
self,
*,
data_source="test",
ax=None,
kind="residual_vs_predicted",
subsample=1_000,
random_state=None,
):
@@ -1008,20 +990,6 @@
- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
default="residual_vs_predicted"
The type of plot to draw:
- "actual_vs_predicted" draws the observed values (y-axis) vs.
the predicted values (x-axis).
- "residual_vs_predicted" draws the residuals, i.e. difference
between observed and predicted values, (y-axis) vs. the predicted
values (x-axis).
subsample : float, int or None, default=1_000
Sampling the samples to be shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
@@ -1045,17 +1013,13 @@
>>> X, y = load_diabetes(return_X_y=True)
>>> regressor = Ridge()
>>> report = CrossValidationReport(regressor, X=X, y=y, cv_splitter=2)
>>> display = report.metrics.prediction_error(
... kind="actual_vs_predicted"
... )
>>> display.plot(line_kwargs={"color": "tab:red"})
>>> display = report.metrics.prediction_error()
>>> display.plot(kind="actual_vs_predicted", line_kwargs={"color": "tab:red"})
"""
display_kwargs = {"subsample": subsample, "random_state": random_state}
display_plot_kwargs = {"ax": ax, "kind": kind}
return self._get_display(
data_source=data_source,
response_method="predict",
display_class=PredictionErrorDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)
41 changes: 2 additions & 39 deletions skore/src/skore/sklearn/_estimator/metrics_accessor.py
@@ -1396,7 +1396,6 @@ def _get_display(
response_method,
display_class,
display_kwargs,
display_plot_kwargs,
):
"""Get the display from the cache or compute it.
@@ -1424,9 +1423,6 @@
display_kwargs : dict
The display kwargs used by `display_class._from_predictions`.
display_plot_kwargs : dict
The display kwargs used by `display.plot`.
Returns
-------
display : display_class
@@ -1442,7 +1438,6 @@

if cache_key in self._parent._cache:
display = self._parent._cache[cache_key]
display.plot(**display_plot_kwargs)
else:
y_pred = _get_cached_response_values(
cache=self._parent._cache,
@@ -1463,7 +1458,6 @@
ml_task=self._parent._ml_task,
data_source=data_source,
**display_kwargs,
**display_plot_kwargs,
)
self._parent._cache[cache_key] = display

@@ -1474,7 +1468,7 @@
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None):
def roc(self, *, data_source="test", X=None, y=None, pos_label=None):
"""Plot the ROC curve.
Parameters
@@ -1497,9 +1491,6 @@ def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None):
pos_label : int, float, bool or str, default=None
The positive class.
ax : matplotlib.axes.Axes, default=None
The axes to plot on.
Returns
-------
RocCurveDisplay
@@ -1527,15 +1518,13 @@ def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None):
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
display_plot_kwargs = {"ax": ax, "plot_chance_level": True, "despine": True}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=RocCurveDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)

@available_if(
@@ -1550,7 +1539,6 @@ def precision_recall(
X=None,
y=None,
pos_label=None,
ax=None,
):
"""Plot the precision-recall curve.
@@ -1574,9 +1562,6 @@
pos_label : int, float, bool or str, default=None
The positive class.
ax : matplotlib.axes.Axes, default=None
The axes to plot on.
Returns
-------
PrecisionRecallCurveDisplay
@@ -1604,15 +1589,13 @@
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {"pos_label": pos_label}
display_plot_kwargs = {"ax": ax, "despine": True}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=PrecisionRecallCurveDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)

@available_if(_check_supported_ml_task(supported_ml_tasks=["regression"]))
Expand All @@ -1622,8 +1605,6 @@ def prediction_error(
data_source="test",
X=None,
y=None,
ax=None,
kind="residual_vs_predicted",
subsample=1_000,
random_state=None,
):
@@ -1648,20 +1629,6 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
default="residual_vs_predicted"
The type of plot to draw:
- "actual_vs_predicted" draws the observed values (y-axis) vs.
the predicted values (x-axis).
- "residual_vs_predicted" draws the residuals, i.e. difference
between observed and predicted values, (y-axis) vs. the predicted
values (x-axis).
subsample : float, int or None, default=1_000
Sampling the samples to be shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
@@ -1694,19 +1661,15 @@
... X_test=X_test,
... y_test=y_test,
... )
>>> display = report.metrics.prediction_error(
... kind="actual_vs_predicted"
... )
>>> display = report.metrics.prediction_error()
>>> display.plot(line_kwargs={"color": "tab:red"})
"""
display_kwargs = {"subsample": subsample, "random_state": random_state}
display_plot_kwargs = {"ax": ax, "kind": kind}
return self._get_display(
X=X,
y=y,
data_source=data_source,
response_method="predict",
display_class=PredictionErrorDisplay,
display_kwargs=display_kwargs,
display_plot_kwargs=display_plot_kwargs,
)
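Because the cached-display branch no longer calls `display.plot(**display_plot_kwargs)`, a display fetched from the cache is returned untouched and plotting stays an explicit, repeatable step. A sketch of what this enables, assuming an existing report for a binary classification task:

```python
import matplotlib.pyplot as plt

# The first call computes the predictions and caches the display object.
display = report.metrics.roc()

# Later calls return the cached display without triggering any plotting.
cached = report.metrics.roc()

# The same display can be drawn several times with different parameters.
fig, (ax1, ax2) = plt.subplots(ncols=2)
display.plot(ax=ax1)
cached.plot(ax=ax2, plot_chance_level=False)
plt.tight_layout()
```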