Coverage for mlprodict/onnxrt/ops_cpu/op_dequantize_linear.py: 82%
28 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class DequantizeLinear(OpRun):
    """
    Runtime for operator *DequantizeLinear*.

    The result is a float32 array:

    * ``y = (x - x_zero_point) * x_scale`` when *x_zero_point* is given,
    * ``y = x * x_scale`` when it is not.

    When the zero point (or, without a zero point, the scale) is a
    vector, it is broadcast along attribute *axis* (default 1).
    """

    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DequantizeLinear.atts,
                       **options)

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Dequantizes the first input.

        :param args: ``(x, x_scale)`` or ``(x, x_scale, x_zero_point)``;
            *x_zero_point* may also be None, meaning it was not provided
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: tuple holding one float32 array
        :raises RuntimeError: if *x_scale* has more than one dimension, or
            if *x_zero_point* and *x* do not share the same dtype
        """
        x = args[0]
        x_scale = args[1]
        # The zero point is an optional third input. Reading args[2]
        # unconditionally made the two-input form crash with IndexError.
        x_zero_point = args[2] if len(args) > 2 else None

        if len(x_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")

        def axis_shape(length):
            # Shape [1, ..., length, ..., 1] so that a 1-D coefficient
            # broadcasts along dimension self.axis of x.
            shape = [1] * len(x.shape)
            shape[self.axis] = length
            return shape

        # Cast first so that subtracting the zero point is done in
        # float32 rather than in the integer dtype of x.
        xf = x.astype(numpy.float32)

        if x_zero_point is not None:
            # A single-element zero point is treated as a scalar.
            if len(x_zero_point.shape) > 0 and x_zero_point.size == 1:
                x_zero_point = x_zero_point[0]
            if x_zero_point.dtype != x.dtype:
                raise RuntimeError(  # pragma no cover
                    "Type mismatch {} != {} in DequantizeLinear.".format(
                        x.dtype, x_zero_point.dtype))
            if len(x_zero_point.shape) > 0:
                new_shape = axis_shape(len(x_zero_point))
                shifted = xf - x_zero_point.reshape(new_shape)
                y = shifted * x_scale.reshape(new_shape)
            else:
                y = (xf - x_zero_point) * x_scale
        elif len(x_scale.shape) > 0:
            # No zero point: the scale alone is broadcast along axis.
            y = xf * x_scale.reshape(axis_shape(len(x_scale)))
        else:
            y = xf * x_scale
        return (y.astype(numpy.float32), )