Skip to content

Commit

Permalink
Merge pull request #419 from ZJUEarthData/dev/HaibinWang
Browse files Browse the repository at this point in the history
refactor: separate classification func name in enum (save_without_id)
  • Loading branch information
SanyHe authored Jan 20, 2025
2 parents 4751992 + 92fda35 commit 8a14bad
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 17 deletions.
69 changes: 52 additions & 17 deletions geochemistrypi/data_mining/model/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@
score,
)
from .func.algo_classification._decision_tree import decision_tree_manual_hyper_parameters
from .func.algo_classification._enum import ClassificationCommonFunction
from .func.algo_classification._enum import (
ClassificationCommonFunction,
DecisionTreeSpecialFunction,
ExtraTreesSpecialFunction,
GradientBoostingSpecialFunction,
LogisticRegressionSpecialFunction,
MLPSpecialFunction,
RandomForestSpecialFunction,
XGBoostSpecialFunction,
)
from .func.algo_classification._extra_trees import extra_trees_manual_hyper_parameters
from .func.algo_classification._gradient_boosting import gradient_boosting_manual_hyper_parameters
from .func.algo_classification._knn import knn_manual_hyper_parameters
Expand Down Expand Up @@ -124,24 +133,24 @@ def manual_hyper_parameters(cls) -> Dict:
return dict()

@staticmethod
def _score(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str, func_name: str) -> str:
    """Print the classification score report of the model.

    Parameters
    ----------
    y_true : pd.DataFrame
        The true target values.
    y_predict : pd.DataFrame
        The predicted target values.
    algorithm_name : str
        Name of the algorithm; used as the suffix of the saved file name.
    store_path : str
        Location handed to ``save_text`` for the serialized scores.
    func_name : str
        Display name of this functionality; used both in the console
        banner and as the prefix of the saved file name.

    Returns
    -------
    str
        The first element of the pair returned by ``score`` (the
        ``average`` value), passed through unchanged.
    """
    print(f"-----* {func_name} *-----")
    average, scores = score(y_true, y_predict)
    # Persist the metrics as pretty-printed JSON text and mirror them to MLflow.
    scores_str = json.dumps(scores, indent=4)
    save_text(scores_str, f"{func_name} - {algorithm_name}", store_path)
    mlflow.log_metrics(scores)
    return average

@staticmethod
def _classification_report(y_true: pd.DataFrame, y_predict: pd.DataFrame, algorithm_name: str, store_path: str, func_name: str) -> None:
    """Print the classification report of the model.

    Parameters
    ----------
    y_true : pd.DataFrame
        The true target values.
    y_predict : pd.DataFrame
        The predicted target values.
    algorithm_name : str
        Name of the algorithm; used as the suffix of the saved file name.
    store_path : str
        Location handed to ``save_text``; also the directory from which
        the artifact is logged to MLflow.
    func_name : str
        Display name of this functionality; used both in the console
        banner and as the prefix of the saved file name.
    """
    print(f"-----* {func_name} *-----")
    # Human-readable table for the console.
    print(classification_report(y_true, y_predict))
    # Dictionary form of the same report, for JSON serialization.
    scores = classification_report(y_true, y_predict, output_dict=True)
    scores_str = json.dumps(scores, indent=4)
    # Build the name once so the saved file and the logged artifact always agree.
    report_name = f"{func_name} - {algorithm_name}"
    save_text(scores_str, report_name, store_path)
    # NOTE(review): assumes save_text writes "<name>.txt" under store_path - confirm.
    mlflow.log_artifact(os.path.join(store_path, f"{report_name}.txt"))

@staticmethod
def _cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.DataFrame, graph_name: str, average: str, cv_num: int, algorithm_name: str, store_path: str) -> None:
Expand All @@ -157,7 +166,7 @@ def _plot_confusion_matrix(
y_test: pd.DataFrame, y_test_predict: pd.DataFrame, name_column: str, graph_name: str, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str
) -> None:
"""Plot the confusion matrix of the model."""
print("-----* {graph_name} *-----")
print(f"-----* {graph_name} *-----")
data = plot_confusion_matrix(y_test, y_test_predict, trained_model)
save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path)
index = [f"true_{i}" for i in range(int(y_test.nunique().values))]
Expand Down Expand Up @@ -275,12 +284,14 @@ def common_components(self) -> None:
average = self._score(
y_true=ClassificationWorkflowBase.y_test,
y_predict=ClassificationWorkflowBase.y_test_predict,
func_name=ClassificationCommonFunction.MODEL_SCORE.value,
algorithm_name=self.naming,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)
self._classification_report(
y_true=ClassificationWorkflowBase.y_test,
y_predict=ClassificationWorkflowBase.y_test_predict,
func_name=ClassificationCommonFunction.CLASSIFICATION_REPORT.value,
algorithm_name=self.naming,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)
Expand Down Expand Up @@ -368,12 +379,14 @@ def common_components(self, is_automl: bool) -> None:
y_true=ClassificationWorkflowBase.y_test,
y_predict=ClassificationWorkflowBase.y_test_predict,
algorithm_name=self.naming,
func_name=ClassificationCommonFunction.MODEL_SCORE.value,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)
self._classification_report(
y_true=ClassificationWorkflowBase.y_test,
y_predict=ClassificationWorkflowBase.y_test_predict,
algorithm_name=self.naming,
func_name=ClassificationCommonFunction.CLASSIFICATION_REPORT.value,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)
self._cross_validation(
Expand Down Expand Up @@ -936,13 +949,15 @@ def special_components(self, **kwargs) -> None:
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=DecisionTreeSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=DecisionTreeSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -957,13 +972,15 @@ def special_components(self, is_automl: bool, **kwargs) -> None:
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=DecisionTreeSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=DecisionTreeSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -1255,13 +1272,15 @@ def special_components(self, **kwargs) -> None:
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=RandomForestSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.model.estimators_[0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=RandomForestSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -1276,13 +1295,15 @@ def special_components(self, is_automl: bool = False, **kwargs) -> None:
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=RandomForestSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.auto_model.estimators_[0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=RandomForestSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -1634,6 +1655,7 @@ def special_components(self, **kwargs) -> None:
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=XGBoostSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -1648,6 +1670,7 @@ def special_components(self, is_automl: bool = False, **kwargs) -> None:
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=XGBoostSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -1889,12 +1912,12 @@ def manual_hyper_parameters(cls) -> Dict:
return hyper_parameters

@staticmethod
def _plot_feature_importance(columns_name: np.ndarray, name_column: str, trained_model: any, algorithm_name: str, local_path: str, mlflow_path: str, func_name: str) -> None:
    """Print the feature coefficient value orderly.

    Parameters
    ----------
    columns_name : np.ndarray
        Feature names handed to ``plot_logistic_importance``.
    name_column : str
        Forwarded to ``save_data`` together with the plotted data.
    trained_model : any
        The fitted model handed to ``plot_logistic_importance``.
    algorithm_name : str
        Name of the algorithm; used as the suffix of the output name.
    local_path : str
        Local output location forwarded to ``save_fig`` and ``save_data``.
    mlflow_path : str
        MLflow artifact location forwarded to ``save_fig`` and ``save_data``.
    func_name : str
        Display name of this functionality; used both in the console
        banner and as the prefix of the output name.
    """
    print(f"-----* {func_name} *-----")
    data = plot_logistic_importance(columns_name, trained_model)
    # The figure and its underlying data share one output name.
    output_name = f"{func_name} - {algorithm_name}"
    save_fig(output_name, local_path, mlflow_path)
    save_data(data, name_column, output_name, local_path, mlflow_path)

@dispatch()
def special_components(self, **kwargs) -> None:
Expand All @@ -1916,6 +1939,7 @@ def special_components(self, **kwargs) -> None:
name_column=LogisticRegressionClassification.name_all,
trained_model=self.model,
algorithm_name=self.naming,
func_name=LogisticRegressionSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -1940,6 +1964,7 @@ def special_components(self, is_automl: bool = False, **kwargs) -> None:
name_column=LogisticRegressionClassification.name_all,
trained_model=self.auto_model,
algorithm_name=self.naming,
func_name=LogisticRegressionSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -2258,13 +2283,13 @@ def manual_hyper_parameters(cls) -> Dict:
return hyper_parameters

@staticmethod
def _plot_loss_curve(trained_model: object, algorithm_name: str, func_name: str, local_path: str, mlflow_path: str) -> None:
    """Plot the learning curve of the trained model.

    Parameters
    ----------
    trained_model : object
        A fitted model exposing a ``loss_curve_`` attribute.
    algorithm_name : str
        Name of the algorithm; used as the suffix of the output name.
    func_name : str
        Display name of this functionality; used both in the console
        banner and as the prefix of the output name.
    local_path : str
        Local output location forwarded to the save helpers.
    mlflow_path : str
        MLflow artifact location forwarded to the save helpers.
    """
    print(f"-----* {func_name} *-----")
    # One row per entry of loss_curve_, in a single "Loss" column.
    data = pd.DataFrame(trained_model.loss_curve_, columns=["Loss"])
    data.plot(title="Loss")
    # The figure and its underlying data share one output name.
    output_name = f"{func_name} - {algorithm_name}"
    save_fig(output_name, local_path, mlflow_path)
    save_data_without_data_identifier(data, output_name, local_path, mlflow_path)

@dispatch()
def special_components(self, **kwargs) -> None:
Expand All @@ -2274,6 +2299,7 @@ def special_components(self, **kwargs) -> None:
self._plot_loss_curve(
trained_model=self.model,
algorithm_name=self.naming,
func_name=MLPSpecialFunction.LOSS_CURVE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -2286,6 +2312,7 @@ def special_components(self, is_automl: bool, **kwargs) -> None:
self._plot_loss_curve(
trained_model=self.auto_model,
algorithm_name=self.naming,
func_name=MLPSpecialFunction.LOSS_CURVE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -2555,13 +2582,15 @@ def special_components(self, **kwargs) -> None:
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=ExtraTreesSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.model.estimators_[0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=ExtraTreesSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -2576,13 +2605,15 @@ def special_components(self, is_automl: bool, **kwargs) -> None:
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=ExtraTreesSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.auto_model.estimators_[0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=ExtraTreesSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down Expand Up @@ -2920,13 +2951,15 @@ def special_components(self, **kwargs) -> None:
trained_model=self.model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=GradientBoostingSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.model.estimators_[0][0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=GradientBoostingSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand All @@ -2941,13 +2974,15 @@ def special_components(self, is_automl: bool, **kwargs) -> None:
trained_model=self.auto_model,
image_config=self.image_config,
algorithm_name=self.naming,
func_name=GradientBoostingSpecialFunction.FEATURE_IMPORTANCE.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_tree(
trained_model=self.auto_model.estimators_[0][0],
image_config=self.image_config,
algorithm_name=self.naming,
func_name=GradientBoostingSpecialFunction.TREE_DIAGRAM.value,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
Expand Down
33 changes: 33 additions & 0 deletions geochemistrypi/data_mining/model/func/algo_classification/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

class ClassificationCommonFunction(Enum):
MODEL_SCORE = "Model Score"
CLASSIFICATION_REPORT = "Classification Report"
CONFUSION_MATRIX = "Confusion Matrix"
CROSS_VALIDATION = "Cross Validation"
MODEL_PREDICTION = "Model Prediction"
Expand All @@ -12,3 +13,35 @@ class ClassificationCommonFunction(Enum):
ROC_CURVE = "ROC Curve"
TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM = "Two-dimensional Decision Boundary Diagram"
PERMUTATION_IMPORTANCE_DIAGRAM = "Permutation Importance Diagram"


class DecisionTreeSpecialFunction(Enum):
    """Display names of the decision-tree-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"
    TREE_DIAGRAM = "Tree Diagram"


class RandomForestSpecialFunction(Enum):
    """Display names of the random-forest-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"
    TREE_DIAGRAM = "Tree Diagram"


class XGBoostSpecialFunction(Enum):
    """Display names of the XGBoost-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"


class LogisticRegressionSpecialFunction(Enum):
    """Display names of the logistic-regression-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"


class MLPSpecialFunction(Enum):
    """Display names of the MLP-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    LOSS_CURVE_DIAGRAM = "Loss Curve Diagram"


class ExtraTreesSpecialFunction(Enum):
    """Display names of the extra-trees-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"
    TREE_DIAGRAM = "Tree Diagram"


class GradientBoostingSpecialFunction(Enum):
    """Display names of the gradient-boosting-specific output functions.

    Each value is a human-readable label intended to be passed as a
    ``func_name`` argument.
    """

    FEATURE_IMPORTANCE = "Feature Importance"
    TREE_DIAGRAM = "Tree Diagram"

0 comments on commit 8a14bad

Please sign in to comment.