Coverage for mlprodict/onnxrt/ops_cpu/op_scatter_elements.py: 97%
37 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
def scatter_elements(data, indices, updates, axis=0):
    """
    Reference implementation of ONNX operator *ScatterElements*
    (without reduction): returns a copy of *data* in which the elements
    designated by *indices* along *axis* are replaced by *updates*.

    ::

        // for 3-dim and axis=0
        // output[indices[i][j][k]][j][k] = updates[i][j][k]
        // for axis 1
        // output[i][indices[i][j][k]][k] = updates[i][j][k]
        // and so on

    :param data: numpy array of rank r >= 1
    :param indices: integer numpy array of rank r; negative values are
        interpreted by numpy indexing, i.e. counted from the end of
        dimension *axis*
    :param updates: values to scatter, read at the same coordinates as
        the elements of *indices* (ONNX requires the same shape)
    :param axis: dimension to scatter along, negative values count
        from the last dimension
    :return: a new array, *data* is not modified
    """
    if axis < 0:
        # Normalized before the 1-D fast path so that 1-D inputs with
        # axis=-1 are handled too.
        axis = data.ndim + axis

    if len(data.shape) == 1 and axis == 0:
        scattered = numpy.copy(data)
        for pos, up in zip(indices, updates):
            scattered[pos] = up
        return scattered

    # One open (broadcastable) coordinate grid per dimension of indices:
    # grid[d] holds the position along dimension d of every element.
    grid = numpy.indices(indices.shape, sparse=True)

    # Output coordinates are the element's own coordinates, except along
    # *axis* where the value stored in *indices* is used instead. This is
    # valid for any rank (the former nested-tuple construction failed
    # for rank >= 3).
    target = list(grid)
    target[axis] = indices

    scattered = numpy.copy(data)
    scattered[tuple(target)] = updates[tuple(grid)]
    return scattered
class ScatterElements(OpRun):
    """
    Runtime for ONNX operator *ScatterElements*: the output is a copy
    of the first input where the elements designated by the second
    input along attribute *axis* are overwritten by the third input.
    """

    # Attribute defaults; presumably exposed by OpRun as ``self.<name>``
    # (``_run`` reads ``self.axis``) -- confirm in OpRun.
    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node,
            desc=desc, **options)

    def _run(self, data, indices, updates, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        "Scatters *updates* into a copy of *data*, returned in a 1-tuple."
        output = scatter_elements(
            data, indices, updates, axis=self.axis)
        return output,