Coverage for mlprodict/onnxrt/ops_cpu/op_gathernd.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Runtime operator. 

4""" 

5import numpy 

6from ._op import OpRun 

7 

8 

def _gather_nd_impl(data, indices, batch_dims):
    """
    Implements operator *GatherND* with numpy.
    Modified version of `gathernd.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/gathernd.py>`_.

    :param data: numpy array the values are gathered from
    :param indices: integer numpy array, its last axis holds the
        index tuples addressing the axes of *data* which follow
        the batch axes
    :param batch_dims: number of leading axes shared by *data*
        and *indices*
    :return: tuple of one array of dtype ``data.dtype`` whose shape is
        ``indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:]``
    :raises ValueError: if the first *batch_dims* axes of *data*
        and *indices* do not have the same shape
    """
    # Shape of the leading axes shared by 'data' and 'indices'.
    batch_dims_shape = [indices.shape[i] for i in range(batch_dims)]

    # The batch axes must be identical for both arrays, otherwise the
    # flattening below would pair unrelated slices of 'data' and 'indices'.
    data_batch_shape = list(data.shape[:len(batch_dims_shape)])
    if data_batch_shape != batch_dims_shape:
        raise ValueError(
            "The first %d dimension(s) of data %r and indices %r "
            "must be identical." % (
                batch_dims, tuple(data.shape), tuple(indices.shape)))

    # Number of elements in the batch axes (kept as a python int).
    batch_dims_size = 1
    for dim in batch_dims_shape:
        batch_dims_size *= dim

    # Length of one index tuple.
    index_size = indices.shape[-1]

    # Shape of the output: batch axes, remaining axes of 'indices'
    # (last one excluded), then the axes of 'data' an index tuple
    # does not address (none if the tuple addresses all of them).
    output_shape = (
        batch_dims_shape + list(indices.shape)[batch_dims:-1] +
        list(data.shape)[batch_dims + index_size:])

    # Flatten 'indices' to (batch_dims_size, number of tuples, index_size).
    reshaped_indices = indices.reshape(batch_dims_size, -1, index_size)

    # Flatten 'data' to (batch_dims_size, ) + data.shape[batch_dims:].
    reshaped_data = data.reshape(
        (batch_dims_size, ) + data.shape[batch_dims:])

    # Gather every slice at once with advanced indexing: the batch index
    # has shape (batch_dims_size, 1) and broadcasts against each index
    # component of shape (batch_dims_size, number of tuples), which keeps
    # the batch-major order of the results.
    batch_index = numpy.arange(batch_dims_size).reshape(-1, 1)
    gather_index = (batch_index, ) + tuple(
        reshaped_indices[:, :, k] for k in range(index_size))
    gathered = reshaped_data[gather_index]

    return (numpy.asarray(gathered,
                          dtype=data.dtype).reshape(output_shape), )

55 

56 

class GatherND(OpRun):
    """
    Python runtime for operator *GatherND*.
    """

    # Expected attributes and their default values, handed to
    # OpRun.__init__ as *expected_attributes*.
    atts = {'batch_dims': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherND.atts,
                       **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Gathers slices of *data* selected by *indices*.

        :param data: array the values are gathered from
        :param indices: array of index tuples
        :param attributes: not used by this operator
        :param verbose: not used by this operator
        :param fLOG: not used by this operator
        :return: result of ``_gather_nd_impl``, a tuple of one array
        """
        # self.batch_dims is presumably set by OpRun.__init__ from
        # *expected_attributes* -- confirm against OpRun.
        return _gather_nd_impl(data, indices, self.batch_dims)  # pylint: disable=E1101

70 return _gather_nd_impl(data, indices, self.batch_dims) # pylint: disable=E1101