Coverage for mlprodict/onnxrt/ops_cpu/op_transpose.py: 100%
16 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRunUnaryNum
class Transpose(OpRunUnaryNum):
    """
    Runtime for operator *Transpose*, implemented with :func:`numpy.transpose`.

    Attribute *perm* gives the new order of the axes. When it is empty
    (its default), the axes are reversed, which is what
    :func:`numpy.transpose` does when *axes* is not given.
    """

    # Expected attributes and their default values; an empty *perm*
    # means "no permutation given" (normalized to None in __init__).
    atts = {'perm': []}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node       node this operator executes
        @param      desc            forwarded to ``OpRunUnaryNum.__init__``
        @param      options         forwarded to ``OpRunUnaryNum.__init__``
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Transpose.atts,
                               **options)
        # Normalize the "no permutation" case to a single sentinel (None)
        # so that _run only has one condition to check.
        self.perm_ = None if len(self.perm) == 0 else self.perm

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Transposes *data*.

        @param      data            input array
        @param      attributes      not used by this implementation
        @param      verbose         not used by this implementation
        @param      fLOG            not used by this implementation
        @return                     tuple containing the transposed array
        @raises     RuntimeError    if *perm* does not have exactly one
                                    entry per dimension of *data*
        """
        if self.perm_ is None:
            # Without *axes*, numpy reverses the order of the dimensions.
            return (numpy.transpose(data), )
        if len(self.perm_) != len(data.shape):
            raise RuntimeError(  # pragma: no cover
                f"Inconsistent permutation {self.perm_!r} with shape {data.shape!r}.")
        return (numpy.transpose(data, axes=self.perm_), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        The generated code reads the attribute through a local name
        ``perm`` and, like ``_run``, reverses the axes when that value
        is None or empty (the attribute defaults to ``[]``).

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        lines = [
            # An empty perm must behave like a missing one, as in _run:
            # numpy.transpose(x, axes=[]) raises for any input with at
            # least one dimension.
            "if perm is None or len(perm) == 0:",
            f"    return numpy.transpose({inputs[0]})",
            f"return numpy.transpose({inputs[0]}, axes=perm)",
        ]
        return "import numpy", "\n".join(lines)