Coverage for mlprodict/testing/onnx_backend.py: 95%

167 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Tests with onnx backend. 

4""" 

5import os 

6import textwrap 

7import numpy 

8try: 

9 # new numpy 

10 from numpy import object_ as dtype_object 

11except ImportError: 

12 # old numpy 

13 from numpy import object as dtype_object 

14from numpy.testing import assert_almost_equal 

15import onnx 

16from onnx.numpy_helper import to_array, to_list 

17from onnx.backend.test import __file__ as backend_folder 

18 

19 

20def assert_almost_equal_string(expected, value): 

21 """ 

22 Compares two arrays knowing they contain strings. 

23 Raises an exception if the test fails. 

24 

25 :param expected: expected array 

26 :param value: value 

27 """ 

28 def is_float(x): 

29 try: 

30 return True 

31 except ValueError: # pragma: no cover 

32 return False 

33 

34 if all(map(is_float, expected.ravel())): 

35 expected_float = expected.astype(numpy.float32) 

36 value_float = value.astype(numpy.float32) 

37 assert_almost_equal(expected_float, value_float) 

38 else: 

39 assert_almost_equal(expected, value) 

40 

41 

42class OnnxBackendTest: 

43 """ 

44 Definition of a backend test. It starts with a folder, 

45 in this folder, one onnx file must be there, then a subfolder 

46 for each test to run with this model. 

47 

48 :param folder: test folder 

49 :param onnx_path: onnx file 

50 :param onnx_model: loaded onnx file 

51 :param tests: list of test 

52 """ 

53 @staticmethod 

54 def _sort(filenames): 

55 temp = [] 

56 for f in filenames: 

57 name = os.path.splitext(f)[0] 

58 i = name.split('_')[-1] 

59 temp.append((int(i), f)) 

60 temp.sort() 

61 return [_[1] for _ in temp] 

62 

63 @staticmethod 

64 def _read_proto_from_file(full): 

65 if not os.path.exists(full): 

66 raise FileNotFoundError( # pragma: no cover 

67 f"File not found: {full!r}.") 

68 with open(full, 'rb') as f: 

69 serialized = f.read() 

70 try: 

71 loaded = to_array(onnx.load_tensor_from_string(serialized)) 

72 except Exception as e: # pylint: disable=W0703 

73 seq = onnx.SequenceProto() # pylint: disable=E1101 

74 try: 

75 seq.ParseFromString(serialized) 

76 loaded = to_list(seq) 

77 except Exception: # pylint: disable=W0703 

78 try: 

79 loaded = onnx.load_model_from_string(serialized) 

80 except Exception: # pragma: no cover 

81 raise RuntimeError( 

82 "Unable to read %r, error is %s, content is %r." % ( 

83 full, e, serialized[:100])) from e 

84 return loaded 

85 

86 @staticmethod 

87 def _load(folder, names): 

88 res = [] 

89 for name in names: 

90 full = os.path.join(folder, name) 

91 new_tensor = OnnxBackendTest._read_proto_from_file(full) 

92 if isinstance(new_tensor, ( 

93 numpy.ndarray, onnx.ModelProto, list)): # pylint: disable=E1101 

94 t = new_tensor 

95 elif isinstance(new_tensor, onnx.TensorProto): # pylint: disable=E1101 

96 t = to_array(new_tensor) 

97 else: 

98 raise RuntimeError( # pragma: no cover 

99 f"Unexpected type {type(new_tensor)!r} for {full!r}.") 

100 res.append(t) 

101 return res 

102 

103 def __repr__(self): 

104 "usual" 

105 return f"{self.__class__.__name__}({self.folder!r})" 

106 

107 def __init__(self, folder): 

108 if not os.path.exists(folder): 

109 raise FileNotFoundError( # pragma: no cover 

110 f"Unable to find folder {folder!r}.") 

111 content = os.listdir(folder) 

112 onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}] 

113 if len(onx) != 1: 

114 raise ValueError( # pragma: no cover 

115 f"There is more than one onnx file in {folder!r} ({onx!r}).") 

116 self.folder = folder 

117 self.onnx_path = os.path.join(folder, onx[0]) 

118 self.onnx_model = onnx.load(self.onnx_path) 

119 

120 self.tests = [] 

121 for sub in content: 

122 full = os.path.join(folder, sub) 

123 if os.path.isdir(full): 

124 pb = [c for c in os.listdir(full) 

125 if os.path.splitext(c)[-1] in {'.pb'}] 

126 inputs = OnnxBackendTest._sort( 

127 c for c in pb if c.startswith('input_')) 

128 outputs = OnnxBackendTest._sort( 

129 c for c in pb if c.startswith('output_')) 

130 

131 t = dict( 

132 inputs=OnnxBackendTest._load(full, inputs), 

133 outputs=OnnxBackendTest._load(full, outputs)) 

134 self.tests.append(t) 

135 

136 @property 

137 def name(self): 

138 "Returns the test name." 

139 return os.path.split(self.folder)[-1] 

140 

141 def __len__(self): 

142 "Returns the number of tests." 

143 return len(self.tests) 

144 

145 def _compare_results(self, index, i, e, o, decimal=None): 

146 """ 

147 Compares the expected output and the output produced 

148 by the runtime. Raises an exception if not equal. 

149 

150 :param index: test index 

151 :param i: output index 

152 :param e: expected output 

153 :param o: output 

154 :param decimal: precision 

155 """ 

156 if isinstance(e, numpy.ndarray): 

157 if isinstance(o, numpy.ndarray): 

158 if decimal is None: 

159 if e.dtype == numpy.float32: 

160 deci = 6 

161 elif e.dtype == numpy.float64: 

162 deci = 12 

163 else: 

164 deci = 7 

165 else: 

166 deci = decimal 

167 if e.dtype == dtype_object: 

168 try: 

169 assert_almost_equal_string(e, o) 

170 except AssertionError as ex: 

171 raise AssertionError( # pragma: no cover 

172 "Output %d of test %d in folder %r failed." % ( 

173 i, index, self.folder)) from ex 

174 else: 

175 try: 

176 assert_almost_equal(e, o, decimal=deci) 

177 except AssertionError as ex: 

178 raise AssertionError( 

179 "Output %d of test %d in folder %r failed." % ( 

180 i, index, self.folder)) from ex 

181 elif hasattr(o, 'is_compatible'): 

182 # A shape 

183 if e.dtype != o.dtype: 

184 raise AssertionError( 

185 "Output %d of test %d in folder %r failed " 

186 "(e.dtype=%r, o=%r)." % ( 

187 i, index, self.folder, e.dtype, o)) 

188 if not o.is_compatible(e.shape): 

189 raise AssertionError( # pragma: no cover 

190 "Output %d of test %d in folder %r failed " 

191 "(e.shape=%r, o=%r)." % ( 

192 i, index, self.folder, e.shape, o)) 

193 else: 

194 raise NotImplementedError( 

195 f"Comparison not implemented for type {type(e)!r}.") 

196 

197 def is_random(self): 

198 "Tells if a test is random or not." 

199 if 'bernoulli' in self.folder: 

200 return True 

201 return False 

202 

203 def run(self, load_fct, run_fct, index=None, decimal=None): 

204 """ 

205 Executes a tests or all tests if index is None. 

206 The function crashes if the tests fails. 

207 

208 :param load_fct: loading function, takes a loaded onnx graph, 

209 and returns an object 

210 :param run_fct: running function, takes the result of previous 

211 function, the inputs, and returns the outputs 

212 :param index: index of the test to run or all. 

213 :param decimal: requested precision to compare results 

214 """ 

215 if index is None: 

216 for i in range(len(self)): 

217 self.run(load_fct, run_fct, index=i, decimal=decimal) 

218 return 

219 

220 obj = load_fct(self.onnx_model) 

221 

222 got = run_fct(obj, *self.tests[index]['inputs']) 

223 expected = self.tests[index]['outputs'] 

224 if len(got) != len(expected): 

225 raise AssertionError( # pragma: no cover 

226 "Unexpected number of output (test %d, folder %r), " 

227 "got %r, expected %r." % ( 

228 index, self.folder, len(got), len(expected))) 

229 for i, (e, o) in enumerate(zip(expected, got)): 

230 if self.is_random(): 

231 if e.dtype != o.dtype: 

232 raise AssertionError( 

233 "Output %d of test %d in folder %r failed " 

234 "(type mismatch %r != %r)." % ( 

235 i, index, self.folder, e.dtype, o.dtype)) 

236 if e.shape != o.shape: 

237 raise AssertionError( 

238 "Output %d of test %d in folder %r failed " 

239 "(shape mismatch %r != %r)." % ( 

240 i, index, self.folder, e.shape, o.shape)) 

241 else: 

242 self._compare_results(index, i, e, o, decimal=decimal) 

243 

244 def to_python(self): 

245 """ 

246 Returns a python code equivalent to the ONNX test. 

247 

248 :return: code 

249 """ 

250 from ..onnx_tools.onnx_export import export2onnx 

251 rows = [] 

252 code = export2onnx(self.onnx_model) 

253 lines = code.split('\n') 

254 lines = [line for line in lines 

255 if not line.strip().startswith('print') and 

256 not line.strip().startswith('# ')] 

257 rows.append(textwrap.dedent("\n".join(lines))) 

258 rows.append("oinf = OnnxInference(onnx_model)") 

259 for test in self.tests: 

260 rows.append("xs = [") 

261 for inp in test['inputs']: 

262 rows.append(textwrap.indent(repr(inp) + ',', ' ' * 2)) 

263 rows.append("]") 

264 rows.append("ys = [") 

265 for out in test['outputs']: 

266 rows.append(textwrap.indent(repr(out) + ',', ' ' * 2)) 

267 rows.append("]") 

268 rows.append("feeds = {n: x for n, x in zip(oinf.input_names, xs)}") 

269 rows.append("got = oinf.run(feeds)") 

270 rows.append("goty = [got[k] for k in oinf.output_names]") 

271 rows.append("for y, gy in zip(ys, goty):") 

272 rows.append(" self.assertEqualArray(y, gy)") 

273 rows.append("") 

274 code = "\n".join(rows) 

275 final = "\n".join([f"def {self.name}(self):", 

276 textwrap.indent(code, ' ')]) 

277 try: 

278 from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8 

279 except ImportError: # pragma: no cover 

280 return final 

281 return remove_extra_spaces_and_pep8(final, aggressive=True) 

282 

283 

284def enumerate_onnx_tests(series, fct_filter=None): 

285 """ 

286 Collects test from a sub folder of `onnx/backend/test`. 

287 Works as an enumerator to start processing them 

288 without waiting or storing too much of them. 

289 

290 :param series: which subfolder to load, possible values: 

291 (`'node'`, ...) 

292 :param fct_filter: function `lambda testname: boolean` 

293 to load or skip the test, None for all 

294 :return: list of @see cl OnnxBackendTest 

295 """ 

296 root = os.path.dirname(backend_folder) 

297 sub = os.path.join(root, 'data', series) 

298 if not os.path.exists(sub): 

299 raise FileNotFoundError( 

300 "Unable to find series of tests in %r, subfolders:\n%s" % ( 

301 root, "\n".join(os.listdir(root)))) 

302 tests = os.listdir(sub) 

303 for t in tests: 

304 if fct_filter is not None and not fct_filter(t): 

305 continue 

306 folder = os.path.join(sub, t) 

307 content = os.listdir(folder) 

308 onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}] 

309 if len(onx) == 1: 

310 yield OnnxBackendTest(folder)