Coverage for mlprodict/onnxrt/ops_cpu/op_split.py: 100%
38 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from onnx.defs import onnx_opset_version
8from ._op import OpRun
class CommonSplit(OpRun):
    """
    Runtime for operator *Split*: logic shared by every opset-specific
    variant defined in this module.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Always forward a 'split' entry to OpRun, defaulting it to None
        # when the caller did not supply one.
        options.setdefault('split', None)
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # One chunk is produced per declared output of the node.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Cuts *mat* into consecutive chunks along ``self.axis``.

        ``self.axis`` is presumably populated by OpRun from the node
        attributes (see ``atts`` in the subclasses) -- not set here.

        :param mat: array to cut; must expose ``shape`` and accept a
            tuple of slices as index
        :param split: sequence of chunk lengths along the axis, or None to
            cut the axis into ``self.nb_outputs`` chunks of
            ``shape[axis] // nb_outputs`` elements each, the last chunk
            additionally absorbing the remainder
        :return: tuple with one slice of *mat* per entry of *split*
        """
        if split is None:
            extent = mat.shape[self.axis]
            count = self.nb_outputs
            chunk = extent // count
            split = [chunk] * count
            # Integer division may leave elements over: the last chunk
            # takes them so that the whole axis is covered.
            split[-1] += extent - chunk * count
        # Full-range slice on every axis; only self.axis is rewritten below.
        index = [slice(0, dim) for dim in mat.shape]
        pieces = []
        start = 0
        for size in split:
            stop = start + size
            index[self.axis] = slice(start, stop)
            pieces.append(mat[tuple(index)])
            start = stop
        return tuple(pieces)
class Split_2(CommonSplit):
    """
    Runtime for operator *Split*. This variant reads the chunk sizes from the
    ``split`` attribute (exposed as ``self.split``) instead of an input.
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Split_2.atts, **options)

    def _run(self, mat, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # Chunk sizes come from the node attribute; None means an even split.
        sizes = self.split
        return self.common_run(mat, sizes)
class Split_11(Split_2):
    """
    Runtime for operator *Split*. This class is identical to :class:`Split_2`
    and exists as a distinct name so that it can be selected when the
    installed onnx opset is 11 or 12.
    """
class Split_13(CommonSplit):
    """
    Runtime for operator *Split*. In this variant the chunk sizes are not an
    attribute; they arrive as the optional second argument of ``_run``.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # When split is None, common_run cuts the axis evenly.
        return self.common_run(mat, split)
# Alias ``Split`` to the newest implementation the installed onnx supports:
# below 11 -> Split_2, 11..12 -> Split_11, 13 and above -> Split_13.
if onnx_opset_version() < 11:  # pragma: no cover
    Split = Split_2
elif onnx_opset_version() < 13:  # pragma: no cover
    Split = Split_11
else:
    Split = Split_13