Coverage for mlprodict/onnxrt/ops_cpu/op_slice.py: 95%
44 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
def _slice(data, starts, ends, axes=None, steps=None):
    """
    Extracts a rectangular slice from *data*.

    @param      data        array to slice
    @param      starts      first index of the slice, one value per sliced
                            axis (a scalar is treated as a single value)
    @param      ends        end index (excluded) of the slice, one value
                            per sliced axis (a scalar is treated as a
                            single value)
    @param      axes        axes the pairs *(start, end)* apply to, if None,
                            the i-th pair applies to the i-th axis
    @param      steps       step for every sliced axis, if None,
                            no step is given to the slices
    @return                 ``data[tuple(slices)]``

    A *TypeError* is raised with the computed slices and the shape
    of *data* if the indexing itself raises a *TypeError*.
    """
    # Every index sequence must be iterable: a scalar (0-d array) is promoted
    # to a one-element vector. The same is done for axes and steps, otherwise
    # zip fails with "iteration over a 0-d array".
    starts = numpy.atleast_1d(starts)
    ends = numpy.atleast_1d(ends)
    if axes is not None:
        axes = numpy.atleast_1d(axes)
    if steps is not None:
        steps = numpy.atleast_1d(steps)

    if axes is None:
        # Pairs (start, end) are applied to the leading axes in order.
        if steps is None:
            slices = [slice(s, e) for s, e in zip(starts, ends)]
        else:
            slices = [slice(s, e, d)
                      for s, e, d in zip(starts, ends, steps)]
    else:
        # Axes which are not mentioned are kept entirely.
        slices = [slice(0, a) for a in data.shape]
        if steps is None:
            for s, e, a in zip(starts, ends, axes):
                slices[a] = slice(s, e)
        else:
            for s, e, a, d in zip(starts, ends, axes, steps):
                slices[a] = slice(s, e, d)
    try:
        return data[tuple(slices)]
    except TypeError as e:  # pragma: no cover
        raise TypeError(
            f"Unable to extract slice {slices!r} for shape {data.shape!r}.") from e
class SliceCommon(OpRun):
    """
    Base class shared by the *Slice* runtime operators.
    The computation itself is done by function *_slice*.
    """

    def __init__(self, onnx_node, desc=None, **options):
        # Nothing specific here, every argument goes to the parent class.
        super().__init__(onnx_node, desc=desc, **options)

    def _run(self, data, starts, ends, axes=None, steps=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Slices *data* and returns the result as a tuple with one element.
        Parameters *attributes*, *verbose*, *fLOG* are not used here.
        """
        return (_slice(data, starts, ends, axes, steps), )
class Slice_10(SliceCommon):
    """
    *Slice* operator which relies on *SliceCommon* for everything,
    the constructor only forwards its arguments.
    """

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)
class Slice_1(SliceCommon):
    """
    *Slice* operator reading *starts*, *ends*, *axes* from the instance
    (declared in *atts*) instead of receiving them as arguments of *_run*.
    """

    # Expected attributes with their default values.
    atts = {'starts': [], 'ends': [], 'axes': []}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Slice_1.atts,
                         **options)
        # An empty attribute is replaced by None,
        # a missing attribute is left missing.
        for name in ('starts', 'ends', 'steps', 'axes'):
            value = getattr(self, name, None)
            if value is None:
                continue
            if len(value) == 0:
                setattr(self, name, None)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Slices *data* with the values stored on the instance,
        no step is given.
        """
        starts, ends, axes = self.starts, self.ends, self.axes
        return SliceCommon._run(self, data, starts, ends, axes)
# Slice_10 receives starts, ends, axes, steps as arguments of _run,
# Slice_1 reads them from the instance; the installed opset version
# decides which one is exposed as Slice.
Slice = Slice_10 if onnx_opset_version() >= 10 else Slice_1