Coverage for mlprodict/onnxrt/ops_cpu/op_scatternd.py: 89%
19 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
def _scatter_nd_impl(data, indices, updates, reduction=b'none'):
    """
    Numpy implementation of ONNX operator *ScatterND*.

    :param data: input tensor, it is copied and never modified in place
    :param indices: integer tensor, ``indices[i]`` (a vector along the last
        dimension) holds the coordinates of the element or slice of *data*
        to update
    :param updates: values to scatter, ``updates[i]`` is written to
        ``output[tuple(indices[i])]``
    :param reduction: ``'none'`` (overwrite), ``'add'``, ``'mul'``,
        ``'max'`` or ``'min'``, given as :class:`str` or :class:`bytes`;
        ``None`` or any other value overwrites
    :return: updated copy of *data*
    """
    # Attribute values may arrive as bytes (the default is b'none'),
    # comparing bytes with str is always False in Python 3.
    if isinstance(reduction, bytes):
        reduction = reduction.decode('utf-8')
    output = numpy.copy(data)
    for i in numpy.ndindex(indices.shape[:-1]):
        # A tuple addresses one element/slice; indexing with the integer
        # array itself would trigger fancy indexing along axis 0.
        index = tuple(indices[i])
        if reduction == 'add':
            output[index] += updates[i]
        elif reduction == 'mul':
            output[index] *= updates[i]
        elif reduction == 'max':
            output[index] = numpy.maximum(output[index], updates[i])
        elif reduction == 'min':
            output[index] = numpy.minimum(output[index], updates[i])
        else:
            output[index] = updates[i]
    return output
class ScatterND(OpRun):
    """
    Runtime operator *ScatterND*: returns a copy of *data* in which the
    positions designated by *indices* receive *updates*, combined
    according to the ``reduction`` attribute (default ``b'none'``).
    """

    # Only attribute accepted by the operator and its default value.
    atts = {
        'reduction': b'none',
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ScatterND.atts, **options)

    def _run(self, data, indices, updates, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # NOTE(review): self.reduction is presumably set by OpRun from
        # expected_attributes; the assignment is not visible here.
        scattered = _scatter_nd_impl(
            data, indices, updates, reduction=self.reduction)
        return (scattered, )