Coverage for mlprodict/onnxrt/ops_cpu/op_reshape.py: 97%
29 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
def reshape_reference_implementation(data, shape, allowzero=0):
    """
    Reshapes *data* following the semantics of ONNX operator *Reshape*.

    :param data: input array
    :param shape: requested shape, a numpy array or any sequence of
        integers; ``-1`` lets numpy infer one dimension
    :param allowzero: when 0 (default), a ``0`` in *shape* copies the
        dimension found at the same position in ``data.shape``; when
        non-zero, a ``0`` is kept as an explicit zero-size dimension
        (attribute *allowzero* introduced by Reshape-14)
    :return: reshaped array
    :raises RuntimeError: if *shape* holds a zero at a position which
        does not exist in ``data.shape`` (and *allowzero* is 0)
    """
    # Accept lists/tuples as well: ``shape == 0`` below must be an
    # element-wise comparison, not a scalar ``False``.
    shape = numpy.asarray(shape, dtype=numpy.int64)
    if allowzero or (len(data.shape) == 1 and data.shape[0] == 0):
        # Zeros are literal dimensions. The empty 1-D case also skips the
        # copy because data.shape may be shorter than shape, which would
        # make the index lookup below fail.
        return numpy.reshape(data, shape)
    new_shape = numpy.copy(shape)
    zeros_index = numpy.where(shape == 0)
    try:
        new_shape[zeros_index] = numpy.array(data.shape)[zeros_index]
    except IndexError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to reshape from shape %r to shape %r (or %r)."
            "" % (data.shape, shape, new_shape)) from e
    return numpy.reshape(data, new_shape)
class CommonReshape(OpRun):
    """
    Runtime shared by the versions of operator *Reshape*.
    The computation is delegated to :func:`reshape_reference_implementation`
    and the single result is returned inside a tuple.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def _run(self, data, shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # attributes, verbose and fLOG are part of the OpRun call
        # convention but are not used here
        reshaped = reshape_reference_implementation(data, shape)
        return (reshaped, )
class Reshape_5(CommonReshape):
    """
    Reshape without any declared attribute.
    *expected_attributes* is accepted to keep the constructor signature
    aligned with :class:`CommonReshape` but it is not forwarded.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        # no attribute to declare for this version
        CommonReshape.__init__(self, onnx_node, desc=desc, **options)
class Reshape_13(Reshape_5):
    """
    Reshape for opset 13, same runtime as :class:`Reshape_5`.
    """
class Reshape_14(CommonReshape):
    """
    Reshape for opset 14 which adds attribute *allowzero* (default 0).
    When *allowzero* is set, a ``0`` in the requested shape is a literal
    zero-size dimension instead of a copy of the input dimension.
    """

    atts = {'allowzero': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonReshape.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Reshape_14.atts, **options)

    def _run(self, data, shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # NOTE(review): assumes OpRun exposes the node attribute as
        # ``self.allowzero`` (defaulting to Reshape_14.atts); getattr keeps
        # the previous behaviour if it does not.
        if getattr(self, 'allowzero', 0):
            # zeros in *shape* are kept, nothing is copied from data.shape
            return (numpy.reshape(data, shape), )
        return CommonReshape._run(
            self, data, shape, attributes=attributes,
            verbose=verbose, fLOG=fLOG)
# Default to the most recent implementation and fall back to Reshape_5
# when the installed onnx package does not know opset 14.
Reshape = Reshape_14
if onnx_opset_version() < 14:
    Reshape = Reshape_5  # pragma: no cover