diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index e02b7f7..26a5581 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -124,7 +124,7 @@ class ModelEval:
     score: float
     hyperparams: dict[str, Any]
     threshold: float | list[float]
-    threshold_ratio: float | list[float] | bool
+    threshold_ratio: float | list[float] | None
 
     def print(self):
         return f"{self.model_type} {self.score} params: {self.hyperparams}"
@@ -180,7 +180,7 @@ def _train_model(
    test_pred = predictions_tmp.toPandas()
 
    precision, recall, thresholds_raw = precision_recall_curve(
-        test_pred[f"{dep_var}"],
+        test_pred[dep_var],
        test_pred["probability"].round(2),
        pos_label=1,
    )
@@ -241,8 +241,7 @@
         dep_var: str,
         id_a: str,
         id_b: str,
-        config,
-        training_conf,
+        training_settings,
     ) -> list[ModelEval]:
         info = f"Begin evaluating all {len(all_model_parameter_combos)} selected hyperparameter combinations."
         print(info)
@@ -263,7 +262,7 @@
             # we need to use model_type, params, score and thresholds to
             # do the next step using thresholds.
             threshold, threshold_ratio = self._get_thresholds(
-                hyperparams, config, training_conf
+                hyperparams, training_settings
             )
             # thresholds and model_type are mixed in with the model hyper-parameters
             # in the config; this removes them before passing to the model training.
@@ -290,24 +289,17 @@
 
     # Grabs the threshold settings from a single model parameter combination row (after all combinations
     # are exploded.) Does not alter the params structure.)
-    def _get_thresholds(
-        self, model_parameters, config, training_conf
-    ) -> tuple[Any, Any]:
+    def _get_thresholds(self, model_parameters, training_settings) -> tuple[Any, Any]:
         alpha_threshold = model_parameters.get(
-            "threshold", config[training_conf].get("threshold", 0.8)
+            "threshold", training_settings.get("threshold", 0.8)
         )
-        if (
-            config[training_conf].get("decision", False)
-            == "drop_duplicate_with_threshold_ratio"
-        ):
+        if training_settings.get("decision") == "drop_duplicate_with_threshold_ratio":
             threshold_ratio = model_parameters.get(
                 "threshold_ratio",
-                threshold_core.get_threshold_ratio(
-                    config[training_conf], model_parameters
-                ),
+                threshold_core.get_threshold_ratio(training_settings, model_parameters),
             )
         else:
-            threshold_ratio = False
+            threshold_ratio = None
 
         return alpha_threshold, threshold_ratio
 
@@ -340,9 +332,12 @@
         id_a: str,
         id_b: str,
     ) -> tuple[pd.DataFrame, Any]:
-        training_conf = str(self.task.training_conf)
+        training_config_name = str(self.task.training_conf)
         config = self.task.link_run.config
+        id_column = config["id_column"]
+        training_settings = config[training_config_name]
+
         thresholded_metrics_df = _create_thresholded_metrics_df()
 
         thresholding_training_data = split.get("training")
 
@@ -417,15 +412,15 @@
                     thresholding_predictions,
                     this_alpha_threshold,
                     this_threshold_ratio,
-                    config[training_conf],
-                    config["id_column"],
+                    training_settings,
+                    id_column,
                 )
                 predict_train = threshold_core.predict_using_thresholds(
                     thresholding_predict_train,
                     this_alpha_threshold,
                     this_threshold_ratio,
-                    config[training_conf],
-                    config["id_column"],
+                    training_settings,
+                    id_column,
                 )
 
                 end_predict_time = perf_counter()
@@ -460,13 +455,14 @@ def _evaluate_threshold_combinations(
         return thresholded_metrics_df, suspicious_data
 
     def _run(self) -> None:
-        training_conf = str(self.task.training_conf)
+        training_section_name = str(self.task.training_conf)
         table_prefix = self.task.table_prefix
         config = self.task.link_run.config
+        training_settings = config[training_section_name]
 
         self.task.spark.sql("set spark.sql.shuffle.partitions=1")
 
-        dep_var = config[training_conf]["dependent_var"]
+        dep_var = training_settings["dependent_var"]
         id_a = config["id_column"] + "_a"
         id_b = config["id_column"] + "_b"
 
@@ -478,15 +474,15 @@ def _run(self) -> None:
         )
 
         # Stores suspicious data
-        otd_data = self._create_otd_data(id_a, id_b)
+        suspicious_data = self._create_suspicious_data(id_a, id_b)
 
-        outer_fold_count = config[training_conf].get("n_training_iterations", 10)
+        outer_fold_count = training_settings.get("n_training_iterations", 10)
         inner_fold_count = 3
 
         if outer_fold_count < 3:
-            raise RuntimeError("You must use at least two training iterations.")
+            raise RuntimeError("You must use at least three outer folds.")
 
-        seed = config[training_conf].get("seed", 2133)
+        seed = training_settings.get("seed", 2133)
 
         outer_folds = self._get_outer_folds(prepped_data, id_a, outer_fold_count, seed)
 
@@ -515,8 +511,7 @@
                 dep_var,
                 id_a,
                 id_b,
-                config,
-                training_conf,
+                training_settings,
             )
 
             print(
@@ -526,7 +521,7 @@
             thresholded_metrics_df, suspicious_data = (
                 self._evaluate_threshold_combinations(
                     hyperparam_evaluation_results,
-                    otd_data,
+                    suspicious_data,
                     {"test": outer_test_data, "training": outer_training_data},
                     dep_var,
                     id_a,
@@ -545,7 +540,7 @@
 
         print("*** Final thresholded metrics ***")
         self._save_training_results(thresholded_metrics_df, self.task.spark)
-        self._save_otd_data(suspicious_data, self.task.spark)
+        self._save_suspicious_data(suspicious_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")
 
     def _split_into_folds(
@@ -673,9 +668,9 @@ def _capture_results(
         dep_var: str,
         model: Model,
         results_df: pd.DataFrame,
-        otd_data: dict[str, Any] | None,
+        suspicious_data: dict[str, Any] | None,
         alpha_threshold: float,
-        threshold_ratio: float,
+        threshold_ratio: float | None,
         pr_auc: float,
     ) -> pd.DataFrame:
         table_prefix = self.task.table_prefix
@@ -695,7 +690,7 @@
             test_FP_count,
             test_FN_count,
             test_TN_count,
-        ) = _get_confusion_matrix(predictions, dep_var, otd_data)
+        ) = _get_confusion_matrix(predictions, dep_var, suspicious_data)
         test_precision, test_recall, test_mcc = _get_aggregate_metrics(
             test_TP_count, test_FP_count, test_FN_count, test_TN_count
         )
@@ -705,7 +700,7 @@
             train_FP_count,
             train_FN_count,
             train_TN_count,
-        ) = _get_confusion_matrix(predict_train, dep_var, otd_data)
+        ) = _get_confusion_matrix(predict_train, dep_var, suspicious_data)
         train_precision, train_recall, train_mcc = _get_aggregate_metrics(
             train_TP_count, train_FP_count, train_FN_count, train_TN_count
         )
@@ -754,7 +749,7 @@ def _save_training_results(
         # f"Training results saved to Spark table '{table_prefix}training_results'."
         # )
 
-    def _prepare_otd_table(
+    def _prepare_suspicious_table(
         self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
     ) -> pyspark.sql.DataFrame:
         spark_df = spark.createDataFrame(df)
@@ -769,21 +764,21 @@
         )
         return counted
 
-    def _save_otd_data(
-        self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
+    def _save_suspicious_data(
+        self, suspicious_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
     ) -> None:
         table_prefix = self.task.table_prefix
 
-        if otd_data is None:
+        if suspicious_data is None:
             print("OTD suspicious data is None, not saving.")
             return
-        id_a = otd_data["id_a"]
-        id_b = otd_data["id_b"]
+        id_a = suspicious_data["id_a"]
+        id_b = suspicious_data["id_b"]
 
-        if not otd_data["FP_data"].empty:
+        if not suspicious_data["FP_data"].empty:
             table_name = f"{table_prefix}repeat_fps"
-            counted_FPs = self._prepare_otd_table(
-                spark, otd_data["FP_data"], id_a, id_b
+            counted_FPs = self._prepare_suspicious_table(
+                spark, suspicious_data["FP_data"], id_a, id_b
             )
             counted_FPs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -792,10 +787,10 @@
         else:
             print("There were no false positives recorded.")
 
-        if not otd_data["FN_data"].empty:
+        if not suspicious_data["FN_data"].empty:
             table_name = f"{table_prefix}repeat_fns"
-            counted_FNs = self._prepare_otd_table(
-                spark, otd_data["FN_data"], id_a, id_b
+            counted_FNs = self._prepare_suspicious_table(
+                spark, suspicious_data["FN_data"], id_a, id_b
             )
             counted_FNs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -804,10 +799,10 @@
         else:
             print("There were no false negatives recorded.")
 
-        if not otd_data["TP_data"].empty:
+        if not suspicious_data["TP_data"].empty:
             table_name = f"{table_prefix}repeat_tps"
-            counted_TPs = self._prepare_otd_table(
-                spark, otd_data["TP_data"], id_a, id_b
+            counted_TPs = self._prepare_suspicious_table(
+                spark, suspicious_data["TP_data"], id_a, id_b
             )
             counted_TPs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -816,10 +811,10 @@
         else:
             print("There were no true positives recorded.")
 
-        if not otd_data["TN_data"].empty:
+        if not suspicious_data["TN_data"].empty:
             table_name = f"{table_prefix}repeat_tns"
-            counted_TNs = self._prepare_otd_table(
-                spark, otd_data["TN_data"], id_a, id_b
+            counted_TNs = self._prepare_suspicious_table(
+                spark, suspicious_data["TN_data"], id_a, id_b
            )
             counted_TNs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -828,14 +823,15 @@
         else:
             print("There were no true negatives recorded.")
 
-    def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
+    def _create_suspicious_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
         """Output Suspicious Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
-        training_conf = str(self.task.training_conf)
+        training_section_name = str(self.task.training_conf)
         config = self.task.link_run.config
+        training_settings = config[training_section_name]
 
         if (
-            "output_suspicious_TD" in config[training_conf]
-            and config[training_conf]["output_suspicious_TD"]
+            "output_suspicious_TD" in training_settings
+            and training_settings["output_suspicious_TD"]
         ):
             return {
                 "FP_data": pd.DataFrame(),
@@ -865,7 +861,7 @@ def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
 
 
 def _calc_threshold_matrix(
-    alpha_threshold: float | list[float], threshold_ratio: float | list[float]
+    alpha_threshold: float | list[float], threshold_ratio: float | list[float] | None
 ) -> list[list[float]]:
     if alpha_threshold and type(alpha_threshold) != list:
         alpha_threshold = [alpha_threshold]
@@ -908,7 +904,9 @@ def _get_probability_and_select_pred_columns(
 
 
 def _get_confusion_matrix(
-    predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
+    predictions: pyspark.sql.DataFrame,
+    dep_var: str,
+    suspicious_data: dict[str, Any] | None,
 ) -> tuple[int, int, int, int]:
     TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
 
@@ -931,29 +929,37 @@
     # f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
     # )
 
-    if otd_data:
-        id_a = otd_data["id_a"]
-        id_b = otd_data["id_b"]
+    if suspicious_data:
+        id_a = suspicious_data["id_a"]
+        id_b = suspicious_data["id_b"]
 
         new_FP_data = FP.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["FP_data"] = pd.concat([otd_data["FP_data"], new_FP_data])
+        suspicious_data["FP_data"] = pd.concat(
+            [suspicious_data["FP_data"], new_FP_data]
+        )
 
         new_FN_data = FN.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["FN_data"] = pd.concat([otd_data["FN_data"], new_FN_data])
+        suspicious_data["FN_data"] = pd.concat(
+            [suspicious_data["FN_data"], new_FN_data]
+        )
 
         new_TP_data = TP.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["TP_data"] = pd.concat([otd_data["TP_data"], new_TP_data])
+        suspicious_data["TP_data"] = pd.concat(
+            [suspicious_data["TP_data"], new_TP_data]
+        )
 
         new_TN_data = TN.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["TN_data"] = pd.concat([otd_data["TN_data"], new_TN_data])
+        suspicious_data["TN_data"] = pd.concat(
+            [suspicious_data["TN_data"], new_TN_data]
+        )
 
     return TP_count, FP_count, FN_count, TN_count
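
Reviewer note: the behavioral core of this patch is that _get_thresholds now reads a pre-resolved training_settings mapping instead of indexing config[training_conf] at each use, and it signals "no ratio-based decision" with None rather than False. Below is a minimal standalone sketch of that resolution logic. The helper name resolve_thresholds is ours, and the plain dict lookup with an assumed 1.3 default stands in for threshold_core.get_threshold_ratio; it is an illustration of the new precedence rules, not the library code itself.

from typing import Any


def resolve_thresholds(
    model_parameters: dict[str, Any], training_settings: dict[str, Any]
) -> tuple[Any, Any]:
    # A model-level "threshold" overrides the training-section default of 0.8,
    # mirroring _get_thresholds after this patch.
    alpha_threshold = model_parameters.get(
        "threshold", training_settings.get("threshold", 0.8)
    )
    if training_settings.get("decision") == "drop_duplicate_with_threshold_ratio":
        # Stand-in for threshold_core.get_threshold_ratio; the 1.3 fallback is
        # an assumption for this sketch, not taken from the library.
        threshold_ratio = model_parameters.get(
            "threshold_ratio", training_settings.get("threshold_ratio", 1.3)
        )
    else:
        # Previously False; None makes "ratio not in play" unambiguous, and a
        # falsy-but-valid ratio value could never be confused with it.
        threshold_ratio = None
    return alpha_threshold, threshold_ratio


# No ratio-based decision configured: the ratio comes back as None.
print(resolve_thresholds({"threshold": 0.9}, {"threshold": 0.8}))
# -> (0.9, None)

# Duplicate-dropping decision enabled: the ratio is resolved from settings.
print(
    resolve_thresholds(
        {}, {"decision": "drop_duplicate_with_threshold_ratio", "threshold_ratio": 2.0}
    )
)
# -> (0.8, 2.0)

The None sentinel also explains the widened type annotations in this diff (ModelEval.threshold_ratio, _capture_results, and _calc_threshold_matrix all now accept float | list[float] | None), since a missing ratio flows through those call sites unchanged.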