Coverage for mlinsights/mltree/tree_digitize.py: 98%

85 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-28 08:46 +0100

1""" 

2@file 

3@brief Helpers to investigate a tree structure. 

4 

5.. versionadded:: 0.4 

6""" 

7import numpy 

8from sklearn.tree._tree import Tree # pylint: disable=E0611 

9from sklearn.tree import DecisionTreeRegressor 

10from ._tree_digitize import tree_add_node # pylint: disable=E0611 

11 

12 

def digitize2tree(bins, right=False):
    """
    Builds a decision tree which returns the same result as
    `lambda x: numpy.digitize(x, bins, right=right)`
    (see :epkg:`numpy:digitize`).

    :param bins: array of bins. It has to be 1-dimensional and monotonic.
    :param right: Indicating whether the intervals include the right
        or the left bin edge. Default behavior is (right==False)
        indicating that the interval does not include the right edge.
        The left bin end is open in this case, i.e.,
        `bins[i-1] <= x < bins[i]` is the default behavior for
        monotonically increasing bins.
    :return: decision tree

    .. note::
        The implementation of decision trees in :epkg:`scikit-learn`
        only allows one type of decision (`<=`). That's why the
        function throws an exception when `right=False`. However,
        this could be overcome by using :epkg:`ONNX` where all
        kind of decision rules are implemented. Default value for
        right is still *False* to follow *numpy* API even though
        this value raises an exception in *digitize2tree*.

    The following example shows what the tree looks like.

    .. runpython::
        :showcode:

        import numpy
        from sklearn.tree import export_text
        from mlinsights.mltree import digitize2tree

        x = numpy.array([0.2, 6.4, 3.0, 1.6])
        bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
        expected = numpy.digitize(x, bins, right=True)
        tree = digitize2tree(bins, right=True)
        pred = tree.predict(x.reshape((-1, 1)))
        print("Comparison with numpy:")
        print(expected, pred)
        print("Tree:")
        print(export_text(tree, feature_names=['x']))

    See also example :ref:`l-example-digitize`.

    .. versionadded:: 0.4
    """
    # Only right=True is supported: sklearn split nodes test `x <= threshold`,
    # which matches numpy's closed-right-edge convention (see the note above).
    if not right:
        raise RuntimeError(
            f"right must be True not right={right!r}")
    # Monotonic direction is inferred from the first two entries only;
    # a 0/1-element `bins` is treated as ascending.
    ascending = len(bins) <= 1 or bins[0] < bins[1]

    if not ascending:
        # Descending bins: build the tree for the reversed (ascending) bins,
        # then remap each stored leaf value v to n - v so the returned index
        # matches what numpy.digitize yields for decreasing bins.
        # NOTE(review): assumes cl.tree_.value stores one scalar per node at
        # [i, 0, 0] — consistent with the assignment at the end of this
        # function.
        bins2 = bins[::-1]
        cl = digitize2tree(bins2, right=right)
        n = len(bins)
        for i in range(cl.tree_.value.shape[0]):
            cl.tree_.value[i, 0, 0] = n - cl.tree_.value[i, 0, 0]
        return cl

    # Raw sklearn tree: 1 feature, 1 output, 1 class-per-output.
    tree = Tree(1, numpy.array([1], dtype=numpy.intp), 1)
    # `values` collects, in node-insertion order, the digitize index for each
    # leaf (split nodes get the UNUSED=NaN placeholder).  It is written into
    # tree_.value at the end, so `values` MUST stay parallel to the order in
    # which tree_add_node is called below.
    values = []
    UNUSED = numpy.nan
    n_nodes = []

    def add_root(index):
        # Creates the root split node testing `x <= bins[index]`.
        if index < 0 or index >= len(bins):
            raise IndexError(  # pragma: no cover
                "Unexpected index %d / len(bins)=%d." % (
                    index, len(bins)))
        parent = -1  # sklearn convention for "no parent" (root)
        is_left = False
        is_leaf = False
        threshold = bins[index]
        # NOTE(review): trailing args (0, threshold, 0, 1, 1.) presumably map
        # to feature, threshold, impurity, n_node_samples, weighted samples —
        # confirm against _tree_digitize.tree_add_node's signature.
        n = tree_add_node(
            tree, parent, is_left, is_leaf, 0, threshold, 0, 1, 1.)
        values.append(UNUSED)
        n_nodes.append(n)
        return n

    def add_nodes(parent, i, j, is_left):
        # Recursively builds a binary-search subtree over bins[i:j]
        # (j excluded), attached under `parent` on the given side.
        # The left/right branches are asymmetric because on the left side
        # the parent's threshold is bins[j] and on the right it is bins[i].
        if is_left:
            # it means j is the parent split
            if i == j:
                # leaf: every x routed here digitizes to index i
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                values.append(i)
                return n
            if i + 1 == j:
                # split on bins[i]; left child is the leaf for index i,
                # right child covers the remaining [i, j) range
                values.append(UNUSED)
                th = bins[i]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, i, True)
                add_nodes(n, i, j, False)
                return n
            if i + 1 < j:
                # split at the midpoint and recurse on both halves
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        else:
            # it means i is the parent split
            if i + 1 == j:
                # leaf: x > bins[i] up to bins[j] digitizes to index j
                values.append(j)
                n = tree_add_node(tree, parent, is_left, True, 0, 0, 0, 1, 1.)
                n_nodes.append(n)
                return n
            if i + 1 < j:
                # split at the midpoint and recurse on both halves
                values.append(UNUSED)
                index = (i + j) // 2
                th = bins[index]
                n = tree_add_node(tree, parent, is_left,
                                  False, 0, th, 0, 1, 1.)
                n_nodes.append(n)
                add_nodes(n, i, index, True)
                add_nodes(n, index, j, False)
                return n
        raise NotImplementedError(  # pragma: no cover
            f"Unexpected case where i={i!r}, j={j!r}, is_left={is_left!r}.")

    # Root splits the bins in half; the two recursive calls fill in the
    # left ([0, index)) and right ([index, len(bins))) subtrees.
    index = len(bins) // 2
    add_root(index)
    add_nodes(0, 0, index, True)
    add_nodes(0, index, len(bins), False)

    # Wrap the raw Tree in a regressor so callers get predict()/export_text().
    cl = DecisionTreeRegressor()
    cl.tree_ = tree
    # Write the collected leaf indices (NaN on split nodes, which predict()
    # never returns) into the tree's value array, in insertion order.
    cl.tree_.value[:, 0, 0] = numpy.array(  # pylint: disable=E1137
        values, dtype=numpy.float64)
    cl.n_outputs = 1
    cl.n_outputs_ = 1
    try:
        # scikit-learn >= 0.24
        cl.n_features_in_ = 1
    except AttributeError:
        # scikit-learn < 0.24
        cl.n_features_ = 1
    try:
        # for scikit-learn<=0.23.2; on later versions n_features_ may be a
        # deprecated/read-only property, hence the best-effort assignment
        cl.n_features_ = 1
    except AttributeError:
        pass
    return cl