diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index e0470e324..ce92bbfab 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -77,7 +77,7 @@ import matplotlib.pyplot as plt roc_plot = est_report.metrics.roc() -roc_plot +roc_plot.plot() plt.tight_layout() # %% @@ -117,7 +117,7 @@ # %% roc_plot_cv = cv_report.metrics.roc() -roc_plot_cv +roc_plot_cv.plot() plt.tight_layout() # %% diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py index 0633d3fe6..340190caf 100644 --- a/examples/model_evaluation/plot_estimator_report.py +++ b/examples/model_evaluation/plot_estimator_report.py @@ -335,6 +335,7 @@ def operational_decision_cost(y_true, y_pred, amount): # # Let's start by plotting the ROC curve for our binary classification task. display = report.metrics.roc(pos_label=pos_label) +display.plot() plt.tight_layout() # %% @@ -362,7 +363,8 @@ def operational_decision_cost(y_true, y_pred, amount): # performance gain we can get. start = time.time() # we already trigger the computation of the predictions in a previous call -report.metrics.roc(pos_label=pos_label) +display = report.metrics.roc(pos_label=pos_label) +display.plot() plt.tight_layout() end = time.time() @@ -376,7 +378,8 @@ def operational_decision_cost(y_true, y_pred, amount): # %% start = time.time() -report.metrics.roc(pos_label=pos_label) +display = report.metrics.roc(pos_label=pos_label) +display.plot() plt.tight_layout() end = time.time() diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py index 42a94ccb9..64489cb00 100644 --- a/examples/use_cases/plot_employee_salaries.py +++ b/examples/use_cases/plot_employee_salaries.py @@ -313,10 +313,12 @@ def periodic_spline_transformer(period, n_splines=None, degree=3): if estimator_report is None: ax.axis("off") continue - estimator_report.metrics.prediction_error(kind="actual_vs_predicted", ax=ax) + estimator_report.metrics.prediction_error().plot(kind="actual_vs_predicted", ax=ax) ax.set_title(f"Split #{split_idx + 1}") ax.legend(loc="lower right") plt.tight_layout() # sphinx_gallery_start_ignore temp_dir.cleanup() # sphinx_gallery_end_ignore + +# %% diff --git a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py index a23e9c2c7..c7e168d29 100644 --- a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py +++ b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py @@ -812,7 +812,6 @@ def _get_display( response_method, display_class, display_kwargs, - display_plot_kwargs, ): """Get the display from the cache or compute it. @@ -833,9 +832,6 @@ def _get_display( display_kwargs : dict The display kwargs used by `display_class._from_predictions`. - display_plot_kwargs : dict - The display kwargs used by `display.plot`. 
- Returns ------- display : display_class @@ -852,7 +848,6 @@ def _get_display( if cache_key in self._parent._cache: display = self._parent._cache[cache_key] - display.plot(**display_plot_kwargs) else: y_true, y_pred = [], [] for report in self._parent.estimator_reports_: @@ -882,7 +877,6 @@ def _get_display( ml_task=self._parent._ml_task, data_source=data_source, **display_kwargs, - **display_plot_kwargs, ) self._parent._cache[cache_key] = display @@ -893,7 +887,7 @@ def _get_display( supported_ml_tasks=["binary-classification", "multiclass-classification"] ) ) - def roc(self, *, data_source="test", pos_label=None, ax=None): + def roc(self, *, data_source="test", pos_label=None): """Plot the ROC curve. Parameters @@ -907,9 +901,6 @@ def roc(self, *, data_source="test", pos_label=None, ax=None): pos_label : int, float, bool or str, default=None The positive class. - ax : matplotlib.axes.Axes, default=None - The axes to plot on. - Returns ------- RocCurveDisplay @@ -928,13 +919,11 @@ def roc(self, *, data_source="test", pos_label=None, ax=None): """ response_method = ("predict_proba", "decision_function") display_kwargs = {"pos_label": pos_label} - display_plot_kwargs = {"ax": ax, "plot_chance_level": True, "despine": True} return self._get_display( data_source=data_source, response_method=response_method, display_class=RocCurveDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) @available_if( @@ -942,7 +931,7 @@ def roc(self, *, data_source="test", pos_label=None, ax=None): supported_ml_tasks=["binary-classification", "multiclass-classification"] ) ) - def precision_recall(self, *, data_source="test", pos_label=None, ax=None): + def precision_recall(self, *, data_source="test", pos_label=None): """Plot the precision-recall curve. Parameters @@ -956,9 +945,6 @@ def precision_recall(self, *, data_source="test", pos_label=None, ax=None): pos_label : int, float, bool or str, default=None The positive class. - ax : matplotlib.axes.Axes, default=None - The axes to plot on. - Returns ------- PrecisionRecallCurveDisplay @@ -977,13 +963,11 @@ def precision_recall(self, *, data_source="test", pos_label=None, ax=None): """ response_method = ("predict_proba", "decision_function") display_kwargs = {"pos_label": pos_label} - display_plot_kwargs = {"ax": ax, "despine": True} return self._get_display( data_source=data_source, response_method=response_method, display_class=PrecisionRecallCurveDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) @@ -991,8 +975,6 @@ def prediction_error( self, *, data_source="test", - ax=None, - kind="residual_vs_predicted", subsample=1_000, random_state=None, ): @@ -1008,20 +990,6 @@ def prediction_error( - "test" : use the test set provided when creating the report. - "train" : use the train set provided when creating the report. - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. - - kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ - default="residual_vs_predicted" - The type of plot to draw: - - - "actual_vs_predicted" draws the observed values (y-axis) vs. - the predicted values (x-axis). - - "residual_vs_predicted" draws the residuals, i.e. difference - between observed and predicted values, (y-axis) vs. the predicted - values (x-axis). - subsample : float, int or None, default=1_000 Sampling the samples to be shown on the scatter plot. 
If `float`, it should be between 0 and 1 and represents the proportion of the @@ -1045,17 +1013,13 @@ def prediction_error( >>> X, y = load_diabetes(return_X_y=True) >>> regressor = Ridge() >>> report = CrossValidationReport(regressor, X=X, y=y, cv_splitter=2) - >>> display = report.metrics.prediction_error( - ... kind="actual_vs_predicted" - ... ) - >>> display.plot(line_kwargs={"color": "tab:red"}) + >>> display = report.metrics.prediction_error() + >>> display.plot(kind="actual_vs_predicted", line_kwargs={"color": "tab:red"}) """ display_kwargs = {"subsample": subsample, "random_state": random_state} - display_plot_kwargs = {"ax": ax, "kind": kind} return self._get_display( data_source=data_source, response_method="predict", display_class=PredictionErrorDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py index 55ed4ce8a..48d876af2 100644 --- a/skore/src/skore/sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -1396,7 +1396,6 @@ def _get_display( response_method, display_class, display_kwargs, - display_plot_kwargs, ): """Get the display from the cache or compute it. @@ -1424,9 +1423,6 @@ def _get_display( display_kwargs : dict The display kwargs used by `display_class._from_predictions`. - display_plot_kwargs : dict - The display kwargs used by `display.plot`. - Returns ------- display : display_class @@ -1442,7 +1438,6 @@ def _get_display( if cache_key in self._parent._cache: display = self._parent._cache[cache_key] - display.plot(**display_plot_kwargs) else: y_pred = _get_cached_response_values( cache=self._parent._cache, @@ -1463,7 +1458,6 @@ def _get_display( ml_task=self._parent._ml_task, data_source=data_source, **display_kwargs, - **display_plot_kwargs, ) self._parent._cache[cache_key] = display @@ -1474,7 +1468,7 @@ def _get_display( supported_ml_tasks=["binary-classification", "multiclass-classification"] ) ) - def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): + def roc(self, *, data_source="test", X=None, y=None, pos_label=None): """Plot the ROC curve. Parameters @@ -1497,9 +1491,6 @@ def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): pos_label : int, float, bool or str, default=None The positive class. - ax : matplotlib.axes.Axes, default=None - The axes to plot on. - Returns ------- RocCurveDisplay @@ -1527,7 +1518,6 @@ def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): """ response_method = ("predict_proba", "decision_function") display_kwargs = {"pos_label": pos_label} - display_plot_kwargs = {"ax": ax, "plot_chance_level": True, "despine": True} return self._get_display( X=X, y=y, @@ -1535,7 +1525,6 @@ def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): response_method=response_method, display_class=RocCurveDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) @available_if( @@ -1550,7 +1539,6 @@ def precision_recall( X=None, y=None, pos_label=None, - ax=None, ): """Plot the precision-recall curve. @@ -1574,9 +1562,6 @@ def precision_recall( pos_label : int, float, bool or str, default=None The positive class. - ax : matplotlib.axes.Axes, default=None - The axes to plot on. 
- Returns ------- PrecisionRecallCurveDisplay @@ -1604,7 +1589,6 @@ def precision_recall( """ response_method = ("predict_proba", "decision_function") display_kwargs = {"pos_label": pos_label} - display_plot_kwargs = {"ax": ax, "despine": True} return self._get_display( X=X, y=y, @@ -1612,7 +1596,6 @@ def precision_recall( response_method=response_method, display_class=PrecisionRecallCurveDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) @@ -1622,8 +1605,6 @@ def prediction_error( data_source="test", X=None, y=None, - ax=None, - kind="residual_vs_predicted", subsample=1_000, random_state=None, ): @@ -1648,20 +1629,6 @@ def prediction_error( New target on which to compute the metric. By default, we use the target provided when creating the report. - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. - - kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ - default="residual_vs_predicted" - The type of plot to draw: - - - "actual_vs_predicted" draws the observed values (y-axis) vs. - the predicted values (x-axis). - - "residual_vs_predicted" draws the residuals, i.e. difference - between observed and predicted values, (y-axis) vs. the predicted - values (x-axis). - subsample : float, int or None, default=1_000 Sampling the samples to be shown on the scatter plot. If `float`, it should be between 0 and 1 and represents the proportion of the @@ -1694,13 +1661,10 @@ def prediction_error( ... X_test=X_test, ... y_test=y_test, ... ) - >>> display = report.metrics.prediction_error( - ... kind="actual_vs_predicted" - ... ) + >>> display = report.metrics.prediction_error() >>> display.plot(line_kwargs={"color": "tab:red"}) """ display_kwargs = {"subsample": subsample, "random_state": random_state} - display_plot_kwargs = {"ax": ax, "kind": kind} return self._get_display( X=X, y=y, @@ -1708,5 +1672,4 @@ def prediction_error( response_method="predict", display_class=PredictionErrorDisplay, display_kwargs=display_kwargs, - display_plot_kwargs=display_plot_kwargs, ) diff --git a/skore/src/skore/sklearn/_plot/precision_recall_curve.py b/skore/src/skore/sklearn/_plot/precision_recall_curve.py index a25f80fdc..11c36a199 100644 --- a/skore/src/skore/sklearn/_plot/precision_recall_curve.py +++ b/skore/src/skore/sklearn/_plot/precision_recall_curve.py @@ -379,9 +379,6 @@ def _from_predictions( data_source=None, pos_label=None, drop_intermediate=False, - ax=None, - pr_curve_kwargs=None, - despine=True, ): """Plot precision-recall curve given binary class predictions. @@ -416,19 +413,6 @@ def _from_predictions( on a plotted precision-recall curve. This is useful in order to create lighter precision-recall curves. - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is created. - - pr_curve_kwargs : dict or list of dict, default=None - Keyword arguments to be passed to matplotlib's `plot` for rendering - the precision-recall curve(s). - - despine : bool, default=True - Whether to remove the top and right spines from the plot. - - **kwargs : dict - Keyword arguments to be passed to matplotlib's `plot`. 
- Returns ------- display : :class:`~sklearn.metrics.PrecisionRecallDisplay` @@ -475,7 +459,7 @@ def _from_predictions( recall[class_].append(recall_class_i) average_precision[class_].append(average_precision_class_i) - viz = cls( + return cls( precision=precision, recall=recall, average_precision=average_precision, @@ -483,12 +467,3 @@ def _from_predictions( pos_label=pos_label_validated, data_source=data_source, ) - - viz.plot( - ax=ax, - estimator_name=estimator_name, - pr_curve_kwargs=pr_curve_kwargs, - despine=despine, - ) - - return viz diff --git a/skore/src/skore/sklearn/_plot/prediction_error.py b/skore/src/skore/sklearn/_plot/prediction_error.py index 7084fb6e9..a859b9e35 100644 --- a/skore/src/skore/sklearn/_plot/prediction_error.py +++ b/skore/src/skore/sklearn/_plot/prediction_error.py @@ -291,13 +291,8 @@ def _from_predictions( estimator_name, ml_task, # FIXME: to be used when having single-output vs. multi-output data_source=None, - kind="residual_vs_predicted", subsample=1_000, random_state=None, - ax=None, - scatter_kwargs=None, - line_kwargs=None, - despine=True, ): """Plot the prediction error given the true and predicted targets. @@ -321,16 +316,6 @@ def _from_predictions( data_source : {"train", "test", "X_y"}, default=None The data source used to compute the ROC curve. - kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ - default="residual_vs_predicted" - The type of plot to draw: - - - "actual_vs_predicted" draws the observed values (y-axis) vs. - the predicted values (x-axis). - - "residual_vs_predicted" draws the residuals, i.e. difference - between observed and predicted values, (y-axis) vs. the predicted - values (x-axis). - subsample : float, int or None, default=1_000 Sampling the samples to be shown on the scatter plot. If `float`, it should be between 0 and 1 and represents the proportion of the @@ -342,21 +327,6 @@ def _from_predictions( Controls the randomness when `subsample` is not `None`. See :term:`Glossary ` for details. - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. - - scatter_kwargs : dict, default=None - Dictionary with keywords passed to the `matplotlib.pyplot.scatter` - call. - - line_kwargs : dict, default=None - Dictionary with keyword passed to the `matplotlib.pyplot.plot` - call to draw the optimal line. - - despine : bool, default=True - Whether to remove the top and right spines from the plot. - Returns ------- display : PredictionErrorDisplay @@ -394,20 +364,9 @@ def _from_predictions( y_true_display.append(y_true_i) y_pred_display.append(y_pred_i) - viz = cls( + return cls( y_true=y_true_display, y_pred=y_pred_display, estimator_name=estimator_name, data_source=data_source, ) - - viz.plot( - ax=ax, - estimator_name=estimator_name, - kind=kind, - scatter_kwargs=scatter_kwargs, - line_kwargs=line_kwargs, - despine=despine, - ) - - return viz diff --git a/skore/src/skore/sklearn/_plot/roc_curve.py b/skore/src/skore/sklearn/_plot/roc_curve.py index eb5663473..e0c352aff 100644 --- a/skore/src/skore/sklearn/_plot/roc_curve.py +++ b/skore/src/skore/sklearn/_plot/roc_curve.py @@ -389,11 +389,6 @@ def _from_predictions( data_source=None, pos_label=None, drop_intermediate=True, - ax=None, - roc_curve_kwargs=None, - plot_chance_level=True, - chance_level_kwargs=None, - despine=True, ): """Private method to create a RocCurveDisplay from predictions. 
@@ -426,24 +421,6 @@ drop_intermediate : bool, default=True Whether to drop intermediate points with identical value. - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. - - roc_curve_kwargs : dict or list of dict, default=None - Keyword arguments to be passed to matplotlib's `plot` for rendering - the ROC curve(s). - - plot_chance_level : bool, default=True - Whether to plot the chance level. - - chance_level_kwargs : dict, default=None - Keyword arguments to be passed to matplotlib's `plot` for rendering - the chance level line. - - despine : bool, default=True - Whether to remove the top and right spines from the plot. - Returns ------- display : RocCurveDisplay @@ -485,7 +462,7 @@ tpr[class_].append(tpr_class_i) roc_auc[class_].append(roc_auc_class_i) - viz = cls( + return cls( fpr=fpr, tpr=tpr, roc_auc=roc_auc, @@ -493,14 +470,3 @@ estimator_name=estimator_name, pos_label=pos_label_validated, data_source=data_source, ) - - viz.plot( - ax=ax, - estimator_name=estimator_name, - roc_curve_kwargs=roc_curve_kwargs, - plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, - despine=despine, - ) - - return viz diff --git a/skore/src/skore/sklearn/_plot/utils.py b/skore/src/skore/sklearn/_plot/utils.py index 611aaa82a..2091bd560 100644 --- a/skore/src/skore/sklearn/_plot/utils.py +++ b/skore/src/skore/sklearn/_plot/utils.py @@ -43,11 +43,12 @@ def _create_help_tree(self): attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]") # Ensure figure_ and ax_ are first sorted_attrs = sorted(attributes) - sorted_attrs.remove(".ax_") - sorted_attrs.remove(".figure_") - sorted_attrs = [".figure_", ".ax_"] + [ - attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"] - ] + if (".figure_" in sorted_attrs) and (".ax_" in sorted_attrs): + sorted_attrs.remove(".ax_") + sorted_attrs.remove(".figure_") + sorted_attrs = [".figure_", ".ax_"] + [ + attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"] + ] for attr in sorted_attrs: attr_branch.add(attr) diff --git a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py index 92c9156be..15aa9ab77 100644 --- a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py +++ b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py @@ -59,6 +59,7 @@ def test_precision_recall_curve_display_binary_classification( assert isinstance(attr[estimator.classes_[1]], list) assert len(attr[estimator.classes_[1]]) == 1 + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == 1 precision_recall_curve_mpl = display.lines_[0] @@ -104,6 +105,7 @@ def test_precision_recall_curve_cross_validation_display_binary_classification( assert isinstance(attr[pos_label], list) assert len(attr[pos_label]) == cv + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == cv expected_colors = sample_mpl_colormap(pyplot.cm.tab10, 10) @@ -136,9 +138,11 @@ def test_precision_recall_curve_display_data_source(pyplot, binary_classificatio estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) display = report.metrics.precision_recall(data_source="train") + display.plot() assert display.lines_[0].get_label() == "Train set (AP = 1.00)" display = report.metrics.precision_recall(data_source="X_y", X=X_train, y=y_train) + display.plot() assert display.lines_[0].get_label() == "AP = 1.00" @@ -165,6 +169,7 @@ def 
test_precision_recall_curve_display_multiclass_classification( assert isinstance(attr[class_label], list) assert len(attr[class_label]) == 1 + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == len(estimator.classes_) default_colors = sample_mpl_colormap(pyplot.cm.tab10, 10) @@ -211,6 +216,7 @@ def test_precision_recall_curve_cross_validation_display_multiclass_classificati assert isinstance(attr[class_label], list) assert len(attr[class_label]) == cv + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == len(class_labels) * cv default_colors = sample_mpl_colormap(pyplot.cm.tab10, 10) @@ -329,7 +335,9 @@ def test_precision_recall_curve_display_data_source_binary_classification( estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) display = report.metrics.precision_recall(data_source="train") + display.plot() assert display.lines_[0].get_label() == "Train set (AP = 1.00)" display = report.metrics.precision_recall(data_source="X_y", X=X_train, y=y_train) + display.plot() assert display.lines_[0].get_label() == "AP = 1.00" diff --git a/skore/tests/unit/sklearn/plot/test_prediction_error.py b/skore/tests/unit/sklearn/plot/test_prediction_error.py index 31e2aa5d7..5f5288304 100644 --- a/skore/tests/unit/sklearn/plot/test_prediction_error.py +++ b/skore/tests/unit/sklearn/plot/test_prediction_error.py @@ -27,7 +27,6 @@ def regression_data_no_split(): ({"subsample": -1}, "When an integer, subsample=-1 should be"), ({"subsample": 20.0}, "When a floating-point, subsample=20.0 should be"), ({"subsample": -20.0}, "When a floating-point, subsample=-20.0 should be"), - ({"kind": "xxx"}, "`kind` must be one of"), ], ) def test_prediction_error_display_raise_error(pyplot, params, err_msg, regression_data): @@ -59,6 +58,7 @@ def test_prediction_error_display_regression(pyplot, regression_data, subsample) np.testing.assert_allclose(display.y_pred[0], estimator.predict(X_test)) assert display.data_source == "test" + display.plot() assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_label() == "Perfect predictions" assert display.line_.get_color() == "black" @@ -90,6 +90,7 @@ def test_prediction_error_cross_validation_display_regression( assert len(display.y_true) == len(display.y_pred) == cv assert display.data_source == "test" + display.plot() assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_label() == "Perfect predictions" assert display.line_.get_color() == "black" @@ -111,9 +112,10 @@ def test_prediction_error_display_regression_kind(pyplot, regression_data): report = EstimatorReport( estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) - display = report.metrics.prediction_error(kind="actual_vs_predicted") + display = report.metrics.prediction_error() assert isinstance(display, PredictionErrorDisplay) + display.plot(kind="actual_vs_predicted") assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_label() == "Perfect predictions" assert display.line_.get_color() == "black" @@ -138,7 +140,8 @@ def test_prediction_error_cross_validation_display_regression_kind( """Check the attributes when switching to the "actual_vs_predicted" kind.""" (estimator, X, y), cv = regression_data_no_split, 3 report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv) - display = report.metrics.prediction_error(kind="actual_vs_predicted") + display = report.metrics.prediction_error() + display.plot(kind="actual_vs_predicted") assert 
isinstance(display, PredictionErrorDisplay) # check the structure of the attributes @@ -170,10 +173,12 @@ def test_prediction_error_display_data_source(pyplot, regression_data): estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) display = report.metrics.prediction_error(data_source="train") + display.plot() assert display.line_.get_label() == "Perfect predictions" assert display.scatter_.get_label() == "Train set" display = report.metrics.prediction_error(data_source="X_y", X=X_train, y=y_train) + display.plot() assert display.line_.get_label() == "Perfect predictions" assert display.scatter_.get_label() == "Data set" @@ -195,8 +200,10 @@ def test_prediction_error_display_kwargs(pyplot, regression_data): expected_subsample = 10 display = report.metrics.prediction_error(subsample=expected_subsample) + display.plot() assert len(display.scatter_.get_offsets()) == expected_subsample expected_subsample = int(X_test.shape[0] * 0.5) display = report.metrics.prediction_error(subsample=0.5) + display.plot() assert len(display.scatter_.get_offsets()) == expected_subsample diff --git a/skore/tests/unit/sklearn/plot/test_roc_curve.py b/skore/tests/unit/sklearn/plot/test_roc_curve.py index 33230cd02..f5d64f48e 100644 --- a/skore/tests/unit/sklearn/plot/test_roc_curve.py +++ b/skore/tests/unit/sklearn/plot/test_roc_curve.py @@ -56,6 +56,7 @@ def test_roc_curve_display_binary_classification(pyplot, binary_classification_d assert isinstance(attr[estimator.classes_[1]], list) assert len(attr[estimator.classes_[1]]) == 1 + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == 1 roc_curve_mpl = display.lines_[0] @@ -104,6 +105,7 @@ def test_roc_curve_display_multiclass_classification( assert isinstance(attr[class_label], list) assert len(attr[class_label]) == 1 + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == len(estimator.classes_) default_colors = sample_mpl_colormap(pyplot.cm.tab10, 10) @@ -141,9 +143,11 @@ def test_roc_curve_display_data_source_binary_classification( estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) display = report.metrics.roc(data_source="train") + display.plot() assert display.lines_[0].get_label() == "Train set (AUC = 1.00)" display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train) + display.plot() assert display.lines_[0].get_label() == "AUC = 1.00" @@ -156,6 +160,7 @@ def test_roc_curve_display_data_source_multiclass_classification( estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test ) display = report.metrics.roc(data_source="train") + display.plot() for class_label in estimator.classes_: assert display.lines_[class_label].get_label() == ( f"{str(class_label).title()} - train set " @@ -163,6 +168,7 @@ def test_roc_curve_display_data_source_multiclass_classification( ) display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train) + display.plot() for class_label in estimator.classes_: assert display.lines_[class_label].get_label() == ( f"{str(class_label).title()} - AUC = 1.00" @@ -267,6 +273,7 @@ def test_roc_curve_display_cross_validation_binary_classification( assert isinstance(attr[pos_label], list) assert len(attr[pos_label]) == cv + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == cv expected_colors = sample_mpl_colormap(pyplot.cm.tab10, 10) @@ -315,6 +322,7 @@ def test_roc_curve_display_cross_validation_multiclass_classification( assert isinstance(attr[class_label], list) assert 
len(attr[class_label]) == cv + display.plot() assert isinstance(display.lines_, list) assert len(display.lines_) == len(class_labels) * cv default_colors = sample_mpl_colormap(pyplot.cm.tab10, 10)
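Illustrative note (not part of the patch): with this change, metric accessors such as `report.metrics.roc()` and `report.metrics.prediction_error()` return an un-plotted display object, and every rendering option (`ax`, `kind`, curve and line kwargs, chance level, despine) is passed to `display.plot()` instead. A minimal sketch of the new usage pattern, assuming the public `from skore import EstimatorReport` import and the diabetes/Ridge setup used in the docstring examples above:

    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import Ridge
    from sklearn.model_selection import train_test_split
    from skore import EstimatorReport

    X, y = load_diabetes(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    report = EstimatorReport(
        Ridge(), X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )

    # Building the display only computes (and caches) the predictions...
    display = report.metrics.prediction_error()
    # ...rendering and styling happen explicitly on the returned object.
    display.plot(kind="actual_vs_predicted", line_kwargs={"color": "tab:red"})

Separating computation from rendering is what lets `_get_display` cache the display and re-plot it cheaply with different styling, as the timing cells in plot_estimator_report.py demonstrate.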