Coverage for mlprodict/testing/test_utils/quantized_tensor.py: 94%

70 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Initializes a quantized tensor from float values. 

4""" 

5import numpy 

6from ...npy.xop import loadop 

7from ...onnxrt import OnnxInference 

8 

9 

class QuantizedTensor:
    """
    Holds a quantized tensor (uint8) with its scale and zero point.

    The instance is built either from float32 values, which are
    quantized by the constructor, or from values already quantized
    as uint8, in which case the quantization parameters must be given.

    :param data: array, dtype must be float32 or uint8
    :param scale: scale if data.dtype is uint8, None if it is float32
    :param zero_point: zero_point if data.dtype is uint8,
        None if it is float32

    The constructor sets attributes `quantized_`, `scale_`,
    `zero_point_`. When built from float32 values, `quantized_`
    is a one dimensional array of `data.size` elements.
    """

    def __init__(self, data, scale=None, zero_point=None):
        "constructor"
        if data.dtype == numpy.float32:
            if scale is not None or zero_point is not None:
                raise ValueError(  # pragma: no cover
                    "scale and zero_point are ignored.")
            self._init(data)
        elif data.dtype == numpy.uint8:
            if scale is None or zero_point is None:
                raise ValueError(  # pragma: no cover
                    "scale and zero_point must be specified.")
            self.quantized_ = data
            self.scale_ = numpy.float32(scale)
            self.zero_point_ = numpy.uint8(zero_point)
        else:
            # Fail early instead of returning an instance
            # without any of its attributes.
            raise TypeError(
                f"Unexpected dtype={data.dtype}, expected float32 or uint8.")

    def _init(self, data):
        """
        Initialization when dtype is float32.

        :param data: float32 array of any shape, it is flattened
            before being quantized
        """
        rav = data.ravel().astype(numpy.float32)
        # The quantized range always includes 0.
        mini = min(rav.min(), numpy.float32(0))
        maxi = max(rav.max(), numpy.float32(0))

        info = numpy.iinfo(numpy.uint8)
        qmin = numpy.float32(info.min)
        qmax = numpy.float32(info.max)

        # NOTE: if every value is zero, the scale is zero and the
        # divisions below produce non finite values.
        self.scale_ = (maxi - mini) / (qmax - qmin)
        initial_zero_point = qmin - mini / self.scale_
        self.zero_point_ = numpy.uint8(numpy.round(
            max(qmin, min(qmax, initial_zero_point))))

        # The flattened values are used so that inputs with more than
        # one dimension are supported. fmin / fmax are preferred to clip
        # because they return the bound when a value is nan, which is
        # what the builtin min / max did in the scalar expression above.
        shifted = numpy.round(rav / self.scale_) + self.zero_point_
        clamped = numpy.fmax(numpy.fmin(shifted, qmax), qmin)
        self.quantized_ = clamped.astype(numpy.uint8)

58 

59 

class QuantizedBiasTensor:
    """
    Holds a quantized bias tensor (int32) and its scale.

    :param data: array, values already quantized as int32 if *W*
        is None, float values to quantize otherwise
    :param X_or_scale: the scale (a float) if *W* is None,
        otherwise a @see cl QuantizedTensor whose attribute
        `scale_` is multiplied by `W.scale_`
    :param W: a @see cl QuantizedTensor or None

    The constructor sets attributes `quantized_` and `scale_`.
    When *W* is specified, `quantized_` is a one dimensional
    array holding `floor(data / scale_)`.
    """

    def __init__(self, data, X_or_scale, W: QuantizedTensor = None):
        if W is None:
            self.quantized_ = data
            self.scale_ = numpy.float32(X_or_scale)
        else:
            self.scale_ = X_or_scale.scale_ * W.scale_

            # `size` is an attribute of numpy arrays, not a method:
            # the values are flattened and quantized in one expression.
            flat = numpy.asarray(data).ravel()
            self.quantized_ = numpy.floor(
                flat / self.scale_).astype(numpy.int32)
        # Checks the values given by the user when W is None.
        if self.quantized_.dtype != numpy.int32:
            raise TypeError(  # pragma: no cover
                f"dtype={self.quantized_.dtype} not int32")

84 

85 

def test_qlinear_conv(x: QuantizedTensor, x_shape,
                      w: QuantizedTensor, w_shape,
                      b: QuantizedBiasTensor,
                      y: QuantizedTensor, y_shape,
                      opset=None, runtime='python',
                      pads=None, strides=None, group=None):
    """
    Runs operator `QLinearConv` with a runtime and compares
    the result to the expected output.

    :param x: @see cl QuantizedTensor, first input
    :param x_shape: shape of X
    :param w: @see cl QuantizedTensor, weights
    :param w_shape: shape of W
    :param b: @see cl QuantizedBiasTensor or None, the bias is
        not added to the inputs if None
    :param y: expected output, @see cl QuantizedTensor,
        its scale and zero point are inputs of the operator
    :param y_shape: shape of Y
    :param opset: desired onnx opset, the maximum opset
        supported by the package if None
    :param runtime: runtime for @see cl OnnxInference
    :param pads: optional parameter for operator `QLinearConv`
    :param strides: optional parameter for operator `QLinearConv`
    :param group: optional parameter for operator `QLinearConv`

    The function raises an exception if the output type is not
    the expected one or if the output differs from the expected
    values.
    """
    qlinear_conv_cls = loadop(('', 'QLinearConv'))

    if opset is None:
        from ... import __max_supported_opset__
        opset = __max_supported_opset__

    # Only the attributes specified by the user are given to the operator.
    attributes = {
        name: value
        for name, value in (
            ('pads', pads), ('strides', strides), ('group', group))
        if value is not None}

    # The insertion order is the order of the operator inputs.
    feeds = {
        'x': x.quantized_.reshape(x_shape),
        'x_scale': x.scale_,
        'x_zero_point': x.zero_point_,
        'w': w.quantized_.reshape(w_shape),
        'w_scale': w.scale_,
        'w_zero_point': w.zero_point_,
        'y_scale': y.scale_,
        'y_zero_point': y.zero_point_,
    }
    if b is not None:
        feeds['b'] = b.quantized_
    input_names = list(feeds)

    # Values without any dimension are converted into arrays
    # of one element.
    feeds = {
        name: (numpy.array([value], dtype=value.dtype)
               if len(value.shape) == 0 else value)
        for name, value in feeds.items()}

    node = qlinear_conv_cls(
        *input_names, output_names=['y'], op_version=opset, **attributes)
    model_def = node.to_onnx(feeds, target_opset=opset)

    session = OnnxInference(
        model_def, runtime=runtime,
        runtime_options={'log_severity_level': 3})
    got = session.run(feeds)['y']
    expected = y.quantized_.reshape(y_shape)
    if got.dtype != expected.dtype:
        raise TypeError(  # pragma: no cover
            f"Unexpected output dtype:\nEXPECTED\n{expected}\nGOT\n{got}")

    # Both arrays are compared as float32 vectors.
    got_values = got.ravel().astype(numpy.float32)
    expected_values = expected.ravel().astype(numpy.float32)
    mdiff = numpy.abs(got_values - expected_values).max()
    if mdiff > 1e-5:
        raise ValueError(  # pragma: no cover
            "Unexpected output maximum difference={}:\nEXPECTED\n{}\nGOT\n{}"
            "".format(mdiff, expected, got))