Coverage for mlprodict/onnxrt/ops_cpu/op_cum_sum.py: 90%

39 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class CumSum(OpRun):
    """
    Runtime for a cumulative sum operator, implemented with
    :func:`numpy.cumsum`.

    Two integer attributes (defaults in ``atts``) alter the result:

    * ``exclusive``: when set, every output element is the sum of the
      elements strictly before it along the axis (the first one is 0),
    * ``reverse``: when set, the sum is accumulated starting from the
      end of the axis.
    """

    atts = {'exclusive': 0, 'reverse': 0}
    python_inputs = ['x', 'axis=None']

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards everything to ``OpRun.__init__`` and declares ``atts``
        as the expected attributes.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CumSum.atts,
                       **options)

    @staticmethod
    def _scalar_axis(axis):
        """
        Validates *axis* and reduces a one-element vector to its element.

        @param      axis        Python or numpy integer, or an object
                                exposing ``shape`` (0-d or one-element
                                1-d array)
        @return                 value usable as an axis index
        @raise      RuntimeError if *axis* holds more than one number
        """
        # Plain integers (Python or numpy) have nothing to validate.
        if isinstance(axis, (int, numpy.integer)):
            return axis
        shape = axis.shape
        if len(shape) > 1 or (len(shape) == 1 and shape[0] != 1):
            raise RuntimeError(  # pragma no cover
                "axis must be an array of one number not {} "
                "(shape {})".format(axis, shape))
        if len(shape) == 1:
            return axis[0]  # pylint: disable=E1136
        return axis

    def _may_overwrite(self, x):
        """
        Tells if the result can be written into *x*: entry 0 of
        ``self.inplaces`` must be true and *x* must be writeable.
        NOTE(review): ``self.inplaces`` is not defined in this class,
        presumably it is set by ``OpRun`` -- confirm.
        """
        return self.inplaces.get(0, False) and x.flags['WRITEABLE']

    def _run(self, x, *axis, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the cumulative sum of *x*.

        @param      x           array to accumulate
        @param      axis        optional, at most one value is used:
                                the axis to accumulate along, given as
                                an integer, a 0-d array or a one-element
                                1-d array; without it, the sum runs over
                                the flattened array
        @param      attributes  unused here
        @param      verbose     unused here
        @param      fLOG        unused here
        @return                 tuple with a single array
        @raise      NotImplementedError if no axis is given while
                                *reverse* or *exclusive* is set
        @raise      RuntimeError if *axis* holds more than one number
        """
        axis = axis[0] if axis else None
        if axis is None:
            if self.reverse or self.exclusive:
                raise NotImplementedError(  # pragma no cover
                    'reverse=1 or exclusive=1 not implemented')
            # Without an axis, numpy.cumsum returns a flattened (1-D)
            # array: *x* can only receive the result if it is 1-D itself.
            if x.ndim == 1 and self._may_overwrite(x):
                return (numpy.cumsum(x, out=x), )
            return (numpy.cumsum(x), )

        axis = self._scalar_axis(axis)

        if self.reverse:
            # View of x with the chosen axis flipped, the same slices
            # flip the result back at the end.
            rev_indices = [slice(0, s) for s in x.shape]
            rev_indices[axis] = slice(None, None, -1)
            x = x[tuple(rev_indices)]

        if self.exclusive:
            # The result is the inclusive sum shifted by one position
            # along the axis: sum all elements but the last one and store
            # them from position 1, position 0 keeps its initial zero.
            indices_c = [slice(0, s) for s in x.shape]
            indices_d = [slice(0, s) for s in x.shape]
            indices_c[axis] = slice(0, -1)
            indices_d[axis] = slice(1, x.shape[axis])
            res = numpy.zeros(x.shape, dtype=x.dtype)
            numpy.cumsum(x[tuple(indices_c)], axis=axis,
                         out=res[tuple(indices_d)])
        elif self._may_overwrite(x):
            res = numpy.cumsum(x, axis=axis, out=x)
        else:
            res = numpy.cumsum(x, axis=axis)

        if self.reverse:
            res = res[tuple(rev_indices)]
        return (res, )

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        @param      inputs      unused here
        @return                 tuple *(import statement, code)*; the
                                code refers to ``x``, ``axis``,
                                ``exclusive`` and ``reverse`` and does not
                                support ``exclusive`` or ``reverse``
        """
        lines = ['if exclusive or reverse:',
                 '    raise NotImplementedError("reverse=1 or exclusive=1 not implemented")',
                 'if axis is None:',
                 '    return numpy.cumsum(x)',
                 'return numpy.cumsum(x, axis=axis[0])']
        return 'import numpy', "\n".join(lines)