Coverage for mlinsights/mltree/tree_digitize.py: 98%
85 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
"""
@file
@brief Helpers to investigate a tree structure.

.. versionadded:: 0.4
"""
import numpy
from sklearn.tree._tree import Tree  # pylint: disable=E0611
from sklearn.tree import DecisionTreeRegressor
from ._tree_digitize import tree_add_node  # pylint: disable=E0611
def digitize2tree(bins, right=False):
    """
    Builds a decision tree which returns the same result as
    `lambda x: numpy.digitize(x, bins, right=right)`
    (see :epkg:`numpy:digitize`).

    :param bins: array of bins. It has to be 1-dimensional and monotonic.
    :param right: Indicating whether the intervals include the right
        or the left bin edge. Default behavior is (right==False)
        indicating that the interval does not include the right edge.
        The left bin end is open in this case, i.e.,
        `bins[i-1] <= x < bins[i]` is the default behavior for
        monotonically increasing bins.
    :return: decision tree (a :class:`sklearn.tree.DecisionTreeRegressor`
        whose internal tree was built by hand)
    :raises RuntimeError: if `right` is False (only `<=` splits are
        expressible with scikit-learn trees, see the note below)

    .. note::
        The implementation of decision trees in :epkg:`scikit-learn`
        only allows one type of decision (`<=`). That's why the
        function throws an exception when `right=False`. However,
        this could be overcome by using :epkg:`ONNX` where all
        kind of decision rules are implemented. Default value for
        right is still *False* to follow *numpy* API even though
        this value raises an exception in *digitize2tree*.

    The following example shows what the tree looks like.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import export_text
        from mlinsights.mltree import digitize2tree

        x = numpy.array([0.2, 6.4, 3.0, 1.6])
        bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
        expected = numpy.digitize(x, bins, right=True)
        tree = digitize2tree(bins, right=True)
        pred = tree.predict(x.reshape((-1, 1)))
        print("Comparison with numpy:")
        print(expected, pred)
        print("Tree:")
        print(export_text(tree, feature_names=['x']))

    See also example :ref:`l-example-digitize`.

    .. versionadded:: 0.4
    """
    if not right:
        raise RuntimeError(
            f"right must be True not right={right!r}")
    # Monotonic direction is inferred from the first two edges;
    # a single-edge (or empty) `bins` is treated as ascending.
    ascending = len(bins) <= 1 or bins[0] < bins[1]

    if not ascending:
        # Descending bins: build the tree for the reversed (ascending)
        # edges, then remap every leaf value v to n - v, which is the
        # identity numpy.digitize satisfies between the two orderings.
        # Internal nodes hold NaN (UNUSED below) and NaN stays NaN
        # under the subtraction, so only leaves are effectively changed.
        bins2 = bins[::-1]
        cl = digitize2tree(bins2, right=right)
        n = len(bins)
        for i in range(cl.tree_.value.shape[0]):
            cl.tree_.value[i, 0, 0] = n - cl.tree_.value[i, 0, 0]
        return cl

    # Raw sklearn tree: 1 feature, 1 output, n_classes=[1] (regression).
    tree = Tree(1, numpy.array([1], dtype=numpy.intp), 1)
    # `values[k]` is the output of node k; it is appended in the exact
    # order nodes are created so it can be copied into tree_.value at
    # the end. Internal (split) nodes get UNUSED (NaN) placeholders.
    values = []
    UNUSED = numpy.nan
    n_nodes = []

    def add_root(index):
        # Creates the root split on threshold bins[index].
        if index < 0 or index >= len(bins):
            raise IndexError(  # pragma: no cover
                "Unexpected index %d / len(bins)=%d." % (
                    index, len(bins)))
        parent = -1
        is_left = False
        is_leaf = False
        threshold = bins[index]
        # tree_add_node(tree, parent, is_left, is_leaf, feature,
        #               threshold, impurity, n_node_samples, weight)
        # and returns the new node id — presumed from the call sites;
        # the helper is project-local Cython code (TODO confirm).
        n = tree_add_node(
            tree, parent, is_left, is_leaf, 0, threshold, 0, 1, 1.)
        values.append(UNUSED)
        n_nodes.append(n)
        return n

    def add_nodes(parent, i, j, is_left):
        # Recursively builds a balanced subtree covering bins[i:j]
        # (j excluded) under `parent`. Every branch keeps `values`
        # and node creation in lockstep: exactly one append per node.
        if is_left:
            # it means j is the parent split
            if i == j:
                # leaf: inputs falling here digitize to index i
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                values.append(i)
                return n
            if i + 1 == j:
                # split on the single remaining edge bins[i]
                values.append(UNUSED)
                th = bins[i]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, i, True)
                add_nodes(n, i, j, False)
                return n
            if i + 1 < j:
                # split on the median edge to keep the tree balanced
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        else:
            # it means i is the parent split
            if i + 1 == j:
                # leaf: inputs falling here digitize to index j
                values.append(j)
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                return n
            if i + 1 < j:
                # split on the median edge to keep the tree balanced
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        # Unreachable for well-formed (i, j) ranges.
        raise NotImplementedError(  # pragma: no cover
            f"Unexpected case where i={i!r}, j={j!r}, is_left={is_left!r}.")

    # Root splits on the median edge; recurse on both halves.
    index = len(bins) // 2
    add_root(index)
    add_nodes(0, 0, index, True)
    add_nodes(0, index, len(bins), False)

    # Wrap the hand-built tree in an unfitted regressor and write the
    # per-node outputs (leaf bin indices, NaN for splits) into it.
    cl = DecisionTreeRegressor()
    cl.tree_ = tree
    cl.tree_.value[:, 0, 0] = numpy.array(  # pylint: disable=E1137
        values, dtype=numpy.float64)
    cl.n_outputs = 1
    cl.n_outputs_ = 1
    # Version shims: set whichever fitted-attribute name the installed
    # scikit-learn expects. NOTE(review): plain attribute assignment
    # rarely raises AttributeError; the try/except only matters for
    # sklearn versions where these names are read-only properties —
    # kept as-is for safety across versions.
    try:
        # scikit-learn >= 0.24
        cl.n_features_in_ = 1
    except AttributeError:
        # scikit-learn < 0.24
        cl.n_features_ = 1
    try:
        # for scikit-learn<=0.23.2
        cl.n_features_ = 1
    except AttributeError:
        pass
    return cl