Coverage for mlprodict/onnxrt/ops_cpu/op_slice.py: 95%

44 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10 

11 

def _slice(data, starts, ends, axes=None, steps=None):
    """
    Extracts a slice of *data* the way ONNX operator *Slice* describes it,
    by building one Python :class:`slice` per sliced axis.

    :param data: array to slice (must support ``.shape`` and tuple indexing)
    :param starts: start index for every sliced axis; a 1-D array,
        a 0-d array, a scalar or a sequence of integers
    :param ends: end index (excluded) for every sliced axis,
        same accepted forms as *starts*
    :param axes: axes *starts*, *ends*, *steps* apply to (negative values
        are resolved through Python list indexing, so ``-1`` is the last
        axis); if None, they apply to the first ``len(starts)`` axes
    :param steps: step for every sliced axis; if None, every step is 1
    :return: ``data[tuple(slices)]``
    :raises TypeError: when indexing *data* with the computed slices fails
    """
    # numpy.atleast_1d leaves 1-D arrays untouched and turns scalars,
    # 0-d arrays and sequences into 1-D arrays, so that zip below can
    # iterate over every index argument (the original code only promoted
    # 0-d starts/ends, a 0-d axes or steps made zip fail).
    starts = numpy.atleast_1d(starts)
    ends = numpy.atleast_1d(ends)
    if axes is not None:
        axes = numpy.atleast_1d(axes)
    # A missing step is equivalent to None: slice(s, e, None) == slice(s, e).
    # Its length matches starts so zip truncates exactly as before.
    steps = ([None] * len(starts) if steps is None
             else numpy.atleast_1d(steps))

    if axes is None:
        # One slice per leading axis; numpy keeps the remaining axes whole
        # when the index tuple is shorter than the number of dimensions.
        slices = [slice(s, e, d) for s, e, d in zip(starts, ends, steps)]
    else:
        # Start from "take everything" on every axis, then overwrite the
        # requested axes only.
        slices = [slice(0, dim) for dim in data.shape]
        for s, e, a, d in zip(starts, ends, axes, steps):
            slices[a] = slice(s, e, d)
    try:
        return data[tuple(slices)]
    except TypeError as e:  # pragma: no cover
        raise TypeError(
            f"Unable to extract slice {slices!r} for shape {data.shape!r}.") from e

37 

38 

class SliceCommon(OpRun):
    """
    Base runtime shared by the Slice operators: :meth:`_run` hands the
    data and index inputs to :func:`_slice` and wraps the result in a
    one-element tuple.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, data, starts, ends, axes=None, steps=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Slices *data*; *attributes*, *verbose* and *fLOG* are accepted
        but not used by this method.
        """
        sliced = _slice(data, starts, ends, axes=axes, steps=steps)
        return (sliced, )

48 

49 

class Slice_10(SliceCommon):
    """
    Slice runtime selected when ``onnx_opset_version() >= 10``: it adds
    nothing to :class:`SliceCommon`, so *starts*, *ends*, *axes* and
    *steps* are received as inputs by :meth:`SliceCommon._run`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(
            self, onnx_node,
            desc=desc, **options)

54 

55 

class Slice_1(SliceCommon):
    """
    Slice runtime used when ``onnx_opset_version() < 10``: *starts*,
    *ends* and *axes* are read from the node attributes instead of
    inputs, and no *steps* are forwarded to :meth:`SliceCommon._run`.
    """

    # Passed to OpRun as expected_attributes (presumably the defaults
    # used for missing attributes — confirm in OpRun).
    atts = {'starts': [], 'ends': [], 'axes': []}

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Slice_1.atts, **options)
        # An empty attribute becomes None so that _slice treats it as
        # absent (axes=None means "the leading axes").
        for name in ('starts', 'ends', 'steps', 'axes'):
            value = getattr(self, name, None)
            if value is not None and len(value) == 0:
                setattr(self, name, None)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        return SliceCommon._run(
            self, data, self.starts, self.ends, axes=self.axes)

73 

74 

# Default Slice runtime: the input-based Slice_10 when
# onnx.defs.onnx_opset_version() reports 10 or more, the attribute-based
# Slice_1 otherwise.
Slice = Slice_10 if onnx_opset_version() >= 10 else Slice_1