Coverage for mlprodict/testing/test_utils/utils_backend_common.py: 86%
180 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Inspired from :epkg:`sklearn-onnx`, handles two backends.
4"""
5import os
6import pickle
7import numpy
8from numpy.testing import assert_array_almost_equal, assert_array_equal
9from scipy.sparse.csr import csr_matrix
10import pandas
11from ...onnxrt.ops_cpu.op_zipmap import ArrayZipMapDictionary
class ExpectedAssertionError(Exception):
    """
    Failure which is known and tolerated (a bug to be fixed later).
    :func:`compare_outputs` returns an instance of it instead of
    :class:`OnnxBackendAssertionError` when option *Disc* or *Mism*
    is set. It does not derive from :class:`AssertionError`.
    """
class OnnxBackendAssertionError(AssertionError):
    """
    Unexpected failure: the outputs produced by a runtime do not match
    the expected values, or they cannot be compared (unexpected types,
    unexpected number of outputs, nothing to compare).
    :func:`compare_outputs` returns it, :func:`_compare_expected`
    raises it. It derives from :class:`AssertionError`.
    """
class OnnxBackendMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Raised when an operator introduced by the latest :epkg:`onnx`
    release has no implementation yet in :epkg:`onnxruntime`
    or :epkg:`mlprodict`.
    """
class OnnxRuntimeMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Raised when a newly added operator cannot be found.
    """
def evaluate_condition(backend, condition):
    """
    Evaluates a Python expression describing a condition on a backend,
    such as ``StrictVersion(onnxruntime.__version__) <= StrictVersion('0.1.3')``.

    The expression goes through :func:`eval` inside this function, so the
    module globals and the local name ``onnxruntime`` are visible to it.
    NOTE(review): ``StrictVersion`` is not imported by this module as far
    as visible here; the example above may fail with a NameError — confirm.

    :param backend: only ``'onnxruntime'`` is supported
    :param condition: expression to evaluate, it must come from a trusted
        source since it is executed by :func:`eval`
    :return: whatever the expression evaluates to
    :raises NotImplementedError: for any other backend
    """
    if backend != "onnxruntime":
        raise NotImplementedError(  # pragma no cover
            f"Not implemented for backend '{backend}' and condition '{condition}'.")
    # Imported locally on purpose: *condition* refers to ``onnxruntime``
    # and eval sees the local scope.
    import onnxruntime  # pylint: disable=W0611  # pragma: no cover
    return eval(condition)  # pylint: disable=W0123  # pragma: no cover
def is_backend_enabled(backend):
    """
    Tells if a backend can be used by the unit tests.

    :param backend: ``'python'`` (always available), ``'onnxruntime'``
        or ``'onnxruntime1'`` (available if :epkg:`onnxruntime`
        can be imported)
    :return: boolean
    :raises NotImplementedError: for any other backend
    """
    if backend == "python":
        return True
    if backend not in ("onnxruntime", "onnxruntime1"):
        raise NotImplementedError(  # pragma no cover
            f"Not implemented for backend '{backend}'")
    try:
        import onnxruntime  # pylint: disable=W0611
    except ImportError:  # pragma no cover
        return False
    return True
def load_data_and_model(items_as_dict, **context):
    """
    Loads every file in a dictionary ``{key: filename}``.

    A string ending with ``.pkl`` is unpickled; any other string
    (``.onnx`` file names included) is kept as it is. A value which is
    not a string is assumed to be already loaded and is kept unchanged.

    :param items_as_dict: dictionary ``{key: filename or object}``
    :param context: unused
    :return: a new dictionary with the same keys, except the ``.pkl``
        files whose path contains ``.model.`` and which could not be
        unpickled because of an :class:`ImportError` (they are skipped)
    :raises ImportError: if unpickling fails with an :class:`ImportError`
        for a path which does not contain ``.model.``
    """
    res = {}
    for key, value in items_as_dict.items():
        if not (isinstance(value, str) and
                os.path.splitext(value)[-1] == ".pkl"):
            res[key] = value
            continue
        # NOTE(review): pickle.load can execute arbitrary code,
        # this function must only be used with trusted files.
        with open(value, "rb") as f:  # pragma: no cover
            try:
                loaded = pickle.load(f)
            except ImportError as e:
                if '.model.' in value:
                    # Presumably a model depending on a package which is
                    # not installed: skipped instead of failing.
                    continue
                raise ImportError(
                    f"Unable to load '{value}' due to {e}") from e
        res[key] = loaded
    return res
def extract_options(name):
    """
    Extracts the comparison options encoded in a file name.

    The base name (after the last ``/`` or ``\\``) is cut at its first
    dot and split on ``-``: the first piece is the model name, every
    other piece is an option. As an example, ``Binarizer-SkipDim1``
    enables option *SkipDim1*, which makes shapes ``(1, 2)`` and ``(2,)``
    compare equal. Available options: see :func:`dump_data_and_model`.

    :param name: file name or path
    :return: dictionary ``{option: True}``, empty if there is no option
    :raises NameError: if any option is unknown
    """
    allowed = {"SkipDim1", "OneOff", "NoProb", "NoProbOpp",
               "Dec4", "Dec3", "Dec2", "Dec1", 'Svm',
               'Out0', 'Reshape', 'SklCol', 'DF', 'OneOffArray'}
    base = name.replace("\\", "/").rsplit("/", 1)[-1]
    _, *flags = base.split('.')[0].split('-')
    if any(flag not in allowed for flag in flags):
        # pragma no cover
        raise NameError(f"Unable to parse option '{flags}'")
    return {flag: True for flag in flags}
def compare_outputs(expected, output, verbose=False, **kwargs):
    """
    Compares expected values and output.
    Returns None if no error, an exception message otherwise.

    :param expected: expected values
    :param output: values produced by the runtime
    :param verbose: forwarded to :mod:`numpy.testing`, on failure the
        returned message also contains both arrays
    :param kwargs: comparison options (see below), any remaining
        argument (such as *decimal*) is forwarded to
        :func:`numpy.testing.assert_array_almost_equal`
    :return: None if both values are considered equal, otherwise an
        exception instance (returned, not raised):
        :class:`ExpectedAssertionError` for a tolerated failure
        (*Disc* or *Mism*), :class:`OnnxBackendAssertionError` otherwise
    :raises NotImplementedError: if shapes cannot be reconciled

    Options:

    * *SkipDim1*: removes every dimension of size 1 from both arrays,
      ``(2, 1, 2, 3)`` becomes ``(2, 2, 3)``
    * *NoProb*, *NoProbOpp*: compares a vector of scores ``(N,)`` to the
      second column of a matrix ``(N, 2)`` (or the only column of
      ``(N, 1)``), *NoProbOpp* takes the opposite of the output
    * *Dec4*, *Dec3*, *Dec2*, *Dec1*: caps the number of compared decimals
    * *Disc*: a value mismatch is a tolerated failure
    * *Mism*: a dimension mismatch is a tolerated failure
    * *OneOff*: removed before the comparison, shapes must be identical
    """
    SkipDim1 = kwargs.pop("SkipDim1", False)
    NoProb = kwargs.pop("NoProb", False)
    NoProbOpp = kwargs.pop("NoProbOpp", False)
    Dec4 = kwargs.pop("Dec4", False)
    Dec3 = kwargs.pop("Dec3", False)
    Dec2 = kwargs.pop("Dec2", False)
    Dec1 = kwargs.pop("Dec1", False)
    Disc = kwargs.pop("Disc", False)
    Mism = kwargs.pop("Mism", False)

    # Each DecN flag caps the precision, the smallest cap wins.
    # 6 is the default of assert_array_almost_equal: it avoids a KeyError
    # when the caller does not give *decimal*.
    for flag, digits in ((Dec4, 4), (Dec3, 3), (Dec2, 2), (Dec1, 1)):
        if flag:
            kwargs["decimal"] = min(kwargs.get("decimal", 6), digits)

    if not (isinstance(expected, numpy.ndarray) and
            isinstance(output, numpy.ndarray)):
        return OnnxBackendAssertionError(  # pragma: no cover
            f"Unexpected types {type(expected)} != {type(output)}")

    if SkipDim1:
        # Dimensions of size 1 carry no information. Each array is squeezed
        # from its own shape so that a real shape mismatch is still
        # detected; dimensions of size 0 are kept.
        expected = numpy.squeeze(expected)
        output = numpy.squeeze(output)

    if NoProb or NoProbOpp:
        # One vector is (N,) with scores, negative for class 0,
        # positive for class 1. The other one is (N, 2) or (N, 1).
        if (len(output.shape) == 2 and output.shape[1] == 2 and
                len(expected.shape) == 1):
            output = output[:, 1]
            if NoProbOpp:
                output = -output
        elif len(output.shape) == 1 and len(expected.shape) == 1:
            pass
        elif (len(expected.shape) == 1 and len(output.shape) == 2 and
                expected.shape[0] == output.shape[0] and
                output.shape[1] == 1):
            output = output[:, 0]
            if NoProbOpp:
                output = -output
        elif expected.shape != output.shape:
            raise NotImplementedError(  # pragma no cover
                "Shape mismatch: {0} != {1}".format(
                    expected.shape, output.shape))

    # (N, 1) output against (N,) expected, (1, N, M) against (N, M).
    if (len(expected.shape) == 1 and len(output.shape) == 2 and
            output.shape[1] == 1):
        output = output.ravel()
    if (len(output.shape) == 3 and output.shape[0] == 1 and
            len(expected.shape) == 2):
        output = output.reshape(output.shape[1:])

    if expected.dtype.kind == 'U':
        # Any unicode dtype (<U1, <U3, <U10, ...) is compared exactly.
        try:
            assert_array_equal(expected, output, verbose=verbose)
        except Exception as e:  # pylint: disable=W0703
            if Disc:  # pragma no cover
                # Bug to be fixed later.
                return ExpectedAssertionError(str(e))
            return OnnxBackendAssertionError(str(e))  # pragma no cover
        return None

    if 'OneOff' in kwargs:
        # kwargs is local to this call, removing the key is safe.
        del kwargs['OneOff']
        if expected.shape != output.shape:
            raise NotImplementedError(
                f"Unable to deal with sort of shapes "
                f"{expected.shape!r} != {output.shape!r}.")

    try:
        assert_array_almost_equal(expected, output, verbose=verbose,
                                  **kwargs)
    except (RuntimeError, AssertionError, TypeError) as e:  # pragma no cover
        longer = (f"\n--EXPECTED--\n{expected}\n--OUTPUT--\n{output}"
                  if verbose else "")
        expected_ = numpy.asarray(expected).ravel()
        output_ = numpy.asarray(output).ravel()
        if len(expected_) != len(output_):
            cls = ExpectedAssertionError if Mism else OnnxBackendAssertionError
            return cls(
                "dimension mismatch={0}, {1}\n{2}{3}".format(
                    expected.shape, output.shape, e, longer))
        # The assertion may fail for reasons other than values
        # (TypeError for example): the difference is computed again and
        # the arrays are considered equal if it is null.
        if numpy.issubdtype(expected_.dtype, numpy.floating):
            diff = numpy.abs(expected_ - output_).max()
        else:
            diff = max((1 if ci != cj else 0)
                       for ci, cj in zip(expected_, output_))
        if diff == 0:
            return None
        # Disc means a bug to be fixed later.
        cls = ExpectedAssertionError if Disc else OnnxBackendAssertionError
        return cls(f"max-diff={diff}\n--expected--output--\n{e}{longer}")
    return None
238def _post_process_output(res):
239 """
240 Applies post processings before running the comparison
241 such as changing type from list to arrays.
242 """
243 if isinstance(res, list):
244 if len(res) == 0:
245 return res
246 if len(res) == 1:
247 return _post_process_output(res[0])
248 if isinstance(res[0], numpy.ndarray):
249 return numpy.array(res)
250 if isinstance(res[0], dict):
251 return pandas.DataFrame(res).values
252 ls = [len(r) for r in res]
253 mi = min(ls)
254 if mi != max(ls):
255 raise NotImplementedError( # pragma no cover
256 "Unable to postprocess various number of "
257 "outputs in [{0}, {1}]"
258 .format(min(ls), max(ls)))
259 if mi > 1:
260 output = []
261 for i in range(mi):
262 output.append(_post_process_output([r[i] for r in res]))
263 return output
264 if isinstance(res[0], list):
265 # list of lists
266 if isinstance(res[0][0], list):
267 return numpy.array(res)
268 if len(res[0]) == 1 and isinstance(res[0][0], dict):
269 return _post_process_output([r[0] for r in res])
270 if len(res) == 1:
271 return res
272 if len(res[0]) != 1:
273 raise NotImplementedError( # pragma no cover
274 f"Not conversion implemented for {res}")
275 st = [r[0] for r in res]
276 return numpy.vstack(st)
277 return res
278 return res
281def _create_column(values, dtype):
282 "Creates a column from values with dtype"
283 if str(dtype) == "tensor(int64)":
284 return numpy.array(values, dtype=numpy.int64)
285 if str(dtype) == "tensor(float)":
286 return numpy.array(values, dtype=numpy.float32)
287 if str(dtype) in ("tensor(double)", "tensor(float64)"):
288 return numpy.array(values, dtype=numpy.float64)
289 if str(dtype) in ("tensor(string)", "tensor(str)"):
290 return numpy.array(values, dtype=numpy.str_)
291 raise OnnxBackendAssertionError(
292 f"Unable to create one column from dtype '{dtype}'")
def _compare_expected(expected, output, sess, onnx_model,
                      decimal=5, verbose=False, classes=None,
                      **kwargs):
    """
    Compares the expected output against the runtime outputs.
    This is specific to :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param expected: expected outputs: a list (one item per output,
        compared recursively), a dictionary, a :class:`numpy.ndarray`
        or a :class:`scipy.sparse.csr_matrix`
    :param output: outputs produced by the runtime
    :param sess: not used by the comparison itself, only forwarded
        to the recursive calls
    :param onnx_model: only used to build error messages
    :param decimal: precision forwarded to :func:`compare_outputs`,
        nested outputs included
    :param verbose: forwarded to :func:`compare_outputs`
    :param classes: when not None and *expected* contains strings,
        *output* is treated as indices into *classes*
    :param kwargs: comparison options, *Out0* only compares the first
        output, *Reshape* concatenates all outputs into one row per
        expected output, the others are forwarded to
        :func:`compare_outputs`
    :raises OnnxBackendAssertionError: when outputs differ, have
        unexpected types, or when nothing was compared
    :raises ExpectedAssertionError: when :func:`compare_outputs` reports
        a tolerated failure for an array output
    """
    tested = 0
    if isinstance(expected, list):
        if not isinstance(output, list):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Type mismatch for '{onnx_model}', output type is {type(output)}")
        if 'Out0' in kwargs:
            expected = expected[:1]
            output = output[:1]
            del kwargs['Out0']
        if 'Reshape' in kwargs:
            del kwargs['Reshape']
            output = numpy.hstack(output).ravel()
            output = output.reshape(
                (len(expected), len(output) // len(expected)))
        if len(expected) != len(output):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected number of outputs '{0}', expected={1}, got={2}"
                .format(onnx_model, len(expected), len(output)))
        for exp, out in zip(expected, output):
            # Nested outputs are compared with the caller's precision.
            _compare_expected(exp, out, sess, onnx_model, decimal=decimal,
                              verbose=verbose, classes=classes, **kwargs)
            tested += 1
    elif isinstance(expected, dict):
        if not isinstance(output, dict):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Type mismatch for '{onnx_model}'")
        for k, v in output.items():
            if k not in expected:
                continue
            msg = compare_outputs(
                expected[k], v, decimal=decimal, verbose=verbose, **kwargs)
            if msg:
                raise OnnxBackendAssertionError(  # pragma no cover
                    f"Unexpected output '{k}' in model '{onnx_model}'\n{msg}")
            tested += 1
    elif isinstance(expected, numpy.ndarray):
        if (isinstance(output, list) and expected.shape[0] == len(output) and
                isinstance(output[0], dict)):
            # List of dictionaries: one column per key, columns sorted.
            if isinstance(output, ArrayZipMapDictionary):
                output = pandas.DataFrame(list(output))
            else:
                output = pandas.DataFrame(output)
            output = output[sorted(output.columns)].values
        if isinstance(output, (dict, list)):
            if len(output) != 1:  # pragma: no cover
                ex = str(output)
                if len(ex) > 170:
                    ex = ex[:170] + "..."
                raise OnnxBackendAssertionError(
                    "More than one output when 1 is expected "
                    "for onnx '{0}'\n{1}"
                    .format(onnx_model, ex))
            # A dictionary cannot be indexed by position: its only value
            # is taken instead.
            output = (next(iter(output.values())) if isinstance(output, dict)
                      else output[-1])
        if not isinstance(output, numpy.ndarray):
            raise OnnxBackendAssertionError(  # pragma no cover
                f"output must be an array for onnx '{onnx_model}' not {type(output)}")
        if (classes is not None and (
                expected.dtype == numpy.str_ or expected.dtype.char == 'U')):
            # Indices are converted into class labels.
            try:
                output = numpy.array([classes[cl] for cl in output])
            except IndexError as e:  # pragma no cover
                raise RuntimeError('Unable to handle\n{}\n{}\n{}'.format(
                    expected, output, classes)) from e
        msg = compare_outputs(
            expected, output, decimal=decimal, verbose=verbose, **kwargs)
        if isinstance(msg, ExpectedAssertionError):
            raise msg  # pylint: disable=E0702
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Unexpected output in model '{onnx_model}'\n{msg}")
        tested += 1
    elif isinstance(expected, csr_matrix):
        # DictVectorizer
        one_array = numpy.array(output)
        dense = numpy.asarray(expected.todense())
        msg = compare_outputs(dense, one_array, decimal=decimal,
                              verbose=verbose, **kwargs)
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                f"Unexpected output in model '{onnx_model}'\n{msg}")
        tested += 1
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Unexpected type for expected output ({1}) and onnx '{0}'".
            format(onnx_model, type(expected)))
    if tested == 0:
        raise OnnxBackendAssertionError(  # pragma no cover
            f"No test for onnx '{onnx_model}'")