
Commit

Merge pull request #284 from ZJUEarthData/dev/Mengqi
perf: improve the cross-validation function and the selection branch for two-class and multi-class classification.
SanyHe authored Nov 26, 2023
2 parents b776957 + 9e3fdea commit 2d512b7
Showing 2 changed files with 100 additions and 69 deletions.
81 changes: 43 additions & 38 deletions geochemistrypi/data_mining/model/classification.py
@@ -121,13 +121,14 @@ def manual_hyper_parameters(cls) -> Dict:
         return dict()
 
     @staticmethod
-    def _score(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str) -> None:
+    def _score(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str) -> str:
         """Print the classification score report of the model."""
         print("-----* Model Score *-----")
-        scores = score(y_true, y_predict)
+        average, scores = score(y_true, y_predict)
         scores_str = json.dumps(scores, indent=4)
         save_text(scores_str, f"Model Score - {algorithm_name}", store_path)
         mlflow.log_metrics(scores)
+        return average
 
     @staticmethod
     def _classification_report(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str) -> None:
@@ -140,11 +141,11 @@ def _classification_report(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str) -> None:
         mlflow.log_artifact(os.path.join(store_path, f"Classification Report - {algorithm_name}.txt"))
 
     @staticmethod
-    def _cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, cv_num: int, algorithm_name: str, store_path: str) -> None:
+    def _cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, average: str, cv_num: int, algorithm_name: str, store_path: str) -> None:
         """Perform cross validation on the model."""
         print("-----* Cross Validation *-----")
         print(f"K-Folds: {cv_num}")
-        scores = cross_validation(trained_model, X_train, y_train, cv_num=cv_num)
+        scores = cross_validation(trained_model, X_train, y_train, average=average, cv_num=cv_num)
         scores_str = json.dumps(scores, indent=4)
         save_text(scores_str, f"Cross Validation - {algorithm_name}", store_path)
@@ -248,7 +249,7 @@ def common_components(self) -> None:
         """Invoke all common application functions for classification algorithms by Scikit-learn framework."""
         GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
         GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH")
-        self._score(
+        average = self._score(
            y_true=ClassificationWorkflowBase.y_test,
            y_predict=ClassificationWorkflowBase.y_test_predict,
            algorithm_name=self.naming,
@@ -264,6 +265,7 @@ def common_components(self) -> None:
            trained_model=self.model,
            X_train=ClassificationWorkflowBase.X_train,
            y_train=ClassificationWorkflowBase.y_train,
+           average=average,
            cv_num=10,
            algorithm_name=self.naming,
            store_path=GEOPI_OUTPUT_METRICS_PATH,
@@ -276,22 +278,23 @@ def common_components(self) -> None:
            local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
            mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
        )
-       self._plot_precision_recall(
-           X_test=ClassificationWorkflowBase.X_test,
-           y_test=ClassificationWorkflowBase.y_test,
-           trained_model=self.model,
-           algorithm_name=self.naming,
-           local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
-           mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
-       )
-       self._plot_ROC(
-           X_test=ClassificationWorkflowBase.X_test,
-           y_test=ClassificationWorkflowBase.y_test,
-           trained_model=self.model,
-           algorithm_name=self.naming,
-           local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
-           mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
-       )
+       if int(ClassificationWorkflowBase.y_test.nunique().values) == 2:
+           self._plot_precision_recall(
+               X_test=ClassificationWorkflowBase.X_test,
+               y_test=ClassificationWorkflowBase.y_test,
+               trained_model=self.model,
+               algorithm_name=self.naming,
+               local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
+               mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
+           )
+           self._plot_ROC(
+               X_test=ClassificationWorkflowBase.X_test,
+               y_test=ClassificationWorkflowBase.y_test,
+               trained_model=self.model,
+               algorithm_name=self.naming,
+               local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
+               mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
+           )
        self._plot_permutation_importance(
            X_test=ClassificationWorkflowBase.X_test,
            y_test=ClassificationWorkflowBase.y_test,
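A quick standalone illustration of the new guard (assuming, as in this workflow, that y_test is a single-column DataFrame):

    import pandas as pd

    y_test = pd.DataFrame({"Label": [0, 1, 1, 0, 1]})
    # .nunique() on a one-column DataFrame yields a one-element Series;
    # .values gives array([2]) and int(...) extracts the class count.
    n_classes = int(y_test.nunique().values)
    print(n_classes == 2)  # True -> precision-recall and ROC plots are drawn

The same gate, previously buried inside plot_precision_recall and plot_ROC, now lives at the call site, so the plotting helpers themselves stay unconditional (see the second file below). One caveat: newer NumPy releases deprecate converting a one-element array to a scalar, so a spelling like y_test.squeeze().nunique() may age better.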
@@ -317,7 +320,7 @@ def common_components(self, is_automl: bool) -> None:
         """Invoke all common application functions for classification algorithms by FLAML framework."""
         GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
         GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH")
-        self._score(
+        average = self._score(
            y_true=ClassificationWorkflowBase.y_test,
            y_predict=ClassificationWorkflowBase.y_test_predict,
            algorithm_name=self.naming,
@@ -333,6 +336,7 @@ def common_components(self, is_automl: bool) -> None:
            trained_model=self.auto_model,
            X_train=ClassificationWorkflowBase.X_train,
            y_train=ClassificationWorkflowBase.y_train,
+           average=average,
            cv_num=10,
            algorithm_name=self.naming,
            store_path=GEOPI_OUTPUT_METRICS_PATH,
@@ -345,22 +349,23 @@ def common_components(self, is_automl: bool) -> None:
            local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
            mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
        )
-       self._plot_precision_recall(
-           X_test=ClassificationWorkflowBase.X_test,
-           y_test=ClassificationWorkflowBase.y_test,
-           trained_model=self.auto_model,
-           algorithm_name=self.naming,
-           local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
-           mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
-       )
-       self._plot_ROC(
-           X_test=ClassificationWorkflowBase.X_test,
-           y_test=ClassificationWorkflowBase.y_test,
-           trained_model=self.auto_model,
-           algorithm_name=self.naming,
-           local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
-           mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
-       )
+       if int(ClassificationWorkflowBase.y_test.nunique().values) == 2:
+           self._plot_precision_recall(
+               X_test=ClassificationWorkflowBase.X_test,
+               y_test=ClassificationWorkflowBase.y_test,
+               trained_model=self.auto_model,
+               algorithm_name=self.naming,
+               local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
+               mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
+           )
+           self._plot_ROC(
+               X_test=ClassificationWorkflowBase.X_test,
+               y_test=ClassificationWorkflowBase.y_test,
+               trained_model=self.auto_model,
+               algorithm_name=self.naming,
+               local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
+               mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
+           )
        self._plot_permutation_importance(
            X_test=ClassificationWorkflowBase.X_test,
            y_test=ClassificationWorkflowBase.y_test,
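The Scikit-learn and FLAML branches apply the same change: the average chosen while scoring the test split is reused for cross-validation, so both report precision/recall/F1 consistently. A runnable toy of that flow (the hard-coded average below stands in for the value _score would return; the selection rule itself is not shown in this diff):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import f1_score, make_scorer
    from sklearn.model_selection import cross_validate

    rng = np.random.default_rng(0)
    X = rng.normal(size=(120, 4))
    y = rng.integers(0, 3, size=120)  # three classes, so "binary" would fail

    average = "macro"  # stand-in for the value returned by _score
    scoring = {"f1": make_scorer(f1_score, average=average)}
    scores = cross_validate(LogisticRegression(max_iter=500), X, y, scoring=scoring, cv=5)
    print(scores["test_f1"].mean())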
88 changes: 57 additions & 31 deletions (second changed file)
@@ -12,12 +12,12 @@
 from imblearn.pipeline import Pipeline
 from imblearn.under_sampling import RandomUnderSampler
 from rich import print
-from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, confusion_matrix, f1_score, precision_recall_curve, precision_score, recall_score, roc_curve
+from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, confusion_matrix, f1_score, make_scorer, precision_recall_curve, precision_score, recall_score, roc_curve
 from sklearn.model_selection import cross_validate
 from sklearn.preprocessing import LabelEncoder
 
 
-def score(y_true: pd.DataFrame, y_predict: pd.DataFrame) -> Dict:
+def score(y_true: pd.DataFrame, y_predict: pd.DataFrame) -> tuple[str, Dict]:
     """Calculate the scores of the classification model.
 
     Parameters
@@ -30,6 +30,9 @@ def score(y_true: pd.DataFrame, y_predict: pd.DataFrame) -> Dict:
 
     Returns
     -------
+    average : str
+        Metric parameters.
+
     scores : dict
         The scores of the classification model.
     """
@@ -61,7 +64,7 @@ def score(y_true: pd.DataFrame, y_predict: pd.DataFrame) -> Dict:
         "Recall": recall,
         "F1 Score": f1,
     }
-    return scores
+    return average, scores
 
 
 def plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, trained_model: object) -> np.ndarray:
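The return shape changes here: callers now unpack average, scores = score(y_true, y_predict). How average gets picked is outside this hunk; a hedged sketch of how such a helper could behave (the two-class/weighted rule below is an assumption, not something this diff confirms):

    from typing import Dict, Tuple

    import pandas as pd
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def toy_score(y_true: pd.Series, y_predict: pd.Series) -> Tuple[str, Dict]:
        # Assumed selection rule: "binary" for two classes, "weighted" otherwise.
        average = "binary" if y_true.nunique() == 2 else "weighted"
        scores = {
            "Accuracy": accuracy_score(y_true, y_predict),
            "Precision": precision_score(y_true, y_predict, average=average),
            "Recall": recall_score(y_true, y_predict, average=average),
            "F1 Score": f1_score(y_true, y_predict, average=average),
        }
        return average, scores

    average, scores = toy_score(pd.Series([0, 1, 1, 0]), pd.Series([0, 1, 0, 0]))
    print(average, scores)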
@@ -120,7 +123,7 @@ def display_cross_validation_scores(scores: np.ndarray, score_name: str) -> Dict:
     return cv_scores
 
 
-def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, cv_num: int = 10) -> Dict:
+def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, average: str, cv_num: int = 10) -> Dict:
     """Evaluate metric(s) by cross-validation and also record fit/score times.
 
     Parameters
@@ -134,6 +137,9 @@ def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, average: str, cv_num: int = 10) -> Dict:
     y_train : pd.DataFrame (n_samples, n_components)
         The training target values.
 
+    average : str
+        Metric parameters.
+
     cv_num : int
         Determines the cross-validation splitting strategy.
@@ -142,8 +148,35 @@ def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, average: str, cv_num: int = 10) -> Dict:
     scores_result : dict
         The scores of cross-validation.
     """
-
-    scores = cross_validate(trained_model, X_train, y_train, scoring=("accuracy", "precision", "recall", "f1"), cv=cv_num)
+    if average == "binary":
+        scoring = {
+            "accuracy": make_scorer(accuracy_score),
+            "precision": make_scorer(precision_score, average="binary"),
+            "recall": make_scorer(recall_score, average="binary"),
+            "f1": make_scorer(f1_score, average="binary"),
+        }
+    elif average == "micro":
+        scoring = {
+            "accuracy": make_scorer(accuracy_score),
+            "precision": make_scorer(precision_score, average="micro"),
+            "recall": make_scorer(recall_score, average="micro"),
+            "f1": make_scorer(f1_score, average="micro"),
+        }
+    elif average == "macro":
+        scoring = {
+            "accuracy": make_scorer(accuracy_score),
+            "precision": make_scorer(precision_score, average="macro"),
+            "recall": make_scorer(recall_score, average="macro"),
+            "f1": make_scorer(f1_score, average="macro"),
+        }
+    elif average == "weighted":
+        scoring = {
+            "accuracy": make_scorer(accuracy_score),
+            "precision": make_scorer(precision_score, average="weighted"),
+            "recall": make_scorer(recall_score, average="weighted"),
+            "f1": make_scorer(f1_score, average="weighted"),
+        }
+    scores = cross_validate(trained_model, X_train, y_train, scoring=scoring, cv=cv_num)
     del scores["fit_time"]
     del scores["score_time"]
     # the keys follow the returns of cross_validate in scikit-learn
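Review note: the four branches differ only in the average string, and an unrecognized value would leave scoring unbound, surfacing as a NameError at the cross_validate call. A more compact equivalent (a refactor sketch, not what this commit ships) parameterizes the dict and fails fast:

    from sklearn.metrics import accuracy_score, f1_score, make_scorer, precision_score, recall_score

    def build_scoring(average: str) -> dict:
        if average not in ("binary", "micro", "macro", "weighted"):
            raise ValueError(f"unsupported average: {average!r}")
        return {
            "accuracy": make_scorer(accuracy_score),
            "precision": make_scorer(precision_score, average=average),
            "recall": make_scorer(recall_score, average=average),
            "f1": make_scorer(f1_score, average=average),
        }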
@@ -193,19 +226,15 @@ def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name: str) -> tuple:
     thresholds : np.ndarray
         The thresholds of the model.
     """
-    if int(y_test.nunique().values) == 2:
-        # Predict probabilities for the positive class
-        y_probs = trained_model.predict_proba(X_test)[:, 1]
-        precisions, recalls, thresholds = precision_recall_curve(y_test, y_probs)
-
-        plt.figure()
-        plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
-        plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
-        plt.legend(labels=["Precision", "Recall"], loc="best")
-        plt.title(f"Precision Recall Curve - {algorithm_name}")
-        return y_probs, precisions, recalls, thresholds
-    else:
-        return None, None, None, None
+    # Predict probabilities for the positive class
+    y_probs = trained_model.predict_proba(X_test)[:, 1]
+    precisions, recalls, thresholds = precision_recall_curve(y_test, y_probs)
+    plt.figure()
+    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
+    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
+    plt.legend(labels=["Precision", "Recall"], loc="best")
+    plt.title(f"Precision Recall Curve - {algorithm_name}")
+    return y_probs, precisions, recalls, thresholds
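Why the [:-1] slices: precision_recall_curve returns one more precision/recall value than thresholds (the final point has no corresponding threshold), so plotting against thresholds drops the last element. A standalone shape check:

    from sklearn.metrics import precision_recall_curve

    y_true = [0, 0, 1, 1]
    y_score = [0.1, 0.4, 0.35, 0.8]
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_score)
    assert len(precisions) == len(recalls) == len(thresholds) + 1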


def plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str) -> tuple:
@@ -239,18 +268,15 @@ def plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str) -> tuple:
     thresholds : np.ndarray
         The thresholds of the model.
     """
-    if int(y_test.nunique().values) == 2:
-        y_probs = trained_model.predict_proba(X_test)[:, 1]
-        fpr, tpr, thresholds = roc_curve(y_test, y_probs)
-        plt.figure()
-        plt.plot(fpr, tpr, linewidth=2)
-        plt.plot([0, 1], [0, 1], "r--")
-        plt.xlabel("False Positive Rate")
-        plt.ylabel("True Positive Rate (Recall)")
-        plt.title(f"ROC Curve - {algorithm_name}")
-        return y_probs, fpr, tpr, thresholds
-    else:
-        return None, None, None, None
+    y_probs = trained_model.predict_proba(X_test)[:, 1]
+    fpr, tpr, thresholds = roc_curve(y_test, y_probs)
+    plt.figure()
+    plt.plot(fpr, tpr, linewidth=2)
+    plt.plot([0, 1], [0, 1], "r--")
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate (Recall)")
+    plt.title(f"ROC Curve - {algorithm_name}")
+    return y_probs, fpr, tpr, thresholds
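Both plotting helpers read predict_proba(X_test)[:, 1], which only means "the positive class" when the probability matrix has shape (n_samples, 2); that is why the binary gate now guards their call sites instead. A standalone binary example:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_curve

    X, y = make_classification(n_samples=100, random_state=0)  # two classes by default
    model = LogisticRegression(max_iter=500).fit(X, y)
    y_probs = model.predict_proba(X)[:, 1]  # shape (100,)
    fpr, tpr, thresholds = roc_curve(y, y_probs)
    print(fpr.shape, tpr.shape, thresholds.shape)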


def plot_2d_decision_boundary(X: pd.DataFrame, X_test: pd.DataFrame, trained_model: object, image_config: Dict) -> None:
