Address PR comments
ccdavis committed Dec 3, 2024
1 parent 40f344e commit 1e55384
Showing 1 changed file with 73 additions and 67 deletions.
hlink/linking/model_exploration/link_step_train_test_models.py
@@ -124,7 +124,7 @@ class ModelEval:
     score: float
     hyperparams: dict[str, Any]
     threshold: float | list[float]
-    threshold_ratio: float | list[float] | bool
+    threshold_ratio: float | list[float] | None

     def print(self):
         return f"{self.model_type} {self.score} params: {self.hyperparams}"
@@ -180,7 +180,7 @@ def _train_model(

     test_pred = predictions_tmp.toPandas()
     precision, recall, thresholds_raw = precision_recall_curve(
-        test_pred[f"{dep_var}"],
+        test_pred[dep_var],
         test_pred["probability"].round(2),
         pos_label=1,
     )
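Note: the hunk above only simplifies the column lookup (the f-string was redundant), but for orientation, here is a self-contained sketch of what this `precision_recall_curve` call computes, using scikit-learn directly on stand-in data. The DataFrame and its column names are illustrative, not hlink's actual predictions table.

```python
import pandas as pd
from sklearn.metrics import auc, precision_recall_curve

# Stand-in for the predictions DataFrame; "match" plays the role of dep_var.
test_pred = pd.DataFrame(
    {
        "match": [0, 1, 1, 0, 1],
        "probability": [0.12, 0.87, 0.65, 0.40, 0.93],
    }
)
precision, recall, thresholds_raw = precision_recall_curve(
    test_pred["match"],
    test_pred["probability"].round(2),
    pos_label=1,
)
pr_auc = auc(recall, precision)  # area under the precision-recall curve
print(f"PR AUC: {pr_auc:.3f}")
```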
@@ -241,8 +241,7 @@ def _evaluate_hyperparam_combinations(
         dep_var: str,
         id_a: str,
         id_b: str,
-        config,
-        training_conf,
+        training_settings,
     ) -> list[ModelEval]:
         info = f"Begin evaluating all {len(all_model_parameter_combos)} selected hyperparameter combinations."
         print(info)
@@ -263,7 +262,7 @@
             # we need to use model_type, params, score and thresholds to
             # do the next step using thresholds.
             threshold, threshold_ratio = self._get_thresholds(
-                hyperparams, config, training_conf
+                hyperparams, training_settings
             )
             # thresholds and model_type are mixed in with the model hyper-parameters
             # in the config; this removes them before passing to the model training.
@@ -290,24 +289,17 @@

     # Grabs the threshold settings from a single model parameter combination row (after all combinations
     # are exploded.) Does not alter the params structure.)
-    def _get_thresholds(
-        self, model_parameters, config, training_conf
-    ) -> tuple[Any, Any]:
+    def _get_thresholds(self, model_parameters, training_settings) -> tuple[Any, Any]:
         alpha_threshold = model_parameters.get(
-            "threshold", config[training_conf].get("threshold", 0.8)
+            "threshold", training_settings.get("threshold", 0.8)
         )
-        if (
-            config[training_conf].get("decision", False)
-            == "drop_duplicate_with_threshold_ratio"
-        ):
+        if training_settings.get("decision") == "drop_duplicate_with_threshold_ratio":
             threshold_ratio = model_parameters.get(
                 "threshold_ratio",
-                threshold_core.get_threshold_ratio(
-                    config[training_conf], model_parameters
-                ),
+                threshold_core.get_threshold_ratio(training_settings, model_parameters),
             )
         else:
-            threshold_ratio = False
+            threshold_ratio = None

         return alpha_threshold, threshold_ratio

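Note: with the `config`/`training_conf` pair collapsed into `training_settings`, the precedence in `_get_thresholds` reads directly off the code: a model-level `threshold` overrides the training section, which falls back to 0.8, and `threshold_ratio` only applies under the `drop_duplicate_with_threshold_ratio` decision (now `None` instead of `False` otherwise). A minimal sketch on illustrative settings, with `threshold_core.get_threshold_ratio` stubbed out as a plain dict lookup:

```python
# Illustrative stand-ins for a model's hyperparameters and the [training]
# config section; not hlink's actual config objects.
model_parameters = {"type": "random_forest", "maxDepth": 7, "threshold": 0.9}
training_settings = {
    "threshold": 0.8,
    "decision": "drop_duplicate_with_threshold_ratio",
    "threshold_ratio": 1.3,
}

# Model-level "threshold" wins, then the training section, then 0.8.
alpha_threshold = model_parameters.get(
    "threshold", training_settings.get("threshold", 0.8)
)

# The ratio only matters for the deduplicating decision; otherwise it is
# None (this commit's change from the old False sentinel).
if training_settings.get("decision") == "drop_duplicate_with_threshold_ratio":
    # Stub for threshold_core.get_threshold_ratio (an assumption).
    threshold_ratio = model_parameters.get(
        "threshold_ratio", training_settings.get("threshold_ratio")
    )
else:
    threshold_ratio = None

assert (alpha_threshold, threshold_ratio) == (0.9, 1.3)
```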
@@ -340,9 +332,12 @@ def _evaluate_threshold_combinations(
         id_a: str,
         id_b: str,
     ) -> tuple[pd.DataFrame, Any]:
-        training_conf = str(self.task.training_conf)
+        training_config_name = str(self.task.training_conf)
         config = self.task.link_run.config

+        id_column = config["id_column"]
+        training_settings = config[training_config_name]
+
         thresholded_metrics_df = _create_thresholded_metrics_df()

         thresholding_training_data = split.get("training")
@@ -417,15 +412,15 @@
                 thresholding_predictions,
                 this_alpha_threshold,
                 this_threshold_ratio,
-                config[training_conf],
-                config["id_column"],
+                training_settings,
+                id_column,
             )
             predict_train = threshold_core.predict_using_thresholds(
                 thresholding_predict_train,
                 this_alpha_threshold,
                 this_threshold_ratio,
-                config[training_conf],
-                config["id_column"],
+                training_settings,
+                id_column,
             )

             end_predict_time = perf_counter()
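Note: `predict_using_thresholds` itself lives in `threshold_core` and runs on Spark DataFrames; as a rough mental model only, reconstructed from the config names rather than from that module, the two-threshold decision looks something like this pandas toy. The ratio test against the runner-up per `id_a` is an assumption about the `drop_duplicate_with_threshold_ratio` semantics.

```python
import pandas as pd

def predict_with_thresholds_sketch(
    df: pd.DataFrame, alpha_threshold: float, threshold_ratio: float | None
) -> pd.DataFrame:
    """Toy two-threshold match decision; not hlink's implementation."""
    out = df.copy()
    # Alpha threshold: a pair must clear the probability cutoff at all.
    out["prediction"] = (out["probability"] >= alpha_threshold).astype(int)
    if threshold_ratio is not None:
        # Keep at most one match per id_a: the top candidate, and only when
        # it beats the runner-up's probability by at least threshold_ratio.
        def dedup(group: pd.DataFrame) -> pd.DataFrame:
            ranked = group.sort_values("probability", ascending=False)
            group = group.copy()
            group["prediction"] = 0
            best = ranked["probability"].iloc[0]
            runner_up = ranked["probability"].iloc[1] if len(ranked) > 1 else 0.0
            clear_winner = runner_up == 0.0 or best / runner_up >= threshold_ratio
            if clear_winner and best >= alpha_threshold:
                group.loc[ranked.index[0], "prediction"] = 1
            return group

        out = out.groupby("id_a", group_keys=False).apply(dedup)
    return out
```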
@@ -460,13 +455,14 @@
         return thresholded_metrics_df, suspicious_data

     def _run(self) -> None:
-        training_conf = str(self.task.training_conf)
+        training_section_name = str(self.task.training_conf)
         table_prefix = self.task.table_prefix
         config = self.task.link_run.config
+        training_settings = config[training_section_name]

         self.task.spark.sql("set spark.sql.shuffle.partitions=1")

-        dep_var = config[training_conf]["dependent_var"]
+        dep_var = training_settings["dependent_var"]
         id_a = config["id_column"] + "_a"
         id_b = config["id_column"] + "_b"

@@ -478,15 +474,15 @@ def _run(self) -> None:
         )

         # Stores suspicious data
-        otd_data = self._create_otd_data(id_a, id_b)
+        suspicious_data = self._create_suspicious_data(id_a, id_b)

-        outer_fold_count = config[training_conf].get("n_training_iterations", 10)
+        outer_fold_count = training_settings.get("n_training_iterations", 10)
         inner_fold_count = 3

         if outer_fold_count < 3:
-            raise RuntimeError("You must use at least two training iterations.")
+            raise RuntimeError("You must use at least three outer folds.")

-        seed = config[training_conf].get("seed", 2133)
+        seed = training_settings.get("seed", 2133)

         outer_folds = self._get_outer_folds(prepped_data, id_a, outer_fold_count, seed)

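Note: the error message now matches the `< 3` guard. As a toy illustration of the nested shape — outer folds for threshold evaluation, a fixed three inner folds for hyperparameter search — with a hypothetical splitting helper (not hlink's `_get_outer_folds`, which splits Spark data by id):

```python
import random

def get_outer_folds_sketch(ids: list[int], fold_count: int, seed: int) -> list[list[int]]:
    """Hypothetical stand-in: deal shuffled ids round-robin into folds."""
    rng = random.Random(seed)
    shuffled = ids[:]
    rng.shuffle(shuffled)
    return [shuffled[i::fold_count] for i in range(fold_count)]

ids = list(range(20))
outer_folds = get_outer_folds_sketch(ids, fold_count=5, seed=2133)
for outer_test in outer_folds:
    # Each outer iteration trains on the union of the other folds.
    outer_train = [x for fold in outer_folds if fold is not outer_test for x in fold]
    assert set(outer_test).isdisjoint(outer_train)
```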
@@ -515,8 +511,7 @@ def _run(self) -> None:
                 dep_var,
                 id_a,
                 id_b,
-                config,
-                training_conf,
+                training_settings,
             )

             print(
@@ -526,7 +521,7 @@
             thresholded_metrics_df, suspicious_data = (
                 self._evaluate_threshold_combinations(
                     hyperparam_evaluation_results,
-                    otd_data,
+                    suspicious_data,
                     {"test": outer_test_data, "training": outer_training_data},
                     dep_var,
                     id_a,
@@ -545,7 +540,7 @@
         print("*** Final thresholded metrics ***")

         self._save_training_results(thresholded_metrics_df, self.task.spark)
-        self._save_otd_data(suspicious_data, self.task.spark)
+        self._save_suspicious_data(suspicious_data, self.task.spark)
         self.task.spark.sql("set spark.sql.shuffle.partitions=200")

     def _split_into_folds(
@@ -673,9 +668,9 @@ def _capture_results(
         dep_var: str,
         model: Model,
         results_df: pd.DataFrame,
-        otd_data: dict[str, Any] | None,
+        suspicious_data: dict[str, Any] | None,
         alpha_threshold: float,
-        threshold_ratio: float,
+        threshold_ratio: float | None,
         pr_auc: float,
     ) -> pd.DataFrame:
         table_prefix = self.task.table_prefix
@@ -695,7 +690,7 @@
             test_FP_count,
             test_FN_count,
             test_TN_count,
-        ) = _get_confusion_matrix(predictions, dep_var, otd_data)
+        ) = _get_confusion_matrix(predictions, dep_var, suspicious_data)
         test_precision, test_recall, test_mcc = _get_aggregate_metrics(
             test_TP_count, test_FP_count, test_FN_count, test_TN_count
         )
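Note: based on these call sites, `_get_aggregate_metrics` turns raw counts into precision, recall, and MCC. The precision/recall part in a self-contained sketch (the zero-count fallbacks are assumptions; MCC is covered with `_calc_mcc` further down):

```python
def precision_recall_sketch(TP: int, FP: int, FN: int) -> tuple[float, float]:
    """Standard precision/recall from confusion-matrix counts."""
    precision = TP / (TP + FP) if (TP + FP) else float("nan")
    recall = TP / (TP + FN) if (TP + FN) else float("nan")
    return precision, recall

assert precision_recall_sketch(TP=80, FP=20, FN=10) == (0.8, 80 / 90)
```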
@@ -705,7 +700,7 @@
             train_FP_count,
             train_FN_count,
             train_TN_count,
-        ) = _get_confusion_matrix(predict_train, dep_var, otd_data)
+        ) = _get_confusion_matrix(predict_train, dep_var, suspicious_data)
         train_precision, train_recall, train_mcc = _get_aggregate_metrics(
             train_TP_count, train_FP_count, train_FN_count, train_TN_count
         )
@@ -754,7 +749,7 @@ def _save_training_results(
         #     f"Training results saved to Spark table '{table_prefix}training_results'."
         # )

-    def _prepare_otd_table(
+    def _prepare_suspicious_table(
         self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
     ) -> pyspark.sql.DataFrame:
         spark_df = spark.createDataFrame(df)
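Note: the body of `_prepare_suspicious_table` between `createDataFrame` and `return counted` is elided by the diff. Judging from the `counted` name and the `repeat_*` table names, it plausibly counts how often each id pair recurs across runs; purely a guess at the elided aggregation, using only standard PySpark calls and illustrative column names:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
spark_df = spark.createDataFrame(
    [("a1", "b1"), ("a1", "b1"), ("a2", "b9")], ["histid_a", "histid_b"]
)
counted = (
    spark_df.groupBy("histid_a", "histid_b")
    .count()  # one "count" column: how often the pair was mis-classified
    .orderBy(F.col("count").desc())
)
counted.show()
```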
@@ -769,21 +764,21 @@ def _prepare_otd_table(
         )
         return counted

-    def _save_otd_data(
-        self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
+    def _save_suspicious_data(
+        self, suspicious_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
     ) -> None:
         table_prefix = self.task.table_prefix

-        if otd_data is None:
+        if suspicious_data is None:
             print("OTD suspicious data is None, not saving.")
             return
-        id_a = otd_data["id_a"]
-        id_b = otd_data["id_b"]
+        id_a = suspicious_data["id_a"]
+        id_b = suspicious_data["id_b"]

-        if not otd_data["FP_data"].empty:
+        if not suspicious_data["FP_data"].empty:
             table_name = f"{table_prefix}repeat_fps"
-            counted_FPs = self._prepare_otd_table(
-                spark, otd_data["FP_data"], id_a, id_b
+            counted_FPs = self._prepare_suspicious_table(
+                spark, suspicious_data["FP_data"], id_a, id_b
             )
             counted_FPs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -792,10 +787,10 @@ def _save_otd_data(
         else:
             print("There were no false positives recorded.")

-        if not otd_data["FN_data"].empty:
+        if not suspicious_data["FN_data"].empty:
             table_name = f"{table_prefix}repeat_fns"
-            counted_FNs = self._prepare_otd_table(
-                spark, otd_data["FN_data"], id_a, id_b
+            counted_FNs = self._prepare_suspicious_table(
+                spark, suspicious_data["FN_data"], id_a, id_b
             )
             counted_FNs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -804,10 +799,10 @@ def _save_otd_data(
         else:
             print("There were no false negatives recorded.")

-        if not otd_data["TP_data"].empty:
+        if not suspicious_data["TP_data"].empty:
             table_name = f"{table_prefix}repeat_tps"
-            counted_TPs = self._prepare_otd_table(
-                spark, otd_data["TP_data"], id_a, id_b
+            counted_TPs = self._prepare_suspicious_table(
+                spark, suspicious_data["TP_data"], id_a, id_b
             )
             counted_TPs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -816,10 +811,10 @@ def _save_otd_data(
         else:
             print("There were no true positives recorded.")

-        if not otd_data["TN_data"].empty:
+        if not suspicious_data["TN_data"].empty:
             table_name = f"{table_prefix}repeat_tns"
-            counted_TNs = self._prepare_otd_table(
-                spark, otd_data["TN_data"], id_a, id_b
+            counted_TNs = self._prepare_suspicious_table(
+                spark, suspicious_data["TN_data"], id_a, id_b
             )
             counted_TNs.write.mode("overwrite").saveAsTable(table_name)
             print(
@@ -828,14 +823,15 @@ def _save_otd_data(
         else:
             print("There were no true negatives recorded.")

-    def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
+    def _create_suspicious_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
         """Output Suspicious Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
-        training_conf = str(self.task.training_conf)
+        training_section_name = str(self.task.training_conf)
         config = self.task.link_run.config
+        training_settings = config[training_section_name]

         if (
-            "output_suspicious_TD" in config[training_conf]
-            and config[training_conf]["output_suspicious_TD"]
+            "output_suspicious_TD" in training_settings
+            and training_settings["output_suspicious_TD"]
         ):
             return {
                 "FP_data": pd.DataFrame(),
@@ -865,7 +861,7 @@ def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:


 def _calc_threshold_matrix(
-    alpha_threshold: float | list[float], threshold_ratio: float | list[float]
+    alpha_threshold: float | list[float], threshold_ratio: float | list[float] | None
 ) -> list[list[float]]:
     if alpha_threshold and type(alpha_threshold) != list:
         alpha_threshold = [alpha_threshold]
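Note: `_calc_threshold_matrix` now accepts `None` for the ratio. Its body is mostly elided here; a plausible completion, assuming it crosses every alpha threshold with every ratio — only the scalar-to-list normalization is visible above, so the pairing step is an assumption:

```python
def calc_threshold_matrix_sketch(
    alpha_threshold: float | list[float],
    threshold_ratio: float | list[float] | None,
) -> list[list[float]]:
    # Visible in the diff: normalize scalars to one-element lists.
    if alpha_threshold and type(alpha_threshold) != list:
        alpha_threshold = [alpha_threshold]
    if threshold_ratio and type(threshold_ratio) != list:
        threshold_ratio = [threshold_ratio]
    # Assumed: cross every alpha threshold with every ratio.
    if threshold_ratio:
        return [[a, r] for a in alpha_threshold for r in threshold_ratio]
    return [[a] for a in alpha_threshold]

print(calc_threshold_matrix_sketch([0.5, 0.8], 1.3))  # [[0.5, 1.3], [0.8, 1.3]]
print(calc_threshold_matrix_sketch(0.8, None))        # [[0.8]]
```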
@@ -908,7 +904,9 @@ def _get_probability_and_select_pred_columns(


 def _get_confusion_matrix(
-    predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
+    predictions: pyspark.sql.DataFrame,
+    dep_var: str,
+    suspicious_data: dict[str, Any] | None,
 ) -> tuple[int, int, int, int]:

     TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
@@ -931,29 +929,37 @@ def _get_confusion_matrix(
     #     f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
     # )

-    if otd_data:
-        id_a = otd_data["id_a"]
-        id_b = otd_data["id_b"]
+    if suspicious_data:
+        id_a = suspicious_data["id_a"]
+        id_b = suspicious_data["id_b"]

         new_FP_data = FP.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["FP_data"] = pd.concat([otd_data["FP_data"], new_FP_data])
+        suspicious_data["FP_data"] = pd.concat(
+            [suspicious_data["FP_data"], new_FP_data]
+        )

         new_FN_data = FN.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["FN_data"] = pd.concat([otd_data["FN_data"], new_FN_data])
+        suspicious_data["FN_data"] = pd.concat(
+            [suspicious_data["FN_data"], new_FN_data]
+        )

         new_TP_data = TP.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["TP_data"] = pd.concat([otd_data["TP_data"], new_TP_data])
+        suspicious_data["TP_data"] = pd.concat(
+            [suspicious_data["TP_data"], new_TP_data]
+        )

         new_TN_data = TN.select(
             id_a, id_b, dep_var, "prediction", "probability"
         ).toPandas()
-        otd_data["TN_data"] = pd.concat([otd_data["TN_data"], new_TN_data])
+        suspicious_data["TN_data"] = pd.concat(
+            [suspicious_data["TN_data"], new_TN_data]
+        )

     return TP_count, FP_count, FN_count, TN_count

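Note: the hunk header above references `_calc_mcc`; for reference, the textbook Matthews correlation coefficient over these counts (the zero-denominator fallback is an assumption about hlink's edge-case handling):

```python
import math

def calc_mcc_sketch(TP: int, TN: int, FP: int, FN: int) -> float:
    """Textbook MCC from confusion-matrix counts."""
    denominator = math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    if denominator == 0:
        return 0.0  # conventional fallback when any margin is empty
    return (TP * TN - FP * FN) / denominator

# Perfect predictions score 1.0.
assert calc_mcc_sketch(TP=50, TN=40, FP=0, FN=0) == 1.0
```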