Coverage for mlprodict/onnxrt/ops_cpu/op_gather_elements.py: 100%
33 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
def gather_numpy_2(self, dim, index):
    """
    Fallback implementation of *GatherElements* used by
    :func:`gather_numpy` when ``numpy.choose`` cannot process the inputs.

    :param self: input tensor
    :param dim: axis along which values are gathered
    :param index: tensor of integer indices, same rank as *self*
    :return: tensor of gathered values, it has the shape of *index*

    Indices are wrapped modulo ``self.shape[dim]`` to match the
    ``numpy.choose(..., mode='wrap')`` call made by :func:`gather_numpy`,
    so ``-1`` designates the last element along *dim*.
    """
    # take_along_axis implements out[..., i, ...] = self[..., index[..., i, ...], ...]
    # for any rank and any axis, unlike the previous row-by-row loop which
    # ignored *dim* and only read the first column of *index*.
    wrapped = numpy.mod(index, self.shape[dim])
    return numpy.take_along_axis(self, wrapped, axis=dim)
def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: input tensor
    :param dim: The axis along which to index, a negative value
        counts from the last dimension
    :param index: A tensor of indices of elements to gather, indices are
        wrapped modulo ``self.shape[dim]`` (``-1`` is the last element)
    :return: tensor of gathered values, it has the shape of *index*
    :raises ValueError: if *index* and *self* differ on a dimension other
        than *dim*, or if *dim* is empty in *self* while *index* is not

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_
    """
    if dim < 0:
        # Without this, ``shape[dim + 1:]`` below selects the whole shape
        # for dim == -1 and the cross-section check fails spuriously.
        dim += len(self.shape)
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))
    size = self.shape[dim]
    if size == 0 and index.size > 0:
        raise ValueError(
            "Cannot gather along dimension {} which is empty.".format(dim))
    # Wrapping the indices reproduces numpy.choose(..., mode='wrap');
    # take_along_axis has no limit on the number of entries along *dim*
    # (numpy.choose rejects large numbers of choices) and works for any rank.
    wrapped = numpy.mod(index, size) if size > 0 else index
    return numpy.take_along_axis(self, wrapped, axis=dim)
class GatherElements(OpRun):
    """
    Runtime for operator *GatherElements*: gathers entries of the first
    input along attribute *axis* at the positions given by the second
    input (see :func:`gather_numpy`).
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherElements.atts,
                       **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the output, which has the shape of *indices* and the
        dtype of *data*. Returns a 1-tuple.
        """
        if indices.size == 0:
            # Keep the shape of *indices* (e.g. (2, 0)) instead of
            # flattening it to (0,): the output shape follows the indices.
            return (numpy.empty(indices.shape, dtype=data.dtype), )
        y = gather_numpy(data, self.axis, indices)
        return (y, )

    def to_python(self, inputs):
        """
        Returns a tuple ``(import statement, code)`` reproducing the
        operator with :epkg:`numpy`, *inputs* are the input variable names.
        """
        # NOTE(review): the generated code uses a variable ``axis`` it does
        # not define (presumably provided by the caller — confirm) and has
        # no fallback if numpy.choose raises ValueError.
        lines = [f'data_swaped = numpy.swapaxes({inputs[0]}, 0, axis)',
                 f'index_swaped = numpy.swapaxes({inputs[1]}, 0, axis)',
                 "gathered = numpy.choose(index_swaped, data_swaped, mode='wrap')",
                 'return numpy.swapaxes(gathered, 0, axis)']
        return "import numpy", "\n".join(lines)