Coverage for mlprodict/onnxrt/ops_cpu/op_quantize_linear.py: 88%

49 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class _CommonQuantizeLinear(OpRun):
    """
    Shared implementation of linear quantization:
    ``y = saturate(round(x / y_scale) + zero_point)``.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """Forwards every argument unchanged to ``OpRun.__init__``."""
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    @staticmethod
    def _saturation_bounds(dtype):
        """
        Returns the inclusive ``(low, high)`` range of a supported
        quantized type.

        :param dtype: target type, ``numpy.uint8`` or ``numpy.int8``
        :return: tuple ``(low, high)``
        :raises RuntimeError: if *dtype* is neither uint8 nor int8
        """
        if dtype == numpy.uint8:
            return 0, 255
        if dtype == numpy.int8:
            return -128, 127
        raise RuntimeError(  # pragma: no cover
            f"Unexpected dtype for input 2 {dtype}.")

    def common_run(self, x, y_scale, zero_point=None, axis=1):  # pylint: disable=W0221
        """
        Quantizes *x*.

        :param x: array to quantize
        :param y_scale: scale, a number or a vector; a vector with more
            than one element is applied along *axis*
        :param zero_point: optional zero point; its dtype (uint8 or int8)
            is the dtype of the result, uint8 is used when it is missing
        :param axis: axis of *x* a vector *y_scale* is aligned with
        :return: tuple with a single element, the quantized array
        :raises RuntimeError: if *y_scale* has more than one dimension
            or if the dtype of *zero_point* is neither uint8 nor int8
        """
        if len(y_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")

        # Fail on an unsupported output type before doing any arithmetic.
        dtype = numpy.uint8 if zero_point is None else zero_point.dtype
        low, high = self._saturation_bounds(dtype)

        # A vector holding a single value behaves like a plain number.
        if len(y_scale.shape) > 0 and y_scale.size == 1:
            y_scale = y_scale[0]

        per_axis = len(y_scale.shape) > 0
        if per_axis:
            # Shape broadcasting the per-axis parameters along *axis*.
            new_shape = [1] * len(x.shape)
            new_shape[axis] = len(y_scale)
            x = x / y_scale.reshape(new_shape)
        else:
            x = x / y_scale

        # Round to the nearest integer (ties go to the even value) *before*
        # adding the zero point. The previous code called
        # ``numpy.around(x, 1)``, which only rounds to one decimal and let
        # ``astype`` truncate the remainder (2.6 became 2 instead of 3).
        x = numpy.rint(x)

        if zero_point is not None:
            if per_axis:
                x = x + zero_point.reshape(new_shape)
            else:
                x = x + zero_point

        # Saturate to the range of the output type, then cast.
        return (numpy.clip(x, low, high).astype(dtype), )

51 

52 

class QuantizeLinear(_CommonQuantizeLinear):
    """
    Runtime for operator *QuantizeLinear*; the computation itself is
    done by ``_CommonQuantizeLinear.common_run``.
    """

    # Passed to the base constructor as ``expected_attributes``.
    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        """
        Calls the base constructor with ``QuantizeLinear.atts`` as the
        expected attributes.
        """
        init_kwargs = dict(options)
        init_kwargs['desc'] = desc
        init_kwargs['expected_attributes'] = QuantizeLinear.atts
        _CommonQuantizeLinear.__init__(self, onnx_node, **init_kwargs)

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Quantizes the inputs along ``self.axis``.

        :param args: positional inputs handed over to ``common_run``
            in the order *x*, *y_scale*, *zero_point*
        :return: whatever ``common_run`` returns
        """
        # NOTE(review): ``self.axis`` is not assigned in this class,
        # presumably OpRun sets it from ``expected_attributes`` - confirm.
        axis = self.axis
        return self.common_run(*args, axis=axis)

67 

68 

class DynamicQuantizeLinear(OpRun):
    """
    Runtime for operator *DynamicQuantizeLinear*: computes a scale and
    a zero point from the values of the input and quantizes it to uint8.
    """

    def __init__(self, onnx_node, desc=None, **options):
        """Calls ``OpRun.__init__`` and sets the output type to uint8."""
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)
        self.dtype = numpy.uint8

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Quantizes *x* with a scale and a zero point derived from its range.

        :param x: array to quantize
        :return: tuple ``(y, y_scale, y_zero_point)``, *y* and
            *y_zero_point* are cast to ``self.dtype``, *y_scale* to the
            dtype of *x*
        """
        qmin, qmax = 0, 255

        # The ONNX specification widens the data range so that it always
        # contains 0, which makes 0.0 exactly representable once quantized.
        minx = numpy.minimum(numpy.min(x), 0)
        maxx = numpy.maximum(numpy.max(x), 0)

        y_scale = (maxx - minx) / (qmax - qmin)
        if y_scale == 0:
            # Zero-width range (every value is 0): any positive scale
            # represents the data exactly, 1 avoids the division by zero.
            y_scale = y_scale.dtype.type(1)

        intermediate_zero_point = qmin - minx / y_scale
        y_zero_point = numpy.round(
            numpy.clip(intermediate_zero_point, qmin, qmax)).astype(self.dtype)
        y = numpy.clip(numpy.round(x / y_scale) + y_zero_point, qmin, qmax)
        return (y.astype(self.dtype),
                y_scale.astype(x.dtype),
                y_zero_point.astype(self.dtype))

87 y_zero_point.astype(self.dtype))