Coverage for mlprodict/testing/test_utils/utils_backend_common_compare.py: 70%

84 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Inspired by sklearn-onnx, handles two backends.

4""" 

5import numpy 

6import onnx 

7import pandas 

8from .utils_backend_common import ( 

9 load_data_and_model, extract_options, 

10 ExpectedAssertionError, OnnxBackendAssertionError, 

11 OnnxRuntimeMissingNewOnnxOperatorException, 

12 _compare_expected, _create_column) 

13 

14 

def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    Compares the expected output (computed with the model before
    being converted to ONNX) and the ONNX output produced by the
    inference session built with *cls_session*
    (:epkg:`onnxruntime` or :epkg:`mlprodict`).

    :param cls_session: inference session class or factory
        (like @see cl OnnxInference), it is called as
        ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options, if None and *onnx* is a
        string, the options are extracted from that string,
        the dictionary given by the caller is never modified
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators, forwarded as keyword
        arguments to *load_data_and_model*
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple *(output, lambda function to run the predictions)*
    :raises TypeError: if *options* is not a dictionary or if
        *cls_session* does not accept the expected signature
    :raises ExpectedAssertionError: if a failure happens which is
        announced as expected (option ``CannotLoad``, ``-Fail`` in
        the model name)
    :raises OnnxRuntimeMissingNewOnnxOperatorException: if the runtime
        reports it has no implementation for a node
    :raises OnnxBackendAssertionError: if the model cannot be loaded,
        cannot be run, or if the comparison fails
    """
    import pprint

    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma: no cover
        print(f"[compare_runtime] test '{test['onnx']}' loaded")

    onx = test['onnx']

    if options is None:
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif not isinstance(options, dict):
        raise TypeError(  # pragma: no cover
            "options must be a dictionary.")
    else:
        # Keys 'DF' and 'SklCol' are popped below: work on a copy so that
        # the dictionary owned by the caller is left untouched.
        options = dict(options)

    if verbose:  # pragma: no cover
        print(f"[compare_runtime] InferenceSession('{onx}')")

    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            f"Wrong signature for '{cls_session.__name__}' ({et}).") from et
    except ExpectedAssertionError:  # pragma: no cover
        raise
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma: no cover
            raise ExpectedAssertionError(
                f"Unable to load onnx '{onx}' due to\n{e}") from e
        else:  # pragma: no cover
            if verbose:  # pragma: no cover
                model = onnx.load(onx)
                smodel = "\nJSON ONNX\n" + str(model)
            else:
                smodel = ""
            reason = str(e)
            if ("NOT_IMPLEMENTED : Could not find an implementation "
                    "for the node" in reason):
                # onnxruntime does not implement a specific node yet.
                raise OnnxRuntimeMissingNewOnnxOperatorException(
                    f"{cls_session} does not implement a new operator "
                    f"'{onx}'\n{e}\nONNX\n{smodel}") from e
            if "NOT_IMPLEMENTED : Failed to find kernel" in reason:
                # onnxruntime does not implement a specific node yet
                # in the kernel included in onnxruntime.
                raise OnnxBackendAssertionError(
                    f"{cls_session} misses a kernel for operator "
                    f"'{onx}'\n{e}\nONNX\n{smodel}") from e
            raise OnnxBackendAssertionError(
                f"Unable to load onnx '{onx}'\nONNX\n{smodel}\n{e}") from e

    # Builds the dictionary {input name: value} given to the session.
    data = load["data"]
    if options.pop('DF', False):
        # One input per column, every column becomes a (n, 1) array,
        # float64 columns are cast to float32.
        inputs = {c: data[c].values for c in data.columns}
        for k, col in inputs.items():
            if col.dtype == numpy.float64:
                col = col.astype(numpy.float32)
            inputs[k] = col.reshape((col.shape[0], 1))
    else:
        if isinstance(data, dict):
            inputs = data
        elif isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
            inp = sess.get_inputs()
            outs = sess.get_outputs()
            if len(outs) == 0:
                raise OnnxBackendAssertionError(  # pragma: no cover
                    f"Wrong number of outputs, onnx='{onx}'")
            if len(inp) == len(data):
                inputs = {i.name: v for i, v in zip(inp, data)}
            elif len(inp) == 1:
                inputs = {inp[0].name: data}
            elif isinstance(data, numpy.ndarray):
                shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                            for i in inp)
                if shape == data.shape[1]:
                    inputs = {n.name: data[:, i] for i, n in enumerate(inp)}
                else:
                    raise OnnxBackendAssertionError(  # pragma: no cover
                        f"Wrong number of inputs onnx {len(inp)} != "
                        f"original shape {data.shape}, onnx='{onx}'")
            elif isinstance(data, list):
                try:
                    array_input = numpy.array(data)
                except Exception as exc:  # pragma: no cover
                    raise OnnxBackendAssertionError(
                        f"Wrong number of inputs onnx {len(inp)} != "
                        f"original {len(data)}, onnx='{onx}'") from exc
                shape = sum(i.shape[1] for i in inp)
                if shape == array_input.shape[1]:
                    # Consecutive groups of columns are given to
                    # consecutive inputs.
                    inputs = {}
                    c = 0
                    for n in inp:
                        d = c + n.shape[1]
                        inputs[n.name] = _create_column(
                            [row[c:d] for row in data], n.type)
                        c = d
                else:
                    raise OnnxBackendAssertionError(  # pragma: no cover
                        f"Wrong number of inputs onnx {len(inp)} != "
                        f"original shape {array_input.shape}, "
                        f"onnx='{onx}'*")
            elif isinstance(data, pandas.DataFrame):
                try:
                    array_input = numpy.array(data)
                except Exception as exc:  # pragma: no cover
                    raise OnnxBackendAssertionError(
                        f"Wrong number of inputs onnx {len(inp)} != "
                        f"original {len(data)}, onnx='{onx}'") from exc
                shape = sum(i.shape[1] for i in inp)
                if shape == array_input.shape[1]:
                    # Consecutive groups of columns are given to
                    # consecutive inputs.
                    inputs = {}
                    c = 0
                    for n in inp:
                        d = c + n.shape[1]
                        inputs[n.name] = _create_column(
                            data.iloc[:, c:d], n.type)
                        c = d
                else:
                    raise OnnxBackendAssertionError(  # pragma: no cover
                        f"Wrong number of inputs onnx {len(inp)}={shape} "
                        f"columns != original shape {array_input.shape}, "
                        f"onnx='{onx}'*")
            else:
                raise OnnxBackendAssertionError(  # pragma: no cover
                    f"Wrong type of inputs onnx {type(data)}, onnx='{onx}'")
        else:
            raise OnnxBackendAssertionError(  # pragma: no cover
                f"Dict or list is expected, not {type(data)}")

        for k in inputs:
            if isinstance(inputs[k], list):
                inputs[k] = numpy.array(inputs[k])

    options.pop('SklCol', None)  # unused here but in dump_data_and_model

    if verbose:  # pragma: no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(data), len(inputs), list(sorted(inputs))))
        run_options = {'verbose': 3 if intermediate_steps else 2,
                       'fLOG': print}
    else:
        run_options = {}

    # Imported here so that the module can be imported
    # without loading onnxruntime.
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
        InvalidArgument as OrtInvalidArgument)

    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma: no cover
            # The runtime does not support the verbosity options.
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma: no cover
            pprint.pprint(output)
    except ExpectedAssertionError:  # pragma: no cover
        raise
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma: no cover
        if intermediate_steps:
            sess.run(None, inputs, verbose=3, fLOG=print)
        # onx may be a model and not a filename, the marker '-Fail'
        # can only be searched in a string.
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                f"{cls_session} cannot compute the prediction "
                f"for '{onx}'") from e
        else:
            if verbose:  # pragma: no cover
                from ...plotting.text_plot import onnx_simple_text_plot
                model = onnx.load(onx)
                smodel = "\nJSON ONNX\n" + onnx_simple_text_plot(
                    model, recursive=True, raise_exc=False)
            else:
                smodel = ""
            raise OnnxBackendAssertionError(
                f"{cls_session} cannot compute the predictions"
                f" for '{onx}' due to {e}{smodel}\n"
                f"{pprint.pformat(inputs)}") from e
    except Exception as e:  # pragma: no cover
        raise OnnxBackendAssertionError(
            f"Unable to run onnx '{onx}' due to {e}") from e
    if verbose:  # pragma: no cover
        print(f"[compare_runtime] done type={type(output)}")

    # The returned value is a copy made before the comparison.
    output0 = output.copy()

    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError:  # pragma: no cover
        raise
    except Exception as e:  # pragma: no cover
        if verbose:  # pragma: no cover
            model = onnx.load(onx)
            smodel = "\nJSON ONNX\n" + str(model)
        else:
            smodel = ""
        raise OnnxBackendAssertionError(
            "Model '{}' has discrepencies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e, smodel)) from e

    return output0, lambda_onnx