Coverage for mlprodict/tools/zoo.py: 92%

109 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Tools to test models from the :epkg:`ONNX Zoo`. 

4 

5.. versionadded:: 0.6 

6""" 

7import os 

8import urllib.request 

9from collections import OrderedDict 

10import numpy 

11from onnx import TensorProto, numpy_helper, load 

12from onnx.reference import ReferenceEvaluator 

13try: 

14 from .ort_wrapper import InferenceSession 

15except ImportError: 

16 from mlprodict.tools.ort_wrapper import InferenceSession 

17 

18 

def short_list_zoo_models():
    """
    Returns a short list from :epkg:`ONNX Zoo`.

    :return: list of dictionaries.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.tools.zoo import short_list_zoo_models
        pprint.pprint(short_list_zoo_models())
    """
    # Every model lives under the same repository; the common prefix is
    # factored out and each entry only stores (name, url suffix, folder).
    prefix = "https://github.com/onnx/models/raw/main/vision/"
    table = [
        ("mobilenet",
         "classification/mobilenet/model/mobilenetv2-7.tar.gz",
         None),
        ("resnet18",
         "classification/resnet/model/resnet18-v1-7.tar.gz",
         None),
        ("squeezenet",
         "classification/squeezenet/model/squeezenet1.0-9.tar.gz",
         "squeezenet"),
        ("densenet121",
         "classification/densenet-121/model/densenet-9.tar.gz",
         "densenet121"),
        ("inception2",
         "classification/inception_and_googlenet/inception_v2/"
         "model/inception-v2-9.tar.gz",
         None),
        ("shufflenet",
         "classification/shufflenet/model/shufflenet-9.tar.gz",
         None),
        ("efficientnet-lite4",
         "classification/efficientnet-lite4/model/"
         "efficientnet-lite4-11.tar.gz",
         None),
    ]
    models = []
    for model_name, suffix, folder in table:
        entry = dict(name=model_name, model=prefix + suffix)
        if folder is not None:
            # a few archives extract into a folder whose name cannot be
            # guessed from the archive name
            entry["folder"] = folder
        models.append(entry)
    return models

60 

61 

def _download_url(url, output_path, name, verbose=False):
    """
    Retrieves *url* into file *output_path*; when *verbose* is True,
    displays a :epkg:`tqdm` progress bar labelled with *name*.
    """
    if not verbose:
        urllib.request.urlretrieve(url, filename=output_path)
        return

    # pragma: no cover
    from tqdm import tqdm

    class _ProgressBar(tqdm):
        "progress bar hook"

        def update_to(self, b=1, bsize=1, tsize=None):
            "urlretrieve reporthook: *b* blocks of *bsize* bytes received"
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    with _ProgressBar(unit='B', unit_scale=True,
                      miniters=1, desc=name) as bar:
        urllib.request.urlretrieve(
            url, filename=output_path, reporthook=bar.update_to)

81 

82 

def load_data(folder):
    """
    Restores protobuf data stored in a folder.

    Every ``*.pb`` file is parsed as a :epkg:`TensorProto` and converted
    into a :epkg:`numpy` array. Files named ``input*`` are gathered under
    key ``'in'``, files named ``output*`` under key ``'out'``.

    :param folder: folder
    :return: dictionary ``{'in': OrderedDict, 'out': OrderedDict}``
    :raises ValueError: if a ``*.pb`` file is neither an input
        nor an output
    """
    res = OrderedDict()
    res['in'] = OrderedDict()
    res['out'] = OrderedDict()
    # os.listdir order is platform-dependent; the keys of these
    # OrderedDicts are later zipped with model input names (see
    # verify_model), so the iteration order must be deterministic.
    for name in sorted(os.listdir(folder)):
        noext, ext = os.path.splitext(name)
        if ext != '.pb':
            # ignore anything that is not a protobuf tensor file
            continue
        data = TensorProto()
        with open(os.path.join(folder, name), 'rb') as f:
            data.ParseFromString(f.read())
        if noext.startswith('input'):
            res['in'][noext] = numpy_helper.to_array(data)
        elif noext.startswith('output'):
            res['out'][noext] = numpy_helper.to_array(data)
        else:
            raise ValueError(  # pragma: no cover
                f"Unable to guess anything about {noext!r}.")
    return res

109 

110 

def download_model_data(name, model=None, cache=None, verbose=False):
    """
    Downloads a model and returns a link to the local
    :epkg:`ONNX` file and data which can be used as inputs.

    The archive is downloaded into *cache* (unless already present),
    gunzipped, untarred, then the folder holding the ONNX file is guessed
    from the archive name unless the model entry declares a ``folder``.

    :param name: model name (see @see fn short_list_zoo_models)
    :param model: url or empty to get the default value
        returned by @see fn short_list_zoo_models)
    :param cache: folder to cache the downloaded data
    :param verbose: display a progress bar
    :return: local onnx file, input data
    :raises ValueError: if *name* has no known default url
    :raises ConnectionError: if the downloaded file is suspiciously small
    :raises FileNotFoundError: if the extracted folder or the ONNX
        file cannot be located
    """
    suggested_folder = None
    if model is None:
        # resolve the short name into a url using the registered list
        model_list = short_list_zoo_models()
        for mod in model_list:
            if mod['name'] == name:
                model = mod['model']
                if 'folder' in mod:  # pylint: disable=R1715
                    suggested_folder = mod['folder']
                break
        if model is None:
            raise ValueError(
                f"Unable to find a default value for name={name!r}.")

    # downloads
    last_name = model.split('/')[-1]
    if cache is None:
        cache = os.path.abspath('.')  # pragma: no cover
    dest = os.path.join(cache, last_name)
    if not os.path.exists(dest):
        _download_url(model, dest, name, verbose=verbose)
    size = os.stat(dest).st_size
    # anything below 1 MiB is assumed to be an error page rather
    # than a model archive
    if size < 2 ** 20:  # pragma: no cover
        os.remove(dest)
        raise ConnectionError(
            f"Unable to download model from {model!r}.")

    # gunzip *.tar.gz into *.tar unless already done
    outtar = os.path.splitext(dest)[0]
    if not os.path.exists(outtar):
        from pyquickhelper.filehelper.compression_helper import (
            ungzip_files)
        ungzip_files(dest, unzip=False, where_to=cache, remove_space=False)

    # untar unless the expected folder already exists; the extracted
    # test_data_* folders are not candidate model folders
    onnx_file = os.path.splitext(outtar)[0]
    if not os.path.exists(onnx_file):
        from pyquickhelper.filehelper.compression_helper import (
            untar_files)
        foldtar = [f for f in untar_files(outtar, where_to=cache)
                   if os.path.isdir(f) and "test_data_" not in f]
    else:
        foldtar = []

    # build the list of candidate folders: either the declared one, or
    # variations derived from the archive name (strip version suffix,
    # dashes replaced by underscores)
    # NOTE(review): suggested_folder is tested as-is, i.e. relative to the
    # current working directory, not joined with *cache* — confirm this is
    # intended when cache != '.'
    if suggested_folder is not None:
        fold_onnx = [suggested_folder] + foldtar
    else:
        fold_onnx = foldtar + [onnx_file, onnx_file.split('-')[0],
                               '-'.join(onnx_file.split('-')[:-1]),
                               '-'.join(onnx_file.split('-')[:-1]).replace('-', '_')]
    fold_onnx_ok = set(
        _ for _ in fold_onnx if os.path.exists(_) and os.path.isdir(_))
    if len(fold_onnx_ok) != 1:
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find an existing folder among {fold_onnx!r}.")
    onnx_file = list(fold_onnx_ok)[0]

    # exactly one *.onnx file is expected in the retained folder
    onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")]
    if len(onnx_files) != 1:
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find any onnx file in {onnx_files!r}.")
    final_onnx = os.path.join(onnx_file, onnx_files[0])

    # data
    # every subfolder holds protobuf input/output tensors for one example
    data = [_ for _ in os.listdir(onnx_file)
            if os.path.isdir(os.path.join(onnx_file, _))]
    examples = OrderedDict()
    for f in data:
        examples[f] = load_data(os.path.join(onnx_file, f))

    return final_onnx, examples

191 

192 

def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4,
                 verbose=0, fLOG=None):
    """
    Verifies a model.

    :param onnx_file: ONNX file
    :param examples: list of examples to verify
    :param runtime: a runtime to use
    :param abs_tol: error tolerance when checking the output
    :param verbose: verbosity level for for runtime other than
        `'onnxruntime'`
    :param fLOG: logging function when `verbose > 0`
    :return: errors for every sample
    :raises RuntimeError: if the number of inputs or outputs does not
        match the model
    :raises ValueError: if shapes differ or the absolute discrepancy
        exceeds *abs_tol*
    """
    if runtime in ('onnxruntime', 'onnxruntime-cuda'):
        sess = InferenceSession(onnx_file, runtime=runtime)
        meth = lambda data, s=sess: s.run(None, data)
        names = [p.name for p in sess.get_inputs()]
        onames = list(range(len(sess.get_outputs())))
    elif runtime == 'onnx':
        # BUG FIX: previously `runtime in ('onnx')` — parentheses without a
        # comma are not a tuple, so this tested substring membership of the
        # string 'onnx': 'on' or 'nn' matched, and the default runtime=None
        # raised TypeError instead of reaching the fallback branch.
        with open(onnx_file, "rb") as f:
            onx = load(f)
        # initializers may be listed among graph inputs; they must not be fed
        inits = set(i.name for i in onx.graph.initializer)
        # honour the requested verbosity instead of the hard-coded debug
        # level 10 previously left in place
        sess = ReferenceEvaluator(onnx_file, verbose=verbose)
        meth = lambda data, s=sess: s.run(None, data)
        names = [n for n in sess.input_names if n not in inits]
        onames = list(range(len(sess.output_names)))
    else:
        def _lin_(sess, data, names):
            # runs the python runtime and orders its named outputs
            r = sess.run(data, verbose=verbose, fLOG=fLOG)
            return [r[n] for n in names]

        from ..onnxrt import OnnxInference
        sess = OnnxInference(
            onnx_file, runtime=runtime,
            runtime_options=dict(log_severity_level=3))
        names = sess.input_names
        onames = sess.output_names
        meth = lambda data, s=sess, ns=onames: _lin_(s, data, ns)

    rows = []
    for index, (name, data_inout) in enumerate(examples.items()):
        data = data_inout["in"]
        if len(data) != len(names):
            raise RuntimeError(  # pragma: no cover
                "Mismathed number of inputs %d != %d\ninputs: %r\nmodel: %r."
                "" % (len(data), len(names), list(sorted(data)), names))
        # pair the stored tensors with the model input names, in order
        inputs = {n: data[v] for n, v in zip(names, data)}
        outputs = meth(inputs)
        expected = data_inout['out']
        if len(outputs) != len(onames):
            raise RuntimeError(  # pragma: no cover
                "Number of outputs %d is != expected outputs %d." % (
                    len(outputs), len(onames)))
        for i, (output, expect) in enumerate(zip(outputs, expected.items())):
            if output.shape != expect[1].shape:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch got %r != expected %r." % (
                        output.shape, expect[1].shape))
            diff = numpy.abs(output - expect[1]).ravel()
            absolute = diff.max()
            # guard against a zero median, which previously triggered a
            # division warning and reported an infinite relative error
            med = numpy.median(diff)
            relative = absolute / med if med > 0 else 0.
            if absolute > abs_tol:
                raise ValueError(  # pragma: no cover
                    "Example %d, inferred and expected results are different "
                    "for output %d: abs=%r rel=%r (runtime=%r)."
                    "" % (index, i, absolute, relative, runtime))
            rows.append(dict(name=name, i=i, abs=absolute, rel=relative))
    return rows