Coverage for mlprodict/testing/test_utils/utils_backend_common_compare.py: 70%
84 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5import numpy
6import onnx
7import pandas
8from .utils_backend_common import (
9 load_data_and_model, extract_options,
10 ExpectedAssertionError, OnnxBackendAssertionError,
11 OnnxRuntimeMissingNewOnnxOperatorException,
12 _compare_expected, _create_column)
def _load_model_for_display(onx):
    """
    Returns *onx* as an :epkg:`onnx` ModelProto so that it can be
    printed in an error message.

    :param onx: model given as a filename (or file-like object),
        serialized bytes or a ModelProto
    :return: ModelProto
    """
    if isinstance(onx, onnx.ModelProto):
        return onx
    if isinstance(onx, bytes):
        return onnx.load_model_from_string(onx)
    return onnx.load(onx)


def _prepare_inputs(sess, data, onx, use_columns):
    """
    Converts the test data into a dictionary ``{input_name: value}``
    which can be given to ``sess.run``.

    :param sess: inference session, only used through ``get_inputs()``
        and ``get_outputs()`` (objects exposing ``name``, ``shape``
        and ``type``)
    :param data: test data, a dictionary, a list, a numpy array
        or a DataFrame
    :param onx: model, only used in error messages
    :param use_columns: option *DF*, every column of DataFrame *data*
        becomes one input of shape ``(N, 1)``, float64 columns are
        cast into float32
    :return: a new dictionary, *data* itself is never modified
    :raises OnnxBackendAssertionError: if *data* cannot be mapped
        onto the session inputs
    """
    if use_columns:
        inputs = {}
        for name in data.columns:
            column = data[name].values
            if column.dtype == numpy.float64:
                column = column.astype(numpy.float32)
            inputs[name] = column.reshape((column.shape[0], 1))
        return inputs

    if isinstance(data, dict):
        # Shallow copy: the caller converts list values into arrays
        # and that must not leak back into the test data.
        return dict(data)

    if not isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
        raise OnnxBackendAssertionError(  # pragma no cover
            f"Dict or list is expected, not {type(data)}")

    inp = sess.get_inputs()
    outs = sess.get_outputs()
    if len(outs) == 0:
        raise OnnxBackendAssertionError(  # pragma: no cover
            f"Wrong number of outputs, onnx='{onx}'")
    if len(inp) == len(data):
        return {i.name: v for i, v in zip(inp, data)}
    if len(inp) == 1:
        return {inp[0].name: data}

    if isinstance(data, numpy.ndarray):
        # Every input consumes a slice of columns of the 2D array.
        shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                    for i in inp)
        if shape == data.shape[1]:
            return {n.name: data[:, i] for i, n in enumerate(inp)}
        raise OnnxBackendAssertionError(  # pragma: no cover
            f"Wrong number of inputs onnx {len(inp)} != "
            f"original shape {data.shape}, onnx='{onx}'")

    # Remaining cases: list or DataFrame, split into consecutive
    # groups of columns, one group per session input.
    try:
        array_input = numpy.array(data)
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            f"Wrong number of inputs onnx {len(inp)} != "
            f"original {len(data)}, onnx='{onx}'") from e
    shape = sum(i.shape[1] for i in inp)
    if shape != array_input.shape[1]:  # pragma no cover
        if isinstance(data, list):
            msg = (f"Wrong number of inputs onnx {len(inp)} != "
                   f"original shape {array_input.shape}, onnx='{onx}'*")
        else:
            msg = (f"Wrong number of inputs onnx {len(inp)}={shape} "
                   f"columns != original shape {array_input.shape}, "
                   f"onnx='{onx}'*")
        raise OnnxBackendAssertionError(msg)

    inputs = {}
    start = 0
    for node in inp:
        end = start + node.shape[1]
        if isinstance(data, list):
            columns = [row[start:end] for row in data]
        else:
            columns = data.iloc[:, start:end]
        inputs[node.name] = _create_column(columns, node.type)
        start = end
    return inputs


def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class
        (like @see cl OnnxInference), instantiated with
        ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options (a dictionary), extracted from
        the model filename if None; the dictionary is copied and never
        modified
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple ``(output, lambda_onnx)``, a copy of the raw output
        of ``sess.run`` and a function without argument running the
        predictions again
    :raises ExpectedAssertionError: when the failure is expected
        (option *CannotLoad*, or a model name containing ``-Fail``)
    :raises OnnxBackendAssertionError: when the model cannot be loaded,
        run, or when the comparison fails

    The function raises an error if the comparison failed.
    """
    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print(f"[compare_runtime] test '{test['onnx']}' loaded")

    onx = test['onnx']

    if options is None:
        # Options may be encoded in the model filename.
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif not isinstance(options, dict):
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")
    else:
        # Keys 'DF' and 'SklCol' are popped below, work on a copy
        # so that the caller's dictionary stays untouched.
        options = dict(options)

    if verbose:  # pragma no cover
        print(f"[compare_runtime] InferenceSession('{onx}')")

    # Step 1: create the inference session.
    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            f"Wrong signature for '{cls_session.__name__}' ({et}).") from et
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(
                f"Unable to load onnx '{onx}' due to\n{e}") from e
        if verbose:  # pragma no cover
            smodel = "\nJSON ONNX\n" + str(_load_model_for_display(onx))
        else:
            smodel = ""
        if ("NOT_IMPLEMENTED : Could not find an implementation "
                "for the node" in str(e)):
            # onnxruntime does not implement a specific node yet.
            raise OnnxRuntimeMissingNewOnnxOperatorException(
                f"{cls_session} does not implement a new operator "
                f"'{onx}'\n{e}\nONNX\n{smodel}") from e
        if "NOT_IMPLEMENTED : Failed to find kernel" in str(e):
            # onnxruntime does not implement a specific node yet
            # in the kernel included in onnxruntime.
            raise OnnxBackendAssertionError(
                f"{cls_session} misses a kernel for operator "
                f"'{onx}'\n{e}\nONNX\n{smodel}") from e
        raise OnnxBackendAssertionError(
            f"Unable to load onnx '{onx}'\nONNX\n{smodel}\n{e}") from e

    # Step 2: build the inputs.
    data = load["data"]
    inputs = _prepare_inputs(sess, data, onx, options.pop('DF', False))
    for name, value in inputs.items():
        if isinstance(value, list):
            inputs[name] = numpy.array(value)

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print(f"[compare_runtime] type(inputs)={type(data)} "
              f"len={len(inputs)} names={list(sorted(inputs))}")
    if not verbose:
        run_options = {}
    elif intermediate_steps:  # pragma no cover
        run_options = {'verbose': 3, 'fLOG': print}
    else:  # pragma no cover
        run_options = {'verbose': 2, 'fLOG': print}

    try:
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
    except ImportError:  # pragma: no cover
        # Lets other runtimes work when onnxruntime is not installed.
        OrtInvalidArgument = RuntimeError

    # Step 3: compute the predictions.
    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            # The runtime does not accept the verbose options.
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma no cover
            import pprint
            pprint.pprint(output)
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            sess.run(None, inputs, verbose=3, fLOG=print)
        # `onx` may be a ModelProto, only a filename can carry "-Fail".
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                f"{cls_session} cannot compute the prediction "
                f"for '{onx}'") from e
        if verbose:  # pragma no cover
            from ...plotting.text_plot import onnx_simple_text_plot
            model = _load_model_for_display(onx)
            smodel = "\nJSON ONNX\n" + onnx_simple_text_plot(
                model, recursive=True, raise_exc=False)
        else:
            smodel = ""
        import pprint
        raise OnnxBackendAssertionError(
            f"{cls_session} cannot compute the predictions"
            f" for '{onx}' due to {e}{smodel}\n"
            f"{pprint.pformat(inputs)}") from e
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            f"Unable to run onnx '{onx}' due to {e}") from e
    if verbose:  # pragma no cover
        print(f"[compare_runtime] done type={type(output)}")

    output0 = output.copy()

    # Step 4: compare with the expected outputs.
    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError as expe:  # pragma no cover
        raise expe
    except Exception as e:  # pragma no cover
        if verbose:  # pragma no cover
            smodel = "\nJSON ONNX\n" + str(_load_model_for_display(onx))
        else:
            smodel = ""
        raise OnnxBackendAssertionError(
            f"Model '{onx}' has discrepencies with "
            f"cls='{sess.__class__.__name__}'.\n{type(e)}: {e}{smodel}"
        ) from e

    return output0, lambda_onnx