perf: move classification common function output names to an enum
Haibin committed Sep 15, 2024
1 parent 6c1f66b commit c4c4bca
Showing 2 changed files with 15 additions and 7 deletions.
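The call sites in this diff pull their display names from a ClassificationCommonFunction enum whose definition is not part of the shown changes. A minimal sketch of the assumed shape, with member names taken from the call sites below and string values guessed purely for illustration:

from enum import Enum


class ClassificationCommonFunction(Enum):
    # Member names appear in the diff below; the string values are assumptions,
    # chosen only to show what `.value` would supply as a graph name.
    CROSS_VALIDATION = "Cross Validation"
    CONFUSION_MATRIX = "Confusion Matrix"
    ROC_CURVE = "ROC Curve"
    TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM = "Two-Dimensional Decision Boundary Diagram"

Each call site then passes ClassificationCommonFunction.<MEMBER>.value as graph_name, so the console banner, the saved-figure name, and the plot title come from one definition instead of repeated string literals.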
14 changes: 11 additions & 3 deletions geochemistrypi/data_mining/model/classification.py
@@ -148,15 +148,15 @@ def _cross_validation(trained_model: object, X_train: pd.DataFrame, graph_name:
"""Perform cross validation on the model."""
print(f"-----* {graph_name} *-----")
print(f"K-Folds: {cv_num}")
scores = cross_validation(trained_model, X_train, y_train, average=average, cv_num=cv_num)
scores = cross_validation(trained_model, X_train, y_train, graph_name, average=average, cv_num=cv_num)
scores_str = json.dumps(scores, indent=4)
save_text(scores_str, f"{graph_name} - {algorithm_name}", store_path)

@staticmethod
def _plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, graph_name: str, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the confusion matrix of the model."""
print(f"-----* {graph_name} *-----")
data = plot_confusion_matrix(y_test, y_test_predict, trained_model)
data = plot_confusion_matrix(y_test, y_test_predict, trained_model, graph_name)
save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path)
index = [f"true_{i}" for i in range(int(y_test.nunique().values))]
columns = [f"pred_{i}" for i in range(int(y_test.nunique().values))]
@@ -194,7 +194,7 @@ def _plot_precision_recall_threshold(
@staticmethod
def _plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
print(f"-----* {graph_name} *-----")
y_probs, fpr, tpr, thresholds = plot_ROC(X_test, y_test, trained_model, algorithm_name)
y_probs, fpr, tpr, thresholds = plot_ROC(X_test, y_test, trained_model, graph_name, algorithm_name)
save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path)
y_probs = pd.DataFrame(y_probs, columns=["Probabilities"])
fpr = pd.DataFrame(fpr, columns=["False Positive Rate"])
@@ -284,6 +284,7 @@ def common_components(self) -> None:
trained_model=self.model,
X_train=ClassificationWorkflowBase.X_train,
y_train=ClassificationWorkflowBase.y_train,
graph_name=ClassificationCommonFunction.CROSS_VALIDATION.value,
average=average,
cv_num=10,
algorithm_name=self.naming,
@@ -294,6 +295,7 @@ def common_components(self) -> None:
y_test_predict=ClassificationWorkflowBase.y_test_predict,
name_column=ClassificationWorkflowBase.name_test,
trained_model=self.model,
graph_name=ClassificationCommonFunction.CONFUSION_MATRIX.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -324,6 +326,7 @@ def common_components(self) -> None:
y_test=ClassificationWorkflowBase.y_test,
name_column=ClassificationWorkflowBase.name_test,
trained_model=self.model,
graph_name=ClassificationCommonFunction.ROC_CURVE.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -346,6 +349,7 @@ def common_components(self) -> None:
name_column2=ClassificationWorkflowBase.name_test,
trained_model=self.model,
image_config=self.image_config,
graph_name=ClassificationCommonFunction.TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -372,6 +376,7 @@ def common_components(self, is_automl: bool) -> None:
trained_model=self.auto_model,
X_train=ClassificationWorkflowBase.X_train,
y_train=ClassificationWorkflowBase.y_train,
graph_name=ClassificationCommonFunction.CROSS_VALIDATION.value,
average=average,
cv_num=10,
algorithm_name=self.naming,
@@ -382,6 +387,7 @@ def common_components(self, is_automl: bool) -> None:
y_test_predict=ClassificationWorkflowBase.y_test_predict,
name_column=ClassificationWorkflowBase.name_test,
trained_model=self.auto_model,
graph_name=ClassificationCommonFunction.CONFUSION_MATRIX.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -412,6 +418,7 @@ def common_components(self, is_automl: bool) -> None:
y_test=ClassificationWorkflowBase.y_test,
name_column=ClassificationWorkflowBase.name_test,
trained_model=self.auto_model,
graph_name=ClassificationCommonFunction.ROC_CURVE.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -434,6 +441,7 @@ def common_components(self, is_automl: bool) -> None:
name_column2=ClassificationWorkflowBase.name_test,
trained_model=self.auto_model,
image_config=self.image_config,
graph_name=ClassificationCommonFunction.TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Second changed file: 4 additions & 4 deletions
@@ -68,7 +68,7 @@ def score(y_true: pd.DataFrame, y_predict: pd.DataFrame) -> tuple[str, Dict]:
return average, scores


def plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, trained_model: object) -> np.ndarray:
def plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, trained_model: object, graph_name: str) -> np.ndarray:
"""Plot the confusion matrix.
Parameters
@@ -124,7 +124,7 @@ def display_cross_validation_scores(scores: np.ndarray, score_name: str) -> Dict
return cv_scores


def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, average: str, cv_num: int = 10) -> Dict:
def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, graph_name: str, average: str, cv_num: int = 10) -> Dict:
"""Evaluate metric(s) by cross-validation and also record fit/score times.
Parameters
@@ -286,7 +286,7 @@ def plot_precision_recall_threshold(X_test: pd.DataFrame, y_test: pd.DataFrame,
return y_probs, precisions, recalls, thresholds


def plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str) -> tuple:
def plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str) -> tuple:
"""Plot the ROC curve.
Parameters
@@ -324,7 +324,7 @@ def plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object,
plt.plot([0, 1], [0, 1], "r--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate (Recall)")
plt.title(f"ROC Curve - {algorithm_name}")
plt.title(f"{graph_name} - {algorithm_name}")
return y_probs, fpr, tpr, thresholds
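As a standalone illustration of the pattern this commit applies (a hedged sketch, not the library's actual API: the enum value, helper name, and algorithm name below are hypothetical stand-ins), routing the enum's value through graph_name lets the console banner, the plot title, and the saved file name share a single definition:

from enum import Enum

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


class ClassificationCommonFunction(Enum):
    ROC_CURVE = "ROC Curve"  # assumed value, as in the sketch above


def plot_toy_roc(graph_name: str, algorithm_name: str) -> None:
    """Toy stand-in for plot_ROC: every label is derived from graph_name."""
    print(f"-----* {graph_name} *-----")  # console banner
    plt.plot([0, 1], [0, 1], "r--")  # placeholder diagonal instead of a real ROC curve
    plt.title(f"{graph_name} - {algorithm_name}")  # plot title
    plt.savefig(f"{graph_name} - {algorithm_name}.png")  # saved-figure name


plot_toy_roc(ClassificationCommonFunction.ROC_CURVE.value, "Logistic Regression")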

