Coverage for mlinsights/mltree/tree_structure.py: 97%
154 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1"""
2@file
3@brief Helpers to investigate a tree structure.
4"""
5import numpy
6from sklearn.tree._tree import TREE_LEAF # pylint: disable=E0611
9def _get_tree(obj):
10 """
11 Returns the tree object.
12 """
13 if hasattr(obj, "children_left"):
14 return obj
15 if hasattr(obj, "tree_"):
16 return obj.tree_
17 raise AttributeError( # pragma: no cover
18 f"obj is no tree: {type(obj)}")
def tree_leave_index(model):
    """
    Returns the indices of every leave in a tree.

    @param model something which has a member ``tree_``
    @return leave indices
    """
    tree = _get_tree(model)
    # a node is a leave when its left child is TREE_LEAF
    return [node for node in range(tree.node_count)
            if tree.children_left[node] == TREE_LEAF]
def tree_find_path_to_root(tree, i, parents=None):
    """
    Lists nodes involved into the path to find node *i*.

    @param tree tree
    @param i node index (``tree.nodes[i]``)
    @param parents precomputed parents (None -> calls @see fn tree_node_parents)
    @return list of node indices, from the root down to node *i*
    """
    tree = _get_tree(tree)
    if parents is None:
        # The docstring always promised this fallback but the previous
        # implementation crashed with a TypeError on ``parents=None``.
        parents = tree_node_parents(tree)
    path_i = [i]
    current_i = i
    while current_i in parents:
        current_i = parents[current_i]
        if current_i < 0:
            # right children are stored with a negated parent index,
            # see tree_node_parents
            current_i = - current_i
        path_i.append(current_i)
    return list(reversed(path_i))
def tree_find_common_node(tree, i, j, parents=None):
    """
    Finds the common node to nodes *i* and *j*.

    @param tree tree
    @param i node index (``tree.nodes[i]``)
    @param j node index (``tree.nodes[j]``)
    @param parents precomputed parents (None -> calls @see fn tree_node_parents)
    @return common root, remaining path to *i*, remaining path to *j*
    """
    tree = _get_tree(tree)
    if parents is None:
        parents = tree_node_parents(tree)
    path_i = tree_find_path_to_root(tree, i, parents)
    path_j = tree_find_path_to_root(tree, j, parents)
    # first divergence between the two root-to-node paths
    for pos, (a, b) in enumerate(zip(path_i, path_j)):
        if a != b:
            return a, path_i[pos:], path_j[pos:]
    # No divergence: one path is a prefix of the other, i.e. one node is
    # an ancestor of the other. The previous implementation only handled
    # direct *left* children (right children are stored negated in
    # ``parents`` so ``parents.get(i) == j`` failed for them) and raised
    # for deeper ancestors.
    if len(path_i) != len(path_j):
        pos = min(len(path_i), len(path_j))
        ancestor = path_i[-1] if len(path_i) < len(path_j) else path_j[-1]
        return ancestor, path_i[pos:], path_j[pos:]
    raise RuntimeError(  # pragma: no cover
        f"Paths are equal, i={i} and j={j} must be different.")
def tree_node_parents(tree):
    """
    Returns a dictionary ``{node_id: parent_id}``.

    Right children are stored with the *opposite* of their parent index
    so that the side of the split can be recovered from the sign.

    @param tree tree
    @return parents
    """
    tree = _get_tree(tree)
    parents = {}
    for node in range(tree.node_count):
        left = tree.children_left[node]
        if left == TREE_LEAF:
            # leaves have no children to register
            continue
        parents[left] = node
        parents[tree.children_right[node]] = -node
    return parents
def tree_node_range(tree, i, parents=None):
    """
    Determines the ranges for a node all dimensions.
    ``nan`` means infinity.

    @param tree tree
    @param i node index (``tree.nodes[i]``)
    @param parents precomputed parents (None -> calls @see fn tree_node_parents)
    @return one array of size *(D, 2)* where *D* is the number of dimensions

    The following example shows what the function returns
    in case of simple grid in two dimensions.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import DecisionTreeClassifier
        from mlinsights.mltree import tree_leave_index, tree_node_range

        X = numpy.array([[0, 0], [0, 1], [0, 2],
                         [1, 0], [1, 1], [1, 2],
                         [2, 0], [2, 1], [2, 2]])
        y = list(range(X.shape[0]))
        clr = DecisionTreeClassifier(max_depth=4)
        clr.fit(X, y)

        leaves = tree_leave_index(clr)
        ra = tree_node_range(clr, leaves[0])

        print(ra)
    """
    tree = _get_tree(tree)
    if parents is None:
        parents = tree_node_parents(tree)
    path = tree_find_path_to_root(tree, i, parents)
    mx = max(tree.feature[p] for p in path)
    res = numpy.full((mx + 1, 2), numpy.nan)
    # walk the path root -> i; each (node, child) pair tells which side
    # of the split was taken, tightening the bound on that feature
    for node, child in zip(path, path[1:]):
        if node == i:
            break
        feat = tree.feature[node]
        th = tree.threshold[node]
        if tree.children_left[node] == child:
            # went left: the threshold is an upper bound
            cur = res[feat, 1]
            res[feat, 1] = th if numpy.isnan(cur) else min(cur, th)
        else:
            # went right: the threshold is a lower bound
            cur = res[feat, 0]
            res[feat, 0] = th if numpy.isnan(cur) else max(cur, th)
    return res
def predict_leaves(model, X):
    """
    Returns the leave every observations of *X*
    falls into.

    @param model a decision tree
    @param X observations
    @return array of leaves
    """
    if hasattr(model, 'get_leaves_index'):
        leaves_index = model.get_leaves_index()
    else:
        children = model.tree_.children_left
        leaves_index = [n for n in range(len(children))
                        if children[n] == TREE_LEAF]
    # keep only the leave columns of the decision path: exactly one of
    # them is nonzero per row, argmax picks its position
    paths = model.decision_path(X)[:, leaves_index]
    positions = numpy.asarray(numpy.argmax(paths, 1)).ravel()
    return numpy.array([leaves_index[p] for p in positions])
def tree_leave_neighbors(model):
    """
    The function determines which leaves are neighbors.
    The method uses some memory as it creates a
    grid of the feature spaces, each split multiplies the
    number of cells by two.

    @param model a :epkg:`sklearn:tree:DecisionTreeRegressor`,
        a :epkg:`sklearn:tree:DecisionTreeClassifier`,
        a model which has a member ``tree_``
    @return a dictionary ``{(i, j): (dimension, x1, x2)}``,
        *i, j* are node indices, if :math:`X_d * sign < th * sign`,
        the observations goes to node *i*, *j* otherwise,
        *i < j*. The border is somewhere in the segment ``[x1, x2]``.

    The following example shows what the function returns
    in case of simple grid in two dimensions.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import DecisionTreeClassifier
        from mlinsights.mltree import tree_leave_neighbors

        X = numpy.array([[0, 0], [0, 1], [0, 2],
                         [1, 0], [1, 1], [1, 2],
                         [2, 0], [2, 1], [2, 2]])
        y = list(range(X.shape[0]))
        clr = DecisionTreeClassifier(max_depth=4)
        clr.fit(X, y)

        nei = tree_leave_neighbors(clr)

        import pprint
        pprint.pprint(nei)
    """
    tree = _get_tree(model)

    # creates the coordinates of the grid:
    # collect, per feature, every threshold used by a split
    features = {}
    for i in range(tree.node_count):
        fe = tree.feature[i]
        if fe < 0:
            # leave
            continue
        th = tree.threshold[i]
        if fe not in features:
            features[fe] = []
        features[fe].append(th)
    features_keys = features.keys()
    for fe in features_keys:
        # deduplicate and sort the thresholds on each dimension
        features[fe] = list(sorted(set(features[fe])))
    for fe, v in features.items():
        # pad each axis with one extra coordinate on both sides so the
        # grid covers a cell beyond the outermost threshold
        if len(v) == 1:
            d = abs(v[0]) / 10
            if d == v[0]:
                # v[0] == 0: fall back to a unit margin
                d = 1
            v.insert(0, v[0] - d)
            v.append(v[-1] + d)
        else:
            # use the smallest gap between thresholds as the margin
            diff = [v[i + 1] - v[i] for i in range(len(v) - 1)]
            mdiff = min(diff)
            v.append(v[-1] + mdiff)
            v.insert(0, v[0] - mdiff)

    # predictions: label each grid cell with the leave its center falls in
    keys = list(sorted(features))
    pos = [0 for k in keys]
    shape = [len(features[k]) - 1 for k in keys]
    cells = numpy.full(shape, 0, numpy.int32)
    while pos[0] < len(features[keys[0]]) - 1:
        # evaluate: the cell center is the midpoint of its bounds
        xy = numpy.zeros((1, model.n_features_))
        for p, k in zip(pos, keys):
            xy[0, k] = (features[k][p] + features[k][p + 1]) / 2
        leave = predict_leaves(model, xy)
        cells[tuple(pos)] = leave[0]

        # next: odometer-style increment of the multi-index *pos*
        ind = len(pos) - 1
        pos[ind] += 1
        while ind > 0 and pos[ind] >= len(features[keys[ind]]) - 1:
            pos[ind] = 0
            ind -= 1
            pos[ind] += 1

    # neighbors: compare each cell with its successor along every axis
    neighbors = {}
    pos = [0 for k in keys]
    while pos[0] <= len(features[keys[0]]) - 1:
        # neighbors
        try:
            cl = cells[tuple(pos)]
        except IndexError:
            # outside the cube
            cl = None
        if cl is not None:
            for k in range(len(pos)):  # pylint: disable=C0200
                # peek at the adjacent cell along axis k
                # (pos[k] is restored before the next iteration)
                pos[k] += 1
                try:
                    cl2 = cells[tuple(pos)]
                except IndexError:
                    # outside the cube
                    pos[k] -= 1
                    continue
                if cl != cl2:
                    # two different leaves share a border along axis k
                    edge = (cl, cl2) if cl < cl2 else (cl2, cl)
                    if edge not in neighbors:
                        neighbors[edge] = []
                    # x2: center of the adjacent cell (pos[k] still shifted)
                    xy = numpy.zeros((model.n_features_))
                    for p, f in zip(pos, keys):
                        xy[f] = (features[f][p] + features[f][p + 1]) / 2
                    x2 = tuple(xy)
                    # x1: center of the current cell after restoring pos[k]
                    pos[k] -= 1
                    p = pos[k]
                    key = keys[k]
                    xy[key] = (features[key][p] + features[key][p + 1]) / 2
                    x1 = tuple(xy)
                    neighbors[edge].append((key, x1, x2))
                else:
                    pos[k] -= 1

        # next: same odometer increment as above
        ind = len(pos) - 1
        pos[ind] += 1
        while ind > 0 and pos[ind] >= len(features[keys[ind]]) - 1:
            pos[ind] = 0
            ind -= 1
            pos[ind] += 1

    # NOTE(review): ``model.n_features_`` is read here; newer scikit-learn
    # versions expose ``n_features_in_`` instead — confirm the supported
    # sklearn range.
    return neighbors