Coverage for mlprodict/onnx_tools/model_checker.py: 94%

80 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Investigate issues happening with float32. 

4""" 

5from io import BytesIO 

6import numpy 

7from numpy.random import randint 

8from onnx import ModelProto, FunctionProto, GraphProto, load 

9from onnx.checker import check_model 

10 

11 

class MissingInputError(RuntimeError):
    """
    Raised when an input is missing, that is to say when a node
    consumes a result which is not known at that point of the graph.
    """

15 

16 

def astype_range(arr, dtype=numpy.float32, force=1):
    """
    Computes ranges for every number in an array
    once converted into *dtype* (*float32* by default).
    The function returns two arrays which produce two numbers
    *a* and *b* for every coefficient, the number rounded to
    *dtype* is in interval :math:`[a, b]`.

    @param      arr     array, any array-like (list, scalar, ...)
                        is converted with ``numpy.asarray``
    @param      dtype   type to convert to
    @param      force   does something like *[i] +/- force |i - [i]|*
    @return             minimum, maximum (both of type *dtype*)
    """
    # Accept lists and scalars as well as arrays.
    arr = numpy.asarray(arr)
    conv = arr.astype(dtype)
    # The half-width is the rounding error, but never less than
    # 1e-7 times the magnitude of the original value.
    delta = numpy.abs(arr - conv)
    delta = numpy.maximum(numpy.abs(arr) * 1e-7, delta)
    spread = delta * force
    maxa = (conv + spread).astype(dtype)
    mina = (conv - spread).astype(dtype)
    return mina, maxa

36 

37 

def enumerate_random_inputs(inputs, n=100, dtype=numpy.float32, force=1):
    """
    Enumerates random matrices. Every coefficient of every yielded
    matrix is randomly picked among the two bounds computed by
    *astype_range* for the corresponding input coefficient.

    @param      inputs      inputs (dictionary *name -> array*)
    @param      n           number of iterations
    @param      dtype       type to convert to
    @param      force       does something like *[i] +/- force |i - [i]|*
    @return                 iterator on dictionaries with the same keys
                            as *inputs*
    """
    ranges = {k: astype_range(v, dtype=dtype, force=force)
              for k, v in inputs.items()}
    for _ in range(n):
        new_inputs = {}
        for k, (mina, maxa) in ranges.items():
            # One random bit per coefficient: 1 selects the lower bound,
            # 0 selects the upper bound. randint(0, 2) only returns 0 or 1,
            # and a draw made of identical bits is legitimate (it is the
            # only possible outcome for a one-element input), so the draw
            # is not validated.
            rnd = randint(0, 2, mina.size).reshape(  # pylint: disable=E1101
                mina.shape).astype(dtype)
            ma1 = mina * rnd
            ma2 = maxa * (-(rnd - 1))
            new_inputs[k] = ma1 + ma2
        yield new_inputs

65 

66 

def onnx_shaker(oinf, inputs, output_fct, n=100, dtype=numpy.float32, force=1):
    """
    Shakes a model :epkg:`ONNX`.
    Explores the ranges for every prediction.
    Uses @see fn astype_range

    @param      oinf        object of type @see cl OnnxInference,
                            only its method *run* is used
    @param      inputs      inputs
    @param      output_fct  output function which extracts
                            a single array from the output
    @param      n           number of random draws, must be positive
    @param      dtype       type to convert to
    @param      force       does something like *[i] +/- force |i - [i]|*
    @return                 ranges for each predictions, an array of
                            shape *(number of predictions, n)* whose
                            rows are sorted

    See notebook :ref:`onnxshakerrst` for an example of use.
    """
    if n < 1:
        # Without any draw, there is no result to allocate or sort.
        raise ValueError(
            f"n must be a positive number of iterations, not {n!r}.")
    results = None
    for i, new_inputs in enumerate(enumerate_random_inputs(
            inputs, n=n, dtype=dtype, force=force)):
        res_ = oinf.run(new_inputs)
        res = output_fct(res_)
        sq = numpy.squeeze(res)
        if len(sq.shape) != 1:
            raise ValueError(  # pragma: no cover
                f"The function only works with one-dimensional outputs "
                f"(after squeeze), not with shape={sq.shape}")
        if results is None:
            # Allocated on the first draw, once the output size is known.
            results = numpy.empty((sq.shape[0], n), dtype=sq.dtype)
        results[:, i] = sq

    # Every row holds the n values observed for one prediction.
    results.sort(axis=1)
    return results

98 

99 

def check_onnx(model, use_onnx=False, known_results=None,
               path=None):
    """
    Checks consistency of the model. Every non-empty node input
    must be an input of the graph or function, an initializer,
    a result given in *known_results* or the output of a previous node.
    Subgraphs stored in node attributes are checked recursively.

    :param model: onnx graph, *ModelProto*, *GraphProto*,
        *FunctionProto* or serialized bytes
    :param use_onnx: calls `onnx.checker.check_model`
    :param known_results: known results (dictionary), it is copied,
        not modified
    :param path: path to a node (through subgraphs), it is copied,
        not modified
    :raises MissingInputError: if a node uses an unknown result
    :raises TypeError: if *model* has an unexpected type
    """
    if isinstance(model, bytes):
        model = load(BytesIO(model))

    def raise_missing(name, node, p, kn):
        raise MissingInputError(
            "Missing input %r in node type=%r and name=%r "
            "path=%r, known=\n%s\n--ONNX--\n%s" % (
                name, node.op_type, node.name,
                [n.name for n in p], "\n".join(sorted(kn)),
                str(model)))

    if isinstance(model, ModelProto):
        try:
            check_onnx(model.graph, known_results=known_results)
        except MissingInputError as e:
            raise MissingInputError(
                f"Wrong ONNX model\n--ONNX\n{str(model)}") from e
        for f in model.functions:
            check_onnx(f)
        if use_onnx:
            # This branch returns early, the official checker
            # must be called here or use_onnx would be ignored.
            check_model(model)
        return

    if known_results is None:
        known_results = {}
    else:
        known_results = known_results.copy()
    if isinstance(model, FunctionProto):
        # Inputs of a function are plain names.
        for i in model.input:
            known_results[i] = i
    elif isinstance(model, GraphProto):
        for i in model.input:
            known_results[i.name] = i
        for i in model.initializer:
            known_results[i.name] = i
    else:
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(model)!r}.")

    if path is None:
        path = []
    else:
        path = path.copy()

    for node in model.node:
        for i in node.input:
            if i == '':
                # optional input
                continue
            if i not in known_results:
                raise_missing(i, node, path + [node], known_results)
        for att in node.attribute:
            # A protobuf message field is never None: reading ``att.g``
            # returns an empty default graph when it is not set.
            # HasField tells whether the attribute really holds a subgraph.
            if att.HasField('g'):
                check_onnx(att.g, use_onnx=use_onnx,
                           known_results=known_results,
                           path=path + [att, node])
        for o in node.output:
            known_results[o] = node

    if use_onnx:
        # NOTE(review): model is a GraphProto or a FunctionProto here,
        # confirm onnx.checker.check_model accepts them.
        check_model(model)