Coverage for mlprodict/onnxrt/ops_cpu/op_split.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from onnx.defs import onnx_opset_version 

8from ._op import OpRun 

9 

10 

class CommonSplit(OpRun):
    """
    Shared runtime logic for operator *Split*.

    Subclasses decide where the chunk sizes come from (node attribute or
    optional input) and delegate the actual slicing to :meth:`common_run`.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Guarantee that a 'split' option always reaches OpRun, defaulting
        # to None when the caller did not supply one.
        options.setdefault('split', None)
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # Number of outputs declared on the node; used to derive equal
        # chunk sizes when no explicit split is given.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Cuts *mat* into consecutive pieces along ``self.axis``.

        :param mat: array to split
        :param split: length of every piece along ``self.axis``; when None,
            the axis is cut into ``self.nb_outputs`` pieces of
            ``mat.shape[self.axis] // self.nb_outputs`` elements and the
            last piece also receives the remainder
        :return: tuple of sub-arrays, one per entry of *split*
        """
        axis = self.axis
        if split is None:
            quotient, remainder = divmod(mat.shape[axis], self.nb_outputs)
            # Equal pieces, the leftover elements are appended to the last one.
            split = [quotient] * (self.nb_outputs - 1) + [quotient + remainder]

        # Full-range slice on every dimension; only the split axis changes.
        window = [slice(0, dim) for dim in mat.shape]
        pieces = []
        start = 0
        for length in split:
            window[axis] = slice(start, start + length)
            pieces.append(mat[tuple(window)])
            start += length
        return tuple(pieces)

38 

39 

class Split_2(CommonSplit):
    """
    Runtime for operator *Split* where the chunk sizes are taken from the
    ``split`` node attribute (None means equal chunks).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Split_2.atts, **options)

    def _run(self, mat, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # The sizes live on the instance (attribute), not in the inputs.
        sizes = self.split
        return self.common_run(mat, sizes)

53 

54 

class Split_11(Split_2):
    """
    Runtime for operator *Split*, opset 11 flavour.

    It reuses :class:`Split_2` unchanged: chunk sizes still come from the
    ``split`` attribute, and the slicing already accepts any ``axis`` value
    that Python indexing accepts, negative ones included.
    """

60 

61 

class Split_13(CommonSplit):
    """
    Runtime for operator *Split* where the chunk sizes are passed as an
    optional second input instead of a node attribute.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # A missing 'split' input (None) falls back to equal chunks.
        pieces = self.common_run(mat, split)
        return pieces

75 

76 

# Expose under the name ``Split`` the newest implementation supported by the
# opset of the installed onnx package.
if onnx_opset_version() < 11:  # pragma: no cover
    Split = Split_2
elif onnx_opset_version() < 13:  # pragma: no cover
    Split = Split_11
else:
    Split = Split_13