Coverage for mlprodict/onnxrt/ops_cpu/op_gathernd.py: 100%
24 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Runtime operator.
4"""
5import numpy
6from ._op import OpRun
def _gather_nd_impl(data, indices, batch_dims):
    """
    Implementation of operator *GatherND*, adapted from
    `gathernd.py <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/gathernd.py>`_.

    :param data: array to gather from
    :param indices: integer array, its last dimension
        ``k = indices.shape[-1]`` is the length of every index tuple
        (numpy raises an IndexError if *k* exceeds
        ``data.ndim - batch_dims``), negative indices count from the end
    :param batch_dims: number of leading dimensions shared by
        *data* and *indices*
    :return: tuple with one array of dtype ``data.dtype`` and shape
        ``indices.shape[:batch_dims] + indices.shape[batch_dims:-1]
        + data.shape[batch_dims + k:]``
    :raises ValueError: if the first *batch_dims* dimensions of *data*
        and *indices* differ
    """
    batch_dims_shape = tuple(indices.shape[:batch_dims])
    # Without this check, batch shapes with equal products but different
    # layouts would silently be reshaped into a wrong result.
    if tuple(data.shape[:batch_dims]) != batch_dims_shape:
        raise ValueError(
            "The first %d dimensions of data %r and indices %r must be "
            "equal." % (batch_dims, data.shape, indices.shape))

    # Length of every index tuple.
    index_len = indices.shape[-1]
    # Number of elements spanned by the batch dimensions (1 if none).
    batch_dims_size = int(numpy.prod(batch_dims_shape, dtype=numpy.int64))

    # When index_len == data rank - batch_dims the trailing slice is empty
    # and every index tuple selects a scalar.
    output_shape = (batch_dims_shape + tuple(indices.shape[batch_dims:-1]) +
                    tuple(data.shape[batch_dims + index_len:]))

    # 'indices' flattened to (batch, number of index tuples, index_len).
    reshaped_indices = indices.reshape(batch_dims_size, -1, index_len)
    # 'data' flattened to (batch,) + data.shape[batch_dims:].
    reshaped_data = data.reshape(
        (batch_dims_size, ) + tuple(data.shape[batch_dims:]))

    # Vectorized gather: the batch position (shape (batch, 1)) broadcasts
    # against each coordinate of the index tuples (shape (batch, count)),
    # so the result has shape
    # (batch, count) + data.shape[batch_dims + index_len:].
    batch_pos = numpy.arange(batch_dims_size)[:, numpy.newaxis]
    coords = tuple(reshaped_indices[..., i] for i in range(index_len))
    gathered = reshaped_data[(batch_pos, ) + coords]
    return (numpy.asarray(gathered, dtype=data.dtype).reshape(output_shape), )
class GatherND(OpRun):
    """
    Python runtime for operator *GatherND*.

    Attribute *batch_dims* (default 0) is the number of leading
    dimensions shared by inputs *data* and *indices*.
    """

    atts = {'batch_dims': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherND.atts,
                       **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # self.batch_dims is presumably populated by OpRun from
        # expected_attributes (hence the pylint E1101 suppression) - confirm.
        return _gather_nd_impl(data, indices, self.batch_dims)  # pylint: disable=E1101