Coverage for mlprodict/onnxrt/ops_cpu/op_topk.py: 89%
92 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401
11 topk_element_min_double, topk_element_max_double, topk_element_fetch_double,
12 topk_element_min_float, topk_element_max_float, topk_element_fetch_float,
13 topk_element_min_int64, topk_element_max_int64, topk_element_fetch_int64)
def topk_sorted_implementation(X, k, axis, largest):
    """
    Retrieves the top-k elements.

    @param      X           data
    @param      k           k in top-k, an integer or an array
                            holding a single integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @return                 top-k values, top-k indices

    The outputs have the shape of *X* except along *axis* where
    the dimension becomes *k*. Values are sorted in decreasing order
    if *largest* is true, in increasing order otherwise. Equal values
    are ordered by increasing index, as the ONNX specification requires.
    The data is never negated: negation overflows for unsigned integers
    and for the minimum value of signed integer types.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                f"k must be an integer not {k!r}.")
        # ravel() also handles 0-d arrays on which k[0] fails.
        k = k.ravel()[0]
    k = int(k)
    dim = X.shape[axis]
    if k < 0 or k > dim:
        raise RuntimeError(
            f"k={k} must be in [0, {dim}] for axis={axis} "
            f"and shape={X.shape}.")
    if largest:
        # A stable ascending sort of the reversed data, reversed again,
        # gives decreasing values where ties keep increasing original
        # indices; position j in the reversed data is dim - 1 - j in X.
        order = numpy.argsort(numpy.flip(X, axis=axis), axis=axis,
                              kind='stable')
        sorted_indices = (dim - 1) - numpy.flip(order, axis=axis)
    else:
        # A stable sort already keeps ties in increasing index order.
        sorted_indices = numpy.argsort(X, axis=axis, kind='stable')
    topk_sorted_indices = numpy.take(sorted_indices, numpy.arange(k),
                                     axis=axis)
    topk_sorted_values = numpy.take_along_axis(X, topk_sorted_indices,
                                               axis=axis)
    return topk_sorted_values, topk_sorted_indices
def topk_sorted_implementation_cpp(X, k, axis, largest, th_para=50):
    """
    Retrieves the top-k elements using a C++ implementation
    when the axis is the last dimension and the type is
    float64, float32 or int64. Otherwise, it falls back to
    @see fn topk_sorted_implementation.

    @param      X           data
    @param      k           k in top-k, an integer or an array
                            holding a single integer
    @param      axis        axis chosen to select the top-k elements
    @param      largest     largest (1) or smallest (0)
    @param      th_para     threshold for parallelisation
    @return                 top-k values, top-k indices

    When *k* is 0, both outputs are empty arrays whose shape is the
    shape of *X* with a null dimension along *axis*.
    """
    if isinstance(k, numpy.ndarray):
        if k.size != 1:
            raise RuntimeError(  # pragma: no cover
                f"k must be an integer not {k!r}.")
        # ravel() also handles 0-d arrays.
        k = k.ravel()[0]
    k = int(k)
    dim = X.shape[axis]
    if k < 0 or k > dim:
        # Checked here so the C++ kernels never receive an invalid k.
        raise RuntimeError(
            f"k={k} must be in [0, {dim}] for axis={axis} "
            f"and shape={X.shape}.")
    if k == 0:
        # Always return the (values, indices) pair the caller unpacks.
        shape = list(X.shape)
        shape[axis] = 0
        return (numpy.empty(tuple(shape), dtype=X.dtype),
                numpy.empty(tuple(shape), dtype=numpy.int64))
    if axis != len(X.shape) - 1:
        return topk_sorted_implementation(X, k, axis, largest)

    # (select largest, select smallest, fetch values) kernels per dtype.
    if X.dtype == numpy.float64:
        kernels = (topk_element_max_double, topk_element_min_double,
                   topk_element_fetch_double)
    elif X.dtype == numpy.float32:
        kernels = (topk_element_max_float, topk_element_min_float,
                   topk_element_fetch_float)
    elif X.dtype == numpy.int64:
        kernels = (topk_element_max_int64, topk_element_min_int64,
                   topk_element_fetch_int64)
    else:
        return topk_sorted_implementation(X, k, axis, largest)

    select_max, select_min, fetch = kernels
    select = select_max if largest else select_min
    topk_sorted_indices = select(X, k, True, th_para)
    topk_sorted_values = fetch(X, topk_sorted_indices)
    return topk_sorted_values, topk_sorted_indices
class _CommonTopK(OpRun):
    """
    This class hides a parameter used as a threshold above
    which the parallelisation is started: ``th_para``.
    """

    atts = {'axis': -1}

    def __init__(self, *args, **options):
        OpRun.__init__(self, *args, **options)
        # Threshold forwarded to topk_sorted_implementation_cpp.
        self.th_para = 50

    def _common_run(self, data, ink, largest=1):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        The selection is delegated to a C++ implementation when
        the axis is the last one and the type is float64, float32
        or int64. Otherwise, the data is sorted along the axis and
        the first *k* elements are kept.

        :param data: input tensor
        :param ink: container whose first element is *k*
        :param largest: selects the largest (1) or smallest (0) values
        :return: tuple (values, indices), indices as int64

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        k = ink[0]
        # A negative axis is counted from the last dimension.
        axis = self.axis if self.axis >= 0 else (self.axis + len(data.shape))
        sort, sorti = topk_sorted_implementation_cpp(
            data, k, axis, largest, self.th_para)
        # copy=False avoids duplicating indices that are already int64.
        return (sort, sorti.astype(numpy.int64, copy=False))
class TopK_1(_CommonTopK):
    """
    Operator *TopK* from opset 1: *k* is an attribute
    and the largest values are returned.
    """

    atts = {'axis': -1, 'k': None}

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class's own attributes: they declare 'k',
        # which TopK_10.atts does not.
        _CommonTopK.__init__(self, onnx_node, desc=desc,
                             expected_attributes=TopK_1.atts,
                             **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        *k* comes from the node attribute and is wrapped in a list
        to share the implementation of the later opsets.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, [self.k])
class TopK_10(_CommonTopK):
    """
    Operator *TopK* from opset 10: *k* is given as the second
    input and the largest values are returned.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=TopK_10.atts, **options)

    def _run(self, data, ink, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        *ink* is the input holding *k*. The shared implementation
        returns the *k* largest values along attribute *axis*
        together with their int64 indices.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, ink)
class TopK_11(_CommonTopK):
    """
    Operator *TopK* from opset 11: *k* is given as the second
    input, attributes *largest* and *sorted* are supported.
    """

    atts = {'axis': -1, 'largest': 1, 'sorted': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _CommonTopK.__init__(self, onnx_node, desc=desc,
                             expected_attributes=TopK_11.atts,
                             **options)
        # The runtime always produces sorted outputs (see
        # topk_sorted_implementation_cpp). With sorted=0 the ONNX
        # specification leaves the order undefined, so sorted outputs
        # are valid for both values; only other values are rejected.
        if self.sorted not in (0, 1):
            raise RuntimeError(  # pragma: no cover
                f"Unexpected value {self.sorted!r} for attribute sorted.")

    def _run(self, data, ink, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *TopK*.
        *ink* is the input holding *k*. Attribute *largest* selects
        the largest (1) or smallest (0) values.

        .. warning::
            ONNX specifications may be imprecise in case of negative value
            for axis. The implementation follows what :epkg:`onnxruntime`
            does in `top_k.cc
            <https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/math/top_k.cc#L63>`_.
        """
        return _CommonTopK._common_run(self, data, ink, self.largest)
# Default TopK runtime: the most recent version supported by the
# opset of the installed onnx package.
TopK = (TopK_11 if onnx_opset_version() >= 11
        else TopK_10 if onnx_opset_version() >= 10
        else TopK_1)