From d6c5c2e8b8e272cbd3c575be2f8c79154decd8aa Mon Sep 17 00:00:00 2001 From: hanhui <193691140@qq.com> Date: Tue, 17 Oct 2023 10:53:32 +0800 Subject: [PATCH] Fix bug at https://github.com/microsoft/FLAML/issues/1244 --- flaml/tune/searcher/search_thread.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/flaml/tune/searcher/search_thread.py b/flaml/tune/searcher/search_thread.py index f0488c8181..a7c43502d5 100644 --- a/flaml/tune/searcher/search_thread.py +++ b/flaml/tune/searcher/search_thread.py @@ -23,6 +23,19 @@ logger = logging.getLogger(__name__) +def recursive_update(d:dict, u:dict): + """ + Args: + d (dict): The target dictionary to be updated. + u (dict): A dictionary containing values to be merged into `d`. + """ + for k, v in u.items(): + if isinstance(v, dict) and k in d and isinstance(d[k], dict): + recursive_update(d[k], v) + else: + d[k] = v + + class SearchThread: """Class of global or local search thread.""" @@ -63,7 +76,7 @@ def suggest(self, trial_id: str) -> Optional[Dict]: try: config = self._search_alg.suggest(trial_id) if isinstance(self._search_alg._space, dict): - config.update(self._const) + recursive_update(config, self._const) else: # define by run config, self.space = unflatten_hierarchical(config, self._space)