From 077c9c8cf057544d3fc67f4ffc6fc0d8d21bcda5 Mon Sep 17 00:00:00 2001
From: "dante.l"
Date: Thu, 1 Sep 2022 09:38:32 +0900
Subject: [PATCH] support LGBMRanker conversion

---
 onnxmltools/convert/lightgbm/_parse.py                     | 7 +++++--
 .../convert/lightgbm/operator_converters/LightGbm.py       | 5 +++++
 onnxmltools/convert/lightgbm/shape_calculators/Ranker.py   | 6 ++++++
 onnxmltools/convert/lightgbm/shape_calculators/__init__.py | 1 +
 4 files changed, 17 insertions(+), 2 deletions(-)
 create mode 100644 onnxmltools/convert/lightgbm/shape_calculators/Ranker.py

diff --git a/onnxmltools/convert/lightgbm/_parse.py b/onnxmltools/convert/lightgbm/_parse.py
index 011f6d812..fca4b5586 100644
--- a/onnxmltools/convert/lightgbm/_parse.py
+++ b/onnxmltools/convert/lightgbm/_parse.py
@@ -7,14 +7,15 @@
 
 from ..common.data_types import (FloatTensorType, SequenceType,
                                  DictionaryType, StringType, Int64Type)
-from lightgbm import LGBMClassifier, LGBMRegressor
+from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker
 
 lightgbm_classifier_list = [LGBMClassifier]
 
 # Associate scikit-learn types with our operator names. If two scikit-learn models share a single name, it means their
 # are equivalent in terms of conversion.
 lightgbm_operator_name_map = {LGBMClassifier: 'LgbmClassifier',
-                              LGBMRegressor: 'LgbmRegressor'}
+                              LGBMRegressor: 'LgbmRegressor',
+                              LGBMRanker: 'LgbmRanker'}
 
 
 class WrappedBooster:
@@ -31,6 +32,8 @@ def __init__(self, booster):
             self.classes_ = self._generate_classes(booster)
         elif self.objective_.startswith('regression'):
             self.operator_name = 'LgbmRegressor'
+        elif self.objective_.startswith('lambdarank'):
+            self.operator_name = 'LgbmRanker'
         else:
             raise NotImplementedError(
                 'Unsupported LightGbm objective: %r.' % self.objective_)
diff --git a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
index 415b63a37..5da4ca5a1 100644
--- a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
+++ b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
@@ -455,6 +455,10 @@
         # so we need to add an 'Exp' post transform node to the model
         attrs['post_transform'] = 'NONE'
         post_transform = "Exp"
+    elif gbm_text['objective'].startswith('lambdarank'):
+        n_classes = 1  # Ranker has only one output variable
+        attrs['post_transform'] = 'NONE'
+        attrs['n_targets'] = n_classes
     else:
         raise RuntimeError(
             "LightGBM objective should be cleaned already not '{}'.".format(
@@ -818,3 +822,4 @@
 register_converter('LgbmClassifier', convert_lightgbm)
 register_converter('LgbmRegressor', convert_lightgbm)
 register_converter('LgbmZipMap', convert_lgbm_zipmap)
+register_converter('LgbmRanker', convert_lightgbm)
\ No newline at end of file
diff --git a/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py b/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py
new file mode 100644
index 000000000..48e0e85cb
--- /dev/null
+++ b/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from ...common._registration import register_shape_calculator
+from ...common.shape_calculator import calculate_linear_regressor_output_shapes
+
+register_shape_calculator('LgbmRanker', calculate_linear_regressor_output_shapes)
diff --git a/onnxmltools/convert/lightgbm/shape_calculators/__init__.py b/onnxmltools/convert/lightgbm/shape_calculators/__init__.py
index e7a2c3d9b..7fd1224e6 100644
--- a/onnxmltools/convert/lightgbm/shape_calculators/__init__.py
+++ b/onnxmltools/convert/lightgbm/shape_calculators/__init__.py
@@ -3,3 +3,4 @@
 # To register shape calculators for lightgbm operators, import associated modules here.
 from . import Classifier
 from . import Regressor
+from . import Ranker
\ No newline at end of file