Skip to content

Commit

Permalink
Fixes early stopping with XGBoost 2.0 (#597)
Browse files Browse the repository at this point in the history
* add one unit test to investigate a bug

Signed-off-by: xadupre <[email protected]>

* add one more test

Signed-off-by: xadupre <[email protected]>

* remove unnecessary print

Signed-off-by: xadupre <[email protected]>

* update CI

Signed-off-by: Xavier Dupre <[email protected]>

* ci

Signed-off-by: Xavier Dupre <[email protected]>

* remove removed files

Signed-off-by: Xavier Dupre <[email protected]>

* update test

Signed-off-by: Xavier Dupre <[email protected]>

* fix early stopping

Signed-off-by: Xavier Dupre <[email protected]>

* fix rf models

Signed-off-by: Xavier Dupre <[email protected]>

* remaining merge issue

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Dec 16, 2023
1 parent 7858f9f commit 180e733
Show file tree
Hide file tree
Showing 8 changed files with 517 additions and 17 deletions.
14 changes: 7 additions & 7 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ jobs:
numpy.version: ''
scipy.version: ''

Python311-1150-RT1160-xgb175-lgbm40:
Python311-1150-RT1163-xgb175-lgbm40:
python.version: '3.11'
ONNX_PATH: 'onnx==1.15.0'
ONNXRT_PATH: 'onnxruntime==1.16.2'
ONNXRT_PATH: 'onnxruntime==1.16.3'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '==1.7.5'
xgboost.version: '>=1.7.5,<2'
numpy.version: ''
scipy.version: ''

Expand All @@ -41,7 +41,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.16.2'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '==1.7.5'
xgboost.version: '>=1.7.5,<2'
numpy.version: ''
scipy.version: ''

Expand All @@ -51,7 +51,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '<4.0'
xgboost.version: '==1.7.5'
xgboost.version: '>=1.7.5,<2'
numpy.version: ''
scipy.version: ''

Expand All @@ -61,7 +61,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.14.0'
COREML_PATH: NONE
lightgbm.version: '<4.0'
xgboost.version: '==1.7.5'
xgboost.version: '>=1.7.5,<2'
numpy.version: ''
scipy.version: ''

Expand All @@ -71,7 +71,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '==1.7.5'
xgboost.version: '>=1.7.5,<2'
numpy.version: ''
scipy.version: '==1.8.0'

Expand Down
8 changes: 8 additions & 0 deletions .azure-pipelines/win32-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ jobs:
strategy:
matrix:

Python311-1150-RT1163:
python.version: '3.11'
ONNX_PATH: 'onnx==1.15.0'
ONNXRT_PATH: 'onnxruntime==1.16.3'
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '2.0.2'

Python311-1150-RT1162:
python.version: '3.11'
ONNX_PATH: 'onnx==1.15.0'
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.12.0

* Fix early stopping for XGBClassifier and xgboost > 2
[#597](https://github.com/onnx/onnxmltools/pull/597)
* Fix discrepancies with XGBRegressor and xgboost > 2
[#670](https://github.com/onnx/onnxmltools/pull/670)
* Support count:poisson for XGBRegressor
Expand Down
8 changes: 8 additions & 0 deletions onnxmltools/convert/xgboost/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def get_xgb_params(xgb_node):
bs = float(config["learner"]["learner_model_param"]["base_score"])
# xgboost >= 2.0
params["base_score"] = bs

bst = xgb_node.get_booster()
if hasattr(bst, "best_ntree_limit"):
params["best_ntree_limit"] = bst.best_ntree_limit
if "gradient_booster" in config["learner"]:
gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
if "num_trees" in gbp:
params["best_ntree_limit"] = int(gbp["num_trees"])
return params


Expand Down
25 changes: 17 additions & 8 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,20 @@ def common_members(xgb_node, inputs):
params = XGBConverter.get_xgb_params(xgb_node)
objective = params["objective"]
base_score = params["base_score"]
if hasattr(xgb_node, "best_ntree_limit"):
best_ntree_limit = xgb_node.best_ntree_limit
elif hasattr(xgb_node, "best_iteration"):
best_ntree_limit = xgb_node.best_iteration + 1
else:
best_ntree_limit = params.get("best_ntree_limit", None)
if base_score is None:
base_score = 0.5
booster = xgb_node.get_booster()
# The json format was available in October 2017.
# XGBoost 0.7 was the first version released with it.
js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
js_trees = [json.loads(s) for s in js_tree_list]
return objective, base_score, js_trees
return objective, base_score, js_trees, best_ntree_limit

@staticmethod
def _get_default_tree_attribute_pairs(is_classifier):
Expand Down Expand Up @@ -231,17 +237,17 @@ def _get_default_tree_attribute_pairs():
def convert(scope, operator, container):
xgb_node = operator.raw_operator
inputs = operator.inputs
objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
xgb_node, inputs
)

if objective in ["reg:gamma", "reg:tweedie"]:
raise RuntimeError("Objective '{}' not supported.".format(objective))

attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
attr_pairs["base_values"] = [base_score]

bst = xgb_node.get_booster()
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees))
if best_ntree_limit < len(js_trees):
if best_ntree_limit and best_ntree_limit < len(js_trees):
js_trees = js_trees[:best_ntree_limit]

XGBConverter.fill_tree_attributes(
Expand Down Expand Up @@ -289,7 +295,9 @@ def convert(scope, operator, container):
xgb_node = operator.raw_operator
inputs = operator.inputs

objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
xgb_node, inputs
)

params = XGBConverter.get_xgb_params(xgb_node)
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
Expand All @@ -305,8 +313,9 @@ def convert(scope, operator, container):
else:
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators

bst = xgb_node.get_booster()
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
best_ntree_limit = best_ntree_limit or len(js_trees)
if ncl > 0:
best_ntree_limit *= ncl
if 0 < best_ntree_limit < len(js_trees):
js_trees = js_trees[:best_ntree_limit]
attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
Expand Down
Loading

0 comments on commit 180e733

Please sign in to comment.