Coverage for mlprodict/tools/zoo.py: 92%
109 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Tools to test models from the :epkg:`ONNX Zoo`.
5.. versionadded:: 0.6
6"""
7import os
8import urllib.request
9from collections import OrderedDict
10import numpy
11from onnx import TensorProto, numpy_helper, load
12from onnx.reference import ReferenceEvaluator
13try:
14 from .ort_wrapper import InferenceSession
15except ImportError:
16 from mlprodict.tools.ort_wrapper import InferenceSession
def short_list_zoo_models():
    """
    Returns a short list from :epkg:`ONNX Zoo`.

    :return: list of dictionaries.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.tools.zoo import short_list_zoo_models
        pprint.pprint(short_list_zoo_models())
    """
    # All models live under the same vision/classification tree of the
    # ONNX Zoo repository; only the tail of the URL differs per model.
    prefix = ("https://github.com/onnx/models/raw/main/vision/"
              "classification/")
    table = [
        ("mobilenet", "mobilenet/model/mobilenetv2-7.tar.gz", None),
        ("resnet18", "resnet/model/resnet18-v1-7.tar.gz", None),
        ("squeezenet", "squeezenet/model/squeezenet1.0-9.tar.gz",
         "squeezenet"),
        ("densenet121", "densenet-121/model/densenet-9.tar.gz",
         "densenet121"),
        ("inception2",
         "inception_and_googlenet/inception_v2/model/inception-v2-9.tar.gz",
         None),
        ("shufflenet", "shufflenet/model/shufflenet-9.tar.gz", None),
        ("efficientnet-lite4",
         "efficientnet-lite4/model/efficientnet-lite4-11.tar.gz", None),
    ]
    models = []
    for name, tail, folder in table:
        entry = dict(name=name, model=prefix + tail)
        if folder is not None:
            # extraction folder differs from the archive name
            entry["folder"] = folder
        models.append(entry)
    return models
def _download_url(url, output_path, name, verbose=False):
    """Downloads *url* into *output_path*, with an optional progress bar.

    :param url: source URL
    :param output_path: destination file
    :param name: label shown next to the progress bar
    :param verbose: display a :epkg:`tqdm` progress bar
    """
    if not verbose:
        urllib.request.urlretrieve(url, filename=output_path)
        return
    # tqdm is imported lazily so the dependency is only needed
    # when a progress bar is requested
    from tqdm import tqdm  # pragma: no cover

    class DownloadProgressBar(tqdm):  # pragma: no cover
        "progress bar hook"

        def update_to(self, b=1, bsize=1, tsize=None):
            "progress bar hook"
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    with DownloadProgressBar(unit='B', unit_scale=True,  # pragma: no cover
                             miniters=1, desc=name) as bar:
        urllib.request.urlretrieve(
            url, filename=output_path, reporthook=bar.update_to)
def load_data(folder):
    """
    Restores protobuf data stored in a folder.

    :param folder: folder
    :return: dictionary
    """
    # inputs go under 'in', expected outputs under 'out'
    res = OrderedDict([('in', OrderedDict()), ('out', OrderedDict())])
    for name in os.listdir(folder):
        noext, ext = os.path.splitext(name)
        if ext != '.pb':
            # anything that is not a serialized tensor is ignored
            continue
        tensor = TensorProto()
        with open(os.path.join(folder, name), 'rb') as stream:
            tensor.ParseFromString(stream.read())
        if noext.startswith('input'):
            res['in'][noext] = numpy_helper.to_array(tensor)
        elif noext.startswith('output'):
            res['out'][noext] = numpy_helper.to_array(tensor)
        else:
            raise ValueError(  # pragma: no cover
                f"Unable to guess anything about {noext!r}.")
    return res
def download_model_data(name, model=None, cache=None, verbose=False):
    """
    Downloads a model and returns a link to the local
    :epkg:`ONNX` file and data which can be used as inputs.

    :param name: model name (see @see fn short_list_zoo_models)
    :param model: url or empty to get the default value
        returned by @see fn short_list_zoo_models)
    :param cache: folder to cache the downloaded data
    :param verbose: display a progress bar
    :return: local onnx file, input data
    """
    suggested_folder = None
    if model is None:
        # Resolve the URL (and optional extraction folder override)
        # from the curated list of zoo models.
        model_list = short_list_zoo_models()
        for mod in model_list:
            if mod['name'] == name:
                model = mod['model']
                if 'folder' in mod:  # pylint: disable=R1715
                    suggested_folder = mod['folder']
                break
        if model is None:
            raise ValueError(
                f"Unable to find a default value for name={name!r}.")

    # downloads (skipped if the archive is already cached)
    last_name = model.split('/')[-1]
    if cache is None:
        cache = os.path.abspath('.')  # pragma: no cover
    dest = os.path.join(cache, last_name)
    if not os.path.exists(dest):
        _download_url(model, dest, name, verbose=verbose)
    size = os.stat(dest).st_size
    # A file under 1 MiB is assumed to be an error page rather than
    # a real model archive; remove it so the next call retries.
    if size < 2 ** 20:  # pragma: no cover
        os.remove(dest)
        raise ConnectionError(
            f"Unable to download model from {model!r}.")

    # decompress .tar.gz -> .tar (pyquickhelper imported lazily)
    outtar = os.path.splitext(dest)[0]
    if not os.path.exists(outtar):
        from pyquickhelper.filehelper.compression_helper import (
            ungzip_files)
        ungzip_files(dest, unzip=False, where_to=cache, remove_space=False)

    # untar, keeping only real model folders (test data folders excluded)
    onnx_file = os.path.splitext(outtar)[0]
    if not os.path.exists(onnx_file):
        from pyquickhelper.filehelper.compression_helper import (
            untar_files)
        foldtar = [f for f in untar_files(outtar, where_to=cache)
                   if os.path.isdir(f) and "test_data_" not in f]
    else:
        foldtar = []

    if suggested_folder is not None:
        fold_onnx = [suggested_folder] + foldtar
    else:
        # Candidate folder names derived from the archive name; the
        # order of this list is cosmetic only - all candidates are
        # filtered into a set and exactly one must exist.
        fold_onnx = foldtar + [onnx_file, onnx_file.split('-')[0],
                               '-'.join(onnx_file.split('-')[:-1]),
                               '-'.join(onnx_file.split('-')[:-1]).replace('-', '_')]
    fold_onnx_ok = set(
        _ for _ in fold_onnx if os.path.exists(_) and os.path.isdir(_))
    if len(fold_onnx_ok) != 1:
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find an existing folder among {fold_onnx!r}.")
    onnx_file = list(fold_onnx_ok)[0]

    # exactly one .onnx file is expected inside the extracted folder
    onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")]
    if len(onnx_files) != 1:
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find any onnx file in {onnx_files!r}.")
    final_onnx = os.path.join(onnx_file, onnx_files[0])

    # data: every subfolder is one protobuf example set (see load_data)
    data = [_ for _ in os.listdir(onnx_file)
            if os.path.isdir(os.path.join(onnx_file, _))]
    examples = OrderedDict()
    for f in data:
        examples[f] = load_data(os.path.join(onnx_file, f))

    return final_onnx, examples
def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4,
                 verbose=0, fLOG=None):
    """
    Verifies a model.

    :param onnx_file: ONNX file
    :param examples: list of examples to verify
    :param runtime: a runtime to use
    :param abs_tol: error tolerance when checking the output
    :param verbose: verbosity level for runtime other than
        `'onnxruntime'`
    :param fLOG: logging function when `verbose > 0`
    :return: errors for every sample
    """
    if runtime in ('onnxruntime', 'onnxruntime-cuda'):
        sess = InferenceSession(onnx_file, runtime=runtime)
        meth = lambda data, s=sess: s.run(None, data)
        names = [p.name for p in sess.get_inputs()]
        onames = list(range(len(sess.get_outputs())))
    elif runtime == 'onnx':
        # Bug fix: this branch previously read ``runtime in ('onnx')``.
        # ``('onnx')`` is a plain string, so the check was a substring
        # test (``'nn' in 'onnx'`` is True) and raised TypeError for the
        # default ``runtime=None``. Equality matches only the intended
        # value and lets None fall through to the OnnxInference branch.
        with open(onnx_file, "rb") as f:
            onx = load(f)
        # initializer names must not be fed as inputs by the caller
        inits = set(i.name for i in onx.graph.initializer)
        # NOTE(review): verbosity is hard-coded to 10 here instead of
        # forwarding *verbose* - confirm whether this is intentional.
        sess = ReferenceEvaluator(onnx_file, verbose=10)
        meth = lambda data, s=sess: s.run(None, data)
        names = [n for n in sess.input_names if n not in inits]
        onames = list(range(len(sess.output_names)))
    else:
        def _lin_(sess, data, names):
            # runs the python runtime and orders outputs by name
            r = sess.run(data, verbose=verbose, fLOG=fLOG)
            return [r[n] for n in names]

        from ..onnxrt import OnnxInference
        sess = OnnxInference(
            onnx_file, runtime=runtime,
            runtime_options=dict(log_severity_level=3))
        names = sess.input_names
        onames = sess.output_names
        meth = lambda data, s=sess, ns=onames: _lin_(s, data, ns)

    rows = []
    for index, (name, data_inout) in enumerate(examples.items()):
        data = data_inout["in"]
        if len(data) != len(names):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of inputs %d != %d\ninputs: %r\nmodel: %r."
                "" % (len(data), len(names), list(sorted(data)), names))
        # stored inputs are mapped onto the model's input names in order
        inputs = {n: data[v] for n, v in zip(names, data)}
        outputs = meth(inputs)
        expected = data_inout['out']
        if len(outputs) != len(onames):
            raise RuntimeError(  # pragma: no cover
                "Number of outputs %d is != expected outputs %d." % (
                    len(outputs), len(onames)))
        for i, (output, expect) in enumerate(zip(outputs, expected.items())):
            if output.shape != expect[1].shape:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch got %r != expected %r." % (
                        output.shape, expect[1].shape))
            diff = numpy.abs(output - expect[1]).ravel()
            absolute = diff.max()
            # NOTE(review): median(diff) can be 0 while max > 0, which
            # yields inf - acceptable for a diagnostic-only value.
            relative = absolute / numpy.median(diff) if absolute > 0 else 0.
            if absolute > abs_tol:
                raise ValueError(  # pragma: no cover
                    "Example %d, inferred and expected results are different "
                    "for output %d: abs=%r rel=%r (runtime=%r)."
                    "" % (index, i, absolute, relative, runtime))
            rows.append(dict(name=name, i=i, abs=absolute, rel=relative))
    return rows