Coverage for mlprodict/onnxrt/ops_cpu/op_conv_transpose.py: 94%
18 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_conv_transpose_ import ( # pylint: disable=E0611,E0401
10 ConvTransposeFloat, ConvTransposeDouble)
class ConvTranspose(OpRun):
    """
    Runtime for the *ConvTranspose* operator. The computation itself is
    delegated to the compiled classes ``ConvTransposeFloat`` and
    ``ConvTransposeDouble`` imported from ``op_conv_transpose_``.
    """

    # Default value for each attribute this operator knows about; handed to
    # OpRun as ``expected_attributes``.
    atts = {
        'auto_pad': 'NOTSET',
        'group': 1,
        'dilations': [],
        'kernel_shape': [],
        'pads': [],
        'strides': [],
        'output_padding': [],
        'output_shape': [],
    }

    def __init__(self, onnx_node, desc=None, **options):
        # OpRun is called explicitly (not via super()) and always with the
        # attribute table of this class.
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ConvTranspose.atts,
            **options)
        self._init()

    def _init(self):
        """
        Creates the two compiled runtimes (``rt32_`` built from
        ``ConvTransposeFloat``, ``rt64_`` built from ``ConvTransposeDouble``)
        and configures both with the node attributes. Each list-valued
        attribute is converted into a new ``int64`` numpy array for every
        runtime, so the two runtimes never receive the same array object.
        """
        def to_int64(values):
            return numpy.array(values, dtype=numpy.int64)

        self.rt32_ = ConvTransposeFloat()
        self.rt64_ = ConvTransposeDouble()
        for runtime in (self.rt32_, self.rt64_):
            runtime.init(
                self.auto_pad,
                to_int64(self.dilations),
                self.group,
                to_int64(self.kernel_shape),
                to_int64(self.pads),
                to_int64(self.strides),
                to_int64(self.output_padding),
                to_int64(self.output_shape))

    def _run(self, X, W, B=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the transposed convolution.

        :param X: input tensor; its dtype selects the runtime
        :param W: weights, forwarded unchanged to the runtime
        :param B: optional bias, forwarded unchanged (``None`` included);
            presumably the compiled runtime treats ``None`` as "no bias" --
            NOTE(review): confirm in ``op_conv_transpose_``
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: a 1-tuple holding the runtime's result

        Only ``float32`` inputs use ``rt32_``; every other dtype is handed to
        ``rt64_``. NOTE(review): the result dtype for inputs that are neither
        float32 nor float64 (e.g. float16) is decided by the compiled
        extension -- TODO confirm it matches the ONNX expectation.
        """
        runtime = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        return (runtime.compute(X, W, B), )