Coverage for mlprodict/testing/test_utils/utils_backend_common.py: 86%

180 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Inspired from :epkg:`sklearn-onnx`, handles two backends. 

4""" 

5import os 

6import pickle 

7import numpy 

8from numpy.testing import assert_array_almost_equal, assert_array_equal 

9from scipy.sparse.csr import csr_matrix 

10import pandas 

11from ...onnxrt.ops_cpu.op_zipmap import ArrayZipMapDictionary 

12 

13 

class ExpectedAssertionError(Exception):
    """
    Expected failure. :func:`compare_outputs` returns an instance of this
    class instead of :class:`OnnxBackendAssertionError` when the options
    ``Disc`` or ``Mism`` declare the discrepancy as a known issue.
    """

19 

20 

class OnnxBackendAssertionError(AssertionError):
    """
    Raised (or returned by :func:`compare_outputs`) when the outputs
    produced by a backend do not match the expected outputs or cannot
    be compared with them.
    """

26 

27 

class OnnxBackendMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Signals that :epkg:`onnxruntime` or :epkg:`mlprodict` lacks an
    implementation for an operator introduced in the latest onnx.
    """

35 

36 

class OnnxRuntimeMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Signals that an operator was added in onnx but could not be found.
    """

42 

43 

def evaluate_condition(backend, condition):
    """
    Evaluates a condition such as
    ``StrictVersion(onnxruntime.__version__) <= StrictVersion('0.1.3')``.

    The condition is a Python expression run with :func:`eval` after
    :epkg:`onnxruntime` was imported, so the expression may refer to it.
    Only trusted expressions should be given.

    :param backend: only ``'onnxruntime'`` is supported
    :param condition: Python expression to evaluate
    :return: the value of the expression
    :raises NotImplementedError: for any other backend
    """
    if backend != "onnxruntime":  # pragma no cover
        raise NotImplementedError(
            f"Not implemented for backend '{backend}' and condition '{condition}'.")
    # the name onnxruntime must be visible to eval
    import onnxruntime  # pylint: disable=W0611,C0415  # pragma: no cover
    return eval(condition)  # pylint: disable=W0123  # pragma: no cover

54 

55 

def is_backend_enabled(backend):
    """
    Tells if a backend is enabled.
    Unit tests only test models against the supported backends.

    :param backend: ``'onnxruntime'``, ``'onnxruntime1'`` or ``'python'``
    :return: True for ``'python'``; for ``'onnxruntime'`` and
        ``'onnxruntime1'``, True if :epkg:`onnxruntime` can be
        imported, False otherwise
    :raises NotImplementedError: for any other backend name
    """
    if backend == "python":
        return True
    if backend in ("onnxruntime", "onnxruntime1"):
        try:
            import onnxruntime  # pylint: disable=W0611,C0415
        except ImportError:  # pragma no cover
            return False
        return True
    raise NotImplementedError(  # pragma no cover
        f"Not implemented for backend '{backend}'")

72 

73 

def load_data_and_model(items_as_dict, **context):
    """
    Loads every file in a dictionary {key: filename}.

    Only string values whose extension is ``.pkl`` are loaded, with
    :mod:`pickle`. Any other value (a non-string, or a filename with
    another extension such as ``.onnx``) is copied as is into the result,
    the function assumes it was already loaded.

    :param items_as_dict: dictionary ``{key: filename or object}``
    :param context: unused
    :return: dictionary ``{key: loaded object or original value}``;
        a pickled file failing with :class:`ImportError` is dropped
        from the result if its filename contains ``'.model.'``
    :raises ImportError: if a pickled file fails with :class:`ImportError`
        and its filename does not contain ``'.model.'``

    .. warning::
        :func:`pickle.load` can execute arbitrary code,
        only trusted files must be loaded.
    """
    res = {}
    for key, value in items_as_dict.items():
        if not isinstance(value, str) or os.path.splitext(value)[-1] != ".pkl":
            res[key] = value
            continue
        with open(value, "rb") as f:  # pragma: no cover
            try:
                obj = pickle.load(f)
            except ImportError as e:
                if '.model.' in value:
                    # a model which cannot be unpickled is skipped
                    continue
                raise ImportError(
                    f"Unable to load '{value}' due to {e}") from e
        res[key] = obj
    return res

99 

100 

def extract_options(name):
    """
    Extracts comparison options from a filename.
    As example, ``Binarizer-SkipDim1`` means option *SkipDim1* is enabled
    (with it, ``(1, 2)`` and ``(2,)`` are considered equal).
    Only the last path component is used, everything after its first
    ``.`` is ignored, options follow the model name and are separated
    by ``-``.
    Available options: see :func:`dump_data_and_model`.

    :param name: filename, ``/`` and ``\\`` are both accepted as separators
    :return: dictionary ``{option: True}``, empty if there is no option
    :raises NameError: if one option is unknown
    """
    known = {"SkipDim1", "OneOff", "NoProb", "NoProbOpp",
             "Dec4", "Dec3", "Dec2", "Dec1", 'Svm',
             'Out0', 'Reshape', 'SklCol', 'DF', 'OneOffArray'}
    basename = name.replace("\\", "/").split("/")[-1]
    opts = basename.split('.')[0].split('-')
    res = {}
    for opt in opts[1:]:
        if opt not in known:
            # report the offending option, not only the whole list
            raise NameError(  # pragma no cover
                f"Unable to parse option '{opt}' in {opts[1:]}")
        res[opt] = True
    return res

122 

123 

def compare_outputs(expected, output, verbose=False, **kwargs):
    """
    Compares expected values and output.
    Returns None if no error, an exception message otherwise.

    :param expected: expected values
    :param output: values produced by the backend
    :param verbose: given to :epkg:`numpy` assertion functions, it also
        adds both arrays to the message when the comparison fails
    :param kwargs: comparison options ``SkipDim1``, ``NoProb``,
        ``NoProbOpp``, ``Dec4``, ``Dec3``, ``Dec2``, ``Dec1``, ``Disc``,
        ``Mism``, ``OneOff``; for non-string arrays, the remaining keys
        (``decimal`` for example) are given to
        :func:`numpy.testing.assert_array_almost_equal`
    :return: None if both match, an instance of
        :class:`ExpectedAssertionError` if they do not but the failure is
        declared as expected (``Disc``, ``Mism``), an instance of
        :class:`OnnxBackendAssertionError` otherwise
    :raises NotImplementedError: if shapes cannot be reconciled
    """
    SkipDim1 = kwargs.pop("SkipDim1", False)
    NoProb = kwargs.pop("NoProb", False)
    NoProbOpp = kwargs.pop("NoProbOpp", False)
    Dec4 = kwargs.pop("Dec4", False)
    Dec3 = kwargs.pop("Dec3", False)
    Dec2 = kwargs.pop("Dec2", False)
    Dec1 = kwargs.pop("Dec1", False)
    Disc = kwargs.pop("Disc", False)
    Mism = kwargs.pop("Mism", False)

    # Options DecN lower the precision. 6 is the default value of
    # assert_array_almost_equal when decimal is not given.
    for enabled, digits in ((Dec4, 4), (Dec3, 3), (Dec2, 2), (Dec1, 1)):
        if enabled:
            kwargs["decimal"] = min(kwargs.get("decimal", 6), digits)

    if not (isinstance(expected, numpy.ndarray) and
            isinstance(output, numpy.ndarray)):
        return OnnxBackendAssertionError(  # pragma: no cover
            f"Unexpected types {type(expected)} != {type(output)}")

    if SkipDim1:
        # Arrays like (2, 1, 2, 3) becomes (2, 2, 3)
        # as one dimension is useless. output is reshaped into
        # the new shape of expected (both must have the same size).
        expected = expected.reshape(
            tuple(d for d in expected.shape if d > 1))
        output = output.reshape(expected.shape)

    if NoProb or NoProbOpp:
        # One vector is (N,) with scores, negative for class 0
        # positive for class 1
        # The other vector is (N, 2) score in two columns.
        if (len(output.shape) == 2 and output.shape[1] == 2 and
                len(expected.shape) == 1):
            output = output[:, 1]
            if NoProbOpp:
                output = -output
        elif len(output.shape) == 1 and len(expected.shape) == 1:
            pass
        elif (len(expected.shape) == 1 and len(output.shape) == 2 and
                expected.shape[0] == output.shape[0] and
                output.shape[1] == 1):
            output = output[:, 0]
            if NoProbOpp:
                output = -output
        elif expected.shape != output.shape:
            raise NotImplementedError("Shape mismatch: {0} != {1}".format(  # pragma no cover
                expected.shape, output.shape))

    if (len(expected.shape) == 1 and len(output.shape) == 2 and
            output.shape[1] == 1):
        output = output.ravel()
    if (len(output.shape) == 3 and output.shape[0] == 1 and
            len(expected.shape) == 2):
        output = output.reshape(output.shape[1:])

    if expected.dtype.kind == 'U':
        # unicode strings of any width are compared exactly
        try:
            assert_array_equal(expected, output, verbose=verbose)
        except Exception as e:  # pylint: disable=W0703  # pragma no cover
            if Disc:
                # Bug to be fixed later.
                return ExpectedAssertionError(str(e))
            return OnnxBackendAssertionError(str(e))
        return None

    # OneOff is not an option of assert_array_almost_equal
    kwargs.pop('OneOff', None)
    if expected.shape != output.shape:
        raise NotImplementedError(
            f"Unable to deal with sort of shapes "
            f"{expected.shape!r} != {output.shape!r}.")
    try:
        assert_array_almost_equal(expected, output, verbose=verbose, **kwargs)
    except (RuntimeError, AssertionError, TypeError) as e:  # pragma no cover
        # NOTE(review): a TypeError caused by an unexpected key in kwargs is
        # caught as well, the comparison then falls back to the exact
        # element-wise check below.
        longer = "\n--EXPECTED--\n{0}\n--OUTPUT--\n{1}".format(
            expected, output) if verbose else ""
        expected_ = numpy.asarray(expected).ravel()
        output_ = numpy.asarray(output).ravel()
        if len(expected_) != len(output_):
            error_cls = ExpectedAssertionError if Mism else OnnxBackendAssertionError
            return error_cls(
                "dimension mismatch={0}, {1}\n{2}{3}".format(
                    expected.shape, output.shape, e, longer))
        if numpy.issubdtype(expected_.dtype, numpy.floating):
            # initial=0 avoids a failure on empty arrays
            diff = numpy.abs(expected_ - output_).max(initial=0)
        else:
            diff = max(((1 if ci != cj else 0)
                        for ci, cj in zip(expected_, output_)), default=0)
        if diff == 0:
            return None
        if Disc:
            # Bug to be fixed later.
            return ExpectedAssertionError(
                f"max-diff={diff}\n--expected--output--\n{e}{longer}")
        return OnnxBackendAssertionError(
            f"max-diff={diff}\n--expected--output--\n{e}{longer}")
    return None

236 

237 

238def _post_process_output(res): 

239 """ 

240 Applies post processings before running the comparison 

241 such as changing type from list to arrays. 

242 """ 

243 if isinstance(res, list): 

244 if len(res) == 0: 

245 return res 

246 if len(res) == 1: 

247 return _post_process_output(res[0]) 

248 if isinstance(res[0], numpy.ndarray): 

249 return numpy.array(res) 

250 if isinstance(res[0], dict): 

251 return pandas.DataFrame(res).values 

252 ls = [len(r) for r in res] 

253 mi = min(ls) 

254 if mi != max(ls): 

255 raise NotImplementedError( # pragma no cover 

256 "Unable to postprocess various number of " 

257 "outputs in [{0}, {1}]" 

258 .format(min(ls), max(ls))) 

259 if mi > 1: 

260 output = [] 

261 for i in range(mi): 

262 output.append(_post_process_output([r[i] for r in res])) 

263 return output 

264 if isinstance(res[0], list): 

265 # list of lists 

266 if isinstance(res[0][0], list): 

267 return numpy.array(res) 

268 if len(res[0]) == 1 and isinstance(res[0][0], dict): 

269 return _post_process_output([r[0] for r in res]) 

270 if len(res) == 1: 

271 return res 

272 if len(res[0]) != 1: 

273 raise NotImplementedError( # pragma no cover 

274 f"Not conversion implemented for {res}") 

275 st = [r[0] for r in res] 

276 return numpy.vstack(st) 

277 return res 

278 return res 

279 

280 

281def _create_column(values, dtype): 

282 "Creates a column from values with dtype" 

283 if str(dtype) == "tensor(int64)": 

284 return numpy.array(values, dtype=numpy.int64) 

285 if str(dtype) == "tensor(float)": 

286 return numpy.array(values, dtype=numpy.float32) 

287 if str(dtype) in ("tensor(double)", "tensor(float64)"): 

288 return numpy.array(values, dtype=numpy.float64) 

289 if str(dtype) in ("tensor(string)", "tensor(str)"): 

290 return numpy.array(values, dtype=numpy.str_) 

291 raise OnnxBackendAssertionError( 

292 f"Unable to create one column from dtype '{dtype}'") 

293 

294 

def _compare_expected(expected, output, sess, onnx_model,
                      decimal=5, verbose=False, classes=None,
                      **kwargs):
    """
    Compares the expected output against the runtime outputs.
    This is specific to :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param expected: expected outputs, a list (one element per output,
        compared recursively), a dictionary, a :epkg:`numpy` array
        or a :epkg:`scipy` sparse matrix
    :param output: outputs produced by the runtime
    :param sess: inference session, only forwarded to recursive calls
    :param onnx_model: model (or its name), used in error messages
    :param decimal: precision given to :func:`compare_outputs`
    :param verbose: verbosity given to :func:`compare_outputs`
    :param classes: if not None and *expected* contains strings,
        *output* contains indices into *classes*
    :param kwargs: comparison options, ``Out0`` restricts the comparison
        to the first output, ``Reshape`` concatenates all outputs and
        reshapes them into ``len(expected)`` rows, the other options
        are given to :func:`compare_outputs`
    :raises OnnxBackendAssertionError: if outputs do not match
    :raises ExpectedAssertionError: if array outputs do not match
        but the failure is declared as expected
    """
    tested = 0
    if isinstance(expected, list):
        if not isinstance(output, list):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Type mismatch for '{onnx_model}', output type is {type(output)}")
        if 'Out0' in kwargs:
            expected = expected[:1]
            output = output[:1]
            del kwargs['Out0']
        if 'Reshape' in kwargs:
            del kwargs['Reshape']
            output = numpy.hstack(output).ravel()
            output = output.reshape(
                (len(expected), len(output.ravel()) // len(expected)))
        if len(expected) != len(output):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected number of outputs '{0}', expected={1}, got={2}"
                .format(onnx_model, len(expected), len(output)))
        for exp, out in zip(expected, output):
            # the caller's precision must be propagated
            # (it used to be forced to 5)
            _compare_expected(exp, out, sess, onnx_model, decimal=decimal,
                              verbose=verbose, classes=classes, **kwargs)
            tested += 1
    elif isinstance(expected, dict):
        if not isinstance(output, dict):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Type mismatch for '{onnx_model}'")
        for k, v in output.items():
            if k not in expected:
                continue
            msg = compare_outputs(
                expected[k], v, decimal=decimal, verbose=verbose, **kwargs)
            # NOTE(review): unlike the array branch, an ExpectedAssertionError
            # returned by compare_outputs is wrapped here — confirm intended.
            if msg:
                raise OnnxBackendAssertionError(  # pragma no cover
                    f"Unexpected output '{k}' in model '{onnx_model}'\n{msg}")
            tested += 1
    elif isinstance(expected, numpy.ndarray):
        if isinstance(output, list):
            if expected.shape[0] == len(output) and isinstance(
                    output[0], dict):
                # one dictionary per row, columns sorted by key
                if isinstance(output, ArrayZipMapDictionary):
                    output = pandas.DataFrame(list(output))
                else:
                    output = pandas.DataFrame(output)
                output = output[list(sorted(output.columns))]
                output = output.values
        if isinstance(output, (dict, list)):
            if len(output) != 1:  # pragma: no cover
                ex = str(output)
                if len(ex) > 170:
                    ex = ex[:170] + "..."
                raise OnnxBackendAssertionError(
                    "More than one output when 1 is expected "
                    "for onnx '{0}'\n{1}"
                    .format(onnx_model, ex))
            output = output[-1]
        if not isinstance(output, numpy.ndarray):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"output must be an array for onnx '{onnx_model}' not {type(output)}")
        if (classes is not None and (
                expected.dtype == numpy.str_ or expected.dtype.char == 'U')):
            # output holds class indices, they are mapped to labels
            try:
                output = numpy.array([classes[cl] for cl in output])
            except IndexError as e:  # pragma no cover
                raise RuntimeError('Unable to handle\n{}\n{}\n{}'.format(
                    expected, output, classes)) from e
        msg = compare_outputs(
            expected, output, decimal=decimal, verbose=verbose, **kwargs)
        if isinstance(msg, ExpectedAssertionError):
            raise msg  # pylint: disable=E0702
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Unexpected output in model '{onnx_model}'\n{msg}")
        tested += 1
    elif isinstance(expected, csr_matrix):
        # DictVectorizer
        one_array = numpy.array(output)
        dense = numpy.asarray(expected.todense())
        msg = compare_outputs(dense, one_array, decimal=decimal,
                              verbose=verbose, **kwargs)
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Unexpected output in model '{onnx_model}'\n{msg}")
        tested += 1
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Unexpected type for expected output ({1}) and onnx '{0}'".
            format(onnx_model, type(expected)))
    if tested == 0:
        raise OnnxBackendAssertionError(  # pragma no cover
            f"No test for onnx '{onnx_model}'")