Coverage for mlprodict/onnxrt/ops_cpu/op_unsqueeze.py: 89%
35 statements
Report generated by coverage.py v7.1.0 on 2023-02-04 02:28 +0100.
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunUnaryNum, OpRun
class Unsqueeze_1(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Unsqueeze* where the axes to insert are
    read from the ``axes`` attribute (used directly and through
    ``Unsqueeze_11``).
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Unsqueeze_1.atts,
                               **options)
        # Normalise the attribute to a tuple. An empty list/tuple is stored
        # as None so that _run rejects it explicitly.
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Inserts a dimension of size one at every position listed in
        ``self.axes``.

        :param data: input tensor
        :return: one-element tuple holding the unsqueezed tensor
        :raises RuntimeError: if ``self.axes`` is not a tuple or list
            (i.e. the attribute was missing or empty)
        """
        if not isinstance(self.axes, (tuple, list)):
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1).")
        # Unsqueeze axes index the *output* tensor, so they must all be
        # resolved against the final rank in a single call. Inserting them
        # one after the other gives wrong shapes for unsorted axes
        # (axes=(2, 0) on a (3, 4) input must give (1, 3, 1, 4)).
        sq = numpy.expand_dims(data, axis=tuple(self.axes))
        return (sq, )
class Unsqueeze_11(Unsqueeze_1):
    """
    Runtime for ONNX operator *Unsqueeze* at opset 11: behaves exactly
    like ``Unsqueeze_1``, axes still come from the ``axes`` attribute.
    """
class Unsqueeze_13(OpRun):
    """
    Runtime for ONNX operator *Unsqueeze* where the axes are received as
    a second input of :meth:`_run` instead of an attribute.
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Unsqueeze_13.atts,
                       **options)
        # axes arrive as an input in _run; the attribute is set to None here
        # (presumably so code reading ``.axes`` still works -- verify callers).
        self.axes = None

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Inserts a dimension of size one at every position given by *axes*.

        :param data: input tensor
        :param axes: integer or sequence/array of integers, positions in
            the output tensor where a new dimension is inserted
        :return: one-element tuple holding the unsqueezed tensor
        :raises RuntimeError: if *axes* is None
        """
        if axes is None:
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13).")
        # Normalise so that numpy arrays, Python lists/tuples and scalars are
        # all accepted (reading ``axes.shape`` directly fails on lists).
        axes = numpy.asarray(axes)
        if axes.ndim > 0:
            sq = numpy.expand_dims(data, axis=tuple(axes.tolist()))
        else:
            sq = numpy.expand_dims(data, axis=int(axes))
        return (sq, )
# Pick the implementation matching the opset supported by the installed
# onnx package: below 11 -> Unsqueeze_1, 11 and 12 -> Unsqueeze_11,
# 13 and above -> Unsqueeze_13.
if onnx_opset_version() < 11:  # pragma: no cover
    Unsqueeze = Unsqueeze_1
elif onnx_opset_version() < 13:  # pragma: no cover
    Unsqueeze = Unsqueeze_11
else:
    Unsqueeze = Unsqueeze_13