# Coverage for mlinsights/mltree/tree_digitize.py: 98%
# 85 statements, created at 2022-08-09 08:45 +0200

"""
@file
@brief Helpers to investigate a tree structure.
"""
import numpy
from sklearn.tree._tree import Tree  # pylint: disable=E0611
from sklearn.tree import DecisionTreeRegressor
from ._tree_digitize import tree_add_node  # pylint: disable=E0611

def digitize2tree(bins, right=False):
    """
    Builds a decision tree which returns the same result as
    `lambda x: numpy.digitize(x, bins, right=right)`
    (see :epkg:`numpy:digitize`).

    :param bins: array of bins. It has to be 1-dimensional and monotonic.
    :param right: Indicating whether the intervals include the right
        or the left bin edge. Default behavior is (right==False)
        indicating that the interval does not include the right edge.
        The left bin end is open in this case, i.e.,
        `bins[i-1] <= x < bins[i]` is the default behavior for
        monotonically increasing bins.
    :return: decision tree

    .. note::
        The implementation of decision trees in :epkg:`scikit-learn`
        only allows one type of decision (`<=`). That's why the
        function throws an exception when `right=False`. However,
        this could be overcome by using :epkg:`ONNX` where all
        kind of decision rules are implemented. Default value for
        right is still *False* to follow *numpy* API even though
        this value raises an exception in *digitize2tree*.

    The following example shows what the tree looks like.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import export_text
        from mlinsights.mltree import digitize2tree

        x = numpy.array([0.2, 6.4, 3.0, 1.6])
        bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
        expected = numpy.digitize(x, bins, right=True)
        tree = digitize2tree(bins, right=True)
        pred = tree.predict(x.reshape((-1, 1)))
        print("Comparison with numpy:")
        print(expected, pred)
        print("Tree:")
        print(export_text(tree, feature_names=['x']))
    """
    if not right:
        # scikit-learn trees only implement the `<=` rule, which
        # corresponds to numpy.digitize(..., right=True).
        raise RuntimeError(
            f"right must be True not right={right!r}")
    ascending = len(bins) <= 1 or bins[0] < bins[1]

    if not ascending:
        # Build the tree on the reversed (ascending) bins, then remap
        # every predicted index i to len(bins) - i, which is what
        # numpy.digitize returns for monotonically decreasing bins.
        bins2 = bins[::-1]
        cl = digitize2tree(bins2, right=right)
        n = len(bins)
        for i in range(cl.tree_.value.shape[0]):
            cl.tree_.value[i, 0, 0] = n - cl.tree_.value[i, 0, 0]
        return cl

    # One input feature, one output, regression tree.
    tree = Tree(1, numpy.array([1], dtype=numpy.intp), 1)
    values = []
    # Placeholder stored for split (non-leaf) nodes; only leaves carry
    # the digitize index as their value.
    UNUSED = numpy.nan
    n_nodes = []

    def add_root(index):
        # Adds the root node, splitting on threshold bins[index].
        if index < 0 or index >= len(bins):
            raise IndexError(  # pragma: no cover
                "Unexpected index %d / len(bins)=%d." % (
                    index, len(bins)))
        parent = -1
        is_left = False
        is_leaf = False
        threshold = bins[index]
        n = tree_add_node(
            tree, parent, is_left, is_leaf, 0, threshold, 0, 1, 1.)
        values.append(UNUSED)
        n_nodes.append(n)
        return n

    def add_nodes(parent, i, j, is_left):
        # Recursively adds the subtree covering bins[i:j] (j excluded).
        if is_left:
            # it means j is the parent split
            if i == j:
                # leaf
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                values.append(i)
                return n
            if i + 1 == j:
                # split
                values.append(UNUSED)
                th = bins[i]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, i, True)
                add_nodes(n, i, j, False)
                return n
            if i + 1 < j:
                # split
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        else:
            # it means i is the parent split
            if i + 1 == j:
                # leaf
                values.append(j)
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                return n
            if i + 1 < j:
                # split
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        raise NotImplementedError(  # pragma: no cover
            f"Unexpected case where i={i!r}, j={j!r}, is_left={is_left!r}.")

    # Build the tree: root splits on the middle bin, left subtree covers
    # bins below it, right subtree the rest.
    index = len(bins) // 2
    add_root(index)
    add_nodes(index, 0, index, True)
    add_nodes(index, index, len(bins), False)

    # Wrap the raw Tree into a fitted-looking DecisionTreeRegressor.
    cl = DecisionTreeRegressor()
    cl.tree_ = tree
    cl.tree_.value[:, 0, 0] = numpy.array(  # pylint: disable=E1137
        values, dtype=numpy.float64)
    cl.n_outputs = 1
    cl.n_outputs_ = 1
    try:
        # scikit-learn >= 0.24
        cl.n_features_in_ = 1
    except AttributeError:
        # scikit-learn < 0.24
        cl.n_features_ = 1
    try:
        # for scikit-learn<=0.23.2
        cl.n_features_ = 1
    except AttributeError:
        pass
    return cl