Coverage for mlprodict/onnx_conv/helpers/lgbm_helper.py: 93%

167 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers to speed up the conversion of Lightgbm models or transform it. 

4""" 

5from collections import deque 

6import ctypes 

7import json 

8import re 

9 

10 

def restore_lgbm_info(tree):
    """
    Restores speed up information to help
    modifying the structure of the tree.

    :param tree: dictionary coming from a lightgbm dump, either
        a whole model (key ``'tree_info'``), one tree
        (key ``'tree_structure'``) or a node
    :return: if *tree* has a key ``'tree_info'``, a list with one
        list of decision nodes per tree, otherwise a flat list of
        decision nodes; a decision node is a node with a key
        ``'left_child'`` or ``'right_child'``, the returned nodes are
        the dictionaries of *tree* themselves, not copies
    """

    def walk_through(t):
        if 'tree_info' in t:
            # None marks the beginning of a new tree,
            # the nodes of this tree follow.
            for sub in t['tree_info']:
                yield None
                yield from walk_through(sub)
        elif 'tree_structure' in t:
            yield from walk_through(t['tree_structure'])
        else:
            yield t
            for key in ('left_child', 'right_child'):
                if key in t:
                    yield from walk_through(t[key])

    grouped = 'tree_info' in tree
    nodes = []
    for node in walk_through(tree):
        if node is None:
            # only produced for a whole model: opens the list of a tree
            nodes.append([])
        elif 'right_child' in node or 'left_child' in node:
            if grouped:
                nodes[-1].append(node)
            else:
                nodes.append(node)
    return nodes

44 

45 

def dump_booster_model(self, num_iteration=None, start_iteration=0,
                       importance_type='split', verbose=0):
    """
    Dumps Booster to JSON format.

    Parameters
    ----------
    self: booster
    num_iteration : int or None, optional (default=None)
        Index of the iteration that should be dumped.
        If None, if the best iteration exists, it is dumped; otherwise,
        all iterations are dumped.
        If <= 0, all iterations are dumped.
    start_iteration : int, optional (default=0)
        Start index of the iteration that should be dumped.
    importance_type : string, optional (default="split")
        What type of feature importance should be dumped.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    verbose: displays progress (useful for big trees)

    Returns
    -------
    json_repr : dict
        JSON format of Booster.
    info : dict or None
        ``info['decision_nodes']`` contains one list per tree with
        all the objects of this tree which have a key
        ``'decision_type'``. These objects are the dictionaries stored
        in *json_repr*, not copies. It is None if the booster has a
        true attribute ``is_mock``, in that case *json_repr* is the
        output of ``self.dump_model()``.

    .. note::
        This function is inspired from
        the :epkg:`lightgbm` (`dump_model
        <https://lightgbm.readthedocs.io/en/latest/pythonapi/
        lightgbm.Booster.html#lightgbm.Booster.dump_model>`_.
        It creates intermediate structure to speed up the conversion
        into ONNX of such model. The function overwrites the
        `json.load` to quickly extract nodes.
    """
    if getattr(self, 'is_mock', False):
        return self.dump_model(), None
    from lightgbm.basic import (
        _LIB, FEATURE_IMPORTANCE_TYPE_MAPPER, _safe_call,
        json_default_with_numpy)
    if num_iteration is None:
        num_iteration = self.best_iteration
    importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]

    def call_capi(length):
        # Allocates a buffer of *length* bytes, asks the C API to write
        # the dump into it, returns the buffer and the length reported
        # by the C API.
        buf = ctypes.create_string_buffer(length)
        out_len = ctypes.c_int64(0)
        _safe_call(_LIB.LGBM_BoosterDumpModel(
            self.handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.c_int(importance_type_int),
            ctypes.c_int64(length),
            ctypes.byref(out_len),
            ctypes.c_char_p(ctypes.addressof(buf))))
        return buf, out_len.value

    if verbose >= 2:
        print(  # pragma: no cover
            "[dump_booster_model] call CAPI: LGBM_BoosterDumpModel")
    buffer_len = 1 << 20
    string_buffer, actual_len = call_capi(buffer_len)
    # if buffer length is not long enough, reallocate a buffer
    if actual_len > buffer_len:
        string_buffer, _ = call_capi(actual_len)

    # whitespace skipped by the decoder between two JSON tokens
    WHITESPACE = re.compile(
        r'[ \t\n\r]*', re.VERBOSE | re.MULTILINE | re.DOTALL)

    class Hook(json.JSONDecoder):
        """
        Keep track of the progress, stores a reference to all objects
        with a decision into a different container in order to walk
        through all nodes in a much faster way than going through
        the architecture.
        """

        def __init__(self, *args, info=None, n_trees=None, verbose=0,
                     **kwargs):
            super().__init__(*args, object_hook=self.hook, **kwargs)
            self.nodes = []  # one list of decision nodes per parsed tree
            self.buffer = []  # decision nodes of the tree being parsed
            self.info = info
            self.n_trees = n_trees
            self.verbose = verbose
            self.stored = 0  # total number of decision nodes seen so far
            if verbose >= 2 and n_trees is not None:
                from tqdm import tqdm  # pragma: no cover
                self.loop = tqdm(total=n_trees)  # pragma: no cover
                self.loop.set_description("dump_booster")  # pragma: no cover
            else:
                self.loop = None

        def decode(self, s, _w=WHITESPACE.match):
            return json.JSONDecoder.decode(self, s, _w=_w)

        def raw_decode(self, s, idx=0):
            return json.JSONDecoder.raw_decode(self, s, idx=idx)

        def hook(self, obj):
            """
            Hook called everytime a JSON object is created.
            Keep track of the progress, stores a reference to all
            objects with a decision into a different container.
            """
            # Every obj goes through this function from the leaves to the root.
            if 'tree_info' in obj:
                if self.info is not None:
                    self.info['decision_nodes'] = self.nodes
                # NOTE(review): n_trees comes from self.num_trees() whereas
                # the dump may be restricted by num_iteration, both numbers
                # may then differ -- to be confirmed against lightgbm.
                if self.n_trees is not None and len(self.nodes) != self.n_trees:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected number of trees %d (expecting %d)." % (
                            len(self.nodes), self.n_trees))
                self.nodes = []
                if self.loop is not None:
                    self.loop.close()
            if 'tree_structure' in obj:
                # a tree is complete, its decision nodes are in the buffer
                self.nodes.append(self.buffer)
                if self.loop is not None:
                    # tqdm expects an increment, not the running total
                    self.loop.update(1)
                    if len(self.nodes) % 10 == 0:
                        self.loop.set_description(
                            "dump_booster: %d/%d trees, %d nodes" % (
                                len(self.nodes), self.n_trees, self.stored))
                self.buffer = []
            if "decision_type" in obj:
                self.buffer.append(obj)
                self.stored += 1
            return obj

    if verbose >= 2:
        print("[dump_booster_model] to_json")  # pragma: no cover
    info = {}
    ret = json.loads(string_buffer.value.decode('utf-8'), cls=Hook,
                     info=info, n_trees=self.num_trees(), verbose=verbose)
    ret['pandas_categorical'] = json.loads(
        json.dumps(self.pandas_categorical,
                   default=json_default_with_numpy))
    if verbose >= 2:
        print("[dump_booster_model] end.")  # pragma: no cover
    return ret, info

193 

194 

def dump_lgbm_booster(booster, verbose=0):
    """
    Converts a Lightgbm booster into its JSON description.

    :param booster: Lightgbm booster
    :param verbose: verbosity
    :return: json, dictionary with more information
    """
    dumped = dump_booster_model(booster, verbose=verbose)
    model_json, extra = dumped
    return model_json, extra

205 

206 

def _str2number(val, use_float):
    """Converts a piece of threshold into a number, int first unless *use_float*."""
    if use_float:
        return float(val)
    try:
        return int(val)
    except ValueError:  # pragma: no cover
        return float(val)


def _split_in_set_node(node, use_float):
    """
    Unfolds the first value of a rule ``== v1||v2||...`` stored in *node*.
    Returns the new node which receives the remaining values or None
    if *node* is left unchanged.
    """
    if node.get('decision_type') != '==':
        return None
    th = node['threshold']
    if not isinstance(th, str):
        return None
    first, sep, rest = th.partition('||')
    if not sep:
        return None
    # Both conversions take place before any modification,
    # the node is left untouched if one of them fails.
    first = _str2number(first, use_float)
    if '||' not in rest:
        rest = _str2number(rest, use_float)
    node['threshold'] = first
    # Shallow copy: the new node shares both children with the original
    # node, it replaces the right child and tests the remaining values.
    new_node = node.copy()
    node['right_child'] = new_node
    new_node['threshold'] = rest
    return new_node


def _collect_decision_nodes(root):
    """Lists once every node with a key ``'decision_type'`` reachable from *root*."""
    found = []
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if 'decision_type' not in node or id(node) in seen:
            continue
        seen.add(id(node))
        found.append(node)
        for key in ('left_child', 'right_child'):
            if key in node:
                stack.append(node[key])
    return found


def modify_tree_for_rule_in_set(gbm, use_float=False, verbose=0, count=0,
                                info=None):
    """
    LightGBM produces sometimes a tree with a node set
    to use rule ``==`` to a set of values (= in set),
    the values are separated by ``||``.
    This function unfolds these nodes. The tree is modified inplace.

    :param gbm: a tree coming from lightgbm dump
    :param use_float: use float otherwise int first
        then float if it does not work
    :param verbose: verbosity, use :epkg:`tqdm` to show progress
    :param count: number of nodes already changed (origin) before this call
    :param info: additional information to speed up this search,
        a dictionary with a key ``'decision_nodes'`` (one list of
        nodes per tree) if *gbm* has a key ``'tree_info'``,
        the list of nodes to look into otherwise,
        the nodes are searched for if it is None
    :return: number of changed nodes (include *count*)

    A child looks like the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.onnx_conv.operator_converters.conv_lightgbm import modify_tree_for_rule_in_set

        tree = {'decision_type': '==',
                'default_left': True,
                'internal_count': 6805,
                'internal_value': 0.117558,
                'left_child': {'leaf_count': 4293,
                               'leaf_index': 18,
                               'leaf_value': 0.003519117642745049},
                'missing_type': 'None',
                'right_child': {'leaf_count': 2512,
                                'leaf_index': 25,
                                'leaf_value': 0.012305307958365394},
                'split_feature': 24,
                'split_gain': 12.233599662780762,
                'split_index': 24,
                'threshold': '10||12||13'}

        modify_tree_for_rule_in_set(tree)

        pprint.pprint(tree)
    """
    if 'tree_info' in gbm:
        dec_nodes = None if info is None else info['decision_nodes']
        trees = gbm['tree_info']
        if verbose >= 2:  # pragma: no cover
            from tqdm import tqdm
            loop = tqdm(trees)
            trees = loop
        else:
            loop = None
        for i, tree in enumerate(trees):
            if loop is not None:  # pragma: no cover
                loop.set_description("rules tree %d c=%d" % (i, count))
            count = modify_tree_for_rule_in_set(
                tree, use_float=use_float, count=count,
                info=None if dec_nodes is None else dec_nodes[i])
        return count

    if 'tree_structure' in gbm:
        return modify_tree_for_rule_in_set(
            gbm['tree_structure'], use_float=use_float, count=count,
            info=info)

    if 'decision_type' not in gbm:
        return count

    if info is None:
        # No precomputed list of nodes: the tree is walked only once
        # before any modification.
        info = _collect_decision_nodes(gbm)

    stack = deque(info)
    while stack:
        node = stack.pop()
        new_node = _split_in_set_node(node, use_float)
        if new_node is None:
            continue
        count += 1
        if isinstance(new_node['threshold'], str):
            # more than one value left, the new node must be split as well
            stack.append(new_node)
    return count

367 return count