From cf34dd22bf2a4c4c901dfc5104246b394e47f32b Mon Sep 17 00:00:00 2001 From: Pablo Mandiola Date: Tue, 24 Dec 2024 10:36:51 -0300 Subject: [PATCH] fix: replace json with pickle for storing lgbm params --- optuna_integration/_lightgbm_tuner/optimize.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/optuna_integration/_lightgbm_tuner/optimize.py b/optuna_integration/_lightgbm_tuner/optimize.py index a6f7bea9..0d98482b 100644 --- a/optuna_integration/_lightgbm_tuner/optimize.py +++ b/optuna_integration/_lightgbm_tuner/optimize.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import base64 from collections.abc import Callable from collections.abc import Container from collections.abc import Generator @@ -8,7 +9,6 @@ from collections.abc import Iterator from collections.abc import Sequence import copy -import json import os import pickle import time @@ -267,8 +267,10 @@ def _postprocess( trial._trial_id, _AVERAGE_ITERATION_TIME_KEY, average_iteration_time ) trial.storage.set_trial_system_attr(trial._trial_id, _STEP_NAME_KEY, self.step_name) + + serialized_params = base64.b64encode(pickle.dumps(self.lgbm_params)).decode('utf-8') trial.storage.set_trial_system_attr( - trial._trial_id, _LGBM_PARAMS_KEY, json.dumps(self.lgbm_params) + trial._trial_id, _LGBM_PARAMS_KEY, serialized_params ) self.trial_count += 1 @@ -439,7 +441,8 @@ def best_score(self) -> float: def best_params(self) -> dict[str, Any]: """Return parameters of the best booster.""" try: - return json.loads(self.study.best_trial.system_attrs[_LGBM_PARAMS_KEY]) + serialized_params = self.study.best_trial.system_attrs[_LGBM_PARAMS_KEY] + return pickle.loads(base64.b64decode(serialized_params.encode('utf-8'))) except ValueError: # Return the default score because no trials have completed. params = copy.deepcopy(_DEFAULT_LIGHTGBM_PARAMETERS)