Skip to content

Commit

Permalink
feat(CrossValidationReporter): Catch exceptions during cross-validati…
Browse files Browse the repository at this point in the history
…on (#1287)
  • Loading branch information
auguste-probabl authored Feb 7, 2025
1 parent ceb3a80 commit 83ea1b1
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
40 changes: 37 additions & 3 deletions skore/src/skore/sklearn/_cross_validation/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import joblib
import numpy as np
from rich.panel import Panel
from sklearn.base import clone, is_classifier
from sklearn.model_selection import check_cv
from sklearn.pipeline import Pipeline
Expand All @@ -28,6 +29,13 @@ def _generate_estimator_report(estimator, X, y, train_indices, test_indices):
class CrossValidationReport(_BaseReport, DirNamesMixin):
"""Report for cross-validation results.
Upon initialization, `CrossValidationReport` will clone ``estimator`` according to
``cv_splitter`` and fit the generated estimators. The fitting is done in parallel,
and can be interrupted: the estimators that have been fitted can be accessed even if
the full cross-validation process did not complete. In particular,
`KeyboardInterrupt` exceptions are swallowed and will only interrupt the
cross-validation process, rather than the entire program.
Parameters
----------
estimator : estimator object
Expand Down Expand Up @@ -163,9 +171,35 @@ def _fit_estimator_reports(self):
)

estimator_reports = []
for report in generator:
estimator_reports.append(report)
progress.update(task, advance=1, refresh=True)
try:
for report in generator:
estimator_reports.append(report)
progress.update(task, advance=1, refresh=True)
except (Exception, KeyboardInterrupt) as e:
from skore import console # avoid circular import

if isinstance(e, KeyboardInterrupt):
message = (
"Cross-validation process was interrupted manually before all "
"estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results."
)
else:
message = (
"Cross-validation process was interrupted by an error before "
"all estimators could be fitted; CrossValidationReport object "
"might not contain all the expected results. "
f"Traceback: \n{e}"
)

console.print(
Panel(
title="Cross-validation interrupted",
renderable=message,
style="orange1",
border_style="cyan",
)
)

return estimator_reports

Expand Down
52 changes: 51 additions & 1 deletion skore/tests/unit/sklearn/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.base import clone
from sklearn.base import BaseEstimator, clone
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
Expand Down Expand Up @@ -737,3 +737,53 @@ def test_cross_validation_report_custom_metric(binary_classification_data):
)
assert result.shape == (2, 1)
assert result.columns == ["accuracy_score"]


@pytest.mark.parametrize(
"error,error_message",
[
(ValueError("No more fitting"), "Cross-validation interrupted by an error"),
(KeyboardInterrupt(), "Cross-validation interrupted manually"),
],
)
def test_cross_validation_report_interrupted(
binary_classification_data, capsys, error, error_message
):
"""Check that we can interrupt cross-validation without losing all
data."""

class MockEstimator(BaseEstimator):
def __init__(self, n_call=0, fail_after_n_clone=3):
self.n_call = n_call
self.fail_after_n_clone = fail_after_n_clone

def fit(self, X, y):
if self.n_call > self.fail_after_n_clone:
raise error
return self

def __sklearn_clone__(self):
"""Do not clone the estimator
Instead, we increment a counter each time that
`sklearn.clone` is called.
"""
self.n_call += 1
return self

def predict(self, X):
return np.ones(X.shape[0])

_, X, y = binary_classification_data

report = CrossValidationReport(MockEstimator(), X, y, cv_splitter=10)

captured = capsys.readouterr()
assert all(word in captured.out for word in error_message.split(" "))

result = report.metrics.custom_metric(
metric_function=accuracy_score,
response_method="predict",
)
assert result.shape == (1, 1)
assert result.columns == ["accuracy_score"]

0 comments on commit 83ea1b1

Please sign in to comment.