Coverage for mlprodict/onnxrt/ops_cpu/op_topk.py: 89%

92 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401 

11 topk_element_min_double, topk_element_max_double, topk_element_fetch_double, 

12 topk_element_min_float, topk_element_max_float, topk_element_fetch_float, 

13 topk_element_min_int64, topk_element_max_int64, topk_element_fetch_int64) 

14 

15 

def topk_sorted_implementation(X, k, axis, largest):
    """
    Retrieves the top-k elements.

    @param      X           data (numpy array)
    @param      k           k in top-k, an integer or an array
                            holding exactly one integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @return                 top-k values, top-k indices

    A :epkg:`RuntimeError` is raised if *k* is an array with more
    than one element.

    See function `_kneighbors_reduce_func
    <https://github.com/scikit-learn/scikit-learn/tree/master/
    sklearn/neighbors/base.py#L304>`_.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                f"k must be an integer not {k!r}.")
        # ravel() copes with 0-d arrays and with nested single-element
        # arrays, both of which break a plain ``k[0]``.
        k = int(k.ravel()[0])

    if len(X.shape) == 2 and axis == 1:
        # Fast path for matrices: partition each row, then sort only
        # the k retained candidates.
        sample_range = numpy.arange(X.shape[0])[:, None]
        if largest == 0:
            keys = X
        elif X.dtype.kind in 'biu':
            # Bitwise inversion reverses the order of booleans and
            # integers without the wrap-around of a negation
            # (unsigned values, smallest signed integer).
            keys = numpy.invert(X)
        else:
            keys = -X
        candidates = numpy.argpartition(keys, axis=axis, kth=k - 1)[:, :k]
        # argpartition doesn't guarantee sorted order, so we sort again
        order = numpy.argsort(keys[sample_range, candidates])
        sorted_indices = candidates[sample_range, order]
        sorted_distances = X[sample_range, sorted_indices]
        return sorted_distances, sorted_indices

    # Generic path: full sort along the axis, then keep the first k.
    sorted_indices = numpy.argsort(X, axis=axis)
    sorted_values = numpy.sort(X, axis=axis)
    if largest:
        sorted_indices = numpy.flip(sorted_indices, axis=axis)
        sorted_values = numpy.flip(sorted_values, axis=axis)
    ark = numpy.arange(k)
    topk_sorted_indices = numpy.take(sorted_indices, ark, axis=axis)
    topk_sorted_values = numpy.take(sorted_values, ark, axis=axis)
    return topk_sorted_values, topk_sorted_indices

61 

62 

def topk_sorted_implementation_cpp(X, k, axis, largest, th_para=50):
    """
    Retrieves the top-k elements using a C++
    implementation when the axis is the last dimension
    and the type is float64, float32 or int64,
    otherwise, it falls back to
    @see fn topk_sorted_implementation.

    @param      X           data
    @param      k           k in top-k, an integer or an array
                            holding exactly one integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @param      th_para     threshold for parallelisation
    @return                 top-k values, top-k indices

    The function always returns a pair, including when *k* is 0
    (two empty one-dimensional arrays).
    A :epkg:`RuntimeError` is raised if *k* is an array with more
    than one element.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                f"k must be an integer not {k!r}.")
        # Same normalisation as the python implementation.
        k = int(k.ravel()[0])

    if k == 0:
        # Every branch must return (values, indices): callers unpack
        # the result into two variables.
        return (numpy.empty((0,), dtype=X.dtype),
                numpy.empty((0,), dtype=numpy.int64))

    if axis != len(X.shape) - 1:
        return topk_sorted_implementation(X, k, axis, largest)

    if X.dtype == numpy.float64:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_double, topk_element_max_double,
            topk_element_fetch_double)
    elif X.dtype == numpy.float32:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_float, topk_element_max_float,
            topk_element_fetch_float)
    elif X.dtype == numpy.int64:
        fct_min, fct_max, fct_fetch = (
            topk_element_min_int64, topk_element_max_int64,
            topk_element_fetch_int64)
    else:
        # No specialised kernel for this type.
        return topk_sorted_implementation(X, k, axis, largest)

    fct_select = fct_max if largest else fct_min
    # NOTE(review): the third argument is presumably the "sorted" flag
    # of the C++ kernel -- confirm against _op_onnx_numpy.
    topk_sorted_indices = fct_select(X, k, True, th_para)
    topk_sorted_values = fct_fetch(X, topk_sorted_indices)
    return topk_sorted_values, topk_sorted_indices

114 

115 

class _CommonTopK(OpRun):
    """
    Base class shared by every version of operator *TopK*.
    It holds ``th_para``, the threshold forwarded to the top-k
    implementation to decide when parallelisation starts.
    """

    atts = {'axis': -1}

    def __init__(self, *args, **options):
        OpRun.__init__(self, *args, **options)
        self.th_para = 50

    def _common_run(self, data, ink, largest=1):  # pylint: disable=W0221
        """
        Runtime shared by all versions of operator *TopK*.
        It delegates the computation to
        @see fn topk_sorted_implementation_cpp.

        @param      data        input data
        @param      ink         sequence whose first element is *k*
        @param      largest     largest (1) or smallest (0)
        @return                 top-k values, top-k indices (int64)

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. A negative axis is shifted by the number of
            dimensions of *data*, which follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        k = ink[0]
        axis = self.axis
        if axis < 0:
            axis = axis + len(data.shape)
        values, indices = topk_sorted_implementation_cpp(
            data, k, axis, largest, self.th_para)
        return (values, indices.astype(numpy.int64))

146 

147 

class TopK_1(_CommonTopK):
    """
    Operator *TopK* for the opset where *k* is an attribute
    of the node and not an input.
    """

    atts = {'axis': -1, 'k': None}

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class's own attribute table: it is the only one
        # which declares ``k``.
        _CommonTopK.__init__(self, onnx_node, desc=desc,
                             expected_attributes=TopK_1.atts,
                             **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        Attribute ``k`` is wrapped into a list and the computation
        is delegated to the shared runtime of the base class with
        its default value for *largest*.

        @param      data        input data
        @return                 top-k values, top-k indices (int64)

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, [self.k])

171 

172 

class TopK_10(_CommonTopK):
    """
    Operator *TopK* for the opset where *k* is received as
    the second input.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            onnx_node, desc=desc, expected_attributes=TopK_10.atts,
            **options)

    def _run(self, data, ink, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        The computation is delegated to the shared runtime of the
        base class with its default value for *largest*.

        @param      data        input data
        @param      ink         sequence whose first element is *k*
        @return                 top-k values, top-k indices (int64)

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        result = self._common_run(data, ink)
        return result

196 

197 

class TopK_11(_CommonTopK):
    """
    Operator *TopK* for the opset which introduces attributes
    ``largest`` and ``sorted``. Only ``sorted=1`` is supported.
    """

    atts = {'axis': -1, 'largest': 1, 'sorted': 1}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            onnx_node, desc=desc, expected_attributes=TopK_11.atts,
            **options)
        wants_sorted = self.sorted in (True, 1)
        if not wants_sorted:
            raise RuntimeError(  # pragma: no cover
                "TopK does not implement anything for sorted=0.")

    def _run(self, data, ink, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        The computation is delegated to the shared runtime of the
        base class, attribute ``largest`` selects the largest or
        the smallest elements.

        @param      data        input data
        @param      ink         sequence whose first element is *k*
        @return                 top-k values, top-k indices (int64)

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        result = self._common_run(data, ink, self.largest)
        return result

224 

225 

# Expose as TopK the implementation matching the opset
# supported by the installed onnx package.
if onnx_opset_version() < 10:  # pragma: no cover
    TopK = TopK_1
elif onnx_opset_version() < 11:  # pragma: no cover
    TopK = TopK_10
else:
    TopK = TopK_11