From 8c724467738dd128124cabfd51bd377c2245ade1 Mon Sep 17 00:00:00 2001 From: rileyh Date: Wed, 27 Nov 2024 10:38:20 -0600 Subject: [PATCH] [#167] Refactor _get_model_parameters() --- .../model_exploration/link_step_train_test_models.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py index d6dce8f..99a929c 100644 --- a/hlink/linking/model_exploration/link_step_train_test_models.py +++ b/hlink/linking/model_exploration/link_step_train_test_models.py @@ -687,9 +687,6 @@ def _custom_param_grid_builder( def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any]]: - model_parameters = training_config["model_parameters"] - model_parameter_search = training_config.get("model_parameter_search") - if "param_grid" in training_config: print( dedent( @@ -705,6 +702,11 @@ def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any ), file=sys.stderr, ) + + model_parameters = training_config["model_parameters"] + model_parameter_search = training_config.get("model_parameter_search") + use_param_grid = training_config.get("param_grid", False) + if model_parameter_search is not None: strategy = model_parameter_search["strategy"] if strategy == "explicit": @@ -713,8 +715,8 @@ def _get_model_parameters(training_config: dict[str, Any]) -> list[dict[str, Any return _custom_param_grid_builder(model_parameters) else: raise ValueError(f"Unknown model_parameter_search strategy '{strategy}'") - elif "param_grid" in training_config and training_config["param_grid"]: - model_parameters = _custom_param_grid_builder(model_parameters) + elif use_param_grid: + return _custom_param_grid_builder(model_parameters) elif model_parameters == []: raise ValueError( "No model parameters found. In 'training' config, either supply 'model_parameters' or 'param_grid'."