Coverage for mlprodict/onnx_conv/operator_converters/parse_lightgbm.py: 99%
97 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Parsers for LightGBM booster.
4"""
5import numpy
6from sklearn.base import ClassifierMixin
7from skl2onnx._parse import _parse_sklearn_classifier, _parse_sklearn_simple_model
8from skl2onnx.common._apply_operation import apply_concat, apply_cast
9from skl2onnx.common.data_types import guess_proto_type
10from skl2onnx.proto import onnx_proto
class WrappedLightGbmBooster:
    """
    A booster can be a classifier, a regressor.
    Trick to wrap it in a minimal function.
    """

    def __init__(self, booster):
        self.booster_ = booster
        self.n_features_ = self.booster_.feature_name()
        self.objective_ = self.get_objective()
        if self.objective_.startswith(('binary', 'multiclass')):
            # Both binary and multiclass objectives map to the classifier.
            self.operator_name = 'LgbmClassifier'
            self.classes_ = self._generate_classes(booster)
        elif self.objective_.startswith('regression'):  # pragma: no cover
            self.operator_name = 'LgbmRegressor'
        else:  # pragma: no cover
            raise NotImplementedError(
                f'Unsupported LightGbm objective: {self.objective_!r}.')
        # The boosting type may be stored as a booster attribute or
        # among the training parameters; default to 'gbdt' otherwise.
        try:
            boosting = self.booster_.attr('boosting_type')
        except KeyError:
            boosting = None
        if boosting is None:
            try:
                boosting = self.booster_.params['boosting_type']
            except AttributeError:
                boosting = 'gbdt'
        self.boosting_type = boosting

    @staticmethod
    def _generate_classes(booster):
        "Builds the array of class labels from the booster metadata."
        if isinstance(booster, dict):
            n_classes = booster['num_class']
        else:
            n_classes = booster.attr('num_class')
        if n_classes is None:
            # Attribute missing: read it from the dumped model instead.
            n_classes = booster.dump_model(num_iteration=1)['num_class']
        # One output column means a binary problem: labels {0, 1}.
        if n_classes == 1:
            return numpy.asarray([0, 1])
        return numpy.arange(n_classes)

    def get_objective(self):
        "Returns the objective."
        cached = getattr(self, 'objective_', None)
        if cached is not None:
            return cached
        objective = self.booster_.attr('objective')
        if objective is not None:
            return objective
        return self.booster_.dump_model(num_iteration=1)['objective']
class WrappedLightGbmBoosterClassifier(ClassifierMixin):
    """
    Trick to wrap a LGBMClassifier into a class.
    """

    # Attributes mirrored from the wrapped booster when present.
    _COPIED_ATTRIBUTES = (
        'boosting_type', '_model_dict', '_model_dict_info',
        'operator_name', 'classes_', 'booster_', 'n_features_',
        'objective_')

    def __init__(self, wrapped):  # pylint: disable=W0231
        for name in self._COPIED_ATTRIBUTES:
            if hasattr(wrapped, name):
                setattr(self, name, getattr(wrapped, name))
class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier):
    """
    Mocked lightgbm.
    """

    def __init__(self, tree):  # pylint: disable=W0231
        # Keeps the dumped tree structure instead of a real booster.
        self.dumped_ = tree
        self.is_mock = True

    def dump_model(self):
        "mock dump_model method"
        self.visited = True
        return self.dumped_

    def feature_name(self):
        "Returns binary features names."
        return [0, 1]

    def attr(self, key):
        "Returns default values for common attributes."
        known = {'objective': "binary", 'num_class': 1,
                 'average_output': None}
        if key in known:
            return known[key]
        raise KeyError(  # pragma: no cover
            f"No response for {key!r}.")
def lightgbm_parser(scope, model, inputs, custom_parsers=None):
    """
    Agnostic parser for LightGBM Booster.
    """
    if hasattr(model, "fit"):
        raise TypeError(  # pragma: no cover
            f"This converter does not apply on type '{type(model)}'.")

    if len(inputs) == 1:
        wrapped = WrappedLightGbmBooster(model)
        objective = wrapped.get_objective()
        if objective.startswith(('binary', 'multiclass')):
            # Both objectives go through the classifier parser.
            return _parse_sklearn_classifier(
                scope, WrappedLightGbmBoosterClassifier(wrapped), inputs,
                custom_parsers=custom_parsers)
        if objective.startswith('regression'):  # pragma: no cover
            return _parse_sklearn_simple_model(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        raise NotImplementedError(  # pragma: no cover
            f"Objective '{objective}' is not implemented yet.")

    # Multiple input columns: insert a concatenation operator producing
    # one single tensor, then parse again with that unique input.
    concat_op = scope.declare_local_operator('LightGBMConcat')
    concat_op.raw_operator = model
    concat_op.inputs = inputs
    merged = scope.declare_local_variable(
        'Xlgbm', inputs[0].type.__class__([None, None]))
    concat_op.outputs.append(merged)
    return lightgbm_parser(scope, model, concat_op.outputs,
                           custom_parsers=custom_parsers)
def shape_calculator_lightgbm_concat(operator):
    """
    Shape calculator for operator *LightGBMConcat*.

    Intentionally a no-op: the operator does not constrain its
    output shape here.
    """
def converter_lightgbm_concat(scope, operator, container):
    """
    Converter for operator *LightGBMConcat*.
    """
    op = operator.raw_operator
    options = container.get_options(op, dict(cast=False))
    # Any element type other than double is converted as float.
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:  # pylint: disable=E1101
        proto_dtype = onnx_proto.TensorProto.FLOAT  # pylint: disable=E1101
    if not options['cast']:
        # No cast requested: concatenate directly into the output.
        target = operator.outputs[0].full_name
    else:
        # Concatenate into an intermediate variable, then cast it
        # into the operator output with the chosen element type.
        target = scope.get_unique_variable_name('cast_lgbm')
        apply_cast(scope, target, operator.outputs[0].full_name, container,
                   operator_name=scope.get_unique_operator_name('cast_lgmb'),
                   to=proto_dtype)
    apply_concat(scope, [inp.full_name for inp in operator.inputs],
                 target, container, axis=1)