Coverage for mlprodict/onnxrt/ops_cpu/op_quantize_linear.py: 88%
49 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class _CommonQuantizeLinear(OpRun):
    """
    Base class implementing linear quantization,
    ``saturate(round(x / y_scale) + zero_point)``.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def common_run(self, x, y_scale, zero_point=None, axis=1):  # pylint: disable=W0221
        """
        Quantizes *x* with scale *y_scale* and an optional zero point.

        :param x: array to quantize
        :param y_scale: scale, a number or a vector; a vector is
            applied along dimension *axis* of *x*
        :param zero_point: optional zero point, its dtype
            (``uint8`` or ``int8``) is the dtype of the output;
            when missing, the output is ``uint8`` without any offset
        :param axis: dimension of *x* a vector scale (and zero point)
            is broadcast along
        :return: tuple with one element, the quantized array
        :raises RuntimeError: if *y_scale* has more than one dimension
            or if *zero_point* is neither ``uint8`` nor ``int8``
        """
        if len(y_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")
        if len(y_scale.shape) > 0 and y_scale.size == 1:
            # A vector with a single element behaves like a number.
            y_scale = y_scale[0]

        per_axis = len(y_scale.shape) > 0
        if per_axis:
            # Shape with ones everywhere but on *axis* so that the
            # scale broadcasts against x.
            new_shape = [1] * len(x.shape)
            new_shape[axis] = len(y_scale)
            y_scale = y_scale.reshape(new_shape)

        # The quotient is rounded to the nearest integer (ties go to the
        # even one) *before* the zero point is added. Rounding to one
        # decimal and letting astype truncate would turn 2.6 into 2.
        # A new array is created, the input is never modified.
        x = numpy.rint(x / y_scale)

        if zero_point is None:
            return (numpy.clip(x, 0, 255).astype(numpy.uint8), )

        dtype = zero_point.dtype
        if per_axis:
            x = x + zero_point.reshape(new_shape)
        else:
            x = x + zero_point
        if dtype == numpy.uint8:
            x = numpy.clip(x, 0, 255)
        elif dtype == numpy.int8:
            x = numpy.clip(x, -128, 127)
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unexpected dtype for input 2 {dtype}.")
        return (x.astype(dtype), )
class QuantizeLinear(_CommonQuantizeLinear):
    """
    Runtime for operator *QuantizeLinear*, the computation itself
    is done by method ``common_run`` of the parent class.
    """

    # Expected attributes and their default values.
    atts = {'axis': 1}
    # The operator receives its inputs as a variable list.
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            onnx_node,
            desc=desc,
            expected_attributes=QuantizeLinear.atts,
            **options)

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Forwards the inputs (x, y_scale and optionally zero_point)
        to ``common_run`` with the operator axis.

        :param args: x, y_scale, zero_point
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: whatever ``common_run`` returns
        """
        # NOTE(review): self.axis is presumably set by the base class
        # from the expected attributes, confirm in OpRun.
        outputs = self.common_run(*args, axis=self.axis)
        return outputs
class DynamicQuantizeLinear(OpRun):
    """
    Runtime for operator *DynamicQuantizeLinear*: computes the scale
    and the zero point from the input and quantizes it into ``uint8``.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)
        # Type of the quantized output and of the zero point.
        self.dtype = numpy.uint8

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Quantizes *x* over the range ``[0, 255]``.

        :param x: array to quantize
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: tuple *(y, y_scale, y_zero_point)*, the quantized array
            (``uint8``), the scale (same type as *x*) and
            the zero point (``uint8``)
        """
        qmin, qmax = 0, 255
        # The quantized range must contain 0 so that 0 is exactly
        # representable, the observed range is extended accordingly.
        minx = numpy.minimum(0, numpy.min(x))
        maxx = numpy.maximum(0, numpy.max(x))
        span = maxx - minx
        if span == 0:
            # Only happens when every value is 0, any positive scale
            # works and 1 avoids a division by zero.
            span = span.dtype.type(1)
        y_scale = span / (qmax - qmin)
        intermediate_zero_point = qmin - minx / y_scale
        y_zero_point = numpy.round(
            numpy.clip(intermediate_zero_point, qmin, qmax)).astype(self.dtype)
        y = numpy.clip(numpy.round(x / y_scale) + y_zero_point, qmin, qmax)
        return (y.astype(self.dtype),
                y_scale.astype(x.dtype),
                y_zero_point.astype(self.dtype))