Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support LGBMRanker conversion #580

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions onnxmltools/convert/lightgbm/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from ..common.data_types import (FloatTensorType,
SequenceType, DictionaryType, StringType, Int64Type)

from lightgbm import LGBMClassifier, LGBMRegressor
from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker

lightgbm_classifier_list = [LGBMClassifier]

# Associate scikit-learn types with our operator names. If two scikit-learn models share a single name, it means their
# are equivalent in terms of conversion.
lightgbm_operator_name_map = {LGBMClassifier: 'LgbmClassifier',
LGBMRegressor: 'LgbmRegressor'}
LGBMRegressor: 'LgbmRegressor',
LGBMRanker: 'LgbmRanker'}


class WrappedBooster:
Expand All @@ -31,6 +32,8 @@ def __init__(self, booster):
self.classes_ = self._generate_classes(booster)
elif self.objective_.startswith('regression'):
self.operator_name = 'LgbmRegressor'
elif self.objective_.startswith('lambdarank'):
self.operator_name = 'LgbmRanker'
else:
raise NotImplementedError(
'Unsupported LightGbm objective: %r.' % self.objective_)
Expand Down
5 changes: 5 additions & 0 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ def convert_lightgbm(scope, operator, container):
# so we need to add an 'Exp' post transform node to the model
attrs['post_transform'] = 'NONE'
post_transform = "Exp"
elif gbm_text['objective'].startswith('lambdarank'):
n_classes = 1 # Ranker has only one output variable
attrs['post_transform'] = 'NONE'
attrs['n_targets'] = n_classes
else:
raise RuntimeError(
"LightGBM objective should be cleaned already not '{}'.".format(
Expand Down Expand Up @@ -818,3 +822,4 @@ def convert_lgbm_zipmap(scope, operator, container):
register_converter('LgbmClassifier', convert_lightgbm)
register_converter('LgbmRegressor', convert_lightgbm)
register_converter('LgbmZipMap', convert_lgbm_zipmap)
register_converter('LgbmRanker', convert_lightgbm)
6 changes: 6 additions & 0 deletions onnxmltools/convert/lightgbm/shape_calculators/Ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from ...common._registration import register_shape_calculator
from ...common.shape_calculator import calculate_linear_regressor_output_shapes

register_shape_calculator('LgbmRanker', calculate_linear_regressor_output_shapes)
1 change: 1 addition & 0 deletions onnxmltools/convert/lightgbm/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
# To register shape calculators for lightgbm operators, import associated modules here.
from . import Classifier
from . import Regressor
from . import Ranker