Coverage for mlprodict/onnxrt/ops_cpu/op_gather_elements.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

def gather_numpy_2(self, dim, index):
    """
    Gathers values from a 2-D tensor along axis *dim* with
    numpy fancy indexing:

    ::

        out[i][j] = self[index[i][j]][j]  # if dim == 0 or dim == -2
        out[i][j] = self[i][index[i][j]]  # if dim == 1 or dim == -1

    :param self: 2-D tensor to gather values from
    :param dim: axis along which to index, in [-2, 1]
    :param index: 2-D tensor of integer indices, negative values
        count from the end of axis *dim*
    :return: tensor with the shape of *index* and the dtype of *self*
    :raises ValueError: if a tensor is not 2-D or *dim* is out of range
    :raises IndexError: if an index is out of bounds
    """
    data = numpy.asarray(self)
    index = numpy.asarray(index)
    if data.ndim != 2 or index.ndim != 2:
        raise ValueError(
            "gather_numpy_2 expects 2-D tensors not {} and {}.".format(
                data.shape, index.shape))
    axis = dim + 2 if dim < 0 else dim
    if axis not in (0, 1):
        raise ValueError(
            "dim must be in [-2, 1] for a 2-D tensor not {}.".format(dim))
    # Row and column coordinates of every output element; the gathered
    # axis is replaced by the values of index, the other one is kept.
    rows, cols = numpy.indices(index.shape)
    if axis == 0:
        return data[index, cols]
    return data[rows, index]

18 

19 

def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    Indices are reduced modulo the size of axis *dim*: negative
    indices count from the end and out-of-range indices wrap around,
    as ``numpy.choose(..., mode='wrap')`` did in previous versions.

    :param dim: The axis along which to index, negative values
        count from the last dimension
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values, it has the shape of *index*
    :raises ValueError: if *dim* is out of range, if *index* and *self*
        differ on a dimension other than *dim*, or if axis *dim* is
        empty while *index* is not

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_
    """
    rank = len(self.shape)
    # Normalize a negative axis first: with dim == -1,
    # shape[dim + 1:] would be shape[0:], the whole shape.
    axis = dim + rank if dim < 0 else dim
    if not 0 <= axis < rank:
        raise ValueError(
            "dim={} is out of range for a tensor of rank {}.".format(
                dim, rank))
    idx_xsection_shape = index.shape[:axis] + index.shape[axis + 1:]
    self_xsection_shape = self.shape[:axis] + self.shape[axis + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))
    size = self.shape[axis]
    if size == 0:
        if index.size == 0:
            return numpy.empty(index.shape, dtype=self.dtype)
        raise ValueError(
            "Cannot gather along dimension {} of size 0.".format(dim))
    # numpy.choose accepts a bounded number of choices (NPY_MAXARGS) and
    # failed on long axes; take_along_axis has no such limit. The modulo
    # keeps the wrap-around semantics of the former mode='wrap'.
    return numpy.take_along_axis(self, numpy.mod(index, size), axis=axis)

56 

57 

class GatherElements(OpRun):
    """
    Runtime for operator *GatherElements*: gathers values of *data*
    at the positions given by *indices* along attribute *axis*.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherElements.atts,
                       **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the gathered tensor.

        :param data: tensor to gather values from
        :param indices: integer tensor of indices, the output
            has its shape
        :return: tuple with one tensor
        """
        if indices.size == 0:
            # The output has the shape of indices, even when it is empty
            # (e.g. (3, 0)), not a flat (0,) tensor.
            return (numpy.empty(indices.shape, dtype=data.dtype), )
        y = gather_numpy(data, self.axis, indices)
        return (y, )

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        :param inputs: names of the two inputs (data, indices)
        :return: imports, python code, both as strings
        """
        # NOTE(review): `axis` is not defined by the generated line;
        # presumably the caller of to_python provides it - TODO confirm.
        # A single expression avoids temporaries that could shadow input
        # names; take_along_axis has no limit on the axis length (unlike
        # numpy.choose) and the modulo keeps the mode='wrap' semantics.
        code = ("return numpy.take_along_axis({0}, "
                "numpy.mod({1}, {0}.shape[axis]), axis=axis)").format(
                    inputs[0], inputs[1])
        return "import numpy", code