Skip to content

Commit

Permalink
[#167] Pull _get_model_parameters() out of the LinkStep class
Browse files Browse the repository at this point in the history
This will make this piece of code easier to understand and test.
  • Loading branch information
riley-harper committed Nov 26, 2024
1 parent 605369b commit 2204152
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _run(self) -> None:

splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)

model_parameters = self._get_model_parameters(config)
model_parameters = _get_model_parameters(training_conf, config)

logger.info(
f"There are {len(model_parameters)} sets of model parameters to explore; "
Expand Down Expand Up @@ -298,18 +298,6 @@ def _capture_results(
)
return pd.concat([results_df, new_results], ignore_index=True)

def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
training_conf = str(self.task.training_conf)

model_parameters = conf[training_conf]["model_parameters"]
if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]:
model_parameters = _custom_param_grid_builder(model_parameters)
elif model_parameters == []:
raise ValueError(
"No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."
)
return model_parameters

def _save_training_results(
self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
) -> None:
Expand Down Expand Up @@ -694,3 +682,16 @@ def _custom_param_grid_builder(

new_params.extend(params_exploded)
return new_params


def _get_model_parameters(
training_conf: str, conf: dict[str, Any]
) -> list[dict[str, Any]]:
model_parameters = conf[training_conf]["model_parameters"]
if "param_grid" in conf[training_conf] and conf[training_conf]["param_grid"]:
model_parameters = _custom_param_grid_builder(model_parameters)
elif model_parameters == []:
raise ValueError(
"No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."
)
return model_parameters

0 comments on commit 2204152

Please sign in to comment.