Coverage for mlprodict/onnx_conv/operator_converters/parse_lightgbm.py: 99%

97 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Parsers for LightGBM booster. 

4""" 

5import numpy 

6from sklearn.base import ClassifierMixin 

7from skl2onnx._parse import _parse_sklearn_classifier, _parse_sklearn_simple_model 

8from skl2onnx.common._apply_operation import apply_concat, apply_cast 

9from skl2onnx.common.data_types import guess_proto_type 

10from skl2onnx.proto import onnx_proto 

11 

12 

class WrappedLightGbmBooster:
    """
    A booster can be a classifier, a regressor.
    Trick to wrap it in a minimal function.
    """

    def __init__(self, booster):
        self.booster_ = booster
        # NOTE(review): feature_name() returns the list of names, not a
        # count — kept as-is because downstream code reads it this way.
        self.n_features_ = self.booster_.feature_name()
        self.objective_ = self.get_objective()
        # binary and multiclass objectives both behave as classifiers
        if self.objective_.startswith(('binary', 'multiclass')):
            self.operator_name = 'LgbmClassifier'
            self.classes_ = self._generate_classes(booster)
        elif self.objective_.startswith('regression'):  # pragma: no cover
            self.operator_name = 'LgbmRegressor'
        else:  # pragma: no cover
            raise NotImplementedError(
                f'Unsupported LightGbm objective: {self.objective_!r}.')
        try:
            boosting = self.booster_.attr('boosting_type')
        except KeyError:
            boosting = None
        if boosting is None:
            # fall back on the training parameters, then on the default
            try:
                boosting = self.booster_.params['boosting_type']
            except AttributeError:
                boosting = 'gbdt'
        self.boosting_type = boosting

    @staticmethod
    def _generate_classes(booster):
        """Guesses the class labels from the booster or its dumped dict."""
        if isinstance(booster, dict):
            n_classes = booster['num_class']
        else:
            n_classes = booster.attr('num_class')
            if n_classes is None:
                # attribute missing: one dumped iteration is enough to know
                n_classes = booster.dump_model(num_iteration=1)['num_class']
        if n_classes == 1:
            # binary classification stores a single score column
            return numpy.asarray([0, 1])
        return numpy.arange(n_classes)

    def get_objective(self):
        "Returns the objective."
        cached = getattr(self, 'objective_', None)
        if cached is not None:
            return cached
        objective = self.booster_.attr('objective')
        if objective is not None:
            return objective
        # last resort: dump one iteration of the model
        return self.booster_.dump_model(num_iteration=1)['objective']

72 

class WrappedLightGbmBoosterClassifier(ClassifierMixin):
    """
    Trick to wrap a LGBMClassifier into a class.
    """

    def __init__(self, wrapped):  # pylint: disable=W0231
        # Copies over only the attributes the converter relies on;
        # anything missing on *wrapped* is silently skipped.
        for name in ('boosting_type', '_model_dict', '_model_dict_info',
                     'operator_name', 'classes_', 'booster_', 'n_features_',
                     'objective_'):
            if hasattr(wrapped, name):
                setattr(self, name, getattr(wrapped, name))

84 

85 

class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier):
    """
    Mocked lightgbm.
    """

    def __init__(self, tree):  # pylint: disable=W0231
        # stores the dumped tree structure instead of a real booster
        self.dumped_ = tree
        self.is_mock = True

    def dump_model(self):
        "mock dump_model method"
        # records the call so a test can check the model was inspected
        self.visited = True
        return self.dumped_

    def feature_name(self):
        "Returns binary features names."
        return [0, 1]

    def attr(self, key):
        "Returns default values for common attributes."
        known = {'objective': "binary", 'num_class': 1, 'average_output': None}
        if key in known:
            return known[key]
        raise KeyError(  # pragma: no cover
            f"No response for {key!r}.")

114 

115 

def lightgbm_parser(scope, model, inputs, custom_parsers=None):
    """
    Agnostic parser for LightGBM Booster.

    :param scope: skl2onnx scope, used to declare operators and variables
    :param model: LightGBM Booster (a fitted scikit-learn wrapper with a
        ``fit`` method is rejected, it has its own converter)
    :param inputs: list of input variables; with more than one input, they
        are first concatenated by operator *LightGBMConcat* and the parser
        is called again on that single column
    :param custom_parsers: additional parsers forwarded to skl2onnx
    :return: list of output variables
    """
    if hasattr(model, "fit"):
        raise TypeError(  # pragma: no cover
            f"This converter does not apply on type '{type(model)}'.")

    if len(inputs) == 1:
        wrapped = WrappedLightGbmBooster(model)
        objective = wrapped.get_objective()
        # binary and multiclass objectives share the classifier path
        # (previously two identical duplicated branches)
        if objective.startswith(('binary', 'multiclass')):
            wrapped = WrappedLightGbmBoosterClassifier(wrapped)
            return _parse_sklearn_classifier(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        if objective.startswith('regression'):  # pragma: no cover
            return _parse_sklearn_simple_model(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        raise NotImplementedError(  # pragma: no cover
            f"Objective '{objective}' is not implemented yet.")

    # Multiple columns: concatenate them into a single input 'Xlgbm',
    # then parse again with that unique input.
    this_operator = scope.declare_local_operator('LightGBMConcat')
    this_operator.raw_operator = model
    this_operator.inputs = inputs
    var = scope.declare_local_variable(
        'Xlgbm', inputs[0].type.__class__([None, None]))
    this_operator.outputs.append(var)
    return lightgbm_parser(scope, model, this_operator.outputs,
                           custom_parsers=custom_parsers)

150 

151 

def shape_calculator_lightgbm_concat(operator):
    """
    Shape calculator for operator *LightGBMConcat*.
    The output shape was already declared by the parser,
    so there is nothing to compute here.
    """

157 

158 

def converter_lightgbm_concat(scope, operator, container):
    """
    Converter for operator *LightGBMConcat*.
    Concatenates all inputs on axis 1, optionally inserting a Cast
    (option ``cast``) so the result matches the expected element type.
    """
    op = operator.raw_operator
    options = container.get_options(op, dict(cast=False))
    # anything but double is computed as float
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:  # pylint: disable=E1101
        proto_dtype = onnx_proto.TensorProto.FLOAT  # pylint: disable=E1101
    if not options['cast']:
        concat_name = operator.outputs[0].full_name
    else:
        concat_name = scope.get_unique_variable_name('cast_lgbm')
        # NOTE(review): 'cast_lgmb' reproduces the historical operator-name
        # spelling; it is only a unique-name seed, renaming it is cosmetic.
        apply_cast(scope, concat_name, operator.outputs[0].full_name, container,
                   operator_name=scope.get_unique_operator_name('cast_lgmb'),
                   to=proto_dtype)

    apply_concat(scope, [inp.full_name for inp in operator.inputs],
                 concat_name, container, axis=1)