Coverage for mlprodict/testing/onnx_backend.py: 95%
167 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Tests with onnx backend.
4"""
5import os
6import textwrap
7import numpy
8try:
9 # new numpy
10 from numpy import object_ as dtype_object
11except ImportError:
12 # old numpy
13 from numpy import object as dtype_object
14from numpy.testing import assert_almost_equal
15import onnx
16from onnx.numpy_helper import to_array, to_list
17from onnx.backend.test import __file__ as backend_folder
def assert_almost_equal_string(expected, value):
    """
    Compares two arrays knowing they contain strings.
    Raises an exception if the test fails.

    If every element of *expected* can be converted with ``float``,
    both arrays are cast to ``float32`` and compared with
    ``assert_almost_equal``. Otherwise the arrays are compared
    for exact equality.

    :param expected: expected array
    :param value: value
    :raises AssertionError: if the two arrays differ
    """
    def is_float(x):
        # The conversion itself is the check: it must really be attempted,
        # otherwise every element is reported as numeric.
        try:
            float(x)
        except (ValueError, TypeError):
            return False
        return True

    if all(map(is_float, expected.ravel())):
        expected_float = expected.astype(numpy.float32)
        value_float = value.astype(numpy.float32)
        assert_almost_equal(expected_float, value_float)
    else:
        # Strings cannot be subtracted, so an approximate comparison
        # is not applicable here: the content must match exactly.
        numpy.testing.assert_array_equal(expected, value)
class OnnxBackendTest:
    """
    Definition of a backend test. It starts with a folder,
    in this folder, one onnx file must be there, then a subfolder
    for each test to run with this model.

    :param folder: test folder
    :param onnx_path: onnx file
    :param onnx_model: loaded onnx file
    :param tests: list of test
    """

    @staticmethod
    def _sort(filenames):
        """
        Sorts file names by the integer which follows the last
        underscore of the name without its extension
        (``input_10.pb`` comes after ``input_2.pb``).

        :param filenames: iterable of file names
        :return: sorted list of file names
        :raises ValueError: if a suffix is not an integer
        """
        def sort_key(filename):
            stem = os.path.splitext(filename)[0]
            return int(stem.split('_')[-1]), filename

        return sorted(filenames, key=sort_key)

    @staticmethod
    def _read_proto_from_file(full):
        """
        Reads a serialized protobuf file. The content is successively
        interpreted as a tensor, a sequence, a model, the first
        successful attempt is returned.

        :param full: file name
        :return: result of `to_array`, result of `to_list`
            or the loaded model
        :raises FileNotFoundError: if the file does not exist
        :raises RuntimeError: if none of the three attempts succeeds
        """
        if not os.path.exists(full):
            raise FileNotFoundError(  # pragma: no cover
                f"File not found: {full!r}.")
        with open(full, 'rb') as f:
            serialized = f.read()
        try:
            loaded = to_array(onnx.load_tensor_from_string(serialized))
        except Exception as e:  # pylint: disable=W0703
            seq = onnx.SequenceProto()  # pylint: disable=E1101
            try:
                seq.ParseFromString(serialized)
                loaded = to_list(seq)
            except Exception:  # pylint: disable=W0703
                try:
                    loaded = onnx.load_model_from_string(serialized)
                except Exception:  # pragma: no cover
                    # The reported cause is the first failure
                    # (reading the content as a tensor).
                    raise RuntimeError(
                        "Unable to read %r, error is %s, content is %r." % (
                            full, e, serialized[:100])) from e
        return loaded

    @staticmethod
    def _load(folder, names):
        """
        Loads every file of a test.

        :param folder: folder containing the files
        :param names: file names, relative to *folder*
        :return: list of loaded objects, in the order of *names*
        :raises RuntimeError: if a loaded object has an unexpected type
        """
        res = []
        for name in names:
            full = os.path.join(folder, name)
            new_tensor = OnnxBackendTest._read_proto_from_file(full)
            if isinstance(new_tensor, (
                    numpy.ndarray, onnx.ModelProto, list)):  # pylint: disable=E1101
                t = new_tensor
            elif isinstance(new_tensor, onnx.TensorProto):  # pylint: disable=E1101
                t = to_array(new_tensor)
            else:
                raise RuntimeError(  # pragma: no cover
                    f"Unexpected type {type(new_tensor)!r} for {full!r}.")
            res.append(t)
        return res

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}({self.folder!r})"

    def __init__(self, folder):
        """
        Loads the model and every test stored in *folder*.

        :param folder: test folder, it must contain exactly one
            file with extension ``.onnx``
        :raises FileNotFoundError: if *folder* does not exist
        :raises ValueError: if the folder does not contain exactly
            one onnx file
        """
        if not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                f"Unable to find folder {folder!r}.")
        # os.listdir returns the entries in arbitrary order,
        # sorting keeps the index of every test stable between runs.
        content = sorted(os.listdir(folder))
        onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}]
        if len(onx) != 1:
            raise ValueError(  # pragma: no cover
                f"Expected exactly one onnx file in {folder!r}, "
                f"found {len(onx)} ({onx!r}).")
        self.folder = folder
        self.onnx_path = os.path.join(folder, onx[0])
        self.onnx_model = onnx.load(self.onnx_path)

        self.tests = []
        for sub in content:
            full = os.path.join(folder, sub)
            if not os.path.isdir(full):
                continue
            pb = [c for c in os.listdir(full)
                  if os.path.splitext(c)[-1] in {'.pb'}]
            inputs = OnnxBackendTest._sort(
                c for c in pb if c.startswith('input_'))
            outputs = OnnxBackendTest._sort(
                c for c in pb if c.startswith('output_'))
            self.tests.append(dict(
                inputs=OnnxBackendTest._load(full, inputs),
                outputs=OnnxBackendTest._load(full, outputs)))

    @property
    def name(self):
        "Returns the test name."
        return os.path.split(self.folder)[-1]

    def __len__(self):
        "Returns the number of tests."
        return len(self.tests)

    def _compare_results(self, index, i, e, o, decimal=None):
        """
        Compares the expected output and the output produced
        by the runtime. Raises an exception if not equal.
        If *o* is neither a numpy array nor an object with
        an attribute `is_compatible`, nothing is checked.

        :param index: test index
        :param i: output index
        :param e: expected output
        :param o: output
        :param decimal: precision, if None, it is 6 for float32,
            12 for float64, 7 for any other type
        :raises AssertionError: if the outputs are different
        :raises NotImplementedError: if *e* is not a numpy array
        """
        if not isinstance(e, numpy.ndarray):
            raise NotImplementedError(
                f"Comparison not implemented for type {type(e)!r}.")

        if isinstance(o, numpy.ndarray):
            if decimal is not None:
                deci = decimal
            elif e.dtype == numpy.float32:
                deci = 6
            elif e.dtype == numpy.float64:
                deci = 12
            else:
                deci = 7
            if e.dtype == dtype_object:
                # arrays of strings have their own comparison
                try:
                    assert_almost_equal_string(e, o)
                except AssertionError as ex:
                    raise AssertionError(  # pragma: no cover
                        "Output %d of test %d in folder %r failed." % (
                            i, index, self.folder)) from ex
            else:
                try:
                    assert_almost_equal(e, o, decimal=deci)
                except AssertionError as ex:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed." % (
                            i, index, self.folder)) from ex
        elif hasattr(o, 'is_compatible'):
            # A shape
            if e.dtype != o.dtype:
                raise AssertionError(
                    "Output %d of test %d in folder %r failed "
                    "(e.dtype=%r, o=%r)." % (
                        i, index, self.folder, e.dtype, o))
            if not o.is_compatible(e.shape):
                raise AssertionError(  # pragma: no cover
                    "Output %d of test %d in folder %r failed "
                    "(e.shape=%r, o=%r)." % (
                        i, index, self.folder, e.shape, o))

    def is_random(self):
        "Tells if a test is random or not."
        return 'bernoulli' in self.folder

    def run(self, load_fct, run_fct, index=None, decimal=None):
        """
        Executes one test or all tests if index is None.
        The function raises an exception if a test fails.
        For a random test (see :meth:`is_random`), only the types
        and the shapes of the outputs are compared.

        :param load_fct: loading function, takes a loaded onnx graph,
            and returns an object
        :param run_fct: running function, takes the result of previous
            function, the inputs, and returns the outputs
        :param index: index of the test to run or all.
        :param decimal: requested precision to compare results
        :raises AssertionError: if the number of outputs or
            one output is not the expected one
        """
        if index is None:
            for i in range(len(self)):
                self.run(load_fct, run_fct, index=i, decimal=decimal)
            return

        obj = load_fct(self.onnx_model)

        got = run_fct(obj, *self.tests[index]['inputs'])
        expected = self.tests[index]['outputs']
        if len(got) != len(expected):
            raise AssertionError(  # pragma: no cover
                "Unexpected number of output (test %d, folder %r), "
                "got %r, expected %r." % (
                    index, self.folder, len(got), len(expected)))
        for i, (e, o) in enumerate(zip(expected, got)):
            if not self.is_random():
                self._compare_results(index, i, e, o, decimal=decimal)
                continue
            # random outputs: values cannot be compared
            if e.dtype != o.dtype:
                raise AssertionError(
                    "Output %d of test %d in folder %r failed "
                    "(type mismatch %r != %r)." % (
                        i, index, self.folder, e.dtype, o.dtype))
            if e.shape != o.shape:
                raise AssertionError(
                    "Output %d of test %d in folder %r failed "
                    "(shape mismatch %r != %r)." % (
                        i, index, self.folder, e.shape, o.shape))

    def to_python(self):
        """
        Returns a python code equivalent to the ONNX test.

        :return: code
        """
        from ..onnx_tools.onnx_export import export2onnx
        rows = []
        code = export2onnx(self.onnx_model)
        # lines starting with 'print' or '# ' are removed from the export
        lines = [line for line in code.split('\n')
                 if not line.strip().startswith('print') and
                 not line.strip().startswith('# ')]
        rows.append(textwrap.dedent("\n".join(lines)))
        rows.append("oinf = OnnxInference(onnx_model)")
        # NOTE(review): the whitespace inside the indentation literals
        # below is kept exactly as it appears in the reviewed listing,
        # confirm the intended indentation width.
        for test in self.tests:
            rows.append("xs = [")
            for inp in test['inputs']:
                rows.append(textwrap.indent(repr(inp) + ',', ' ' * 2))
            rows.append("]")
            rows.append("ys = [")
            for out in test['outputs']:
                rows.append(textwrap.indent(repr(out) + ',', ' ' * 2))
            rows.append("]")
            rows.append("feeds = {n: x for n, x in zip(oinf.input_names, xs)}")
            rows.append("got = oinf.run(feeds)")
            rows.append("goty = [got[k] for k in oinf.output_names]")
            rows.append("for y, gy in zip(ys, goty):")
            rows.append(" self.assertEqualArray(y, gy)")
            rows.append("")
        code = "\n".join(rows)
        final = "\n".join([f"def {self.name}(self):",
                           textwrap.indent(code, ' ')])
        # the code is returned unformatted if pyquickhelper is not installed
        try:
            from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
        except ImportError:  # pragma: no cover
            return final
        return remove_extra_spaces_and_pep8(final, aggressive=True)
def enumerate_onnx_tests(series, fct_filter=None):
    """
    Collects test from a sub folder of `onnx/backend/test`.
    Works as an enumerator to start processing them
    without waiting or storing too much of them.
    Tests are enumerated in alphabetical order. Entries which are
    not folders or which do not contain exactly one onnx file
    are skipped.

    :param series: which subfolder to load, possible values:
        (`'node'`, ...)
    :param fct_filter: function `lambda testname: boolean`
        to load or skip the test, None for all
    :return: generator of @see cl OnnxBackendTest
    :raises FileNotFoundError: if the subfolder does not exist,
        the exception is raised when the iteration starts
    """
    root = os.path.dirname(backend_folder)
    sub = os.path.join(root, 'data', series)
    if not os.path.exists(sub):
        raise FileNotFoundError(
            "Unable to find series of tests in %r, subfolders:\n%s" % (
                root, "\n".join(os.listdir(root))))
    # os.listdir returns the entries in arbitrary order,
    # sorting makes the enumeration reproducible.
    for t in sorted(os.listdir(sub)):
        if fct_filter is not None and not fct_filter(t):
            continue
        folder = os.path.join(sub, t)
        if not os.path.isdir(folder):
            # a plain file cannot be listed, it is not a test
            continue
        onx = [c for c in os.listdir(folder)
               if os.path.splitext(c)[-1] in {'.onnx'}]
        if len(onx) == 1:
            yield OnnxBackendTest(folder)