diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index a7c79ec..a740853 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -6,9 +6,12 @@
 import itertools
 import math
 import re
+from typing import Any
 import numpy as np
 import pandas as pd
 from sklearn.metrics import precision_recall_curve, auc
+from pyspark.ml import Model, Transformer
+import pyspark.sql
 from pyspark.sql.functions import count, mean
 
 import hlink.linking.core.threshold as threshold_core
@@ -18,7 +21,7 @@
 
 
 class LinkStepTrainTestModels(LinkStep):
-    def __init__(self, task):
+    def __init__(self, task) -> None:
         super().__init__(
             task,
             "train test models",
@@ -35,7 +38,7 @@ def __init__(self, task):
             ],
         )
 
-    def _run(self):
+    def _run(self) -> None:
         training_conf = str(self.task.training_conf)
         table_prefix = self.task.table_prefix
         config = self.task.link_run.config
@@ -80,7 +83,7 @@ def _run(self):
             threshold_ratio = False
 
         threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
-        results_dfs = {}
+        results_dfs: dict[int, pd.DataFrame] = {}
         for i in range(len(threshold_matrix)):
             results_dfs[i] = _create_results_df()
 
@@ -175,7 +178,19 @@ def _run(self):
         self._save_otd_data(otd_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")
 
-    def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
+    def _get_splits(
+        self,
+        prepped_data: pyspark.sql.DataFrame,
+        id_a: str,
+        n_training_iterations: int,
+        seed: int,
+    ) -> list[list[pyspark.sql.DataFrame]]:
+        """
+        Get a list of random splits of the prepped_data into two DataFrames.
+        There are n_training_iterations elements in the list. Each element is
+        itself a list of two DataFrames which are the splits of prepped_data.
+        The split DataFrames are roughly equal in size.
+        """
         if self.task.link_run.config[f"{self.task.training_conf}"].get(
             "split_by_id_a", False
         ):
@@ -200,7 +215,7 @@ def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
 
         return splits
 
-    def _custom_param_grid_builder(self, conf):
+    def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
         print("Building param grid for models")
         given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]
         new_params = []
@@ -231,16 +246,16 @@ def _custom_param_grid_builder(self, conf):
 
     def _capture_results(
         self,
-        predictions,
-        predict_train,
-        dep_var,
-        model,
-        results_df,
-        otd_data,
-        at,
-        tr,
-        pr_auc,
-    ):
+        predictions: pyspark.sql.DataFrame,
+        predict_train: pyspark.sql.DataFrame,
+        dep_var: str,
+        model: Model,
+        results_df: pd.DataFrame,
+        otd_data: dict[str, Any] | None,
+        at: float,
+        tr: float,
+        pr_auc: float,
+    ) -> pd.DataFrame:
         table_prefix = self.task.table_prefix
 
         print("Evaluating model performance...")
@@ -284,7 +299,7 @@ def _capture_results(
         )
         return pd.concat([results_df, new_results], ignore_index=True)
 
-    def _get_model_parameters(self, conf):
+    def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
         training_conf = str(self.task.training_conf)
 
         model_parameters = conf[training_conf]["model_parameters"]
@@ -296,7 +311,9 @@ def _get_model_parameters(self, conf):
             )
         return model_parameters
 
-    def _save_training_results(self, desc_df, spark):
+    def _save_training_results(
+        self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
+    ) -> None:
         table_prefix = self.task.table_prefix
 
         if desc_df.empty:
@@ -310,7 +327,9 @@ def _save_training_results(self, desc_df, spark):
             f"Training results saved to Spark table '{table_prefix}training_results'."
         )
 
-    def _prepare_otd_table(self, spark, df, id_a, id_b):
+    def _prepare_otd_table(
+        self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
+    ) -> pyspark.sql.DataFrame:
         spark_df = spark.createDataFrame(df)
         counted = (
             spark_df.groupby(id_a, id_b)
@@ -323,7 +342,9 @@ def _prepare_otd_table(self, spark, df, id_a, id_b):
         )
         return counted
 
-    def _save_otd_data(self, otd_data, spark):
+    def _save_otd_data(
+        self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
+    ) -> None:
         table_prefix = self.task.table_prefix
 
         if otd_data is None:
@@ -379,7 +400,7 @@ def _save_otd_data(self, otd_data, spark):
         else:
             print("There were no true negatives recorded.")
 
-    def _create_otd_data(self, id_a, id_b):
+    def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
         """Output Suspicous Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
         training_conf = str(self.task.training_conf)
         config = self.task.link_run.config
@@ -400,7 +421,12 @@ def _create_otd_data(self, id_a, id_b):
             return None
 
 
-def _calc_mcc(TP, TN, FP, FN):
+def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
+    """
+    Given the counts of true positives (TP), true negatives (TN), false
+    positives (FP), and false negatives (FN) for a model run, compute the
+    Matthews Correlation Coefficient (MCC).
+ """ if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0: mcc = ((TP * TN) - (FP * FN)) / ( math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) @@ -410,7 +436,9 @@ def _calc_mcc(TP, TN, FP, FN): return mcc -def _calc_threshold_matrix(alpha_threshold, threshold_ratio): +def _calc_threshold_matrix( + alpha_threshold: float | list[float], threshold_ratio: float | list[float] +) -> list[list[float]]: if alpha_threshold and type(alpha_threshold) != list: alpha_threshold = [alpha_threshold] @@ -426,8 +454,13 @@ def _calc_threshold_matrix(alpha_threshold, threshold_ratio): def _get_probability_and_select_pred_columns( - pred_df, model, post_transformer, id_a, id_b, dep_var -): + pred_df: pyspark.sql.DataFrame, + model: Model, + post_transformer: Transformer, + id_a: str, + id_b: str, + dep_var: str, +) -> pyspark.sql.DataFrame: all_prediction_cols = set( [ f"{id_a}", @@ -446,7 +479,9 @@ def _get_probability_and_select_pred_columns( return required_col_df -def _get_confusion_matrix(predictions, dep_var, otd_data): +def _get_confusion_matrix( + predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None +) -> tuple[int, int, int, int]: TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1)) TP_count = TP.count() @@ -486,7 +521,16 @@ def _get_confusion_matrix(predictions, dep_var, otd_data): return TP_count, FP_count, FN_count, TN_count -def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count): +def _get_aggregate_metrics( + TP_count: int, FP_count: int, FN_count: int, TN_count: int +) -> tuple[float, float, float]: + """ + Given the counts of true positives, false positivies, false negatives, and + true negatives for a model run, compute several metrics to evaluate the + model's quality. + + Return a tuple of (precision, recall, Matthews Correlation Coefficient). + """ if (TP_count + FP_count) == 0: precision = np.nan else: @@ -499,7 +543,7 @@ def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count): return precision, recall, mcc -def _create_results_df(): +def _create_results_df() -> pd.DataFrame: return pd.DataFrame( columns=[ "precision_test", @@ -516,7 +560,12 @@ def _create_results_df(): ) -def _append_results(desc_df, results_df, model_type, params): +def _append_results( + desc_df: pd.DataFrame, + results_df: pd.DataFrame, + model_type: str, + params: dict[str, Any], +) -> pd.DataFrame: # run.pop("type") print(results_df) @@ -548,7 +597,7 @@ def _append_results(desc_df, results_df, model_type, params): return desc_df -def _print_desc_df(desc_df): +def _print_desc_df(desc_df: pd.DataFrame) -> None: pd.set_option("display.max_colwidth", None) print( desc_df.drop( @@ -564,7 +613,7 @@ def _print_desc_df(desc_df): print("\n") -def _load_desc_df_params(desc_df): +def _load_desc_df_params(desc_df: pd.DataFrame) -> pd.DataFrame: params = [ "maxDepth", "numTrees", @@ -591,7 +640,7 @@ def _load_desc_df_params(desc_df): return desc_df -def _create_desc_df(): +def _create_desc_df() -> pd.DataFrame: return pd.DataFrame( columns=[ "model",