Skip to content

Commit

Permalink
[#167] Refactor _get_model_parameters()
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 27, 2024
1 parent a476884 commit 8c72446
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,6 @@ def _custom_param_grid_builder(


def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any]]:
model_parameters = training_config["model_parameters"]
model_parameter_search = training_config.get("model_parameter_search")

if "param_grid" in training_config:
print(
dedent(
Expand All @@ -705,6 +702,11 @@ def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any
),
file=sys.stderr,
)

model_parameters = training_config["model_parameters"]
model_parameter_search = training_config.get("model_parameter_search")
use_param_grid = training_config.get("param_grid", False)

if model_parameter_search is not None:
strategy = model_parameter_search["strategy"]
if strategy == "explicit":
Expand All @@ -713,8 +715,8 @@ def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any
return _custom_param_grid_builder(model_parameters)
else:
raise ValueError(f"Unknown model_parameter_search strategy '{strategy}'")
elif "param_grid" in training_config and training_config["param_grid"]:
model_parameters = _custom_param_grid_builder(model_parameters)
elif use_param_grid:
return _custom_param_grid_builder(model_parameters)
elif model_parameters == []:
raise ValueError(
"No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."
Expand Down

0 comments on commit 8c72446

Please sign in to comment.