1""" 

2@file 

3@brief Helpers to investigate a tree structure. 

4""" 

5import numpy 

6from sklearn.tree._tree import TREE_LEAF # pylint: disable=E0611 

7 

8 

def _get_tree(obj):
    """
    Returns the tree object.
    """
    if hasattr(obj, "children_left"):
        return obj
    if hasattr(obj, "tree_"):
        return obj.tree_
    raise AttributeError(  # pragma: no cover
        "obj is not a tree: {}".format(type(obj)))


def tree_leave_index(model):
    """
    Returns the indices of every leaf in a tree.

    @param      model       something which has a member ``tree_``
    @return                 leaf indices
    """
    tree = _get_tree(model)
    res = []
    for i in range(tree.node_count):
        if tree.children_left[i] == TREE_LEAF:
            res.append(i)
    return res

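# A minimal usage sketch; the exact leaf indices depend on the splits the
# tree learns (this small grid mirrors the runpython examples further below):
#
#     import numpy
#     from sklearn.tree import DecisionTreeClassifier
#     X = numpy.array([[0, 0], [0, 1], [1, 0], [1, 1]])
#     clr = DecisionTreeClassifier().fit(X, [0, 1, 2, 3])
#     print(tree_leave_index(clr))  # e.g. [2, 3, 5, 6]
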

def tree_find_path_to_root(tree, i, parents=None):
    """
    Lists the nodes involved in the path from the root to node *i*.

    @param      tree        tree
    @param      i           node index (``tree.nodes[i]``)
    @param      parents     precomputed parents (None -> calls @see fn tree_node_parents)
    @return                 path as a list of node indices, from the root down to *i*
    """
    tree = _get_tree(tree)
    if parents is None:
        parents = tree_node_parents(tree)
    path_i = [i]
    current_i = i
    while current_i in parents:
        current_i = parents[current_i]
        if current_i < 0:
            # a negative id marks a right child (see tree_node_parents)
            current_i = -current_i
        path_i.append(current_i)
    return list(reversed(path_i))

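# A minimal sketch, assuming the fitted tree ``clr`` from the sketch above:
# the path starts at the root (node 0) and ends at the requested node:
#
#     leaves = tree_leave_index(clr)
#     print(tree_find_path_to_root(clr, leaves[0]))  # e.g. [0, 1, 2]
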

def tree_find_common_node(tree, i, j, parents=None):
    """
    Finds the common node to nodes *i* and *j*.

    @param      tree        tree
    @param      i           node index (``tree.nodes[i]``)
    @param      j           node index (``tree.nodes[j]``)
    @param      parents     precomputed parents (None -> calls @see fn tree_node_parents)
    @return                 common root, remaining path to *i*, remaining path to *j*
    """
    tree = _get_tree(tree)
    if parents is None:
        parents = tree_node_parents(tree)
    path_i = tree_find_path_to_root(tree, i, parents)
    path_j = tree_find_path_to_root(tree, j, parents)
    for pos, (a, b) in enumerate(zip(path_i, path_j)):
        if a != b:
            return a, path_i[pos:], path_j[pos:]
    pi = parents.get(i, None)
    pj = parents.get(j, None)
    pos = min(len(path_i), len(path_j))
    if pi is not None and pi == j:
        return j, path_i[pos:], path_j[pos:]
    if pj is not None and pj == i:
        return i, path_i[pos:], path_j[pos:]
    raise RuntimeError(  # pragma: no cover
        "Paths are equal, i={} and j={} must be different.".format(i, j))

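# A minimal sketch, assuming the fitted tree ``clr`` from the sketches above:
# the node where the paths to two leaves diverge, plus what remains of each
# path below the divergence:
#
#     leaves = tree_leave_index(clr)
#     node, rest_i, rest_j = tree_find_common_node(clr, leaves[0], leaves[-1])
#     print(node, rest_i, rest_j)
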

def tree_node_parents(tree):
    """
    Returns a dictionary ``{node_id: parent_id}``.
    A negative parent id means the node is the right child of
    node ``-parent_id``; the root has no entry.

    @param      tree        tree
    @return                 parents
    """
    tree = _get_tree(tree)
    parents = {}
    for i in range(tree.node_count):
        if tree.children_left[i] == TREE_LEAF:
            continue
        parents[tree.children_left[i]] = i
        # the sign encodes the side: right children get a negative parent id
        parents[tree.children_right[i]] = -i
    return parents

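# A minimal sketch, assuming the fitted tree ``clr`` from the sketches above;
# the exact ids depend on the splits, and right children carry the negative
# sign decoded by tree_find_path_to_root:
#
#     print(tree_node_parents(clr))  # e.g. {1: 0, 4: 0, 2: 1, 3: -1, ...}
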

def tree_node_range(tree, i, parents=None):
    """
    Determines the ranges for a node across all dimensions.
    ``nan`` means infinity.

    @param      tree        tree
    @param      i           node index (``tree.nodes[i]``)
    @param      parents     precomputed parents (None -> calls @see fn tree_node_parents)
    @return                 one array of size *(D, 2)* where *D* is the number of dimensions

    The following example shows what the function returns
    in case of a simple grid in two dimensions.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import DecisionTreeClassifier
        from mlinsights.mltree import tree_leave_index, tree_node_range

        X = numpy.array([[0, 0], [0, 1], [0, 2],
                         [1, 0], [1, 1], [1, 2],
                         [2, 0], [2, 1], [2, 2]])
        y = list(range(X.shape[0]))
        clr = DecisionTreeClassifier(max_depth=4)
        clr.fit(X, y)

        leaves = tree_leave_index(clr)
        ra = tree_node_range(clr, leaves[0])

        print(ra)
    """
    tree = _get_tree(tree)
    if parents is None:
        parents = tree_node_parents(tree)
    path = tree_find_path_to_root(tree, i, parents)
    mx = max([tree.feature[p] for p in path])
    res = numpy.full((mx + 1, 2), numpy.nan)
    for ind, p in enumerate(path):
        if p == i:
            break
        fn = tree.feature[p]
        lr = tree.children_left[p] == path[ind + 1]
        th = tree.threshold[p]
        if lr:
            # went left: the threshold is an upper bound on feature fn
            res[fn, 1] = min(res[fn, 1], th) if not numpy.isnan(
                res[fn, 1]) else th
        else:
            # went right: the threshold is a lower bound on feature fn
            res[fn, 0] = max(res[fn, 0], th) if not numpy.isnan(
                res[fn, 0]) else th
    return res


def predict_leaves(model, X):
    """
    Returns the leaf every observation of *X* falls into.

    @param      model       a decision tree
    @param      X           observations
    @return                 array of leaves
    """
    if hasattr(model, 'get_leaves_index'):
        leaves_index = model.get_leaves_index()
    else:
        leaves_index = [i for i in range(len(model.tree_.children_left))
                        if model.tree_.children_left[i] == TREE_LEAF]
    # decision_path returns the nodes each observation goes through;
    # keeping only the leaf columns, argmax locates the single leaf reached
    leaves = model.decision_path(X)
    leaves = leaves[:, leaves_index]
    mat = numpy.argmax(leaves, 1)
    res = numpy.asarray(mat).ravel()
    res = numpy.array([leaves_index[r] for r in res])
    return res

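# A minimal sketch, assuming the fitted tree ``clr`` and data ``X`` from the
# sketches above; the values align with the indices of tree_leave_index(clr):
#
#     print(predict_leaves(clr, X))  # one leaf index per row of X
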

def tree_leave_neighbors(model):
    """
    The function determines which leaves are neighbors.
    The method uses some memory as it creates a
    grid of the feature space; each split multiplies the
    number of cells by two.

    @param      model       a :epkg:`sklearn:tree:DecisionTreeRegressor`,
                            a :epkg:`sklearn:tree:DecisionTreeClassifier`,
                            a model which has a member ``tree_``
    @return                 a dictionary ``{(i, j): (dimension, x1, x2)}``,
                            *i, j* are node indices, if :math:`X_d * sign < th * sign`,
                            the observation goes to node *i*, to node *j* otherwise,
                            *i < j*. The border is somewhere in the segment ``[x1, x2]``.

    The following example shows what the function returns
    in case of a simple grid in two dimensions.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import DecisionTreeClassifier
        from mlinsights.mltree import tree_leave_neighbors

        X = numpy.array([[0, 0], [0, 1], [0, 2],
                         [1, 0], [1, 1], [1, 2],
                         [2, 0], [2, 1], [2, 2]])
        y = list(range(X.shape[0]))
        clr = DecisionTreeClassifier(max_depth=4)
        clr.fit(X, y)

        nei = tree_leave_neighbors(clr)

        import pprint
        pprint.pprint(nei)
    """
    tree = _get_tree(model)

    # creates the coordinates of the grid

    features = {}
    for i in range(tree.node_count):
        fe = tree.feature[i]
        if fe < 0:
            # a leaf
            continue
        th = tree.threshold[i]
        if fe not in features:
            features[fe] = []
        features[fe].append(th)
    for fe in features:
        features[fe] = list(sorted(set(features[fe])))
    for fe, v in features.items():
        if len(v) == 1:
            d = abs(v[0]) / 10
            if d == v[0]:
                d = 1
            v.insert(0, v[0] - d)
            v.append(v[-1] + d)
        else:
            diff = [v[i + 1] - v[i] for i in range(len(v) - 1)]
            mdiff = min(diff)
            v.append(v[-1] + mdiff)
            v.insert(0, v[0] - mdiff)

    # predictions

    keys = list(sorted(features))
    pos = [0 for k in keys]
    shape = [len(features[k]) - 1 for k in keys]
    cells = numpy.full(shape, 0, numpy.int32)
    while pos[0] < len(features[keys[0]]) - 1:
        # evaluate
        xy = numpy.zeros((1, model.n_features_))
        for p, k in zip(pos, keys):
            xy[0, k] = (features[k][p] + features[k][p + 1]) / 2
        leave = predict_leaves(model, xy)
        cells[tuple(pos)] = leave[0]

        # next
        ind = len(pos) - 1
        pos[ind] += 1
        while ind > 0 and pos[ind] >= len(features[keys[ind]]) - 1:
            pos[ind] = 0
            ind -= 1
            pos[ind] += 1

    # neighbors

    neighbors = {}
    pos = [0 for k in keys]
    while pos[0] <= len(features[keys[0]]) - 1:
        # neighbors
        try:
            cl = cells[tuple(pos)]
        except IndexError:
            # outside the cube
            cl = None
        if cl is not None:
            for k in range(len(pos)):  # pylint: disable=C0200
                pos[k] += 1
                try:
                    cl2 = cells[tuple(pos)]
                except IndexError:
                    # outside the cube
                    pos[k] -= 1
                    continue
                if cl != cl2:
                    edge = (cl, cl2) if cl < cl2 else (cl2, cl)
                    if edge not in neighbors:
                        neighbors[edge] = []
                    xy = numpy.zeros(model.n_features_)
                    for p, f in zip(pos, keys):
                        xy[f] = (features[f][p] + features[f][p + 1]) / 2
                    x2 = tuple(xy)
                    pos[k] -= 1
                    p = pos[k]
                    key = keys[k]
                    xy[key] = (features[key][p] + features[key][p + 1]) / 2
                    x1 = tuple(xy)
                    neighbors[edge].append((key, x1, x2))
                else:
                    pos[k] -= 1

        # next

        ind = len(pos) - 1
        pos[ind] += 1
        while ind > 0 and pos[ind] >= len(features[keys[ind]]) - 1:
            pos[ind] = 0
            ind -= 1
            pos[ind] += 1

    return neighbors