Coverage for mlprodict/onnxrt/ops_cpu/op_scatternd.py: 89%

19 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

def _scatter_nd_impl(data, indices, updates, reduction=b'none'):
    """
    Reference implementation of the ONNX *ScatterND* operator.

    Returns a copy of *data* in which every slice addressed by a row of
    *indices* is replaced by (or combined with) the matching entry of
    *updates*.

    :param data: input tensor, left unmodified
    :param indices: integer tensor; its last dimension holds the
        coordinates of one slice of *data*, all leading dimensions
        enumerate the individual updates
    :param updates: values to scatter, ``updates[i]`` is applied to
        ``data[tuple(indices[i])]``
    :param reduction: ``'none'`` (assignment), ``'add'``, ``'mul'``,
        ``'max'`` or ``'min'``; both :class:`str` and :class:`bytes`
        are accepted. Any other value falls back to plain assignment.
    :return: the updated copy of *data*
    """
    if isinstance(reduction, bytes):
        # The default (and attribute values) are bytes while the
        # comparisons below use str: without this, b'add' == 'add' is
        # False and every reduction silently became an assignment.
        reduction = reduction.decode('utf-8')
    output = numpy.copy(data)
    # Updates are applied sequentially: with duplicated indices the
    # reductions accumulate and plain assignment keeps the last value.
    for i in numpy.ndindex(indices.shape[:-1]):
        # indices[i] is a 1-D array of coordinates. It must become a
        # tuple; an ndarray index would be interpreted by numpy as a
        # fancy index along the first axis only.
        index = tuple(indices[i])
        if reduction == 'add':
            output[index] += updates[i]
        elif reduction == 'mul':
            output[index] *= updates[i]
        elif reduction == 'max':
            output[index] = numpy.maximum(output[index], updates[i])
        elif reduction == 'min':
            output[index] = numpy.minimum(output[index], updates[i])
        else:
            output[index] = updates[i]
    return output

21 

22 

class ScatterND(OpRun):
    """
    Runtime for the ONNX *ScatterND* operator.

    The single declared attribute is *reduction* (default ``b'none'``);
    its value is handed as-is to :func:`_scatter_nd_impl`.
    """

    atts = {'reduction': b'none'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ScatterND.atts, **options)

    def _run(self, data, indices, updates, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # attributes, verbose and fLOG are part of the signature but are
        # not used by this method.
        result = _scatter_nd_impl(
            data, indices, updates, reduction=self.reduction)
        return (result, )