Coverage for mlprodict/onnxrt/ops_cpu/op_shape.py: 97%
32 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
class Shape_1(OpRun):
    """
    Runtime operator *Shape* (version 1): returns every dimension
    of the input as a 1-D int64 array.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Builds the output of the operator.

        :param data: input tensor, only its ``shape`` attribute is read
        :return: 1-tuple holding the dimensions of *data* as int64
        """
        dims = data.shape
        result = numpy.array(dims, dtype=numpy.int64)
        return (result, )
class Shape_15(Shape_1):
    """
    Runtime operator *Shape* (version 15): returns the dimensions
    of the input restricted to the axes ``[start, end)``.
    """

    # Default attribute values; a NaN *end* means the attribute was not
    # given, i.e. the slice runs up to the last axis.
    atts = {'start': 0, 'end': numpy.nan}

    def __init__(self, onnx_node, desc=None, **options):
        Shape_1.__init__(self, onnx_node, desc=desc,
                         expected_attributes=Shape_15.atts, **options)

    @staticmethod
    def _clamp_axis(axis, n):
        """
        Normalises one axis boundary for a shape of rank *n*.

        A negative *axis* counts from the back (*n* is added once), then the
        value is clamped to ``[0, n]``. The clamp matters: without it an
        axis below ``-n`` stays negative and Python slicing would count
        it from the back a second time.
        """
        if axis < 0:
            axis += n
        return int(min(max(axis, 0), n))

    def _interval(self, n):
        """
        Computes the slice of the shape to return.

        :param n: rank of the input tensor
        :return: None when the whole shape is requested (default
            attributes), otherwise a tuple ``(start, end)`` with both
            values in ``[0, n]``
        """
        if self.start == 0 and numpy.isnan(self.end):
            return None
        start = self._clamp_axis(self.start, n)
        if numpy.isnan(self.end):
            end = n
        else:
            end = self._clamp_axis(self.end, n)
        return (start, end)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Builds the output of the operator.

        :param data: input tensor, only its ``shape`` attribute is read
        :return: 1-tuple holding the selected dimensions as int64
        """
        ab = self._interval(len(data.shape))
        if ab is None:
            return (numpy.array(data.shape, dtype=numpy.int64), )
        # start > end (after clamping) yields an empty slice, as expected.
        return (numpy.array(data.shape[ab[0]: ab[1]], dtype=numpy.int64), )
# Pick the opset-15 implementation when onnx_opset_version() reports
# at least 15, the opset-1 implementation otherwise.
Shape = Shape_15 if onnx_opset_version() >= 15 else Shape_1