Coverage for mlprodict/onnxrt/ops_cpu/op_squeeze.py: 84%
37 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunUnaryNum, OpRun
class Squeeze_1(OpRunUnaryNum):
    """
    Runtime for operator *Squeeze* when the axes are stored on the
    operator itself (``self.axes``) rather than passed as an input.

    ``self.axes`` is presumably filled by ``OpRunUnaryNum.__init__`` from
    the node attributes declared in ``atts`` (not visible here, to be
    confirmed). After construction it is either ``None`` (remove every
    dimension of size one) or a non-empty tuple of axes.
    """

    # Expected attributes with their default values. ``keepdims`` is
    # declared but not read anywhere in this class.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to the base class constructor
        :param desc: forwarded to the base class constructor
        :param options: forwarded to the base class constructor
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Squeeze_1.atts,
                               **options)
        if isinstance(self.axes, (numpy.ndarray, list, tuple)):
            axes = tuple(self.axes)
            # An empty sequence means "no axes specified", whatever the
            # container type: an empty array must behave like an empty
            # list and squeeze every dimension of size one.
            self.axes = axes if axes else None

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Removes dimensions of size one from *data*.

        :param data: array to squeeze
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: tuple holding the squeezed array
        """
        axes = self.axes
        if isinstance(axes, list):
            # numpy.squeeze only accepts None, an int or a tuple.
            axes = tuple(axes)
        # All axes are given to numpy in a single call: they are then
        # interpreted relative to the input shape, so unsorted or
        # negative axes work. Squeezing them one by one would shift the
        # remaining indices after each removal.
        return (numpy.squeeze(data, axis=axes), )
class Squeeze_11(Squeeze_1):
    """
    Variant of *Squeeze* registered for opset 11; it adds nothing to
    and overrides nothing from :class:`Squeeze_1`.
    """
class Squeeze_13(OpRun):
    """
    Runtime for operator *Squeeze* when the axes are received as an
    optional second argument of ``_run`` instead of being stored on the
    operator.
    """

    # Expected attributes with their default values; ``keepdims`` is
    # declared but not read anywhere in this class.
    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to the base class constructor
        :param desc: forwarded to the base class constructor
        :param options: forwarded to the base class constructor
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Squeeze_13.atts,
                       **options)
        # Always None here: the axes come through ``_run``.
        self.axes = None

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Removes dimensions of size one from *data*.

        :param data: array to squeeze
        :param axes: ``None`` to squeeze every dimension of size one,
            otherwise a single axis or an iterable of axes
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: tuple holding the squeezed array
        """
        if axes is None:
            return (numpy.squeeze(data), )
        # numpy expects an int or a tuple, so any iterable is converted.
        selected = tuple(axes) if hasattr(axes, '__iter__') else axes
        return (numpy.squeeze(data, axis=selected), )
# Picks the implementation matching the opset of the installed onnx
# package: 13 and above, then 11 and above, then the oldest one.
Squeeze = (
    Squeeze_13 if onnx_opset_version() >= 13
    else Squeeze_11 if onnx_opset_version() >= 11  # pragma: no cover
    else Squeeze_1)  # pragma: no cover