Coverage for mlprodict/onnxrt/ops_cpu/op_cum_sum.py: 90%
39 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
11class CumSum(OpRun):
13 atts = {'exclusive': 0, 'reverse': 0}
14 python_inputs = ['x', 'axis=None']
16 def __init__(self, onnx_node, desc=None, **options):
17 OpRun.__init__(self, onnx_node, desc=desc,
18 expected_attributes=CumSum.atts,
19 **options)
21 def _run(self, x, *axis, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221
22 axis = None if len(axis) == 0 else axis[0]
23 if axis is None:
24 if self.reverse or self.exclusive:
25 raise NotImplementedError( # pragma no cover
26 'reverse=1 or exclusive=1 not implemented')
27 if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
28 return (numpy.cumsum(x, out=x), )
29 return (numpy.cumsum(x), )
30 if not isinstance(axis, (numpy.int32, numpy.int64)):
31 if (len(axis.shape) > 1 or
32 (len(axis.shape) > 0 and axis.shape[0] != 1)):
33 raise RuntimeError( # pragma no cover
34 "axis must be an array of one number not {} "
35 "(shape {})".format(axis, axis.shape))
36 if len(axis.shape) > 0:
37 axis = axis[0] # pylint: disable=E1136
38 if self.reverse:
39 rev_indices = [slice(0, s) for s in x.shape]
40 rev_indices[axis] = slice(None, None, -1)
41 x = x[tuple(rev_indices)]
42 if self.exclusive:
43 indices_c = [slice(0, s) for s in x.shape]
44 indices_d = [slice(0, s) for s in x.shape]
45 indices_c[axis] = slice(0, -1)
46 indices_d[axis] = slice(1, x.shape[axis])
47 res = numpy.zeros(x.shape, dtype=x.dtype)
48 numpy.cumsum(x[tuple(indices_c)], axis=axis,
49 out=res[tuple(indices_d)])
50 else:
51 if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
52 res = numpy.cumsum(x, axis=axis, out=x)
53 else:
54 res = numpy.cumsum(x, axis=axis)
55 if self.reverse:
56 res = res[tuple(rev_indices)]
57 return (res, )
59 def to_python(self, inputs):
60 lines = ['if exclusive or reverse:',
61 ' raise NotImplementedError("reverse=1 or exclusive=1 not implemente")',
62 'if axis is None:',
63 ' return numpy.cumsum(x)',
64 'return numpy.cumsum(x, axis=axis[0])']
65 return 'import numpy', "\n".join(lines)