Coverage for mlprodict/onnxrt/ops_cpu/op_scatter_elements.py: 97%

37 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

def scatter_elements(data, indices, updates, axis=0):
    """
    Reference implementation of *ScatterElements* without reduction.

    ::

        // for 3-dim and axis=0
        // output[indices[i][j][k]][j][k] = updates[i][j][k]
        // for axis 1
        // output[i][indices[i][j][k]][k] = updates[i][j][k]
        // and so on

    Every element of *updates* at position ``p`` is written into the output
    at ``p`` with its ``axis`` coordinate replaced by ``indices[p]``.
    Works for any rank >= 1.

    :param data: array to scatter into; it is copied, never modified
    :param indices: integer array with the same rank as *data*; negative
        values count from the end of dimension *axis*
    :param updates: values to write, expected to have the shape of
        *indices* (anything numpy broadcasts to that shape is accepted)
    :param axis: dimension along which *indices* select positions,
        may be negative
    :return: a copy of *data* (ndarray) with *updates* scattered into it
    :raises IndexError: if *axis* is out of range for *data* or an index
        is out of bounds
    :raises ValueError: if *indices* and *data* have different ranks
    """
    scattered = numpy.copy(data)
    indices = numpy.asarray(indices)
    updates = numpy.asarray(updates)
    rank = scattered.ndim

    # Normalize the axis *before* any rank-specific logic so that negative
    # axes behave like their positive counterparts for every rank.
    if axis < 0:
        axis += rank
    if not 0 <= axis < rank:
        raise IndexError(
            "axis=%r is out of range for data of rank %d." % (axis, rank))
    if indices.ndim != rank:
        raise ValueError(
            "indices (rank %d) and data (rank %d) must have the same rank."
            % (indices.ndim, rank))

    # One broadcastable coordinate array per dimension of `indices`
    # (sparse=True keeps them as thin aranges instead of full grids);
    # replacing the `axis` coordinate by `indices` gives, for each position
    # p of `indices`, the target coordinate in the output.
    coords = list(numpy.indices(indices.shape, sparse=True))
    coords[axis] = indices

    # ONNX leaves duplicate targets undefined; here whichever write numpy's
    # fancy assignment performs last is kept.
    scattered[tuple(coords)] = updates
    return scattered

61 

62 

class ScatterElements(OpRun):
    """
    Runtime operator *ScatterElements*: writes *updates* into a copy of
    *data* at the positions given by *indices* along attribute *axis*.
    Only the *axis* attribute is declared here.
    """

    # Attribute defaults for this operator. ``_run`` reads ``self.axis``;
    # presumably OpRun populates it from the node or from this default
    # (confirm in OpRun).
    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        # Explicit base call kept on purpose (no super()) so the MRO
        # resolution stays exactly as before.
        OpRun.__init__(
            self, onnx_node, desc=desc, **options)

    def _run(self, data, indices, updates, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Delegates to ``scatter_elements`` with the node's axis and
        returns the single output wrapped in a tuple.
        """
        output = scatter_elements(
            data, indices, updates, axis=self.axis)
        return (output,)